This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 783a9bf3e3 [REFACTOR][S-TIR] Migrate more transform to s_tir (#18771)
783a9bf3e3 is described below

commit 783a9bf3e30f4c3aebe269e3942aa3ca9742be8b
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Feb 13 07:12:28 2026 -0500

    [REFACTOR][S-TIR] Migrate more transform to s_tir (#18771)
    
    This PR migrates more transformations into s_tir to reduce the number of
    passes in the tir namespace.
---
 include/tvm/s_tir/transform.h                      |  96 +++++++++++
 python/tvm/s_tir/backend/adreno/pipeline.py        |  36 ++--
 python/tvm/s_tir/pipeline.py                       |  36 ++--
 python/tvm/s_tir/transform/__init__.py             |   1 +
 python/tvm/s_tir/transform/transform.py            | 189 +++++++++++++++++++++
 python/tvm/tir/transform/transform.py              |   4 +-
 .../analysis/calculate_allocated_memory.cc         |   3 +-
 .../meta_schedule/postproc/verify_gpu_code.cc      |   2 +-
 src/{tir => s_tir}/transform/bound_checker.cc      |  12 +-
 src/{tir => s_tir}/transform/hoist_expression.cc   |  72 ++++----
 .../transform/inject_ptx_async_copy.cc             |  13 +-
 src/{tir => s_tir}/transform/inject_ptx_ldg32.cc   |  11 +-
 src/{tir => s_tir}/transform/lower_async_dma.cc    |  13 +-
 .../transform/lower_thread_allreduce.cc            |  21 +--
 src/{tir => s_tir}/transform/lower_vtcm_alloc.cc   |  13 +-
 .../transform/merge_shared_memory_allocations.cc   |  24 +--
 .../transform/profile_instrumentation.cc           |  31 ++--
 .../transform/renormalize_split_pattern.cc         |  11 +-
 .../transform/rewrite_unsafe_select.cc             |  11 +-
 src/{tir => s_tir}/transform/storage_access.cc     |  15 +-
 src/{tir => s_tir}/transform/storage_access.h      |  11 +-
 .../transform/tensorcore_infer_fragment.cc         |  29 ++--
 .../transform/thread_storage_sync.cc               |  22 +--
 .../test_s_tir_transform_hoist_expression.py}      |   7 +-
 .../transform/test_s_tir_transform_hoist_if.py}    |  72 ++++----
 .../test_s_tir_transform_inject_ptx_async_copy.py} |  11 +-
 .../test_s_tir_transform_inject_ptx_ldg32.py}      |   5 +-
 ...est_s_tir_transform_lower_thread_all_reduce.py} |  21 +--
 ...orm_merge_dynamic_shared_memory_allocations.py} |  12 +-
 .../test_s_tir_transform_profiling_instr.py}       |  24 +--
 ...t_s_tir_transform_renormalize_split_pattern.py} |   5 +-
 .../test_s_tir_transform_rewrite_unsafe_select.py} |   7 +-
 .../transform/test_s_tir_transform_thread_sync.py} |   8 +-
 33 files changed, 584 insertions(+), 264 deletions(-)

diff --git a/include/tvm/s_tir/transform.h b/include/tvm/s_tir/transform.h
index 9914c6e49a..55343fdce5 100644
--- a/include/tvm/s_tir/transform.h
+++ b/include/tvm/s_tir/transform.h
@@ -230,6 +230,102 @@ TVM_DLL Pass InjectVirtualThread();
  */
 TVM_DLL Pass InjectDoubleBuffer();
 
+/*!
+ * \brief Hoist loop-invariant IfThenElse nodes to
+ * outside the eligible loops.
+ *
+ * \param variant The variant of the pass.
+ *        variant can have any one of following values ["basic", ""(Default)].
+ * \return The pass.
+ */
+TVM_DLL Pass HoistIfThenElse(tvm::ffi::String variant = "");
+
+/*!
+ * \brief Hoist loop-invariant expressions to outside the eligible loops.
+ *
+ * Can hoist conditionals used in IfThenElse statements and
+ * expressions, bindings of variables in Let statements and
+ * expressions, or boolean expressions, configurable to enable/disable
+ * each hoistable type.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass HoistExpression();
+
+/*!
+ * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()).
+ * \return The pass.
+ */
+TVM_DLL Pass RenormalizeSplitPattern();
+
+/*!
+ * \brief Detect and rewrite unsafe select that contains memory access.
+ * \return The pass.
+ */
+TVM_DLL Pass RewriteUnsafeSelect();
+
+/*!
+ * \brief Instruments bound checkers.
+ * \return The pass.
+ */
+TVM_DLL Pass InstrumentBoundCheckers();
+
+/*!
+ * \brief Rewrite global to local memory copy on CUDA with ldg32 instruction.
+ * \param enable_inject Whether to enable injection.
+ * \return The pass.
+ */
+TVM_DLL Pass InjectPTXLDG32(bool enable_inject = true);
+
+/*!
+ * \brief Insert intrinsic calls to instrument function and loop level profiling.
+ * \return The pass.
+ */
+TVM_DLL Pass InstrumentProfileIntrinsics();
+
+/*!
+ * \brief Lower VTCM allocations.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerVtcmAlloc();
+
+/*!
+ * \brief Insert sync between parallel read/write of shared buffers.
+ * \param storage_scope The storage scope considered.
+ * \return The pass.
+ */
+TVM_DLL Pass ThreadSync(tvm::ffi::String storage_scope);
+
+/*!
+ * \brief Infer the TensorCore fragment information using tensor intrinsics.
+ * \return The pass.
+ */
+TVM_DLL Pass InferFragment();
+
+/*!
+ * \brief Lower cross thread allreduce.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerThreadAllreduce();
+
+/*!
+ * \brief Lower Async TIR primitives to DMA copy and wait builtins.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerAsyncDMA();
+
+/*!
+ * \brief Rewrite global to shared memory copy on CUDA with asynchronous copy.
+ * \return The pass.
+ */
+TVM_DLL Pass InjectPTXAsyncCopy();
+
+/*!
+ * \brief Merge multiple TIR-level shared memory allocations into one.
+ * \return The pass.
+ */
+TVM_DLL Pass MergeSharedMemoryAllocations();
+
 }  // namespace transform
 }  // namespace s_tir
 }  // namespace tvm
diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py 
b/python/tvm/s_tir/backend/adreno/pipeline.py
index e895025bc4..0237429e80 100644
--- a/python/tvm/s_tir/backend/adreno/pipeline.py
+++ b/python/tvm/s_tir/backend/adreno/pipeline.py
@@ -62,22 +62,22 @@ def default_tir_pipeline():
         if not bool(config.get("tir.disable_storage_rewrite", False)):
             passes.append(tir.transform.StorageRewrite())
         if config.get("tir.use_async_copy", False):
-            passes.append(tir.transform.LowerAsyncDMA())
+            passes.append(s_tir.transform.LowerAsyncDMA())
         passes.extend(
             [
-                tir.transform.HoistIfThenElse(),
+                s_tir.transform.HoistIfThenElse(),
                 tir.transform.UnrollLoop(),
-                tir.transform.RenormalizeSplitPattern(),
+                s_tir.transform.RenormalizeSplitPattern(),
                 tir.transform.Simplify(),
                 tir.transform.RemoveNoOp(),
-                tir.transform.RewriteUnsafeSelect(),
+                s_tir.transform.RewriteUnsafeSelect(),
             ]
         )
         # Additional passes based on configuration.
         if bool(config.get("tir.instrument_bound_checkers", False)):
-            passes.append(tir.transform.InstrumentBoundCheckers())
+            passes.append(s_tir.transform.InstrumentBoundCheckers())
         if bool(config.get("tir.ptx_ldg32", False)):
-            passes.append(tir.transform.InjectPTXLDG32(True))
+            passes.append(s_tir.transform.InjectPTXLDG32(True))
         passes.append(
             tir.transform.CommonSubexprElimTIR(
                 not bool(config.get("tir.disable_cse_tir", False)),
@@ -85,39 +85,39 @@ def default_tir_pipeline():
             )
         )
         if bool(config.get("tir.instrument_lwp", False)):
-            passes.append(tir.transform.InstrumentProfileIntrinsics())
+            passes.append(s_tir.transform.InstrumentProfileIntrinsics())
         passes.extend(
             [
                 # Bind the target first so that target-specific attributes are available.
                 tir.transform.FP8ComputeLegalize(),
                 # VerifyVTCMLimit must occur before LowerVtcmAlloc.
-                tir.transform.VerifyVTCMLimit(),
-                tir.transform.LowerVtcmAlloc(),
+                s_tir.transform.VerifyVTCMLimit(),
+                s_tir.transform.LowerVtcmAlloc(),
                 tir.transform.VerifyMemory(),
                 tir.transform.AnnotateEntryFunc(),
             ]
         )
         if bool(config.get("tir.detect_global_barrier", False)):
-            passes.append(tir.transform.ThreadSync("global"))
+            passes.append(s_tir.transform.ThreadSync("global"))
         passes.extend(
             [
-                tir.transform.ThreadSync("shared"),
-                tir.transform.ThreadSync("shared.dyn"),
-                tir.transform.ThreadSync("warp"),
-                tir.transform.InferFragment(),
-                tir.transform.LowerThreadAllreduce(),
+                s_tir.transform.ThreadSync("shared"),
+                s_tir.transform.ThreadSync("shared.dyn"),
+                s_tir.transform.ThreadSync("warp"),
+                s_tir.transform.InferFragment(),
+                s_tir.transform.LowerThreadAllreduce(),
             ]
         )
         if bool(config.get("tir.use_async_copy", False)):
