This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 783a9bf3e3 [REFACTOR][S-TIR] Migrate more transform to s_tir (#18771)
783a9bf3e3 is described below
commit 783a9bf3e30f4c3aebe269e3942aa3ca9742be8b
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Feb 13 07:12:28 2026 -0500
[REFACTOR][S-TIR] Migrate more transform to s_tir (#18771)
This PR migrates more transformations into s_tir to reduce the number of
passes in the tir namespace.
---
include/tvm/s_tir/transform.h | 96 +++++++++++
python/tvm/s_tir/backend/adreno/pipeline.py | 36 ++--
python/tvm/s_tir/pipeline.py | 36 ++--
python/tvm/s_tir/transform/__init__.py | 1 +
python/tvm/s_tir/transform/transform.py | 189 +++++++++++++++++++++
python/tvm/tir/transform/transform.py | 4 +-
.../analysis/calculate_allocated_memory.cc | 3 +-
.../meta_schedule/postproc/verify_gpu_code.cc | 2 +-
src/{tir => s_tir}/transform/bound_checker.cc | 12 +-
src/{tir => s_tir}/transform/hoist_expression.cc | 72 ++++----
.../transform/inject_ptx_async_copy.cc | 13 +-
src/{tir => s_tir}/transform/inject_ptx_ldg32.cc | 11 +-
src/{tir => s_tir}/transform/lower_async_dma.cc | 13 +-
.../transform/lower_thread_allreduce.cc | 21 +--
src/{tir => s_tir}/transform/lower_vtcm_alloc.cc | 13 +-
.../transform/merge_shared_memory_allocations.cc | 24 +--
.../transform/profile_instrumentation.cc | 31 ++--
.../transform/renormalize_split_pattern.cc | 11 +-
.../transform/rewrite_unsafe_select.cc | 11 +-
src/{tir => s_tir}/transform/storage_access.cc | 15 +-
src/{tir => s_tir}/transform/storage_access.h | 11 +-
.../transform/tensorcore_infer_fragment.cc | 29 ++--
.../transform/thread_storage_sync.cc | 22 +--
.../test_s_tir_transform_hoist_expression.py} | 7 +-
.../transform/test_s_tir_transform_hoist_if.py} | 72 ++++----
.../test_s_tir_transform_inject_ptx_async_copy.py} | 11 +-
.../test_s_tir_transform_inject_ptx_ldg32.py} | 5 +-
...est_s_tir_transform_lower_thread_all_reduce.py} | 21 +--
...orm_merge_dynamic_shared_memory_allocations.py} | 12 +-
.../test_s_tir_transform_profiling_instr.py} | 24 +--
...t_s_tir_transform_renormalize_split_pattern.py} | 5 +-
.../test_s_tir_transform_rewrite_unsafe_select.py} | 7 +-
.../transform/test_s_tir_transform_thread_sync.py} | 8 +-
33 files changed, 584 insertions(+), 264 deletions(-)
diff --git a/include/tvm/s_tir/transform.h b/include/tvm/s_tir/transform.h
index 9914c6e49a..55343fdce5 100644
--- a/include/tvm/s_tir/transform.h
+++ b/include/tvm/s_tir/transform.h
@@ -230,6 +230,102 @@ TVM_DLL Pass InjectVirtualThread();
*/
TVM_DLL Pass InjectDoubleBuffer();
+/*!
+ * \brief Hoist loop-invariant IfThenElse nodes to
+ * outside the eligible loops.
+ *
+ * \param variant The variant of the pass.
+ * variant can have any one of following values ["basic", ""(Default)].
+ * \return The pass.
+ */
+TVM_DLL Pass HoistIfThenElse(tvm::ffi::String variant = "");
+
+/*!
+ * \brief Hoist loop-invariant expressions to outside the eligible loops.
+ *
+ * Can hoist conditionals used in IfThenElse statements and
+ * expressions, bindings of variables in Let statements and
+ * expressions, or boolean expressions, configurable to enable/disable
+ * each hoistable type.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass HoistExpression();
+
+/*!
+ * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()).
+ * \return The pass.
+ */
+TVM_DLL Pass RenormalizeSplitPattern();
+
+/*!
+ * \brief Detect and rewrite unsafe select that contains memory access.
+ * \return The pass.
+ */
+TVM_DLL Pass RewriteUnsafeSelect();
+
+/*!
+ * \brief Instruments bound checkers.
+ * \return The pass.
+ */
+TVM_DLL Pass InstrumentBoundCheckers();
+
+/*!
+ * \brief Rewrite global to local memory copy on CUDA with ldg32 instruction.
+ * \param enable_inject Whether to enable injection.
+ * \return The pass.
+ */
+TVM_DLL Pass InjectPTXLDG32(bool enable_inject = true);
+
+/*!
+ * \brief Insert intrinsic calls to instrument function and loop level profiling.
+ * \return The pass.
+ */
+TVM_DLL Pass InstrumentProfileIntrinsics();
+
+/*!
+ * \brief Lower VTCM allocations.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerVtcmAlloc();
+
+/*!
+ * \brief Insert sync between parallel read/write of shared buffers.
+ * \param storage_scope The storage scope considered.
+ * \return The pass.
+ */
+TVM_DLL Pass ThreadSync(tvm::ffi::String storage_scope);
+
+/*!
+ * \brief Infer the TensorCore fragment information using tensor intrinsics.
+ * \return The pass.
+ */
+TVM_DLL Pass InferFragment();
+
+/*!
+ * \brief Lower cross thread allreduce.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerThreadAllreduce();
+
+/*!
+ * \brief Lower Async TIR primitives to DMA copy and wait builtins.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerAsyncDMA();
+
+/*!
+ * \brief Rewrite global to shared memory copy on CUDA with asynchronous copy.
+ * \return The pass.
+ */
+TVM_DLL Pass InjectPTXAsyncCopy();
+
+/*!
+ * \brief Merge multiple TIR-level shared memory allocations into one.
+ * \return The pass.
+ */
+TVM_DLL Pass MergeSharedMemoryAllocations();
+
} // namespace transform
} // namespace s_tir
} // namespace tvm
diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py
b/python/tvm/s_tir/backend/adreno/pipeline.py
index e895025bc4..0237429e80 100644
--- a/python/tvm/s_tir/backend/adreno/pipeline.py
+++ b/python/tvm/s_tir/backend/adreno/pipeline.py
@@ -62,22 +62,22 @@ def default_tir_pipeline():
if not bool(config.get("tir.disable_storage_rewrite", False)):
passes.append(tir.transform.StorageRewrite())
if config.get("tir.use_async_copy", False):
- passes.append(tir.transform.LowerAsyncDMA())
+ passes.append(s_tir.transform.LowerAsyncDMA())
passes.extend(
[
- tir.transform.HoistIfThenElse(),
+ s_tir.transform.HoistIfThenElse(),
tir.transform.UnrollLoop(),
- tir.transform.RenormalizeSplitPattern(),
+ s_tir.transform.RenormalizeSplitPattern(),
tir.transform.Simplify(),
tir.transform.RemoveNoOp(),
- tir.transform.RewriteUnsafeSelect(),
+ s_tir.transform.RewriteUnsafeSelect(),
]
)
# Additional passes based on configuration.
if bool(config.get("tir.instrument_bound_checkers", False)):
- passes.append(tir.transform.InstrumentBoundCheckers())
+ passes.append(s_tir.transform.InstrumentBoundCheckers())
if bool(config.get("tir.ptx_ldg32", False)):
- passes.append(tir.transform.InjectPTXLDG32(True))
+ passes.append(s_tir.transform.InjectPTXLDG32(True))
passes.append(
tir.transform.CommonSubexprElimTIR(
not bool(config.get("tir.disable_cse_tir", False)),
@@ -85,39 +85,39 @@ def default_tir_pipeline():
)
)
if bool(config.get("tir.instrument_lwp", False)):
- passes.append(tir.transform.InstrumentProfileIntrinsics())
+ passes.append(s_tir.transform.InstrumentProfileIntrinsics())
passes.extend(
[
# Bind the target first so that target-specific attributes are available.
tir.transform.FP8ComputeLegalize(),
# VerifyVTCMLimit must occur before LowerVtcmAlloc.
- tir.transform.VerifyVTCMLimit(),
- tir.transform.LowerVtcmAlloc(),
+ s_tir.transform.VerifyVTCMLimit(),
+ s_tir.transform.LowerVtcmAlloc(),
tir.transform.VerifyMemory(),
tir.transform.AnnotateEntryFunc(),
]
)
if bool(config.get("tir.detect_global_barrier", False)):
- passes.append(tir.transform.ThreadSync("global"))
+ passes.append(s_tir.transform.ThreadSync("global"))
passes.extend(
[
- tir.transform.ThreadSync("shared"),
- tir.transform.ThreadSync("shared.dyn"),
- tir.transform.ThreadSync("warp"),
- tir.transform.InferFragment(),
- tir.transform.LowerThreadAllreduce(),
+ s_tir.transform.ThreadSync("shared"),
+ s_tir.transform.ThreadSync("shared.dyn"),
+ s_tir.transform.ThreadSync("warp"),
+ s_tir.transform.InferFragment(),
+ s_tir.transform.LowerThreadAllreduce(),
]
)
if bool(config.get("tir.use_async_copy", False)):
- passes.append(tir.transform.InjectPTXAsyncCopy())
+ passes.append(s_tir.transform.InjectPTXAsyncCopy())
if bool(config.get("tir.ptx_ldg32", False)):
- passes.append(tir.transform.InjectPTXLDG32())
+ passes.append(s_tir.transform.InjectPTXLDG32())
passes.extend(
[
tir.transform.AnnotateDeviceRegions(),
tir.transform.SplitHostDevice(),
# MergeSharedMemoryAllocations must follow SplitHostDevice.
- tir.transform.MergeSharedMemoryAllocations(),
+ s_tir.transform.MergeSharedMemoryAllocations(),
tir.transform.MakePackedAPI(),
tir.transform.FP8StorageLegalize(),
tir.transform.BF16StorageLegalize(),
diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py
index 59cec5b582..8c8ffb69af 100644
--- a/python/tvm/s_tir/pipeline.py
+++ b/python/tvm/s_tir/pipeline.py
@@ -60,22 +60,22 @@ def default_s_tir_pipeline():
if not bool(config.get("tir.disable_storage_rewrite", False)):
passes.append(tir.transform.StorageRewrite())
if config.get("tir.use_async_copy", False):
- passes.append(tir.transform.LowerAsyncDMA())
+ passes.append(s_tir.transform.LowerAsyncDMA())
passes.extend(
[
- tir.transform.HoistIfThenElse(),
+ s_tir.transform.HoistIfThenElse(),
tir.transform.UnrollLoop(),
- tir.transform.RenormalizeSplitPattern(),
+ s_tir.transform.RenormalizeSplitPattern(),
tir.transform.Simplify(),
tir.transform.RemoveNoOp(),
- tir.transform.RewriteUnsafeSelect(),
+ s_tir.transform.RewriteUnsafeSelect(),
]
)
# Additional passes based on configuration.
if bool(config.get("tir.instrument_bound_checkers", False)):
- passes.append(tir.transform.InstrumentBoundCheckers())
+ passes.append(s_tir.transform.InstrumentBoundCheckers())
if bool(config.get("tir.ptx_ldg32", False)):
- passes.append(tir.transform.InjectPTXLDG32(True))
+ passes.append(s_tir.transform.InjectPTXLDG32(True))
passes.append(
tir.transform.CommonSubexprElimTIR(
not bool(config.get("tir.disable_cse_tir", False)),
@@ -83,39 +83,39 @@ def default_s_tir_pipeline():
)
)
if bool(config.get("tir.instrument_lwp", False)):
- passes.append(tir.transform.InstrumentProfileIntrinsics())
+ passes.append(s_tir.transform.InstrumentProfileIntrinsics())
passes.extend(
[
# Bind the target first so that target-specific attributes are available.
tir.transform.FP8ComputeLegalize(),
# VerifyVTCMLimit must occur before LowerVtcmAlloc.
- tir.transform.VerifyVTCMLimit(),
- tir.transform.LowerVtcmAlloc(),
+ s_tir.transform.VerifyVTCMLimit(),
+ s_tir.transform.LowerVtcmAlloc(),
tir.transform.VerifyMemory(),
tir.transform.AnnotateEntryFunc(),
]
)
if bool(config.get("tir.detect_global_barrier", False)):
- passes.append(tir.transform.ThreadSync("global"))
+ passes.append(s_tir.transform.ThreadSync("global"))
passes.extend(
[
- tir.transform.ThreadSync("shared"),
- tir.transform.ThreadSync("shared.dyn"),
- tir.transform.ThreadSync("warp"),
- tir.transform.InferFragment(),
- tir.transform.LowerThreadAllreduce(),
+ s_tir.transform.ThreadSync("shared"),
+ s_tir.transform.ThreadSync("shared.dyn"),
+ s_tir.transform.ThreadSync("warp"),
+ s_tir.transform.InferFragment(),
+ s_tir.transform.LowerThreadAllreduce(),
]
)
if bool(config.get("tir.use_async_copy", False)):
- passes.append(tir.transform.InjectPTXAsyncCopy())
+ passes.append(s_tir.transform.InjectPTXAsyncCopy())
if bool(config.get("tir.ptx_ldg32", False)):
- passes.append(tir.transform.InjectPTXLDG32())
+ passes.append(s_tir.transform.InjectPTXLDG32())
passes.extend(
[
tir.transform.AnnotateDeviceRegions(),
tir.transform.SplitHostDevice(),
# MergeSharedMemoryAllocations must follow SplitHostDevice.
- tir.transform.MergeSharedMemoryAllocations(),
+ s_tir.transform.MergeSharedMemoryAllocations(),
tir.transform.MakePackedAPI(),
tir.transform.FP8StorageLegalize(),
tir.transform.BF16StorageLegalize(),
diff --git a/python/tvm/s_tir/transform/__init__.py
b/python/tvm/s_tir/transform/__init__.py
index 4529684dc2..c669f3eaa7 100644
--- a/python/tvm/s_tir/transform/__init__.py
+++ b/python/tvm/s_tir/transform/__init__.py
@@ -18,3 +18,4 @@
# pylint: disable=wildcard-import, invalid-name
from .transform import *
+from ...tir.transform.transform import HoistedConditionals, HoistedLetBindings
diff --git a/python/tvm/s_tir/transform/transform.py
b/python/tvm/s_tir/transform/transform.py
index d4dbb8ee86..05d1a46746 100644
--- a/python/tvm/s_tir/transform/transform.py
+++ b/python/tvm/s_tir/transform/transform.py
@@ -253,3 +253,192 @@ def InjectDoubleBuffer():
The result pass
"""
return _ffi_api.InjectDoubleBuffer() # type: ignore
+
+
+def HoistIfThenElse(variant=None):
+ """Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
+
+ Parameters
+ ----------
+ variant : Optional[String]
+ The variant of the pass.
+ variant can have any one of following values ["basic", None(Default)].
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ if variant == "basic":
+ return _ffi_api.HoistIfThenElseBasic() # type: ignore
+ elif variant is None:
+ return _ffi_api.HoistIfThenElse() # type: ignore
+ else:
+ raise ValueError("wrong variant of HoistIfThenElse, " + variant)
+
+
+def HoistExpression():
+ """Hoist loop-invariant expressions to outside the eligible loops.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.HoistExpression() # type: ignore
+
+
+def RenormalizeSplitPattern():
+ """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RenormalizeSplitPattern() # type: ignore
+
+
+def RewriteUnsafeSelect():
+ """Detect and rewrite unsafe select that contains memory access.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RewriteUnsafeSelect() # type: ignore
+
+
+def InstrumentBoundCheckers():
+ """Instruments bound checkers.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentBoundCheckers() # type: ignore
+
+
+def InjectPTXLDG32(enable_inject_ptx_intrin=True):
+ """Inject ptx.ldg.32 intrinsics.
+
+ Parameters
+ ----------
+ enable_inject_ptx_intrin : bool
+ If True, inject ptx.ldg.32 intrinsics.
+ """
+ return _ffi_api.InjectPTXLDG32(enable_inject_ptx_intrin) # type: ignore
+
+
+def InstrumentProfileIntrinsics():
+ """Insert intrinsic calls to instrument function and loop level profiling.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentProfileIntrinsics() # type: ignore
+
+
+def VerifyVTCMLimit(default_target=None):
+ """Verify if the size of the allocated vtcm memory satisfies the limit.
+
+ The limit is determined from the "vtcm-capacity" attribute of the target.
+
+ Parameters
+ ----------
+ default_target : Optional[tvm.target.Target]
+ The default target to use if a PrimFunc does not have a target attribute.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.VerifyVTCMLimit(default_target) # type: ignore
+
+
+def LowerVtcmAlloc():
+ """Lower vtcm allocation.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerVtcmAlloc() # type: ignore
+
+
+def ThreadSync(storage_scope):
+ """Insert sync between parallel read/write of shared buffers.
+
+ Parameters
+ ----------
+ storage_scope: str
+ The target storage scope.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.ThreadSync(storage_scope) # type: ignore
+
+
+def InferFragment():
+ """Infer the TensorCore fragment information using tensor intrinsics.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InferFragment() # type: ignore
+
+
+def LowerThreadAllreduce():
+ """Lower cross thread allreduce.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerThreadAllreduce() # type: ignore
+
+
+def LowerAsyncDMA():
+ """Lower async DMA to DMA.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerAsyncDMA() # type: ignore
+
+
+def InjectPTXAsyncCopy():
+ """Rewrite global to shared memory copy on CUDA with asynchronous copy.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InjectPTXAsyncCopy() # type: ignore
+
+
+def MergeSharedMemoryAllocations():
+ """This pass merges multiple TIR-level shared memory allocations
+ into one allocation.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.MergeSharedMemoryAllocations() # type: ignore
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index 7de12d5301..a9439a77f3 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -585,7 +585,7 @@ def VerifyVTCMLimit(limit=None):
return _ffi_api.VerifyVTCMLimit(limit) # type: ignore
-@_ffi.register_object("tir.transform.HoistIfThenElseConfig")
+@_ffi.register_object("s_tir.transform.HoistIfThenElseConfig")
class HoistIfThenElseConfig(_ir.Attrs):
"""Config for hoist if then else pass"""
@@ -669,7 +669,7 @@ class HoistedLetBindings(enum.Flag):
""" Enable all hoisting of let bindings """
-@_ffi.register_object("tir.transform.HoistExpressionConfig")
+@_ffi.register_object("s_tir.transform.HoistExpressionConfig")
class HoistExpressionConfig(_ir.Attrs):
"""Config for hoist expression pass"""
diff --git a/src/tir/analysis/calculate_allocated_memory.cc
b/src/s_tir/analysis/calculate_allocated_memory.cc
similarity index 97%
rename from src/tir/analysis/calculate_allocated_memory.cc
rename to src/s_tir/analysis/calculate_allocated_memory.cc
index 1741eff937..ba7c7438bc 100644
--- a/src/tir/analysis/calculate_allocated_memory.cc
+++ b/src/s_tir/analysis/calculate_allocated_memory.cc
@@ -198,12 +198,13 @@ Pass VerifyVTCMLimit(ffi::Optional<Target>
default_target) {
}
return mod;
};
- return tvm::transform::CreateModulePass(pass_func, 0,
"tir.calculate_allocated_bytes", {});
+ return tvm::transform::CreateModulePass(pass_func, 0,
"s_tir.VerifyVTCMLimit", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tir.transform.VerifyVTCMLimit", VerifyVTCMLimit);
+ refl::GlobalDef().def("s_tir.transform.VerifyVTCMLimit", VerifyVTCMLimit);
}
} // namespace transform
diff --git a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
index 553647f3e6..ef349416b4 100644
--- a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
@@ -179,7 +179,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(s_tir::transform::InjectVirtualThread());
pass_list.push_back(s_tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::StorageRewrite());
- pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
+
pass_list.push_back(s_tir::transform::MergeSharedMemoryAllocations());
pass_list.push_back(tir::transform::LowerIntrin());
// Convert Function to IRModule
tvm::transform::PassContext pass_ctx =
tvm::transform::PassContext::Current();
diff --git a/src/tir/transform/bound_checker.cc
b/src/s_tir/transform/bound_checker.cc
similarity index 96%
rename from src/tir/transform/bound_checker.cc
rename to src/s_tir/transform/bound_checker.cc
index 99d990ece6..2f2061d9c1 100644
--- a/src/tir/transform/bound_checker.cc
+++ b/src/s_tir/transform/bound_checker.cc
@@ -25,11 +25,11 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
#include <unordered_map>
#include <utility>
@@ -38,7 +38,8 @@
#include "../../arith/unwrap_vector_expr.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
// TODO(Lunderberg): Move this pass to be before
// FlattenBuffer. That will simplify this pass,
@@ -254,15 +255,16 @@ Pass InstrumentBoundCheckers() {
n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body));
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {});
+ return CreatePrimFuncPass(pass_func, 0, "s_tir.InstrumentBoundCheckers", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.InstrumentBoundCheckers",
InstrumentBoundCheckers);
+ refl::GlobalDef().def("s_tir.transform.InstrumentBoundCheckers",
+ static_cast<Pass (*)()>(InstrumentBoundCheckers));
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/hoist_expression.cc
b/src/s_tir/transform/hoist_expression.cc
similarity index 91%
rename from src/tir/transform/hoist_expression.cc
rename to src/s_tir/transform/hoist_expression.cc
index ebd90583c9..add5b663bc 100644
--- a/src/tir/transform/hoist_expression.cc
+++ b/src/s_tir/transform/hoist_expression.cc
@@ -23,10 +23,10 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
#include <queue>
#include <unordered_map>
@@ -36,10 +36,11 @@
#include "../../arith/interval_set.h"
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../runtime/thread_storage_scope.h"
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
enum class HoistedConditionals : int {
kNone = 0,
@@ -81,7 +82,7 @@ struct HoistExpressionConfigNode : public
AttrsNodeReflAdapter<HoistExpressionCo
bool FlagSet(HoistedLetBindings flag) const {
return static_cast<int>(flag) & hoisted_let_bindings;
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.HoistExpressionConfig",
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.transform.HoistExpressionConfig",
HoistExpressionConfigNode, Object);
};
@@ -99,7 +100,7 @@ class HoistExpressionConfig : public Attrs {
TVM_FFI_STATIC_INIT_BLOCK() { HoistExpressionConfigNode::RegisterReflection();
}
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistExpression", HoistExpressionConfig);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.HoistExpression",
HoistExpressionConfig);
struct HoistIfThenElseConfigNode : public
AttrsNodeReflAdapter<HoistIfThenElseConfigNode> {
bool support_block_scope_hoisting;
@@ -110,7 +111,7 @@ struct HoistIfThenElseConfigNode : public
AttrsNodeReflAdapter<HoistIfThenElseCo
"support_block_scope_hoisting",
&HoistIfThenElseConfigNode::support_block_scope_hoisting,
"Hoist if cond with block scope variables", refl::DefaultValue(false));
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.HoistIfThenElseConfig",
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.transform.HoistIfThenElseConfig",
HoistIfThenElseConfigNode, Object);
};
@@ -122,7 +123,7 @@ class HoistIfThenElseConfig : public Attrs {
TVM_FFI_STATIC_INIT_BLOCK() { HoistIfThenElseConfigNode::RegisterReflection();
}
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.HoistIfThenElse",
HoistIfThenElseConfig);
class HoistInfoCollector : public StmtExprVisitor {
public:
@@ -541,7 +542,7 @@ namespace transform {
Pass HoistExpression() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
- auto cfg = ctx->GetConfig<HoistExpressionConfig>("tir.HoistExpression");
+ auto cfg = ctx->GetConfig<HoistExpressionConfig>("s_tir.HoistExpression");
if (!cfg.defined()) {
cfg = AttrsWithDefaultValues<HoistExpressionConfig>();
@@ -549,26 +550,26 @@ Pass HoistExpression() {
n->body = ExpressionHoister::Hoist(std::move(n->body), cfg.value());
return f;
};
- auto insertion_pass = CreatePrimFuncPass(pass_func, 0,
"tir.InsertHoistedExpression", {});
+ auto insertion_pass = CreatePrimFuncPass(pass_func, 0,
"s_tir.InsertHoistedExpression", {});
- return Sequential(
+ return tvm::transform::Sequential(
{
insertion_pass,
- Simplify(),
- RemoveNoOp(),
+ tir::transform::Simplify(),
+ tir::transform::RemoveNoOp(),
},
- "tir.HoistExpression");
+ "s_tir.HoistExpression");
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.HoistExpression", HoistExpression);
+ refl::GlobalDef().def("s_tir.transform.HoistExpression", HoistExpression);
}
-Pass HoistIfThenElse() {
+static Pass HoistIfThenElseImpl() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
- auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");
+ auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("s_tir.HoistIfThenElse");
auto flag = f->GetAttr<Integer>("tir.HoistIfThenElseExprWithBlock");
if (flag && flag.value().IntValue() == 1) {
HoistExpressionConfig
config(static_cast<int>(HoistedConditionals::kUsingBlockVar) |
@@ -588,22 +589,17 @@ Pass HoistIfThenElse() {
n->body = ExpressionHoister::Hoist(std::move(n->body), config);
return f;
};
- auto insertion_pass = CreatePrimFuncPass(pass_func, 0,
"tir.InsertHoistIfThenElse", {});
- return Sequential(
+ auto insertion_pass = CreatePrimFuncPass(pass_func, 0,
"s_tir.InsertHoistIfThenElse", {});
+ return tvm::transform::Sequential(
{
insertion_pass,
- Simplify(),
- RemoveNoOp(),
+ tir::transform::Simplify(),
+ tir::transform::RemoveNoOp(),
},
- "tir.HoistIfThenElse");
+ "s_tir.HoistIfThenElse");
}
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.HoistIfThenElse", HoistIfThenElse);
-}
-
-Pass HoistIfThenElseBasic() {
+static Pass HoistIfThenElseBasicImpl() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
HoistExpressionConfig
config(static_cast<int>(HoistedConditionals::kIfElseStmt),
@@ -611,22 +607,30 @@ Pass HoistIfThenElseBasic() {
n->body = ExpressionHoister::Hoist(std::move(n->body), config);
return f;
};
- auto insertion_pass = CreatePrimFuncPass(pass_func, 0,
"tir.InsertHoistIfThenElseBasic", {});
- return Sequential(
+ auto insertion_pass = CreatePrimFuncPass(pass_func, 0,
"s_tir.InsertHoistIfThenElseBasic", {});
+ return tvm::transform::Sequential(
{
insertion_pass,
- Simplify(),
- RemoveNoOp(),
+ tir::transform::Simplify(),
+ tir::transform::RemoveNoOp(),
},
- "tir.HoistIfThenElseBasic");
+ "s_tir.HoistIfThenElseBasic");
+}
+
+Pass HoistIfThenElse(tvm::ffi::String variant) {
+ if (variant == "basic") {
+ return HoistIfThenElseBasicImpl();
+ }
+ return HoistIfThenElseImpl();
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.HoistIfThenElseBasic",
HoistIfThenElseBasic);
+ refl::GlobalDef().def("s_tir.transform.HoistIfThenElse",
HoistIfThenElseImpl);
+ refl::GlobalDef().def("s_tir.transform.HoistIfThenElseBasic",
HoistIfThenElseBasicImpl);
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/inject_ptx_async_copy.cc
b/src/s_tir/transform/inject_ptx_async_copy.cc
similarity index 96%
rename from src/tir/transform/inject_ptx_async_copy.cc
rename to src/s_tir/transform/inject_ptx_async_copy.cc
index 0e9820aa65..c9b1e42d7f 100644
--- a/src/tir/transform/inject_ptx_async_copy.cc
+++ b/src/s_tir/transform/inject_ptx_async_copy.cc
@@ -22,19 +22,20 @@
* \file inject_ptx_async_copy.cc
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
-#include "../ir/buffer_common.h"
+#include "../../tir/ir/buffer_common.h"
#include "storage_access.h"
#include "tvm/tir/stmt.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
class PTXAsyncCopyInjector : public StmtMutator {
public:
@@ -197,15 +198,15 @@ Pass InjectPTXAsyncCopy() {
n->body = PTXAsyncCopyInjector()(n->body);
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {});
+ return CreatePrimFuncPass(pass_func, 0, "s_tir.InjectPTXAsyncCopy", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.InjectPTXAsyncCopy",
InjectPTXAsyncCopy);
+ refl::GlobalDef().def("s_tir.transform.InjectPTXAsyncCopy",
InjectPTXAsyncCopy);
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/inject_ptx_ldg32.cc
b/src/s_tir/transform/inject_ptx_ldg32.cc
similarity index 95%
rename from src/tir/transform/inject_ptx_ldg32.cc
rename to src/s_tir/transform/inject_ptx_ldg32.cc
index f52539fa77..5f10d25917 100644
--- a/src/tir/transform/inject_ptx_ldg32.cc
+++ b/src/s_tir/transform/inject_ptx_ldg32.cc
@@ -21,17 +21,18 @@
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
#include "../../arith/const_fold.h"
#include "../../arith/pattern_match.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
class PTXRewriter : public StmtMutator {
public:
@@ -145,16 +146,16 @@ Pass InjectPTXLDG32(bool enable_inject_ptx_intrin) {
}
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXLDG32", {});
+ return CreatePrimFuncPass(pass_func, 0, "s_tir.InjectPTXLDG32", {});
}
// The pass can now be invoked via the pass infrastructure, but we also add a
// Python binding for it
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.InjectPTXLDG32", InjectPTXLDG32);
+ refl::GlobalDef().def("s_tir.transform.InjectPTXLDG32", InjectPTXLDG32);
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/lower_async_dma.cc
b/src/s_tir/transform/lower_async_dma.cc
similarity index 95%
rename from src/tir/transform/lower_async_dma.cc
rename to src/s_tir/transform/lower_async_dma.cc
index 1b7bf14c38..3b6fd0480c 100644
--- a/src/tir/transform/lower_async_dma.cc
+++ b/src/s_tir/transform/lower_async_dma.cc
@@ -25,19 +25,20 @@
#include <tvm/arith/bound.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
#include <optional>
#include "../../arith/ir_mutator_with_analyzer.h"
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
public:
@@ -174,14 +175,14 @@ Pass LowerAsyncDMA() {
fptr->body = AsyncDMALowerer(dma_bypass_cache,
&analyzer)(std::move(fptr->body));
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {});
+ return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerAsyncDMA", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.LowerAsyncDMA", LowerAsyncDMA);
+ refl::GlobalDef().def("s_tir.transform.LowerAsyncDMA", LowerAsyncDMA);
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/lower_thread_allreduce.cc
b/src/s_tir/transform/lower_thread_allreduce.cc
similarity index 98%
rename from src/tir/transform/lower_thread_allreduce.cc
rename to src/s_tir/transform/lower_thread_allreduce.cc
index 4a0eb49cc3..5f1fc9afaa 100644
--- a/src/tir/transform/lower_thread_allreduce.cc
+++ b/src/s_tir/transform/lower_thread_allreduce.cc
@@ -24,20 +24,21 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
#include <unordered_set>
#include "../../runtime/thread_storage_scope.h"
-#include "ir_utils.h"
-#include "update_pointer_storage_scope.h"
+#include "../../tir/transform/ir_utils.h"
+#include "../../tir/transform/update_pointer_storage_scope.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
class ThreadAllreduceBuilder final : public StmtExprMutator {
public:
@@ -47,12 +48,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
max_num_threads_(target->GetAttr<Integer>("max_num_threads",
-1).value().IntValue()) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent) {
+ if (op->attr_key == tir::attr::thread_extent) {
thread_extents_.push_back(op);
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back();
return ret;
- } else if (op->attr_key == attr::reduce_scope) {
+ } else if (op->attr_key == tir::attr::reduce_scope) {
const CommReducerNode* combiner = op->node.as<CommReducerNode>();
ICHECK(combiner);
reduce_combiner_.push_back(combiner);
@@ -86,7 +87,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
if (buf.scope() == "shared") {
// Use volatile access to shared buffer.
- write_ptr->body = AttrStmt(buf->data, attr::volatile_scope, 1,
write_ptr->body);
+ write_ptr->body = AttrStmt(buf->data, tir::attr::volatile_scope, 1,
write_ptr->body);
}
}
return node;
@@ -807,14 +808,14 @@ Pass LowerThreadAllreduce() {
n->body = thread_all_reduce(n->body);
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
+ return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerThreadAllreduce", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.LowerThreadAllreduce",
LowerThreadAllreduce);
+ refl::GlobalDef().def("s_tir.transform.LowerThreadAllreduce",
LowerThreadAllreduce);
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/lower_vtcm_alloc.cc
b/src/s_tir/transform/lower_vtcm_alloc.cc
similarity index 88%
rename from src/tir/transform/lower_vtcm_alloc.cc
rename to src/s_tir/transform/lower_vtcm_alloc.cc
index c3b03f8623..469f7c4655 100644
--- a/src/tir/transform/lower_vtcm_alloc.cc
+++ b/src/s_tir/transform/lower_vtcm_alloc.cc
@@ -18,14 +18,15 @@
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt.h>
-#include <tvm/tir/transform.h>
#include "../../arith/ir_visitor_with_analyzer.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
inline bool IsVtcmStorage(std::string scope) {
return scope.find("global.vtcm") != std::string::npos;
@@ -68,17 +69,17 @@ namespace transform {
Pass LowerVtcmAlloc() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
- return LowerVtcmAlloc(std::move(f));
+ return s_tir::LowerVtcmAlloc(std::move(f));
};
- return CreatePrimFuncPass(pass_func, 0, "tir.LowerVtcmAlloc", {});
+ return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerVtcmAlloc", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.LowerVtcmAlloc", LowerVtcmAlloc);
+ refl::GlobalDef().def("s_tir.transform.LowerVtcmAlloc", static_cast<Pass
(*)()>(LowerVtcmAlloc));
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/merge_shared_memory_allocations.cc
b/src/s_tir/transform/merge_shared_memory_allocations.cc
similarity index 97%
rename from src/tir/transform/merge_shared_memory_allocations.cc
rename to src/s_tir/transform/merge_shared_memory_allocations.cc
index 4a2b8698d8..11fe15e95a 100644
--- a/src/tir/transform/merge_shared_memory_allocations.cc
+++ b/src/s_tir/transform/merge_shared_memory_allocations.cc
@@ -25,20 +25,21 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "../../runtime/thread_storage_scope.h"
#include "../../support/arena.h"
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
using runtime::StorageRank;
using runtime::StorageScope;
@@ -207,13 +208,13 @@ class SharedMemLinearAccessPatternFinder final : public
StmtExprVisitor {
void VisitStmt_(const AttrStmtNode* op) final {
// Only record the outer most thread extent.
- if (op->attr_key == attr::thread_extent && !in_thread_env_) {
+ if (op->attr_key == tir::attr::thread_extent && !in_thread_env_) {
in_thread_env_ = true;
VisitNewScope(op);
in_thread_env_ = false;
- } else if (op->attr_key == attr::extern_scope) {
+ } else if (op->attr_key == tir::attr::extern_scope) {
VisitNewScope(op);
- } else if (op->attr_key == attr::virtual_thread) {
+ } else if (op->attr_key == tir::attr::virtual_thread) {
VisitNewScope(op);
} else {
StmtExprVisitor::VisitStmt_(op);
@@ -273,7 +274,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
private:
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent && !allocated_) {
+ if (op->attr_key == tir::attr::thread_extent && !allocated_) {
// Allocate one dynamic shared memory allocation at the beginning of
thread scope
int max_layer_num = 0;
std::vector<const StorageEntry*> all_entry;
@@ -690,17 +691,18 @@ Pass MergeSharedMemoryAllocations() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem",
Bool(false)).value();
auto* n = f.CopyOnWrite();
- n->body = MergeSharedMemoryAllocations(std::move(n->body),
merge_static_smem);
+ n->body = s_tir::MergeSharedMemoryAllocations(std::move(n->body),
merge_static_smem);
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.MergeSharedMemoryAllocations",
{});
+ return CreatePrimFuncPass(pass_func, 0,
"s_tir.MergeSharedMemoryAllocations", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.MergeSharedMemoryAllocations",
MergeSharedMemoryAllocations);
+ refl::GlobalDef().def("s_tir.transform.MergeSharedMemoryAllocations",
+ static_cast<Pass (*)()>(MergeSharedMemoryAllocations));
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/profile_instrumentation.cc
b/src/s_tir/transform/profile_instrumentation.cc
similarity index 89%
rename from src/tir/transform/profile_instrumentation.cc
rename to src/s_tir/transform/profile_instrumentation.cc
index 513f0d730e..268190557c 100644
--- a/src/tir/transform/profile_instrumentation.cc
+++ b/src/s_tir/transform/profile_instrumentation.cc
@@ -25,21 +25,22 @@
// and can be used to capture profiling information such as processor cycles.
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
namespace lwp {
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_disable_func_prof", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_max_depth", Integer);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_min_height", Integer);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.instr_siblings", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.reset_start_id", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.lwp_disable_func_prof", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.lwp_max_depth", Integer);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.lwp_min_height", Integer);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.instr_siblings", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.reset_start_id", Bool);
static int32_t start_id = 0;
@@ -259,12 +260,12 @@ Pass InstrumentProfileIntrinsics() {
// In addition, loops with siblings are also instrumented provided
// their loop depth is >= min_instr_height. This is done to avoid
// instrumenting inner-most loops.
- auto max_instr_depth = ctx->GetConfig<Integer>("tir.lwp_max_depth",
Integer(0)).value();
- auto min_instr_height = ctx->GetConfig<Integer>("tir.lwp_min_height",
Integer(1)).value();
- bool instr_siblings = ctx->GetConfig<Bool>("tir.instr_siblings",
Bool(true)).value();
+ auto max_instr_depth = ctx->GetConfig<Integer>("s_tir.lwp_max_depth",
Integer(0)).value();
+ auto min_instr_height = ctx->GetConfig<Integer>("s_tir.lwp_min_height",
Integer(1)).value();
+ bool instr_siblings = ctx->GetConfig<Bool>("s_tir.instr_siblings",
Bool(true)).value();
bool disable_func_instrumentation =
- ctx->GetConfig<Bool>("tir.lwp_disable_func_prof", Bool(false)).value();
- bool reset_start_id = ctx->GetConfig<Bool>("tir.reset_start_id",
Bool(false)).value();
+ ctx->GetConfig<Bool>("s_tir.lwp_disable_func_prof",
Bool(false)).value();
+ bool reset_start_id = ctx->GetConfig<Bool>("s_tir.reset_start_id",
Bool(false)).value();
if (reset_start_id) lwp::start_id = 0;
std::vector<std::pair<GlobalVar, PrimFunc>> updates;
for (const auto& kv : mptr->functions) {
@@ -281,15 +282,15 @@ Pass InstrumentProfileIntrinsics() {
return m;
};
- return tvm::transform::CreateModulePass(pass_func, 0,
"tir.InstrumentProfileIntrinsics", {});
+ return tvm::transform::CreateModulePass(pass_func, 0,
"s_tir.InstrumentProfileIntrinsics", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.InstrumentProfileIntrinsics",
InstrumentProfileIntrinsics);
+ refl::GlobalDef().def("s_tir.transform.InstrumentProfileIntrinsics",
InstrumentProfileIntrinsics);
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/renormalize_split_pattern.cc
b/src/s_tir/transform/renormalize_split_pattern.cc
similarity index 96%
rename from src/tir/transform/renormalize_split_pattern.cc
rename to src/s_tir/transform/renormalize_split_pattern.cc
index 04dbcca510..abf127e2a2 100644
--- a/src/tir/transform/renormalize_split_pattern.cc
+++ b/src/s_tir/transform/renormalize_split_pattern.cc
@@ -23,17 +23,18 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../arith/pattern_match.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
using namespace arith;
@@ -203,15 +204,15 @@ Pass RenormalizeSplitPattern() {
n->body = SplitPatternReNormalizer(&analyzer)(std::move(n->body));
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {});
+ return CreatePrimFuncPass(pass_func, 0, "s_tir.RenormalizeSplitPattern", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.RenormalizeSplitPattern",
RenormalizeSplitPattern);
+ refl::GlobalDef().def("s_tir.transform.RenormalizeSplitPattern",
RenormalizeSplitPattern);
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/rewrite_unsafe_select.cc
b/src/s_tir/transform/rewrite_unsafe_select.cc
similarity index 95%
rename from src/tir/transform/rewrite_unsafe_select.cc
rename to src/s_tir/transform/rewrite_unsafe_select.cc
index 3dfbcb9967..267d9b4d00 100644
--- a/src/tir/transform/rewrite_unsafe_select.cc
+++ b/src/s_tir/transform/rewrite_unsafe_select.cc
@@ -23,14 +23,15 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
// For now, rewrite unsafe select expression to if_then_else
// TODO(tqchen) pattern matching to support masked load
@@ -137,15 +138,15 @@ Pass RewriteUnsafeSelect() {
n->body = UnsafeSelectRewriter()(std::move(n->body));
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {});
+ return CreatePrimFuncPass(pass_func, 0, "s_tir.RewriteUnsafeSelect", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.RewriteUnsafeSelect",
RewriteUnsafeSelect);
+ refl::GlobalDef().def("s_tir.transform.RewriteUnsafeSelect",
RewriteUnsafeSelect);
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/storage_access.cc
b/src/s_tir/transform/storage_access.cc
similarity index 96%
rename from src/tir/transform/storage_access.cc
rename to src/s_tir/transform/storage_access.cc
index 2a38e64cc7..7a5b6e1622 100644
--- a/src/tir/transform/storage_access.cc
+++ b/src/s_tir/transform/storage_access.cc
@@ -28,10 +28,11 @@
#include <string>
#include <utility>
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) {
Var buf = op->buffer->data;
@@ -109,7 +110,7 @@ void StorageAccessVisitor::VisitStmt_(const LetStmtNode*
op) {
}
void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
- if (op->attr_key == attr::double_buffer_write) {
+ if (op->attr_key == tir::attr::double_buffer_write) {
ICHECK(double_buffer_write_ == nullptr);
double_buffer_write_ = op->node.as<VarNode>();
scope_.push_back(std::vector<StmtEntry>());
@@ -127,12 +128,12 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode*
op) {
scope_.back().emplace_back(std::move(s));
}
double_buffer_write_ = nullptr;
- } else if (op->attr_key == attr::coproc_scope) {
+ } else if (op->attr_key == tir::attr::coproc_scope) {
IterVar iv = Downcast<IterVar>(op->node);
env_threads_.push_back(iv);
StmtExprVisitor::VisitStmt_(op);
env_threads_.pop_back();
- } else if (op->attr_key == attr::thread_extent) {
+ } else if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
env_threads_.push_back(iv);
if (!in_device_env_) {
@@ -147,7 +148,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode*
op) {
StmtExprVisitor::VisitStmt_(op);
}
env_threads_.pop_back();
- } else if (op->attr_key == attr::hand_threaded) {
+ } else if (op->attr_key == tir::attr::hand_threaded) {
// skip this pass on blocks that were hand_threaded
// this avoids control flow and read/write conflicts
// between hand-threaded kernels and automatic threading
@@ -293,5 +294,5 @@ StorageScope StorageAccessVisitor::GetScope(Var buffer_var)
const {
return StorageScope(); // global by default
}
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/storage_access.h
b/src/s_tir/transform/storage_access.h
similarity index 95%
rename from src/tir/transform/storage_access.h
rename to src/s_tir/transform/storage_access.h
index 7c96068229..848d8edcf5 100644
--- a/src/tir/transform/storage_access.h
+++ b/src/s_tir/transform/storage_access.h
@@ -21,8 +21,8 @@
* \file storage_access.h
* \brief Common data structure for storage access analysis.
*/
-#ifndef TVM_TIR_TRANSFORM_STORAGE_ACCESS_H_
-#define TVM_TIR_TRANSFORM_STORAGE_ACCESS_H_
+#ifndef TVM_S_TIR_TRANSFORM_STORAGE_ACCESS_H_
+#define TVM_S_TIR_TRANSFORM_STORAGE_ACCESS_H_
#include <tvm/arith/int_set.h>
#include <tvm/ir/attrs.h>
@@ -35,7 +35,8 @@
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tir;
using runtime::StorageRank;
using runtime::StorageScope;
@@ -140,6 +141,6 @@ class StorageAccessVisitor : public StmtExprVisitor {
// The involving threads
ffi::Array<IterVar> env_threads_;
};
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
-#endif // TVM_TIR_TRANSFORM_STORAGE_ACCESS_H_
+#endif // TVM_S_TIR_TRANSFORM_STORAGE_ACCESS_H_
diff --git a/src/tir/transform/tensorcore_infer_fragment.cc
b/src/s_tir/transform/tensorcore_infer_fragment.cc
similarity index 91%
rename from src/tir/transform/tensorcore_infer_fragment.cc
rename to src/s_tir/transform/tensorcore_infer_fragment.cc
index 7c1b5b05d0..d1232e5164 100644
--- a/src/tir/transform/tensorcore_infer_fragment.cc
+++ b/src/s_tir/transform/tensorcore_infer_fragment.cc
@@ -23,19 +23,20 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "../../runtime/thread_storage_scope.h"
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
#include "storage_access.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
// Get fragment information from tensor intrinsics
class FragmentGetter : public StmtExprVisitor {
@@ -113,11 +114,17 @@ class FragmentGetter : public StmtExprVisitor {
std::unordered_map<const VarNode*, FragmentInfo> fragments;
};
+} // namespace s_tir
+
+namespace tir {
std::unordered_map<const VarNode*, FragmentInfo>
GetTensorCoreFragmentInfo(const Stmt& stmt) {
- FragmentGetter getter;
+ s_tir::FragmentGetter getter;
getter(stmt);
return std::move(getter.fragments);
}
+} // namespace tir
+
+namespace s_tir {
// Check shape of fragment making sure it is a valid shape for tvm_mma_sync
class FragmentChecker : public StmtExprVisitor {
@@ -180,11 +187,11 @@ class InferFragmenter : public StmtMutator {
std::string shape =
std::to_string(info.m) + ", " + std::to_string(info.n) + ", " +
std::to_string(info.k);
PrimExpr shape_expr = StringImm(shape);
- Stmt shape_attr = AttrStmt(op->buffer_var, attr::fragment_shape,
shape_expr, stmt);
+ Stmt shape_attr = AttrStmt(op->buffer_var, tir::attr::fragment_shape,
shape_expr, stmt);
if (info.layout != "") {
// Add shape attribute to matrix_a and matrix_b
- Stmt layout_attr =
- AttrStmt(op->buffer_var, attr::fragment_layout,
StringImm(info.layout), shape_attr);
+ Stmt layout_attr = AttrStmt(op->buffer_var, tir::attr::fragment_layout,
+ StringImm(info.layout), shape_attr);
return layout_attr;
} else {
return shape_attr;
@@ -212,17 +219,17 @@ namespace transform {
Pass InferFragment() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
- n->body = InferFragment(std::move(n->body));
+ n->body = s_tir::InferFragment(std::move(n->body));
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {});
+ return CreatePrimFuncPass(pass_func, 0, "s_tir.InferFragment", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.InferFragment", InferFragment);
+ refl::GlobalDef().def("s_tir.transform.InferFragment", static_cast<Pass
(*)()>(InferFragment));
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/src/tir/transform/thread_storage_sync.cc
b/src/s_tir/transform/thread_storage_sync.cc
similarity index 96%
rename from src/tir/transform/thread_storage_sync.cc
rename to src/s_tir/transform/thread_storage_sync.cc
index d41d474a08..57d8f25b51 100644
--- a/src/tir/transform/thread_storage_sync.cc
+++ b/src/s_tir/transform/thread_storage_sync.cc
@@ -22,21 +22,22 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "../../runtime/thread_storage_scope.h"
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
#include "storage_access.h"
namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
class ThreadSyncPlanner : public StorageAccessVisitor {
public:
@@ -291,7 +292,7 @@ class ThreadSyncAfterWaitQueueInserter : public
StmtExprMutator {
explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) :
sync_scope_(sync_scope) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::async_wait_queue_scope) {
+ if (op->attr_key == tir::attr::async_wait_queue_scope) {
auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string())}));
auto inner = op->body.as<AttrStmtNode>();
@@ -346,7 +347,7 @@ class ThreadSyncInserter : public StmtExprMutator {
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::thread_extent) {
+ if (op->attr_key == tir::attr::thread_extent) {
bool temp = true;
std::swap(temp, in_thread_env_);
thread_extents_.push_back(op);
@@ -407,7 +408,7 @@ class ThreadSyncInserter : public StmtExprMutator {
for (const auto& kv : rw_stats_) {
const auto& e = kv.second;
if (e.read_count != 0 && e.write_count != 0) {
- body = AttrStmt(kv.first, attr::volatile_scope, 1, body);
+ body = AttrStmt(kv.first, tir::attr::volatile_scope, 1, body);
}
}
rw_stats_.clear();
@@ -466,17 +467,18 @@ namespace transform {
Pass ThreadSync(ffi::String storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
- n->body = ThreadSync(std::move(n->body), storage_scope);
+ n->body = s_tir::ThreadSync(std::move(n->body), storage_scope);
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {});
+ return CreatePrimFuncPass(pass_func, 0, "s_tir.ThreadSync", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.transform.ThreadSync", ThreadSync);
+ refl::GlobalDef().def("s_tir.transform.ThreadSync",
+ static_cast<Pass (*)(ffi::String)>(ThreadSync));
}
} // namespace transform
-} // namespace tir
+} // namespace s_tir
} // namespace tvm
diff --git a/tests/python/tir-transform/test_tir_transform_hoist_expression.py
b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
similarity index 98%
rename from tests/python/tir-transform/test_tir_transform_hoist_expression.py
rename to tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
index c973e86e18..25419dc88b 100644
--- a/tests/python/tir-transform/test_tir_transform_hoist_expression.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
@@ -16,8 +16,9 @@
# under the License.
import tvm
import tvm.testing
+from tvm import s_tir
from tvm.script import tir as T
-from tvm.tir.transform import HoistedConditionals, HoistedLetBindings
+from tvm.s_tir.transform import HoistedConditionals, HoistedLetBindings
def _run_transform(before, hoisted_conditionals, hoisted_let_bindings):
@@ -25,14 +26,14 @@ def _run_transform(before, hoisted_conditionals,
hoisted_let_bindings):
before_mod = tvm.IRModule.from_expr(before)
config = {
- "tir.HoistExpression": {
+ "s_tir.HoistExpression": {
"hoisted_conditionals": hoisted_conditionals.value,
"hoisted_let_bindings": hoisted_let_bindings.value,
}
}
with tvm.transform.PassContext(config=config):
- after_mod = tvm.tir.transform.HoistExpression()(before_mod)
+ after_mod = tvm.s_tir.transform.HoistExpression()(before_mod)
return after_mod["main"]
diff --git a/tests/python/tir-transform/test_tir_transform_hoist_if.py
b/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py
similarity index 88%
rename from tests/python/tir-transform/test_tir_transform_hoist_if.py
rename to tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py
index a941c6c019..5d0f48df4d 100644
--- a/tests/python/tir-transform/test_tir_transform_hoist_if.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import te
+from tvm import te, s_tir
from tvm.script import tir as T, ir as I
import numpy as np
import pytest
@@ -77,7 +77,7 @@ def test_hoist_top_for():
T.evaluate(T.call_extern("int32", "dummy", n))
mod = tvm.IRModule.from_expr(func)
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
expected_struct = {
("tir.For", "k"): (None,),
("tir.For", "j"): (("tir.For", "k"),),
@@ -99,7 +99,7 @@ def test_hoist_multi_var_if():
T.evaluate(T.call_extern("int32", "dummy", n))
mod = tvm.IRModule.from_expr(func)
- new_mod = tvm.tir.transform.HoistIfThenElse()(mod)
+ new_mod = tvm.s_tir.transform.HoistIfThenElse()(mod)
new_stmt = new_mod["main"].body
expected_struct = {
("tir.For", "k"): (None,),
@@ -124,7 +124,7 @@ def test_hoist_no_match_for():
T.evaluate(T.call_extern("int32", "dummy", n))
mod = tvm.IRModule.from_expr(func)
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
expected_struct = {
("tir.For", "k"): (None,),
("tir.IfThenElse", ("i",)): (("tir.For", "k"), ("tir.For", "k")),
@@ -144,7 +144,7 @@ def test_no_else():
T.evaluate(T.call_extern("int32", "dummy", m))
mod = tvm.IRModule.from_expr(func)
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
expected_struct = {
("tir.For", "k"): (None,),
("tir.For", "j"): (("tir.For", "k"),),
@@ -175,7 +175,7 @@ def test_attr_stmt():
)
mod = tvm.IRModule.from_expr(func)
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
expected_struct = {
("tir.For", "k"): (None,),
("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")),
@@ -207,7 +207,7 @@ def test_nested_for():
] * T.float32(1.5)
mod = tvm.IRModule.from_expr(func)
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
expected_struct = {
("tir.For", "l"): (None,),
("tir.For", "k"): (("tir.For", "l"),),
@@ -248,7 +248,7 @@ def test_if_block():
if n >= 3:
data[i2 * 3 + j2 + k2] = data[i2 * 3 + j2 + k2] +
T.float32(0.6)
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
# Updated expected_struct with renamed second nest variables
expected_struct = {
("tir.IfThenElse", ("i", "j")): (None, None),
@@ -280,7 +280,7 @@ def test_multi_if():
] + T.float32(0.5)
mod = tvm.IRModule.from_expr(func)
- new_mod = tvm.tir.transform.HoistIfThenElse()(mod)
+ new_mod = tvm.s_tir.transform.HoistIfThenElse()(mod)
new_stmt = new_mod["main"].body
expected_struct = {
("tir.For", "k"): (None,),
@@ -306,13 +306,13 @@ def test_no_hoisting_1():
mod = tvm.IRModule.from_expr(func)
stmt = mod["main"].body
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
with tvm.transform.PassContext(
- config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+ config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting":
True}}
):
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
@@ -331,13 +331,13 @@ def test_no_hoisting_2():
mod = tvm.IRModule.from_expr(func)
stmt = mod["main"].body
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
with tvm.transform.PassContext(
- config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+ config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting":
True}}
):
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
@@ -370,13 +370,13 @@ def test_no_hoisting_4():
] + T.float32(1.3)
stmt = Module["main"].body
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
with tvm.transform.PassContext(
- config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+ config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting":
True}}
):
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
@@ -398,13 +398,13 @@ def test_no_hoisting_6():
data[bx * j + tx * j * k] = data[bx * j + tx * j *
k] + T.float32(1.3)
stmt = Module["main"].body
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
with tvm.transform.PassContext(
- config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+ config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting":
True}}
):
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
@@ -427,13 +427,13 @@ def test_no_hoisting_7():
)
stmt = Module["main"].body
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
with tvm.transform.PassContext(
- config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+ config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting":
True}}
):
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
@@ -469,13 +469,13 @@ def test_hoisting_block_scope_2():
mod = tvm.tir.transform.RemoveNoOp()(mod)
stmt = mod["main"].body
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
with tvm.transform.PassContext(
- config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+ config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting":
True}}
):
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
assert not tvm.ir.structural_equal(new_stmt, stmt)
@@ -493,16 +493,16 @@ def test_hoisting_block_scope_5():
data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k]
+ T.float32(1.3)
stmt = Module["main"].body
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
assert not tvm.ir.structural_equal(new_stmt, stmt)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], new_stmt))
stmt = new_stmt
with tvm.transform.PassContext(
- config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+ config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting":
True}}
):
- new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
@@ -524,13 +524,13 @@ def test_hoisting_block_scope_6():
data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + T.float32(1.3)
stmt = Module["main"].body
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
with tvm.transform.PassContext(
- config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+ config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
):
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
assert not tvm.ir.structural_equal(new_stmt, stmt)
@@ -552,13 +552,13 @@ def test_hoisting_block_scope_7():
data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + T.float32(1.3)
stmt = Module["main"].body
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
tvm.ir.assert_structural_equal(new_stmt, stmt)
with tvm.transform.PassContext(
- config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+ config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
):
- new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+ new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
assert not tvm.ir.structural_equal(new_stmt, stmt)
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
similarity index 99%
rename from tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
rename to tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
index 5ab0cfe6f0..37b6d02f02 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
@@ -18,6 +18,7 @@
import tvm
import tvm_ffi
import tvm.testing
+from tvm import s_tir
from tvm.script import tir as T, ir as I
import pytest
@@ -135,7 +136,7 @@ def test_inject_async_copy():
mod = tvm.tir.transform.FlattenBuffer()(mod)
if vec_size > 1:
mod = tvm.tir.transform.VectorizeLoop()(mod)
- mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+ mod = tvm.s_tir.transform.InjectPTXAsyncCopy()(mod)
assert count_cp_async(mod["main"].body) == 1
@@ -162,8 +163,8 @@ def test_inject_async_copy_shared_dyn():
mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.FlattenBuffer()(mod)
mod = tvm.tir.transform.VectorizeLoop()(mod)
- mod = tvm.tir.transform.MergeSharedMemoryAllocations()(mod)
- mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+ mod = tvm.s_tir.transform.MergeSharedMemoryAllocations()(mod)
+ mod = tvm.s_tir.transform.InjectPTXAsyncCopy()(mod)
assert count_cp_async(mod["main"].body) == 2
@@ -223,7 +224,7 @@ def test_inject_async_copy_barrier():
mod = tvm.IRModule.from_expr(f)
mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.FlattenBuffer()(mod)
- mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+ mod = tvm.s_tir.transform.InjectPTXAsyncCopy()(mod)
assert count_cp_async(mod["main"].body) == 1
@@ -981,7 +982,7 @@ def test_multiplication_nodes_are_inlined():
T.ptx_commit_group()
T.ptx_wait_group(0)
- After = tvm.tir.transform.InjectPTXAsyncCopy()(Before)
+ After = tvm.s_tir.transform.InjectPTXAsyncCopy()(Before)
tvm.ir.assert_structural_equal(After, Expected)
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py
similarity index 95%
rename from tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py
rename to tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py
index 55099f252c..c0aed0351e 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py
@@ -17,6 +17,7 @@
import tvm
import tvm.testing
+from tvm import s_tir
from tvm.script import tir as T
@@ -60,7 +61,7 @@ def test_inject_ptx_ldg32_inserts_alloc_for_no_alloc_func():
mod = tvm.IRModule.from_expr(where_no_alloc)
assert _count_alloc(mod["main"].body) == 0
- mod = tvm.tir.transform.InjectPTXLDG32()(mod)
+ mod = tvm.s_tir.transform.InjectPTXLDG32()(mod)
assert _count_alloc(mod["main"].body) > 0
assert _count_ptx_ldg32(mod["main"].body) == 1
@@ -71,7 +72,7 @@ def test_inject_ptx_ldg32_skip_non_cuda_target():
mod = tvm.IRModule({"main": mod["main"].with_attr("target", cpu_target)})
assert _count_alloc(mod["main"].body) == 0
- mod = tvm.tir.transform.InjectPTXLDG32()(mod)
+ mod = tvm.s_tir.transform.InjectPTXLDG32()(mod)
assert _count_alloc(mod["main"].body) == 0
assert _count_ptx_ldg32(mod["main"].body) == 0
diff --git a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
similarity index 98%
rename from tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py
rename to tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
index cfd81377c1..edffa0dcc5 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
@@ -17,11 +17,12 @@
import tvm
import tvm.testing
+from tvm import s_tir
from tvm.script import tir as T, ir as I
def test_basic():
- transform = tvm.tir.transform.LowerThreadAllreduce()
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
@I.ir_module
class Before:
@@ -97,7 +98,7 @@ def test_basic():
def test_basic_with_decl_buffer():
- transform = tvm.tir.transform.LowerThreadAllreduce()
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
@I.ir_module
class Before:
@@ -171,7 +172,7 @@ def test_basic_with_decl_buffer():
def test_reduce_summation():
- transform = tvm.tir.transform.LowerThreadAllreduce()
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
@I.ir_module
class Before:
@@ -261,7 +262,7 @@ def test_reduce_summation():
def test_multi_group_reduction():
- transform = tvm.tir.transform.LowerThreadAllreduce()
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
@I.ir_module
class Before:
@@ -330,7 +331,7 @@ def test_multi_group_reduction():
def test_multi_group_mask1():
- transform = tvm.tir.transform.LowerThreadAllreduce()
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
@I.ir_module
class Before:
@@ -393,7 +394,7 @@ def test_multi_group_mask1():
def test_multi_warp_reduce1():
- transform = tvm.tir.transform.LowerThreadAllreduce()
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
@I.ir_module
class Before:
@@ -478,7 +479,7 @@ def test_multi_warp_reduce1():
def test_multi_warp_reduce2():
- transform = tvm.tir.transform.LowerThreadAllreduce()
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
@I.ir_module
class Before:
@@ -563,7 +564,7 @@ def test_multi_warp_reduce2():
def test_multi_group_multi_warp_reduction():
- transform = tvm.tir.transform.LowerThreadAllreduce()
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
@I.ir_module
class Before:
@@ -648,7 +649,7 @@ def test_multi_group_multi_warp_reduction():
def test_multi_group_multi_warp_predicated_reduction():
- transform = tvm.tir.transform.LowerThreadAllreduce()
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
@I.ir_module
class Before:
@@ -743,7 +744,7 @@ def test_multi_group_multi_warp_predicated_reduction():
def test_metal_no_mask():
- transform = tvm.tir.transform.LowerThreadAllreduce()
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
@I.ir_module
class Before:
diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
similarity index 97%
rename from tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
rename to tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
index eb08607dc8..26e8a6f2bb 100644
--- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -18,7 +18,7 @@ import numpy as np
import tvm
import tvm.testing
-from tvm import te
+from tvm import te, s_tir
from tvm.topi.math import cast
from tvm.script import ir as I, tir as T
@@ -30,7 +30,7 @@ def test_matmul_t_buffer():
test_matmul_dyn_shared, using `T.Buffer` (Allocate without
DeclBuffer) for the replaced allocations.
"""
- transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+ transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
buffer_func = T.Buffer
@I.ir_module
@@ -145,7 +145,7 @@ def test_matmul_decl_buffer():
test_matmul_dyn_shared, using `T.decl_buffer` (Allocate followed by
DeclBuffer)
for the replaced allocations.
"""
- transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+ transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
buffer_func = T.decl_buffer
@I.ir_module
@@ -255,7 +255,7 @@ def test_matmul_decl_buffer():
def test_simple_alloc_no_reuse():
"""Test alloc and free within the same scope."""
- transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+ transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
@I.ir_module
class Before:
@@ -284,7 +284,7 @@ def test_simple_alloc_no_reuse():
def test_simple_alloc_reuse():
"""Test alloc and free within the same scope with a reuse chance."""
- transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+ transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
@I.ir_module
class Before:
@@ -315,7 +315,7 @@ def test_simple_alloc_reuse():
def test_async_copy():
"""Test async copy in shared memory."""
- transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+ transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
@I.ir_module
class Before:
diff --git a/tests/python/tir-transform/test_tir_transform_profiling_instr.py b/tests/python/s_tir/transform/test_s_tir_transform_profiling_instr.py
similarity index 95%
rename from tests/python/tir-transform/test_tir_transform_profiling_instr.py
rename to tests/python/s_tir/transform/test_s_tir_transform_profiling_instr.py
index 139524fa83..f6c409aa8e 100644
--- a/tests/python/tir-transform/test_tir_transform_profiling_instr.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_profiling_instr.py
@@ -17,15 +17,15 @@
import tvm
import tvm.testing
-from tvm import te
+from tvm import te, s_tir
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy
default_lwp_test_config = {
"tir.instrument_lwp": True,
- "tir.lwp_disable_func_prof": True,
- "tir.reset_start_id": True,
+ "s_tir.lwp_disable_func_prof": True,
+ "s_tir.reset_start_id": True,
}
@@ -278,7 +278,7 @@ def test6_expected_output(a: T.handle, b: T.handle, c: T.handle, d: T.handle) ->
def test1():
with tvm.transform.PassContext(config=default_lwp_test_config):
mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main"))
- mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+ mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
tvm.ir.assert_structural_equal(
mod["main"], test1_expected_output.with_attr("global_symbol", "main")
)
@@ -288,10 +288,10 @@ def test1():
# doesn't have any effect unless 'instr_siblings' is set to False (ex: test3).
def test2():
test2_config = default_lwp_test_config.copy()
- test2_config.update({"tir.lwp_max_depth": 3})
+ test2_config.update({"s_tir.lwp_max_depth": 3})
with tvm.transform.PassContext(config=test2_config):
mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main"))
- mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+ mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
tvm.ir.assert_structural_equal(
mod["main"], test1_expected_output.with_attr("global_symbol", "main")
)
@@ -303,10 +303,10 @@ def test2():
# 'lwp_min_height' (ex: test5)
def test3():
test3_config = default_lwp_test_config.copy()
- test3_config.update({"tir.lwp_max_depth": 3, "tir.instr_siblings": False})
+ test3_config.update({"s_tir.lwp_max_depth": 3, "s_tir.instr_siblings": False})
with tvm.transform.PassContext(config=test3_config):
mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main"))
- mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+ mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
tvm.ir.assert_structural_equal(
mod["main"], test3_expected_output.with_attr("global_symbol", "main")
)
@@ -317,7 +317,7 @@ def test3():
def test4():
with tvm.transform.PassContext(config=default_lwp_test_config):
mod = tvm.IRModule.from_expr(input2.with_attr("global_symbol", "main"))
- mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+ mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
tvm.ir.assert_structural_equal(
mod["main"], test4_expected_output.with_attr("global_symbol", "main")
)
@@ -328,11 +328,11 @@ def test4():
def test5():
test5_config = default_lwp_test_config.copy()
test5_config.update(
- {"tir.lwp_max_depth": 3, "tir.instr_siblings": False, "tir.lwp_min_height": 2}
+ {"s_tir.lwp_max_depth": 3, "s_tir.instr_siblings": False, "s_tir.lwp_min_height": 2}
)
with tvm.transform.PassContext(config=test5_config):
mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main"))
- mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+ mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
tvm.ir.assert_structural_equal(
mod["main"], test5_expected_output.with_attr("global_symbol", "main")
)
@@ -342,7 +342,7 @@ def test5():
def test6():
with tvm.transform.PassContext(config=default_lwp_test_config):
mod = tvm.IRModule.from_expr(input3.with_attr("global_symbol", "main"))
- mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+ mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
tvm.ir.assert_structural_equal(
mod["main"], test6_expected_output.with_attr("global_symbol", "main")
)
diff --git a/tests/python/tir-transform/test_tir_transform_renormalize_split_pattern.py b/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py
similarity index 98%
rename from tests/python/tir-transform/test_tir_transform_renormalize_split_pattern.py
rename to tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py
index 057cfc42e4..96c9089c23 100644
--- a/tests/python/tir-transform/test_tir_transform_renormalize_split_pattern.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py
@@ -17,6 +17,7 @@
import tvm
import tvm.testing
+from tvm import s_tir
from tvm.script import tir as T
# fmt: off
@@ -119,7 +120,7 @@ class After_simplified:
def test_renormalize_split_pattern():
- after = tvm.tir.transform.RenormalizeSplitPattern()(Before)
+ after = tvm.s_tir.transform.RenormalizeSplitPattern()(Before)
tvm.ir.assert_structural_equal(after, After)
after = tvm.tir.transform.Simplify()(after)
tvm.ir.assert_structural_equal(after, After_simplified)
@@ -168,7 +169,7 @@ def test_analyze_inside_integer_conditional(integer_condition):
# exception, as it rewrites the integer conditionals first. These
# tests are written using RenormalizeSplitPattern as it is the
# first case identified.
- transform = tvm.tir.transform.RenormalizeSplitPattern()
+ transform = tvm.s_tir.transform.RenormalizeSplitPattern()
# Issue would result in an error through while applying the transformation.
mod = tvm.IRModule.from_expr(integer_condition)
diff --git a/tests/python/tir-transform/test_tir_transform_rewrite_unsafe_select.py b/tests/python/s_tir/transform/test_s_tir_transform_rewrite_unsafe_select.py
similarity index 90%
rename from tests/python/tir-transform/test_tir_transform_rewrite_unsafe_select.py
rename to tests/python/s_tir/transform/test_s_tir_transform_rewrite_unsafe_select.py
index eee05315c9..8500aa8b33 100644
--- a/tests/python/tir-transform/test_tir_transform_rewrite_unsafe_select.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_rewrite_unsafe_select.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
+from tvm import s_tir
from tvm.script import ir as I, tir as T
@@ -27,7 +28,7 @@ def test_rewrite_Select():
A = T.Buffer(100, "float32", data=A_data)
T.evaluate(T.Select(i > 1, A[i - 1], T.float32(1.0)))
- yy = tvm.tir.transform.RewriteUnsafeSelect()(ModuleY)["main"].body.body.value
+ yy = tvm.s_tir.transform.RewriteUnsafeSelect()(ModuleY)["main"].body.body.value
@I.ir_module
class ModuleZ:
@@ -41,7 +42,7 @@ def test_rewrite_Select():
)
)
- zz = tvm.tir.transform.RewriteUnsafeSelect()(ModuleZ)["main"].body.body.value
+ zz = tvm.s_tir.transform.RewriteUnsafeSelect()(ModuleZ)["main"].body.body.value
@I.ir_module
class ModuleA:
@@ -62,7 +63,7 @@ def test_rewrite_Select():
)
)
- aa = tvm.tir.transform.RewriteUnsafeSelect()(ModuleA)["main"].body.body.value
+ aa = tvm.s_tir.transform.RewriteUnsafeSelect()(ModuleA)["main"].body.body.value
builtin_if_then_else = tvm.ir.Op.get("tir.if_then_else")
assert yy.op.same_as(builtin_if_then_else)
diff --git a/tests/python/tir-transform/test_tir_transform_thread_sync.py b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
similarity index 97%
rename from tests/python/tir-transform/test_tir_transform_thread_sync.py
rename to tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
index 48de01a629..7d5720ea3f 100644
--- a/tests/python/tir-transform/test_tir_transform_thread_sync.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
@@ -16,7 +16,7 @@
# under the License.
import tvm
import tvm.testing
-from tvm import te
+from tvm import te, s_tir
from tvm.script import tir as T
@@ -31,7 +31,7 @@ def run_passes(func: tvm.tir.PrimFunc):
mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
mod = tvm.tir.transform.SplitHostDevice()(mod)
- return tvm.tir.transform.ThreadSync("shared")(mod)
+ return tvm.s_tir.transform.ThreadSync("shared")(mod)
@tvm.testing.requires_cuda
@@ -94,7 +94,7 @@ def test_sync_shared_dyn():
E_1[threadIdx_x] = D_1_1[threadIdx_x]
mod = tvm.IRModule({"main": func})
- mod = tvm.tir.transform.ThreadSync("shared.dyn")(mod)
+ mod = tvm.s_tir.transform.ThreadSync("shared.dyn")(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)
@@ -170,7 +170,7 @@ def test_sync_let_stmt():
)
mod = tvm.IRModule({"main": func})
- mod = tvm.tir.transform.ThreadSync("shared")(mod)
+ mod = tvm.s_tir.transform.ThreadSync("shared")(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)