This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 790c5d1e9e [Pass] Add DumpIR pass instrument to save IR snapshots
(#18511)
790c5d1e9e is described below
commit 790c5d1e9ed6f02040f46dc9ab75932063b796f7
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Nov 27 23:25:41 2025 +0800
[Pass] Add DumpIR pass instrument to save IR snapshots (#18511)
Add a new DumpIR pass instrument that automatically dumps the IR module
to files after each pass execution. This helps with debugging and
understanding pass transformations.
Features:
- Dumps IR to numbered files (e.g., 000_PassName.py, 001_PassName.py)
- Optional `refresh` parameter that clears the dump directory when the instrument is created
- Safe directory removal: the dump directory is deleted only if it contains nothing but
dump files (no subdirectories, no other files)
- Graceful error handling: if generating or writing the IR script fails, a warning is
printed and the pass pipeline continues
Example usage:
```python
from tvm.ir.instrument import DumpIR

with tvm.transform.PassContext(instruments=[DumpIR("./dump", refresh=True)]):
    lib = tvm.compile(module, target="llvm")
```
Also includes minor cleanup:
- Rename RelayPassContextThreadLocalStore to PassContextThreadLocalStore
- Remove unused includes in transform.cc and unroll_loop.cc, and an unused
`using tvm::ffi::PackedArgs;` declaration in transform.cc
- Add type hints to PrintAfterAll and PrintBeforeAll
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
python/tvm/ir/instrument.py | 53 ++++++++++++++++++++++++++++++++++++++-
src/ir/transform.cc | 14 +++--------
src/tir/transforms/unroll_loop.cc | 2 --
3 files changed, 56 insertions(+), 13 deletions(-)
diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py
index 7b6749f113..0f1bcf3adf 100644
--- a/python/tvm/ir/instrument.py
+++ b/python/tvm/ir/instrument.py
@@ -16,10 +16,15 @@
# under the License.
# pylint: disable=invalid-name,unused-argument
"""Common pass instrumentation across IR variants."""
-import inspect
import functools
+import inspect
+import re
+import shutil
+from pathlib import Path
+from typing import Union
import tvm_ffi
+
import tvm.runtime
from . import _ffi_instrument_api
@@ -288,3 +293,49 @@ class PrintBeforeAll:
def run_before_pass(self, mod, info):
print(f"Before Running Pass: {info}")
print(mod)
+
+
+@pass_instrument
+class DumpIR:
+    """Dump the IR module after every pass to ``{dump_dir}/{NNN}_{pass_name}.py``.
+
+    Dump failures are printed as warnings and never abort compilation.
+    """
+
+    def __init__(self, dump_dir: Union[Path, str], refresh: bool = False):
+        self.dump_dir = Path(dump_dir)
+        self.counter = 0  # index of the next dump file; advances once per pass
+        if refresh and self.dump_dir.is_dir():
+            self._safe_remove_dump_dir()
+
+    def _safe_remove_dump_dir(self):
+        """Delete ``dump_dir`` only if every entry in it looks like a dump file."""
+        # Matches names written by run_after_pass; `{counter:03d}` grows past 3 digits.
+        dump_pattern = re.compile(r"^\d{3,}_.*\.py$")
+        for item in self.dump_dir.iterdir():
+            # A subdirectory or any other file means the directory holds non-dump data.
+            if item.is_dir() or not dump_pattern.match(item.name):
+                print(
+                    f"WARNING: Skipping removal of {self.dump_dir} as it contains "
+                    f"non-dumped files or directories. Please clean it manually."
+                )
+                return
+        try:
+            shutil.rmtree(self.dump_dir)
+        except OSError as e:
+            print(f"WARNING: Failed to remove directory {self.dump_dir}: {e}")
+
+    def run_after_pass(self, mod, info):
+        """Write ``mod.script()`` to the next numbered file, printing a warning on failure."""
+        try:
+            # Replace path separators and characters Windows forbids in file names.
+            sanitized_pass_name = re.sub(r'[<>:"/\\|?*]', "_", info.name)
+            file_name = f"{self.counter:03d}_{sanitized_pass_name}.py"
+            # Render first so a printer failure does not leave an empty file behind.
+            script = mod.script()
+            self.dump_dir.mkdir(parents=True, exist_ok=True)
+            (self.dump_dir / file_name).write_text(script, encoding="utf-8")
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            print(f"WARNING: Failed to dump IR for pass {info.name}: {e}")
+        finally:
+            self.counter += 1
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index d1e5950459..3cbf8a629f 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -31,19 +31,13 @@
#include <tvm/relax/expr.h>
#include <tvm/runtime/device_api.h>
-#include <chrono>
-#include <iomanip>
#include <stack>
-#include <unordered_set>
-
-#include "../runtime/regex.h"
namespace tvm {
namespace transform {
using tvm::ReprPrinter;
using tvm::ffi::Any;
-using tvm::ffi::PackedArgs;
TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool);
@@ -60,17 +54,17 @@ struct PassContextThreadLocalEntry {
};
/*! \brief Thread local store to hold the pass context. */
-typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
RelayPassContextThreadLocalStore;
+typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
PassContextThreadLocalStore;
void PassContext::EnterWithScope() {
InstrumentEnterPassContext();
- PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
+ PassContextThreadLocalEntry* entry = PassContextThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void PassContext::ExitWithScope() {
- PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
+ PassContextThreadLocalEntry* entry = PassContextThreadLocalStore::Get();
ICHECK(!entry->context_stack.empty());
ICHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
@@ -79,7 +73,7 @@ void PassContext::ExitWithScope() {
}
PassContext PassContext::Current() {
- PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get();
+ PassContextThreadLocalEntry* entry = PassContextThreadLocalStore::Get();
if (!entry->context_stack.empty()) {
return entry->context_stack.top();
} else {
diff --git a/src/tir/transforms/unroll_loop.cc
b/src/tir/transforms/unroll_loop.cc
index 74abea57ba..7b92bad12d 100644
--- a/src/tir/transforms/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -30,9 +30,7 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include <unordered_map>
#include <unordered_set>
-#include <vector>
#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"