-            passes.append(tir.transform.InjectPTXAsyncCopy())
+            passes.append(s_tir.transform.InjectPTXAsyncCopy())
         if bool(config.get("tir.ptx_ldg32", False)):
-            passes.append(tir.transform.InjectPTXLDG32())
+            passes.append(s_tir.transform.InjectPTXLDG32())
         passes.extend(
             [
                 tir.transform.AnnotateDeviceRegions(),
                 tir.transform.SplitHostDevice(),
                 # MergeSharedMemoryAllocations must follow SplitHostDevice.
-                tir.transform.MergeSharedMemoryAllocations(),
+                s_tir.transform.MergeSharedMemoryAllocations(),
                 tir.transform.MakePackedAPI(),
                 tir.transform.FP8StorageLegalize(),
                 tir.transform.BF16StorageLegalize(),
diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py
index 59cec5b582..8c8ffb69af 100644
--- a/python/tvm/s_tir/pipeline.py
+++ b/python/tvm/s_tir/pipeline.py
@@ -60,22 +60,22 @@ def default_s_tir_pipeline():
         if not bool(config.get("tir.disable_storage_rewrite", False)):
             passes.append(tir.transform.StorageRewrite())
         if config.get("tir.use_async_copy", False):
-            passes.append(tir.transform.LowerAsyncDMA())
+            passes.append(s_tir.transform.LowerAsyncDMA())
         passes.extend(
             [
-                tir.transform.HoistIfThenElse(),
+                s_tir.transform.HoistIfThenElse(),
                 tir.transform.UnrollLoop(),
-                tir.transform.RenormalizeSplitPattern(),
+                s_tir.transform.RenormalizeSplitPattern(),
                 tir.transform.Simplify(),
                 tir.transform.RemoveNoOp(),
-                tir.transform.RewriteUnsafeSelect(),
+                s_tir.transform.RewriteUnsafeSelect(),
             ]
         )
         # Additional passes based on configuration.
         if bool(config.get("tir.instrument_bound_checkers", False)):
-            passes.append(tir.transform.InstrumentBoundCheckers())
+            passes.append(s_tir.transform.InstrumentBoundCheckers())
         if bool(config.get("tir.ptx_ldg32", False)):
-            passes.append(tir.transform.InjectPTXLDG32(True))
+            passes.append(s_tir.transform.InjectPTXLDG32(True))
         passes.append(
             tir.transform.CommonSubexprElimTIR(
                 not bool(config.get("tir.disable_cse_tir", False)),
@@ -83,39 +83,39 @@ def default_s_tir_pipeline():
             )
         )
         if bool(config.get("tir.instrument_lwp", False)):
-            passes.append(tir.transform.InstrumentProfileIntrinsics())
+            passes.append(s_tir.transform.InstrumentProfileIntrinsics())
         passes.extend(
             [
                 # Bind the target first so that target-specific attributes are available.
                 tir.transform.FP8ComputeLegalize(),
                 # VerifyVTCMLimit must occur before LowerVtcmAlloc.
-                tir.transform.VerifyVTCMLimit(),
-                tir.transform.LowerVtcmAlloc(),
+                s_tir.transform.VerifyVTCMLimit(),
+                s_tir.transform.LowerVtcmAlloc(),
                 tir.transform.VerifyMemory(),
                 tir.transform.AnnotateEntryFunc(),
             ]
         )
         if bool(config.get("tir.detect_global_barrier", False)):
-            passes.append(tir.transform.ThreadSync("global"))
+            passes.append(s_tir.transform.ThreadSync("global"))
         passes.extend(
             [
-                tir.transform.ThreadSync("shared"),
-                tir.transform.ThreadSync("shared.dyn"),
-                tir.transform.ThreadSync("warp"),
-                tir.transform.InferFragment(),
-                tir.transform.LowerThreadAllreduce(),
+                s_tir.transform.ThreadSync("shared"),
+                s_tir.transform.ThreadSync("shared.dyn"),
+                s_tir.transform.ThreadSync("warp"),
+                s_tir.transform.InferFragment(),
+                s_tir.transform.LowerThreadAllreduce(),
             ]
         )
         if bool(config.get("tir.use_async_copy", False)):
-            passes.append(tir.transform.InjectPTXAsyncCopy())
+            passes.append(s_tir.transform.InjectPTXAsyncCopy())
         if bool(config.get("tir.ptx_ldg32", False)):
-            passes.append(tir.transform.InjectPTXLDG32())
+            passes.append(s_tir.transform.InjectPTXLDG32())
         passes.extend(
             [
                 tir.transform.AnnotateDeviceRegions(),
                 tir.transform.SplitHostDevice(),
                 # MergeSharedMemoryAllocations must follow SplitHostDevice.
-                tir.transform.MergeSharedMemoryAllocations(),
+                s_tir.transform.MergeSharedMemoryAllocations(),
                 tir.transform.MakePackedAPI(),
                 tir.transform.FP8StorageLegalize(),
                 tir.transform.BF16StorageLegalize(),
diff --git a/python/tvm/s_tir/transform/__init__.py 
b/python/tvm/s_tir/transform/__init__.py
index 4529684dc2..c669f3eaa7 100644
--- a/python/tvm/s_tir/transform/__init__.py
+++ b/python/tvm/s_tir/transform/__init__.py
@@ -18,3 +18,4 @@
 # pylint: disable=wildcard-import, invalid-name
 
 from .transform import *
+from ...tir.transform.transform import HoistedConditionals, HoistedLetBindings
diff --git a/python/tvm/s_tir/transform/transform.py 
b/python/tvm/s_tir/transform/transform.py
index d4dbb8ee86..05d1a46746 100644
--- a/python/tvm/s_tir/transform/transform.py
+++ b/python/tvm/s_tir/transform/transform.py
@@ -253,3 +253,192 @@ def InjectDoubleBuffer():
         The result pass
     """
     return _ffi_api.InjectDoubleBuffer()  # type: ignore
+
+
+def HoistIfThenElse(variant=None):
+    """Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
+
+    Parameters
+    ----------
+    variant : Optional[String]
+        The variant of the pass.
+        variant can have any one of following values ["basic", None(Default)].
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    if variant == "basic":
+        return _ffi_api.HoistIfThenElseBasic()  # type: ignore
+    elif variant is None:
+        return _ffi_api.HoistIfThenElse()  # type: ignore
+    else:
+        raise ValueError("wrong variant of HoistIfThenElse, " + variant)
+
+
+def HoistExpression():
+    """Hoist loop-invariant expressions to outside the eligible loops.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.HoistExpression()  # type: ignore
+
+
+def RenormalizeSplitPattern():
+    """Renormalize the split pattern from floordiv(floormod()) to 
floormod(floordiv())
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.RenormalizeSplitPattern()  # type: ignore
+
+
+def RewriteUnsafeSelect():
+    """Detect and rewrite unsafe select that contains memory access.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.RewriteUnsafeSelect()  # type: ignore
+
+
+def InstrumentBoundCheckers():
+    """Instruments bound checkers.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InstrumentBoundCheckers()  # type: ignore
+
+
+def InjectPTXLDG32(enable_inject_ptx_intrin=True):
+    """Inject ptx.ldg.32 intrinsics.
+
+    Parameters
+    ----------
+    enable_inject_ptx_intrin : bool
+        If True, inject ptx.ldg.32 intrinsics.
+    """
+    return _ffi_api.InjectPTXLDG32(enable_inject_ptx_intrin)  # type: ignore
+
+
+def InstrumentProfileIntrinsics():
+    """Insert intrinsic calls to instrument function and loop level profiling.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InstrumentProfileIntrinsics()  # type: ignore
+
+
+def VerifyVTCMLimit(default_target=None):
+    """Verify if the size of the allocated vtcm memory satisfies the limit.
+
+    The limit is determined from the "vtcm-capacity" attribute of the target.
+
+    Parameters
+    ----------
+    default_target : Optional[tvm.target.Target]
+        The default target to use if a PrimFunc does not have a target 
attribute.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.VerifyVTCMLimit(default_target)  # type: ignore
+
+
+def LowerVtcmAlloc():
+    """Lower vtcm allocation.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerVtcmAlloc()  # type: ignore
+
+
+def ThreadSync(storage_scope):
+    """Insert sync between parallel read/write of shared buffers.
+
+    Parameters
+    ----------
+    storage_scope: str
+        The target storage scope.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.ThreadSync(storage_scope)  # type: ignore
+
+
+def InferFragment():
+    """Infer the TensorCore fragment information using tensor intrinsics.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InferFragment()  # type: ignore
+
+
+def LowerThreadAllreduce():
+    """Lower cross thread allreduce.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerThreadAllreduce()  # type: ignore
+
+
+def LowerAsyncDMA():
+    """Lower async DMA to DMA.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerAsyncDMA()  # type: ignore
+
+
+def InjectPTXAsyncCopy():
+    """Rewrite global to shared memory copy on CUDA with asynchronous copy.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InjectPTXAsyncCopy()  # type: ignore
+
+
+def MergeSharedMemoryAllocations():
+    """This pass merges multiple TIR-level shared memory allocations
+    into one allocation.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.MergeSharedMemoryAllocations()  # type: ignore
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index 7de12d5301..a9439a77f3 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -585,7 +585,7 @@ def VerifyVTCMLimit(limit=None):
     return _ffi_api.VerifyVTCMLimit(limit)  # type: ignore
 
 
-@_ffi.register_object("tir.transform.HoistIfThenElseConfig")
+@_ffi.register_object("s_tir.transform.HoistIfThenElseConfig")
 class HoistIfThenElseConfig(_ir.Attrs):
     """Config for hoist if then else pass"""
 
@@ -669,7 +669,7 @@ class HoistedLetBindings(enum.Flag):
     """ Enable all hoisting of let bindings """
 
 
-@_ffi.register_object("tir.transform.HoistExpressionConfig")
+@_ffi.register_object("s_tir.transform.HoistExpressionConfig")
 class HoistExpressionConfig(_ir.Attrs):
     """Config for hoist expression pass"""
 
diff --git a/src/tir/analysis/calculate_allocated_memory.cc 
b/src/s_tir/analysis/calculate_allocated_memory.cc
similarity index 97%
rename from src/tir/analysis/calculate_allocated_memory.cc
rename to src/s_tir/analysis/calculate_allocated_memory.cc
index 1741eff937..ba7c7438bc 100644
--- a/src/tir/analysis/calculate_allocated_memory.cc
+++ b/src/s_tir/analysis/calculate_allocated_memory.cc
@@ -198,12 +198,13 @@ Pass VerifyVTCMLimit(ffi::Optional<Target> 
default_target) {
     }
     return mod;
   };
-  return tvm::transform::CreateModulePass(pass_func, 0, 
"tir.calculate_allocated_bytes", {});
+  return tvm::transform::CreateModulePass(pass_func, 0, 
"s_tir.VerifyVTCMLimit", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tir.transform.VerifyVTCMLimit", VerifyVTCMLimit);
+  refl::GlobalDef().def("s_tir.transform.VerifyVTCMLimit", VerifyVTCMLimit);
 }
 
 }  // namespace transform
diff --git a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc 
b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
index 553647f3e6..ef349416b4 100644
--- a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc
@@ -179,7 +179,7 @@ class VerifyGPUCodeNode : public PostprocNode {
           pass_list.push_back(s_tir::transform::InjectVirtualThread());
           pass_list.push_back(s_tir::transform::InjectDoubleBuffer());
           pass_list.push_back(tir::transform::StorageRewrite());
-          pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
+          
pass_list.push_back(s_tir::transform::MergeSharedMemoryAllocations());
           pass_list.push_back(tir::transform::LowerIntrin());
           // Convert Function to IRModule
           tvm::transform::PassContext pass_ctx = 
tvm::transform::PassContext::Current();
diff --git a/src/tir/transform/bound_checker.cc 
b/src/s_tir/transform/bound_checker.cc
similarity index 96%
rename from src/tir/transform/bound_checker.cc
rename to src/s_tir/transform/bound_checker.cc
index 99d990ece6..2f2061d9c1 100644
--- a/src/tir/transform/bound_checker.cc
+++ b/src/s_tir/transform/bound_checker.cc
@@ -25,11 +25,11 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 #include <unordered_map>
 #include <utility>
@@ -38,7 +38,8 @@
 #include "../../arith/unwrap_vector_expr.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 // TODO(Lunderberg): Move this pass to be before
 // FlattenBuffer.  That will simplify this pass,
@@ -254,15 +255,16 @@ Pass InstrumentBoundCheckers() {
     n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body));
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.InstrumentBoundCheckers", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.InstrumentBoundCheckers", 
InstrumentBoundCheckers);
+  refl::GlobalDef().def("s_tir.transform.InstrumentBoundCheckers",
+                        static_cast<Pass (*)()>(InstrumentBoundCheckers));
 }
 
 }  // namespace transform
 
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/hoist_expression.cc 
b/src/s_tir/transform/hoist_expression.cc
similarity index 91%
rename from src/tir/transform/hoist_expression.cc
rename to src/s_tir/transform/hoist_expression.cc
index ebd90583c9..add5b663bc 100644
--- a/src/tir/transform/hoist_expression.cc
+++ b/src/s_tir/transform/hoist_expression.cc
@@ -23,10 +23,10 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 #include <queue>
 #include <unordered_map>
@@ -36,10 +36,11 @@
 #include "../../arith/interval_set.h"
 #include "../../arith/ir_mutator_with_analyzer.h"
 #include "../../runtime/thread_storage_scope.h"
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 enum class HoistedConditionals : int {
   kNone = 0,
@@ -81,7 +82,7 @@ struct HoistExpressionConfigNode : public 
AttrsNodeReflAdapter<HoistExpressionCo
   bool FlagSet(HoistedLetBindings flag) const {
     return static_cast<int>(flag) & hoisted_let_bindings;
   }
-  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.HoistExpressionConfig",
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.transform.HoistExpressionConfig",
                                     HoistExpressionConfigNode, Object);
 };
 
@@ -99,7 +100,7 @@ class HoistExpressionConfig : public Attrs {
 
 TVM_FFI_STATIC_INIT_BLOCK() { HoistExpressionConfigNode::RegisterReflection(); 
}
 
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistExpression", HoistExpressionConfig);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.HoistExpression", 
HoistExpressionConfig);
 
 struct HoistIfThenElseConfigNode : public 
AttrsNodeReflAdapter<HoistIfThenElseConfigNode> {
   bool support_block_scope_hoisting;
@@ -110,7 +111,7 @@ struct HoistIfThenElseConfigNode : public 
AttrsNodeReflAdapter<HoistIfThenElseCo
         "support_block_scope_hoisting", 
&HoistIfThenElseConfigNode::support_block_scope_hoisting,
         "Hoist if cond with block scope variables", refl::DefaultValue(false));
   }
-  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.HoistIfThenElseConfig",
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.transform.HoistIfThenElseConfig",
                                     HoistIfThenElseConfigNode, Object);
 };
 
@@ -122,7 +123,7 @@ class HoistIfThenElseConfig : public Attrs {
 
 TVM_FFI_STATIC_INIT_BLOCK() { HoistIfThenElseConfigNode::RegisterReflection(); 
}
 
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.HoistIfThenElse", 
HoistIfThenElseConfig);
 
 class HoistInfoCollector : public StmtExprVisitor {
  public:
@@ -541,7 +542,7 @@ namespace transform {
 Pass HoistExpression() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    auto cfg = ctx->GetConfig<HoistExpressionConfig>("tir.HoistExpression");
+    auto cfg = ctx->GetConfig<HoistExpressionConfig>("s_tir.HoistExpression");
 
     if (!cfg.defined()) {
       cfg = AttrsWithDefaultValues<HoistExpressionConfig>();
@@ -549,26 +550,26 @@ Pass HoistExpression() {
     n->body = ExpressionHoister::Hoist(std::move(n->body), cfg.value());
     return f;
   };
-  auto insertion_pass = CreatePrimFuncPass(pass_func, 0, 
"tir.InsertHoistedExpression", {});
+  auto insertion_pass = CreatePrimFuncPass(pass_func, 0, 
"s_tir.InsertHoistedExpression", {});
 
-  return Sequential(
+  return tvm::transform::Sequential(
       {
           insertion_pass,
-          Simplify(),
-          RemoveNoOp(),
+          tir::transform::Simplify(),
+          tir::transform::RemoveNoOp(),
       },
-      "tir.HoistExpression");
+      "s_tir.HoistExpression");
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.HoistExpression", HoistExpression);
+  refl::GlobalDef().def("s_tir.transform.HoistExpression", HoistExpression);
 }
 
-Pass HoistIfThenElse() {
+static Pass HoistIfThenElseImpl() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");
+    auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("s_tir.HoistIfThenElse");
     auto flag = f->GetAttr<Integer>("tir.HoistIfThenElseExprWithBlock");
     if (flag && flag.value().IntValue() == 1) {
       HoistExpressionConfig 
config(static_cast<int>(HoistedConditionals::kUsingBlockVar) |
@@ -588,22 +589,17 @@ Pass HoistIfThenElse() {
     n->body = ExpressionHoister::Hoist(std::move(n->body), config);
     return f;
   };
-  auto insertion_pass = CreatePrimFuncPass(pass_func, 0, 
"tir.InsertHoistIfThenElse", {});
-  return Sequential(
+  auto insertion_pass = CreatePrimFuncPass(pass_func, 0, 
"s_tir.InsertHoistIfThenElse", {});
+  return tvm::transform::Sequential(
       {
           insertion_pass,
-          Simplify(),
-          RemoveNoOp(),
+          tir::transform::Simplify(),
+          tir::transform::RemoveNoOp(),
       },
-      "tir.HoistIfThenElse");
+      "s_tir.HoistIfThenElse");
 }
 
-TVM_FFI_STATIC_INIT_BLOCK() {
-  namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.HoistIfThenElse", HoistIfThenElse);
-}
-
-Pass HoistIfThenElseBasic() {
+static Pass HoistIfThenElseBasicImpl() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
     HoistExpressionConfig 
config(static_cast<int>(HoistedConditionals::kIfElseStmt),
@@ -611,22 +607,30 @@ Pass HoistIfThenElseBasic() {
     n->body = ExpressionHoister::Hoist(std::move(n->body), config);
     return f;
   };
-  auto insertion_pass = CreatePrimFuncPass(pass_func, 0, 
"tir.InsertHoistIfThenElseBasic", {});
-  return Sequential(
+  auto insertion_pass = CreatePrimFuncPass(pass_func, 0, 
"s_tir.InsertHoistIfThenElseBasic", {});
+  return tvm::transform::Sequential(
       {
           insertion_pass,
-          Simplify(),
-          RemoveNoOp(),
+          tir::transform::Simplify(),
+          tir::transform::RemoveNoOp(),
       },
-      "tir.HoistIfThenElseBasic");
+      "s_tir.HoistIfThenElseBasic");
+}
+
+Pass HoistIfThenElse(tvm::ffi::String variant) {
+  if (variant == "basic") {
+    return HoistIfThenElseBasicImpl();
+  }
+  return HoistIfThenElseImpl();
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.HoistIfThenElseBasic", 
HoistIfThenElseBasic);
+  refl::GlobalDef().def("s_tir.transform.HoistIfThenElse", 
HoistIfThenElseImpl);
+  refl::GlobalDef().def("s_tir.transform.HoistIfThenElseBasic", 
HoistIfThenElseBasicImpl);
 }
 
 }  // namespace transform
 
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/inject_ptx_async_copy.cc 
b/src/s_tir/transform/inject_ptx_async_copy.cc
similarity index 96%
rename from src/tir/transform/inject_ptx_async_copy.cc
rename to src/s_tir/transform/inject_ptx_async_copy.cc
index 0e9820aa65..c9b1e42d7f 100644
--- a/src/tir/transform/inject_ptx_async_copy.cc
+++ b/src/s_tir/transform/inject_ptx_async_copy.cc
@@ -22,19 +22,20 @@
  * \file inject_ptx_async_copy.cc
  */
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
-#include "../ir/buffer_common.h"
+#include "../../tir/ir/buffer_common.h"
 #include "storage_access.h"
 #include "tvm/tir/stmt.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 class PTXAsyncCopyInjector : public StmtMutator {
  public:
@@ -197,15 +198,15 @@ Pass InjectPTXAsyncCopy() {
     n->body = PTXAsyncCopyInjector()(n->body);
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.InjectPTXAsyncCopy", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
+  refl::GlobalDef().def("s_tir.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
 }
 
 }  // namespace transform
 
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/inject_ptx_ldg32.cc b/src/s_tir/transform/inject_ptx_ldg32.cc
similarity index 95%
rename from src/tir/transform/inject_ptx_ldg32.cc
rename to src/s_tir/transform/inject_ptx_ldg32.cc
index f52539fa77..5f10d25917 100644
--- a/src/tir/transform/inject_ptx_ldg32.cc
+++ b/src/s_tir/transform/inject_ptx_ldg32.cc
@@ -21,17 +21,18 @@
 #include <tvm/arith/iter_affine_map.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 #include "../../arith/const_fold.h"
 #include "../../arith/pattern_match.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 class PTXRewriter : public StmtMutator {
  public:
@@ -145,16 +146,16 @@ Pass InjectPTXLDG32(bool enable_inject_ptx_intrin) {
     }
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXLDG32", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.InjectPTXLDG32", {});
 }
 
 // The pass can now be invoked via the pass infrastructure, but we also add a
 // Python binding for it
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.InjectPTXLDG32", InjectPTXLDG32);
+  refl::GlobalDef().def("s_tir.transform.InjectPTXLDG32", InjectPTXLDG32);
 }
 
 }  // namespace transform
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/lower_async_dma.cc b/src/s_tir/transform/lower_async_dma.cc
similarity index 95%
rename from src/tir/transform/lower_async_dma.cc
rename to src/s_tir/transform/lower_async_dma.cc
index 1b7bf14c38..3b6fd0480c 100644
--- a/src/tir/transform/lower_async_dma.cc
+++ b/src/s_tir/transform/lower_async_dma.cc
@@ -25,19 +25,20 @@
 #include <tvm/arith/bound.h>
 #include <tvm/arith/iter_affine_map.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/buffer.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 #include <optional>
 
 #include "../../arith/ir_mutator_with_analyzer.h"
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
  public:
@@ -174,14 +175,14 @@ Pass LowerAsyncDMA() {
     fptr->body = AsyncDMALowerer(dma_bypass_cache, &analyzer)(std::move(fptr->body));
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerAsyncDMA", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.LowerAsyncDMA", LowerAsyncDMA);
+  refl::GlobalDef().def("s_tir.transform.LowerAsyncDMA", LowerAsyncDMA);
 }
 }  // namespace transform
 
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/lower_thread_allreduce.cc b/src/s_tir/transform/lower_thread_allreduce.cc
similarity index 98%
rename from src/tir/transform/lower_thread_allreduce.cc
rename to src/s_tir/transform/lower_thread_allreduce.cc
index 4a0eb49cc3..5f1fc9afaa 100644
--- a/src/tir/transform/lower_thread_allreduce.cc
+++ b/src/s_tir/transform/lower_thread_allreduce.cc
@@ -24,20 +24,21 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/target/target.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
-#include "ir_utils.h"
-#include "update_pointer_storage_scope.h"
+#include "../../tir/transform/ir_utils.h"
+#include "../../tir/transform/update_pointer_storage_scope.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 class ThreadAllreduceBuilder final : public StmtExprMutator {
  public:
@@ -47,12 +48,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         max_num_threads_(target->GetAttr<Integer>("max_num_threads", -1).value().IntValue()) {}
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::thread_extent) {
+    if (op->attr_key == tir::attr::thread_extent) {
       thread_extents_.push_back(op);
       Stmt ret = StmtExprMutator::VisitStmt_(op);
       thread_extents_.pop_back();
       return ret;
-    } else if (op->attr_key == attr::reduce_scope) {
+    } else if (op->attr_key == tir::attr::reduce_scope) {
       const CommReducerNode* combiner = op->node.as<CommReducerNode>();
       ICHECK(combiner);
       reduce_combiner_.push_back(combiner);
@@ -86,7 +87,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
       if (buf.scope() == "shared") {
         // Use volatile access to shared buffer.
-        write_ptr->body = AttrStmt(buf->data, attr::volatile_scope, 1, write_ptr->body);
+        write_ptr->body = AttrStmt(buf->data, tir::attr::volatile_scope, 1, write_ptr->body);
       }
     }
     return node;
@@ -807,14 +808,14 @@ Pass LowerThreadAllreduce() {
     n->body = thread_all_reduce(n->body);
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerThreadAllreduce", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.LowerThreadAllreduce", LowerThreadAllreduce);
+  refl::GlobalDef().def("s_tir.transform.LowerThreadAllreduce", LowerThreadAllreduce);
 }
 
 }  // namespace transform
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/lower_vtcm_alloc.cc b/src/s_tir/transform/lower_vtcm_alloc.cc
similarity index 88%
rename from src/tir/transform/lower_vtcm_alloc.cc
rename to src/s_tir/transform/lower_vtcm_alloc.cc
index c3b03f8623..469f7c4655 100644
--- a/src/tir/transform/lower_vtcm_alloc.cc
+++ b/src/s_tir/transform/lower_vtcm_alloc.cc
@@ -18,14 +18,15 @@
  */
 
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/stmt.h>
-#include <tvm/tir/transform.h>
 
 #include "../../arith/ir_visitor_with_analyzer.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 inline bool IsVtcmStorage(std::string scope) {
   return scope.find("global.vtcm") != std::string::npos;
@@ -68,17 +69,17 @@ namespace transform {
 
 Pass LowerVtcmAlloc() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    return LowerVtcmAlloc(std::move(f));
+    return s_tir::LowerVtcmAlloc(std::move(f));
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.LowerVtcmAlloc", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerVtcmAlloc", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.LowerVtcmAlloc", LowerVtcmAlloc);
+  refl::GlobalDef().def("s_tir.transform.LowerVtcmAlloc", static_cast<Pass (*)()>(LowerVtcmAlloc));
 }
 
 }  // namespace transform
 
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/merge_shared_memory_allocations.cc b/src/s_tir/transform/merge_shared_memory_allocations.cc
similarity index 97%
rename from src/tir/transform/merge_shared_memory_allocations.cc
rename to src/s_tir/transform/merge_shared_memory_allocations.cc
index 4a2b8698d8..11fe15e95a 100644
--- a/src/tir/transform/merge_shared_memory_allocations.cc
+++ b/src/s_tir/transform/merge_shared_memory_allocations.cc
@@ -25,20 +25,21 @@
  */
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 #include <unordered_map>
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
 #include "../../support/arena.h"
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 using runtime::StorageRank;
 using runtime::StorageScope;
@@ -207,13 +208,13 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
 
   void VisitStmt_(const AttrStmtNode* op) final {
     // Only record the outer most thread extent.
-    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
+    if (op->attr_key == tir::attr::thread_extent && !in_thread_env_) {
       in_thread_env_ = true;
       VisitNewScope(op);
       in_thread_env_ = false;
-    } else if (op->attr_key == attr::extern_scope) {
+    } else if (op->attr_key == tir::attr::extern_scope) {
       VisitNewScope(op);
-    } else if (op->attr_key == attr::virtual_thread) {
+    } else if (op->attr_key == tir::attr::virtual_thread) {
       VisitNewScope(op);
     } else {
       StmtExprVisitor::VisitStmt_(op);
@@ -273,7 +274,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
 
  private:
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::thread_extent && !allocated_) {
+    if (op->attr_key == tir::attr::thread_extent && !allocated_) {
       // Allocate one dynamic shared memory allocation at the beginning of thread scope
       int max_layer_num = 0;
       std::vector<const StorageEntry*> all_entry;
@@ -690,17 +691,18 @@ Pass MergeSharedMemoryAllocations() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
     auto* n = f.CopyOnWrite();
-    n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
+    n->body = s_tir::MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.MergeSharedMemoryAllocations", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.MergeSharedMemoryAllocations", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.MergeSharedMemoryAllocations", MergeSharedMemoryAllocations);
+  refl::GlobalDef().def("s_tir.transform.MergeSharedMemoryAllocations",
+                        static_cast<Pass (*)()>(MergeSharedMemoryAllocations));
 }
 
 }  // namespace transform
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/profile_instrumentation.cc b/src/s_tir/transform/profile_instrumentation.cc
similarity index 89%
rename from src/tir/transform/profile_instrumentation.cc
rename to src/s_tir/transform/profile_instrumentation.cc
index 513f0d730e..268190557c 100644
--- a/src/tir/transform/profile_instrumentation.cc
+++ b/src/s_tir/transform/profile_instrumentation.cc
@@ -25,21 +25,22 @@
 // and can be used to capture profiling information such as processor cycles.
 
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 namespace lwp {
 
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_disable_func_prof", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_max_depth", Integer);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_min_height", Integer);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.instr_siblings", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.reset_start_id", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.lwp_disable_func_prof", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.lwp_max_depth", Integer);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.lwp_min_height", Integer);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.instr_siblings", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.reset_start_id", Bool);
 
 static int32_t start_id = 0;
 
@@ -259,12 +260,12 @@ Pass InstrumentProfileIntrinsics() {
     // In addition, loops with siblings are also instrumented provided
     // their loop depth is >= min_instr_height. This is done to avoid
     // instrumenting inner-most loops.
-    auto max_instr_depth = ctx->GetConfig<Integer>("tir.lwp_max_depth", Integer(0)).value();
-    auto min_instr_height = ctx->GetConfig<Integer>("tir.lwp_min_height", Integer(1)).value();
-    bool instr_siblings = ctx->GetConfig<Bool>("tir.instr_siblings", Bool(true)).value();
+    auto max_instr_depth = ctx->GetConfig<Integer>("s_tir.lwp_max_depth", Integer(0)).value();
+    auto min_instr_height = ctx->GetConfig<Integer>("s_tir.lwp_min_height", Integer(1)).value();
+    bool instr_siblings = ctx->GetConfig<Bool>("s_tir.instr_siblings", Bool(true)).value();
     bool disable_func_instrumentation =
-        ctx->GetConfig<Bool>("tir.lwp_disable_func_prof", Bool(false)).value();
-    bool reset_start_id = ctx->GetConfig<Bool>("tir.reset_start_id", Bool(false)).value();
+        ctx->GetConfig<Bool>("s_tir.lwp_disable_func_prof", Bool(false)).value();
+    bool reset_start_id = ctx->GetConfig<Bool>("s_tir.reset_start_id", Bool(false)).value();
     if (reset_start_id) lwp::start_id = 0;
     std::vector<std::pair<GlobalVar, PrimFunc>> updates;
     for (const auto& kv : mptr->functions) {
@@ -281,15 +282,15 @@ Pass InstrumentProfileIntrinsics() {
     return m;
   };
 
-  return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstrumentProfileIntrinsics", {});
+  return tvm::transform::CreateModulePass(pass_func, 0, "s_tir.InstrumentProfileIntrinsics", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.InstrumentProfileIntrinsics", InstrumentProfileIntrinsics);
+  refl::GlobalDef().def("s_tir.transform.InstrumentProfileIntrinsics", InstrumentProfileIntrinsics);
 }
 
 }  // namespace transform
 
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/renormalize_split_pattern.cc b/src/s_tir/transform/renormalize_split_pattern.cc
similarity index 96%
rename from src/tir/transform/renormalize_split_pattern.cc
rename to src/s_tir/transform/renormalize_split_pattern.cc
index 04dbcca510..abf127e2a2 100644
--- a/src/tir/transform/renormalize_split_pattern.cc
+++ b/src/s_tir/transform/renormalize_split_pattern.cc
@@ -23,17 +23,18 @@
  */
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 #include "../../arith/ir_mutator_with_analyzer.h"
 #include "../../arith/pattern_match.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 using namespace arith;
 
@@ -203,15 +204,15 @@ Pass RenormalizeSplitPattern() {
     n->body = SplitPatternReNormalizer(&analyzer)(std::move(n->body));
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.RenormalizeSplitPattern", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.RenormalizeSplitPattern", RenormalizeSplitPattern);
+  refl::GlobalDef().def("s_tir.transform.RenormalizeSplitPattern", RenormalizeSplitPattern);
 }
 
 }  // namespace transform
 
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/rewrite_unsafe_select.cc b/src/s_tir/transform/rewrite_unsafe_select.cc
similarity index 95%
rename from src/tir/transform/rewrite_unsafe_select.cc
rename to src/s_tir/transform/rewrite_unsafe_select.cc
index 3dfbcb9967..267d9b4d00 100644
--- a/src/tir/transform/rewrite_unsafe_select.cc
+++ b/src/s_tir/transform/rewrite_unsafe_select.cc
@@ -23,14 +23,15 @@
  */
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op_attr_types.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 // For now, rewrite unsafe select expression to if_then_else
 // TODO(tqchen) pattern matching to support masked load
@@ -137,15 +138,15 @@ Pass RewriteUnsafeSelect() {
     n->body = UnsafeSelectRewriter()(std::move(n->body));
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.RewriteUnsafeSelect", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.RewriteUnsafeSelect", RewriteUnsafeSelect);
+  refl::GlobalDef().def("s_tir.transform.RewriteUnsafeSelect", RewriteUnsafeSelect);
 }
 
 }  // namespace transform
 
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/storage_access.cc b/src/s_tir/transform/storage_access.cc
similarity index 96%
rename from src/tir/transform/storage_access.cc
rename to src/s_tir/transform/storage_access.cc
index 2a38e64cc7..7a5b6e1622 100644
--- a/src/tir/transform/storage_access.cc
+++ b/src/s_tir/transform/storage_access.cc
@@ -28,10 +28,11 @@
 #include <string>
 #include <utility>
 
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) {
   Var buf = op->buffer->data;
@@ -109,7 +110,7 @@ void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) {
 }
 
 void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
-  if (op->attr_key == attr::double_buffer_write) {
+  if (op->attr_key == tir::attr::double_buffer_write) {
     ICHECK(double_buffer_write_ == nullptr);
     double_buffer_write_ = op->node.as<VarNode>();
     scope_.push_back(std::vector<StmtEntry>());
@@ -127,12 +128,12 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
       scope_.back().emplace_back(std::move(s));
     }
     double_buffer_write_ = nullptr;
-  } else if (op->attr_key == attr::coproc_scope) {
+  } else if (op->attr_key == tir::attr::coproc_scope) {
     IterVar iv = Downcast<IterVar>(op->node);
     env_threads_.push_back(iv);
     StmtExprVisitor::VisitStmt_(op);
     env_threads_.pop_back();
-  } else if (op->attr_key == attr::thread_extent) {
+  } else if (op->attr_key == tir::attr::thread_extent) {
     IterVar iv = Downcast<IterVar>(op->node);
     env_threads_.push_back(iv);
     if (!in_device_env_) {
@@ -147,7 +148,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
       StmtExprVisitor::VisitStmt_(op);
     }
     env_threads_.pop_back();
-  } else if (op->attr_key == attr::hand_threaded) {
+  } else if (op->attr_key == tir::attr::hand_threaded) {
     // skip this pass on blocks that were hand_threaded
     // this avoids control flow and read/write conflicts
     // between hand-threaded kernels and automatic threading
@@ -293,5 +294,5 @@ StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const {
   return StorageScope();  // global by default
 }
 
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/storage_access.h b/src/s_tir/transform/storage_access.h
similarity index 95%
rename from src/tir/transform/storage_access.h
rename to src/s_tir/transform/storage_access.h
index 7c96068229..848d8edcf5 100644
--- a/src/tir/transform/storage_access.h
+++ b/src/s_tir/transform/storage_access.h
@@ -21,8 +21,8 @@
  * \file storage_access.h
  * \brief Common data structure for storage access analysis.
  */
-#ifndef TVM_TIR_TRANSFORM_STORAGE_ACCESS_H_
-#define TVM_TIR_TRANSFORM_STORAGE_ACCESS_H_
+#ifndef TVM_S_TIR_TRANSFORM_STORAGE_ACCESS_H_
+#define TVM_S_TIR_TRANSFORM_STORAGE_ACCESS_H_
 
 #include <tvm/arith/int_set.h>
 #include <tvm/ir/attrs.h>
@@ -35,7 +35,8 @@
 #include "../../runtime/thread_storage_scope.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tir;
 
 using runtime::StorageRank;
 using runtime::StorageScope;
@@ -140,6 +141,6 @@ class StorageAccessVisitor : public StmtExprVisitor {
   // The involving threads
   ffi::Array<IterVar> env_threads_;
 };
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
-#endif  // TVM_TIR_TRANSFORM_STORAGE_ACCESS_H_
+#endif  // TVM_S_TIR_TRANSFORM_STORAGE_ACCESS_H_
diff --git a/src/tir/transform/tensorcore_infer_fragment.cc b/src/s_tir/transform/tensorcore_infer_fragment.cc
similarity index 91%
rename from src/tir/transform/tensorcore_infer_fragment.cc
rename to src/s_tir/transform/tensorcore_infer_fragment.cc
index 7c1b5b05d0..d1232e5164 100644
--- a/src/tir/transform/tensorcore_infer_fragment.cc
+++ b/src/s_tir/transform/tensorcore_infer_fragment.cc
@@ -23,19 +23,20 @@
  */
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 #include <unordered_map>
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
 #include "storage_access.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 // Get fragment information from tensor intrinsics
 class FragmentGetter : public StmtExprVisitor {
@@ -113,11 +114,17 @@ class FragmentGetter : public StmtExprVisitor {
   std::unordered_map<const VarNode*, FragmentInfo> fragments;
 };
 
+}  // namespace s_tir
+
+namespace tir {
 std::unordered_map<const VarNode*, FragmentInfo> GetTensorCoreFragmentInfo(const Stmt& stmt) {
-  FragmentGetter getter;
+  s_tir::FragmentGetter getter;
   getter(stmt);
   return std::move(getter.fragments);
 }
+}  // namespace tir
+
+namespace s_tir {
 
 // Check shape of fragment making sure it is a valid shape for tvm_mma_sync
 class FragmentChecker : public StmtExprVisitor {
@@ -180,11 +187,11 @@ class InferFragmenter : public StmtMutator {
       std::string shape =
           std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k);
       PrimExpr shape_expr = StringImm(shape);
-      Stmt shape_attr = AttrStmt(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
+      Stmt shape_attr = AttrStmt(op->buffer_var, tir::attr::fragment_shape, shape_expr, stmt);
       if (info.layout != "") {
         // Add shape attribute to matrix_a and matrix_b
-        Stmt layout_attr =
-            AttrStmt(op->buffer_var, attr::fragment_layout, StringImm(info.layout), shape_attr);
+        Stmt layout_attr = AttrStmt(op->buffer_var, tir::attr::fragment_layout,
+                                    StringImm(info.layout), shape_attr);
         return layout_attr;
       } else {
         return shape_attr;
@@ -212,17 +219,17 @@ namespace transform {
 Pass InferFragment() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    n->body = InferFragment(std::move(n->body));
+    n->body = s_tir::InferFragment(std::move(n->body));
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.InferFragment", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.InferFragment", InferFragment);
+  refl::GlobalDef().def("s_tir.transform.InferFragment", static_cast<Pass (*)()>(InferFragment));
 }
 
 }  // namespace transform
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/src/tir/transform/thread_storage_sync.cc b/src/s_tir/transform/thread_storage_sync.cc
similarity index 96%
rename from src/tir/transform/thread_storage_sync.cc
rename to src/s_tir/transform/thread_storage_sync.cc
index d41d474a08..57d8f25b51 100644
--- a/src/tir/transform/thread_storage_sync.cc
+++ b/src/s_tir/transform/thread_storage_sync.cc
@@ -22,21 +22,22 @@
  */
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/s_tir/transform.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/transform.h>
 
 #include <unordered_map>
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
-#include "ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
 #include "storage_access.h"
 
 namespace tvm {
-namespace tir {
+namespace s_tir {
+using namespace tvm::tir;
 
 class ThreadSyncPlanner : public StorageAccessVisitor {
  public:
@@ -291,7 +292,7 @@ class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator {
   explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope) : sync_scope_(sync_scope) {}
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::async_wait_queue_scope) {
+    if (op->attr_key == tir::attr::async_wait_queue_scope) {
       auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
                                 {StringImm(sync_scope_.to_string())}));
       auto inner = op->body.as<AttrStmtNode>();
@@ -346,7 +347,7 @@ class ThreadSyncInserter : public StmtExprMutator {
     return StmtExprMutator::VisitStmt_(op);
   }
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::thread_extent) {
+    if (op->attr_key == tir::attr::thread_extent) {
       bool temp = true;
       std::swap(temp, in_thread_env_);
       thread_extents_.push_back(op);
@@ -407,7 +408,7 @@ class ThreadSyncInserter : public StmtExprMutator {
     for (const auto& kv : rw_stats_) {
       const auto& e = kv.second;
       if (e.read_count != 0 && e.write_count != 0) {
-        body = AttrStmt(kv.first, attr::volatile_scope, 1, body);
+        body = AttrStmt(kv.first, tir::attr::volatile_scope, 1, body);
       }
     }
     rw_stats_.clear();
@@ -466,17 +467,18 @@ namespace transform {
 Pass ThreadSync(ffi::String storage_scope) {
   auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    n->body = ThreadSync(std::move(n->body), storage_scope);
+    n->body = s_tir::ThreadSync(std::move(n->body), storage_scope);
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {});
+  return CreatePrimFuncPass(pass_func, 0, "s_tir.ThreadSync", {});
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tir.transform.ThreadSync", ThreadSync);
+  refl::GlobalDef().def("s_tir.transform.ThreadSync",
+                        static_cast<Pass (*)(ffi::String)>(ThreadSync));
 }
 
 }  // namespace transform
-}  // namespace tir
+}  // namespace s_tir
 }  // namespace tvm
diff --git a/tests/python/tir-transform/test_tir_transform_hoist_expression.py b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
similarity index 98%
rename from tests/python/tir-transform/test_tir_transform_hoist_expression.py
rename to tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
index c973e86e18..25419dc88b 100644
--- a/tests/python/tir-transform/test_tir_transform_hoist_expression.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py
@@ -16,8 +16,9 @@
 # under the License.
 import tvm
 import tvm.testing
+from tvm import s_tir
 from tvm.script import tir as T
-from tvm.tir.transform import HoistedConditionals, HoistedLetBindings
+from tvm.s_tir.transform import HoistedConditionals, HoistedLetBindings
 
 
 def _run_transform(before, hoisted_conditionals, hoisted_let_bindings):
@@ -25,14 +26,14 @@ def _run_transform(before, hoisted_conditionals, hoisted_let_bindings):
     before_mod = tvm.IRModule.from_expr(before)
 
     config = {
-        "tir.HoistExpression": {
+        "s_tir.HoistExpression": {
             "hoisted_conditionals": hoisted_conditionals.value,
             "hoisted_let_bindings": hoisted_let_bindings.value,
         }
     }
 
     with tvm.transform.PassContext(config=config):
-        after_mod = tvm.tir.transform.HoistExpression()(before_mod)
+        after_mod = tvm.s_tir.transform.HoistExpression()(before_mod)
 
     return after_mod["main"]
 
diff --git a/tests/python/tir-transform/test_tir_transform_hoist_if.py b/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py
similarity index 88%
rename from tests/python/tir-transform/test_tir_transform_hoist_if.py
rename to tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py
index a941c6c019..5d0f48df4d 100644
--- a/tests/python/tir-transform/test_tir_transform_hoist_if.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te
+from tvm import te, s_tir
 from tvm.script import tir as T, ir as I
 import numpy as np
 import pytest
@@ -77,7 +77,7 @@ def test_hoist_top_for():
                         T.evaluate(T.call_extern("int32", "dummy", n))
 
     mod = tvm.IRModule.from_expr(func)
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     expected_struct = {
         ("tir.For", "k"): (None,),
         ("tir.For", "j"): (("tir.For", "k"),),
@@ -99,7 +99,7 @@ def test_hoist_multi_var_if():
                         T.evaluate(T.call_extern("int32", "dummy", n))
 
     mod = tvm.IRModule.from_expr(func)
-    new_mod = tvm.tir.transform.HoistIfThenElse()(mod)
+    new_mod = tvm.s_tir.transform.HoistIfThenElse()(mod)
     new_stmt = new_mod["main"].body
     expected_struct = {
         ("tir.For", "k"): (None,),
@@ -124,7 +124,7 @@ def test_hoist_no_match_for():
                         T.evaluate(T.call_extern("int32", "dummy", n))
 
     mod = tvm.IRModule.from_expr(func)
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     expected_struct = {
         ("tir.For", "k"): (None,),
         ("tir.IfThenElse", ("i",)): (("tir.For", "k"), ("tir.For", "k")),
@@ -144,7 +144,7 @@ def test_no_else():
                         T.evaluate(T.call_extern("int32", "dummy", m))
 
     mod = tvm.IRModule.from_expr(func)
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     expected_struct = {
         ("tir.For", "k"): (None,),
         ("tir.For", "j"): (("tir.For", "k"),),
@@ -175,7 +175,7 @@ def test_attr_stmt():
                         )
 
     mod = tvm.IRModule.from_expr(func)
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     expected_struct = {
         ("tir.For", "k"): (None,),
         ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")),
@@ -207,7 +207,7 @@ def test_nested_for():
                                 ] * T.float32(1.5)
 
     mod = tvm.IRModule.from_expr(func)
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     expected_struct = {
         ("tir.For", "l"): (None,),
         ("tir.For", "k"): (("tir.For", "l"),),
@@ -248,7 +248,7 @@ def test_if_block():
                         if n >= 3:
                             data[i2 * 3 + j2 + k2] = data[i2 * 3 + j2 + k2] + T.float32(0.6)
 
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     # Updated expected_struct with renamed second nest variables
     expected_struct = {
         ("tir.IfThenElse", ("i", "j")): (None, None),
@@ -280,7 +280,7 @@ def test_multi_if():
                             ] + T.float32(0.5)
 
     mod = tvm.IRModule.from_expr(func)
-    new_mod = tvm.tir.transform.HoistIfThenElse()(mod)
+    new_mod = tvm.s_tir.transform.HoistIfThenElse()(mod)
     new_stmt = new_mod["main"].body
     expected_struct = {
         ("tir.For", "k"): (None,),
@@ -306,13 +306,13 @@ def test_no_hoisting_1():
 
     mod = tvm.IRModule.from_expr(func)
     stmt = mod["main"].body
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
     with tvm.transform.PassContext(
-        config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+        config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
     ):
-        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
 
@@ -331,13 +331,13 @@ def test_no_hoisting_2():
 
     mod = tvm.IRModule.from_expr(func)
     stmt = mod["main"].body
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
     with tvm.transform.PassContext(
-        config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+        config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
     ):
-        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
 
@@ -370,13 +370,13 @@ def test_no_hoisting_4():
                             ] + T.float32(1.3)
 
     stmt = Module["main"].body
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
     with tvm.transform.PassContext(
-        config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+        config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
     ):
-        new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+        new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
 
@@ -398,13 +398,13 @@ def test_no_hoisting_6():
                             data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + T.float32(1.3)
 
     stmt = Module["main"].body
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
     with tvm.transform.PassContext(
-        config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+        config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
     ):
-        new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+        new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
 
@@ -427,13 +427,13 @@ def test_no_hoisting_7():
                                 )
 
     stmt = Module["main"].body
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
     with tvm.transform.PassContext(
-        config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+        config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
     ):
-        new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+        new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
 
@@ -469,13 +469,13 @@ def test_hoisting_block_scope_2():
     mod = tvm.tir.transform.RemoveNoOp()(mod)
     stmt = mod["main"].body
 
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
     with tvm.transform.PassContext(
-        config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+        config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
     ):
-        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     assert not tvm.ir.structural_equal(new_stmt, stmt)
 
 
@@ -493,16 +493,16 @@ def test_hoisting_block_scope_5():
                             data[9 * j + 3 * j * k] = data[9 * j + 3 * j * k] + T.float32(1.3)
 
     stmt = Module["main"].body
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     assert not tvm.ir.structural_equal(new_stmt, stmt)
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], new_stmt))
     stmt = new_stmt
 
     with tvm.transform.PassContext(
-        config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+        config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
     ):
-        new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body
+        new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
 
@@ -524,13 +524,13 @@ def test_hoisting_block_scope_6():
                             data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + T.float32(1.3)
 
     stmt = Module["main"].body
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
     with tvm.transform.PassContext(
-        config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+        config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
     ):
-        new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+        new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     assert not tvm.ir.structural_equal(new_stmt, stmt)
 
 
@@ -552,13 +552,13 @@ def test_hoisting_block_scope_7():
                             data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + T.float32(1.3)
 
     stmt = Module["main"].body
-    new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+    new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     tvm.ir.assert_structural_equal(new_stmt, stmt)
 
     with tvm.transform.PassContext(
-        config={"tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
+        config={"s_tir.HoistIfThenElse": {"support_block_scope_hoisting": True}}
     ):
-        new_stmt = tvm.tir.transform.HoistIfThenElse()(Module)["main"].body
+        new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body
     assert not tvm.ir.structural_equal(new_stmt, stmt)
 
 
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
similarity index 99%
rename from tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
rename to tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
index 5ab0cfe6f0..37b6d02f02 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
@@ -18,6 +18,7 @@
 import tvm
 import tvm_ffi
 import tvm.testing
+from tvm import s_tir
 from tvm.script import tir as T, ir as I
 
 import pytest
@@ -135,7 +136,7 @@ def test_inject_async_copy():
         mod = tvm.tir.transform.FlattenBuffer()(mod)
         if vec_size > 1:
             mod = tvm.tir.transform.VectorizeLoop()(mod)
-        mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+        mod = tvm.s_tir.transform.InjectPTXAsyncCopy()(mod)
 
         assert count_cp_async(mod["main"].body) == 1
 
@@ -162,8 +163,8 @@ def test_inject_async_copy_shared_dyn():
     mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod)
     mod = tvm.tir.transform.FlattenBuffer()(mod)
     mod = tvm.tir.transform.VectorizeLoop()(mod)
-    mod = tvm.tir.transform.MergeSharedMemoryAllocations()(mod)
-    mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+    mod = tvm.s_tir.transform.MergeSharedMemoryAllocations()(mod)
+    mod = tvm.s_tir.transform.InjectPTXAsyncCopy()(mod)
 
     assert count_cp_async(mod["main"].body) == 2
 
@@ -223,7 +224,7 @@ def test_inject_async_copy_barrier():
     mod = tvm.IRModule.from_expr(f)
     mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod)
     mod = tvm.tir.transform.FlattenBuffer()(mod)
-    mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+    mod = tvm.s_tir.transform.InjectPTXAsyncCopy()(mod)
 
     assert count_cp_async(mod["main"].body) == 1
 
@@ -981,7 +982,7 @@ def test_multiplication_nodes_are_inlined():
             T.ptx_commit_group()
             T.ptx_wait_group(0)
 
-    After = tvm.tir.transform.InjectPTXAsyncCopy()(Before)
+    After = tvm.s_tir.transform.InjectPTXAsyncCopy()(Before)
     tvm.ir.assert_structural_equal(After, Expected)
 
 
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py
similarity index 95%
rename from tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py
rename to tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py
index 55099f252c..c0aed0351e 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py
@@ -17,6 +17,7 @@
 
 import tvm
 import tvm.testing
+from tvm import s_tir
 from tvm.script import tir as T
 
 
@@ -60,7 +61,7 @@ def test_inject_ptx_ldg32_inserts_alloc_for_no_alloc_func():
     mod = tvm.IRModule.from_expr(where_no_alloc)
     assert _count_alloc(mod["main"].body) == 0
 
-    mod = tvm.tir.transform.InjectPTXLDG32()(mod)
+    mod = tvm.s_tir.transform.InjectPTXLDG32()(mod)
     assert _count_alloc(mod["main"].body) > 0
     assert _count_ptx_ldg32(mod["main"].body) == 1
 
@@ -71,7 +72,7 @@ def test_inject_ptx_ldg32_skip_non_cuda_target():
     mod = tvm.IRModule({"main": mod["main"].with_attr("target", cpu_target)})
     assert _count_alloc(mod["main"].body) == 0
 
-    mod = tvm.tir.transform.InjectPTXLDG32()(mod)
+    mod = tvm.s_tir.transform.InjectPTXLDG32()(mod)
     assert _count_alloc(mod["main"].body) == 0
     assert _count_ptx_ldg32(mod["main"].body) == 0
 
diff --git a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
similarity index 98%
rename from tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py
rename to tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
index cfd81377c1..edffa0dcc5 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
@@ -17,11 +17,12 @@
 
 import tvm
 import tvm.testing
+from tvm import s_tir
 from tvm.script import tir as T, ir as I
 
 
 def test_basic():
-    transform = tvm.tir.transform.LowerThreadAllreduce()
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
     @I.ir_module
     class Before:
@@ -97,7 +98,7 @@ def test_basic():
 
 
 def test_basic_with_decl_buffer():
-    transform = tvm.tir.transform.LowerThreadAllreduce()
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
     @I.ir_module
     class Before:
@@ -171,7 +172,7 @@ def test_basic_with_decl_buffer():
 
 
 def test_reduce_summation():
-    transform = tvm.tir.transform.LowerThreadAllreduce()
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
     @I.ir_module
     class Before:
@@ -261,7 +262,7 @@ def test_reduce_summation():
 
 
 def test_multi_group_reduction():
-    transform = tvm.tir.transform.LowerThreadAllreduce()
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
     @I.ir_module
     class Before:
@@ -330,7 +331,7 @@ def test_multi_group_reduction():
 
 
 def test_multi_group_mask1():
-    transform = tvm.tir.transform.LowerThreadAllreduce()
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
     @I.ir_module
     class Before:
@@ -393,7 +394,7 @@ def test_multi_group_mask1():
 
 
 def test_multi_warp_reduce1():
-    transform = tvm.tir.transform.LowerThreadAllreduce()
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
     @I.ir_module
     class Before:
@@ -478,7 +479,7 @@ def test_multi_warp_reduce1():
 
 
 def test_multi_warp_reduce2():
-    transform = tvm.tir.transform.LowerThreadAllreduce()
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
     @I.ir_module
     class Before:
@@ -563,7 +564,7 @@ def test_multi_warp_reduce2():
 
 
 def test_multi_group_multi_warp_reduction():
-    transform = tvm.tir.transform.LowerThreadAllreduce()
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
     @I.ir_module
     class Before:
@@ -648,7 +649,7 @@ def test_multi_group_multi_warp_reduction():
 
 
 def test_multi_group_multi_warp_predicated_reduction():
-    transform = tvm.tir.transform.LowerThreadAllreduce()
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
     @I.ir_module
     class Before:
@@ -743,7 +744,7 @@ def test_multi_group_multi_warp_predicated_reduction():
 
 
 def test_metal_no_mask():
-    transform = tvm.tir.transform.LowerThreadAllreduce()
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
     @I.ir_module
     class Before:
diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
similarity index 97%
rename from tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
rename to tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
index eb08607dc8..26e8a6f2bb 100644
--- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -18,7 +18,7 @@ import numpy as np
 
 import tvm
 import tvm.testing
-from tvm import te
+from tvm import te, s_tir
 from tvm.topi.math import cast
 from tvm.script import ir as I, tir as T
 
@@ -30,7 +30,7 @@ def test_matmul_t_buffer():
     test_matmul_dyn_shared, using `T.Buffer` (Allocate without
     DeclBuffer) for the replaced allocations.
     """
-    transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+    transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
     buffer_func = T.Buffer
 
     @I.ir_module
@@ -145,7 +145,7 @@ def test_matmul_decl_buffer():
     test_matmul_dyn_shared, using `T.decl_buffer` (Allocate followed by DeclBuffer)
     for the replaced allocations.
     """
-    transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+    transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
     buffer_func = T.decl_buffer
 
     @I.ir_module
@@ -255,7 +255,7 @@ def test_matmul_decl_buffer():
 
 def test_simple_alloc_no_reuse():
     """Test alloc and free within the same scope."""
-    transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+    transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
 
     @I.ir_module
     class Before:
@@ -284,7 +284,7 @@ def test_simple_alloc_no_reuse():
 
 def test_simple_alloc_reuse():
     """Test alloc and free within the same scope with a reuse chance."""
-    transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+    transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
 
     @I.ir_module
     class Before:
@@ -315,7 +315,7 @@ def test_simple_alloc_reuse():
 
 def test_async_copy():
     """Test async copy in shared memory."""
-    transform = tvm.tir.transform.MergeSharedMemoryAllocations()
+    transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
 
     @I.ir_module
     class Before:
diff --git a/tests/python/tir-transform/test_tir_transform_profiling_instr.py b/tests/python/s_tir/transform/test_s_tir_transform_profiling_instr.py
similarity index 95%
rename from tests/python/tir-transform/test_tir_transform_profiling_instr.py
rename to tests/python/s_tir/transform/test_s_tir_transform_profiling_instr.py
index 139524fa83..f6c409aa8e 100644
--- a/tests/python/tir-transform/test_tir_transform_profiling_instr.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_profiling_instr.py
@@ -17,15 +17,15 @@
 
 import tvm
 import tvm.testing
-from tvm import te
+from tvm import te, s_tir
 from tvm.ir.module import IRModule
 from tvm.script import tir as T
 import numpy
 
 default_lwp_test_config = {
     "tir.instrument_lwp": True,
-    "tir.lwp_disable_func_prof": True,
-    "tir.reset_start_id": True,
+    "s_tir.lwp_disable_func_prof": True,
+    "s_tir.reset_start_id": True,
 }
 
 
@@ -278,7 +278,7 @@ def test6_expected_output(a: T.handle, b: T.handle, c: T.handle, d: T.handle) ->
 def test1():
     with tvm.transform.PassContext(config=default_lwp_test_config):
         mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main"))
-        mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+        mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
     tvm.ir.assert_structural_equal(
         mod["main"], test1_expected_output.with_attr("global_symbol", "main")
     )
@@ -288,10 +288,10 @@ def test1():
 # doesn't have any effect unless 'instr_siblings' is set to False (ex: test3).
 def test2():
     test2_config = default_lwp_test_config.copy()
-    test2_config.update({"tir.lwp_max_depth": 3})
+    test2_config.update({"s_tir.lwp_max_depth": 3})
     with tvm.transform.PassContext(config=test2_config):
         mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main"))
-        mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+        mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
     tvm.ir.assert_structural_equal(
         mod["main"], test1_expected_output.with_attr("global_symbol", "main")
     )
@@ -303,10 +303,10 @@ def test2():
 # 'lwp_min_height' (ex: test5)
 def test3():
     test3_config = default_lwp_test_config.copy()
-    test3_config.update({"tir.lwp_max_depth": 3, "tir.instr_siblings": False})
+    test3_config.update({"s_tir.lwp_max_depth": 3, "s_tir.instr_siblings": False})
     with tvm.transform.PassContext(config=test3_config):
         mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main"))
-        mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+        mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
     tvm.ir.assert_structural_equal(
         mod["main"], test3_expected_output.with_attr("global_symbol", "main")
     )
@@ -317,7 +317,7 @@ def test3():
 def test4():
     with tvm.transform.PassContext(config=default_lwp_test_config):
         mod = tvm.IRModule.from_expr(input2.with_attr("global_symbol", "main"))
-        mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+        mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
     tvm.ir.assert_structural_equal(
         mod["main"], test4_expected_output.with_attr("global_symbol", "main")
     )
@@ -328,11 +328,11 @@ def test4():
 def test5():
     test5_config = default_lwp_test_config.copy()
     test5_config.update(
-        {"tir.lwp_max_depth": 3, "tir.instr_siblings": False, "tir.lwp_min_height": 2}
+        {"s_tir.lwp_max_depth": 3, "s_tir.instr_siblings": False, "s_tir.lwp_min_height": 2}
     )
     with tvm.transform.PassContext(config=test5_config):
         mod = tvm.IRModule.from_expr(input1.with_attr("global_symbol", "main"))
-        mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+        mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
     tvm.ir.assert_structural_equal(
         mod["main"], test5_expected_output.with_attr("global_symbol", "main")
     )
@@ -342,7 +342,7 @@ def test5():
 def test6():
     with tvm.transform.PassContext(config=default_lwp_test_config):
         mod = tvm.IRModule.from_expr(input3.with_attr("global_symbol", "main"))
-        mod = tvm.tir.transform.InstrumentProfileIntrinsics()(mod)
+        mod = tvm.s_tir.transform.InstrumentProfileIntrinsics()(mod)
     tvm.ir.assert_structural_equal(
         mod["main"], test6_expected_output.with_attr("global_symbol", "main")
     )
diff --git a/tests/python/tir-transform/test_tir_transform_renormalize_split_pattern.py b/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py
similarity index 98%
rename from tests/python/tir-transform/test_tir_transform_renormalize_split_pattern.py
rename to tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py
index 057cfc42e4..96c9089c23 100644
--- a/tests/python/tir-transform/test_tir_transform_renormalize_split_pattern.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py
@@ -17,6 +17,7 @@
 
 import tvm
 import tvm.testing
+from tvm import s_tir
 from tvm.script import tir as T
 
 # fmt: off
@@ -119,7 +120,7 @@ class After_simplified:
 
 
 def test_renormalize_split_pattern():
-    after = tvm.tir.transform.RenormalizeSplitPattern()(Before)
+    after = tvm.s_tir.transform.RenormalizeSplitPattern()(Before)
     tvm.ir.assert_structural_equal(after, After)
     after = tvm.tir.transform.Simplify()(after)
     tvm.ir.assert_structural_equal(after, After_simplified)
@@ -168,7 +169,7 @@ def test_analyze_inside_integer_conditional(integer_condition):
     # exception, as it rewrites the integer conditionals first.  These
     # tests are written using RenormalizeSplitPattern as it is the
     # first case identified.
-    transform = tvm.tir.transform.RenormalizeSplitPattern()
+    transform = tvm.s_tir.transform.RenormalizeSplitPattern()
 
     # Issue would result in an error through while applying the transformation.
     mod = tvm.IRModule.from_expr(integer_condition)
diff --git a/tests/python/tir-transform/test_tir_transform_rewrite_unsafe_select.py b/tests/python/s_tir/transform/test_s_tir_transform_rewrite_unsafe_select.py
similarity index 90%
rename from tests/python/tir-transform/test_tir_transform_rewrite_unsafe_select.py
rename to tests/python/s_tir/transform/test_s_tir_transform_rewrite_unsafe_select.py
index eee05315c9..8500aa8b33 100644
--- a/tests/python/tir-transform/test_tir_transform_rewrite_unsafe_select.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_rewrite_unsafe_select.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+from tvm import s_tir
 from tvm.script import ir as I, tir as T
 
 
@@ -27,7 +28,7 @@ def test_rewrite_Select():
             A = T.Buffer(100, "float32", data=A_data)
             T.evaluate(T.Select(i > 1, A[i - 1], T.float32(1.0)))
 
-    yy = tvm.tir.transform.RewriteUnsafeSelect()(ModuleY)["main"].body.body.value
+    yy = tvm.s_tir.transform.RewriteUnsafeSelect()(ModuleY)["main"].body.body.value
 
     @I.ir_module
     class ModuleZ:
@@ -41,7 +42,7 @@ def test_rewrite_Select():
                 )
             )
 
-    zz = tvm.tir.transform.RewriteUnsafeSelect()(ModuleZ)["main"].body.body.value
+    zz = tvm.s_tir.transform.RewriteUnsafeSelect()(ModuleZ)["main"].body.body.value
 
     @I.ir_module
     class ModuleA:
@@ -62,7 +63,7 @@ def test_rewrite_Select():
                 )
             )
 
-    aa = tvm.tir.transform.RewriteUnsafeSelect()(ModuleA)["main"].body.body.value
+    aa = tvm.s_tir.transform.RewriteUnsafeSelect()(ModuleA)["main"].body.body.value
     builtin_if_then_else = tvm.ir.Op.get("tir.if_then_else")
 
     assert yy.op.same_as(builtin_if_then_else)
diff --git a/tests/python/tir-transform/test_tir_transform_thread_sync.py b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
similarity index 97%
rename from tests/python/tir-transform/test_tir_transform_thread_sync.py
rename to tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
index 48de01a629..7d5720ea3f 100644
--- a/tests/python/tir-transform/test_tir_transform_thread_sync.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
@@ -16,7 +16,7 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm import te
+from tvm import te, s_tir
 from tvm.script import tir as T
 
 
@@ -31,7 +31,7 @@ def run_passes(func: tvm.tir.PrimFunc):
 
     mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
     mod = tvm.tir.transform.SplitHostDevice()(mod)
-    return tvm.tir.transform.ThreadSync("shared")(mod)
+    return tvm.s_tir.transform.ThreadSync("shared")(mod)
 
 
 @tvm.testing.requires_cuda
@@ -94,7 +94,7 @@ def test_sync_shared_dyn():
         E_1[threadIdx_x] = D_1_1[threadIdx_x]
 
     mod = tvm.IRModule({"main": func})
-    mod = tvm.tir.transform.ThreadSync("shared.dyn")(mod)
+    mod = tvm.s_tir.transform.ThreadSync("shared.dyn")(mod)
     tvm.ir.assert_structural_equal(mod["main"], expected)
 
 
@@ -170,7 +170,7 @@ def test_sync_let_stmt():
         )
 
     mod = tvm.IRModule({"main": func})
-    mod = tvm.tir.transform.ThreadSync("shared")(mod)
+    mod = tvm.s_tir.transform.ThreadSync("shared")(mod)
     tvm.ir.assert_structural_equal(mod["main"], expected)
 
 


Reply via email to