This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a60cd0fecf [TIR] Allow symbolic bounds in IndexMap analysis (#15264)
a60cd0fecf is described below

commit a60cd0fecfdbc56187c7632469c5f444f835e99f
Author: Junru Shao <[email protected]>
AuthorDate: Sat Jul 8 21:50:05 2023 -0700

    [TIR] Allow symbolic bounds in IndexMap analysis (#15264)
    
    This PR adds the bounds of shape variables to the arithmetic analyzer
    so that it is possible to simplify certain expressions.
---
 include/tvm/tir/index_map.h                        |  12 +--
 include/tvm/topi/transform.h                       |   9 +-
 pyproject.toml                                     |   3 +
 python/tvm/te/schedule.py                          |  11 +-
 python/tvm/tir/function.py                         |  20 ++--
 python/tvm/tir/schedule/schedule.py                |  40 ++++++--
 python/tvm/tir/schedule/testing.py                 |   6 +-
 src/arith/ir_mutator_with_analyzer.cc              |  20 ++++
 src/arith/ir_mutator_with_analyzer.h               |   6 ++
 src/arith/iter_affine_map.cc                       |   6 ++
 src/relay/backend/te_compiler_cache.cc             |   4 +-
 src/relay/op/tensor/transform.cc                   |   4 +-
 src/runtime/logging.cc                             |   3 +
 src/target/source/codegen_cuda.cc                  |   3 +-
 src/te/schedule/message_passing.cc                 |  14 ++-
 src/te/schedule/schedule_lang.cc                   |   6 +-
 src/tir/ir/index_map.cc                            |  64 ++++++------
 src/tir/schedule/analysis.h                        |  10 ++
 src/tir/schedule/analysis/analysis.cc              |  14 +++
 src/tir/schedule/primitive/cache_read_write.cc     |   7 +-
 src/tir/schedule/primitive/compute_at.cc           |  14 ---
 .../schedule/primitive/layout_transformation.cc    |  81 ++++++++-------
 src/tir/schedule/transform.cc                      |  22 ++--
 src/tir/transforms/flatten_buffer.cc               |  13 +--
 src/tir/transforms/storage_flatten.cc              |   7 +-
 src/tir/transforms/transform_mma_buffer_layout.cc  |   6 +-
 .../python/unittest/test_arith_iter_affine_map.py  |   4 +-
 .../test_meta_schedule_relay_integration.py        |  20 +++-
 ...meta_schedule_schedule_cuda_layout_transform.py |   6 +-
 .../test_meta_schedule_schedule_rule_mlt_tc.py     |  19 ++--
 .../unittest/test_meta_schedule_trace_apply.py     |  10 +-
 .../python/unittest/test_meta_schedule_tune_tir.py |   5 +
 ...sform_layout.py => test_te_transform_layout.py} |   0
 .../{test_index_map.py => test_tir_index_map.py}   |  11 +-
 .../unittest/test_tir_schedule_transform_layout.py | 112 +++++++++++++++++++--
 35 files changed, 411 insertions(+), 181 deletions(-)

diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
index 796e04e74a..340d953ccf 100644
--- a/include/tvm/tir/index_map.h
+++ b/include/tvm/tir/index_map.h
@@ -102,8 +102,7 @@ class IndexMapNode : public Object {
    * \returns The indices in the output space.  Contains one value for
    * each expression in `final_indices`.
    */
-  Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices,
-                             arith::Analyzer* analyzer = nullptr) const;
+  Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices, arith::Analyzer* 
analyzer) const;
 
   /*! \brief Map a memory range to the output space
    *
@@ -121,7 +120,7 @@ class IndexMapNode : public Object {
    * \returns The ranges in the output space.  Contains one value for
    * each expression in `final_indices`.
    */
-  Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer 
= nullptr) const;
+  Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* 
analyzer) const;
 
   /*! \brief Map a buffer shape to the output space
    *
@@ -134,7 +133,7 @@ class IndexMapNode : public Object {
    * \returns The buffer shape in the output space.  Contains one
    * value for each expression in `final_indices`.
    */
-  Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* 
analyzer = nullptr) const;
+  Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* 
analyzer) const;
 
   /* \brief Map an NDArray according to this index map
    *
@@ -203,7 +202,7 @@ class IndexMap : public ObjectRef {
    * If the user has supplied an `inverse_index_map`, that map is
    * assumed to be correct and bijective, and is returned.
    */
-  IndexMap Inverse(Array<Range> initial_ranges) const;
+  IndexMap Inverse(Array<Range> initial_ranges, arith::Analyzer* analyzer) 
const;
 
   /*! \brief Rename the variables in the index map and ensure the names are 
unique.
    *
@@ -225,7 +224,8 @@ class IndexMap : public ObjectRef {
    * \return The inverted index map, along with the predicate for
    * which the inverse maps to a valid range.
    */
-  std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> 
initial_ranges) const;
+  std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> 
initial_ranges,
+                                                     arith::Analyzer* 
analyzer) const;
 
   TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
 };
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index cab3466765..d881c4f423 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -24,6 +24,7 @@
 #ifndef TVM_TOPI_TRANSFORM_H_
 #define TVM_TOPI_TRANSFORM_H_
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/te/operation.h>
 #include <tvm/tir/data_layout.h>
 #include <tvm/tir/index_map.h>
@@ -1738,16 +1739,18 @@ inline Tensor auto_scheduler_layout_transform(const 
Tensor& src, const String& s
 inline Tensor meta_schedule_layout_transform(const Tensor& src, const 
tir::IndexMap& index_map,
                                              const String name = 
"T_meta_schedule_layout_trans",
                                              const String tag = kInjective) {
+  arith::Analyzer analyzer;
   Array<Range> iter_domain;
   iter_domain.reserve(src->shape.size());
   for (const PrimExpr& e : src->shape) {
     iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
   }
-  Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape);
+  Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape, 
&analyzer);
   return compute(
       post_transform_shape,
-      [src, inv = index_map.Inverse(iter_domain)](const Array<Var>& indices) 
-> PrimExpr {
-        return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), 
indices.end()}));
+      [src, inv = index_map.Inverse(iter_domain, &analyzer),
+       &analyzer](const Array<Var>& indices) -> PrimExpr {
+        return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), 
indices.end()}, &analyzer));
       },
       name, tag);
 }
diff --git a/pyproject.toml b/pyproject.toml
index 5cca711ddb..91740f2b4b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+[tool.isort]
+profile = "black"
+src_paths = ["python", "tests/python"]
 
 [tool.black]
 line-length = 100
diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py
index 3dbf2cefe4..0b6abc2566 100644
--- a/python/tvm/te/schedule.py
+++ b/python/tvm/te/schedule.py
@@ -22,13 +22,12 @@ from typing import Callable, List
 
 import tvm._ffi
 from tvm._ffi.base import string_types
-
-from tvm.runtime import Object, convert
 from tvm.ir import container as _container
-from tvm.tir import IterVar, Buffer, Var, IndexMap
+from tvm.runtime import Object, convert
+from tvm.tir import Buffer, IndexMap, IterVar, Var
 
-from . import tensor as _tensor
 from . import _ffi_api
+from . import tensor as _tensor
 
 
 @tvm._ffi.register_object
@@ -600,7 +599,9 @@ class Stage(Object):
         """
 
         ndim = len(self.op.output(0).shape)
-        index_map, axis_separators = 
IndexMap.from_func_with_separators(mapping_function, ndim=ndim)
+        index_map, axis_separators = IndexMap.from_func_with_separators(
+            mapping_function, ndim=ndim, index_dtype="int32"
+        )
 
         new_iter_vars = _ffi_api.StageTransformLayout(
             self, index_map.initial_indices, index_map.final_indices
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 32ec347039..bd44e3f7c3 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -67,7 +67,6 @@ class PrimFunc(BaseFunc, Scriptable):
         attrs=None,
         span=None,
     ):
-
         param_list = []
         buffer_map = {} if buffer_map is None else buffer_map
         for x in params:
@@ -266,6 +265,8 @@ class IndexMap(Object):
         mapping_function: Callable,
         ndim: Optional[int] = None,
         inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
+        *,
+        index_dtype: str = "int64",
     ):
         """Create an index map from a function
 
@@ -302,7 +303,10 @@ class IndexMap(Object):
 
         """
         index_map, axis_separators = IndexMap.from_func_with_separators(
-            mapping_function, ndim, inverse_index_map
+            mapping_function,
+            ndim,
+            inverse_index_map,
+            index_dtype=index_dtype,
         )
         assert not axis_separators, (
             "The mapping_function provided to IndexMap.from_func "
@@ -316,6 +320,8 @@ class IndexMap(Object):
         mapping_function: Callable,
         ndim: Optional[int] = None,
         inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
+        *,
+        index_dtype: str = "int64",
     ):
         """Create an index map from a function
 
@@ -346,6 +352,9 @@ class IndexMap(Object):
             It is the user's responsibility to ensure the correctness of the 
pre-defined inverse
             index map.
 
+        index_dtype : str
+            The default index dtype to use for input iters in the mapping function.
+
         Returns
         -------
         ret: Tuple[IndexMap, List[int]]
@@ -361,20 +370,19 @@ class IndexMap(Object):
         args = []
         var_arg_name = None
         kwargs = collections.OrderedDict()
-        default_index_dtype = "int32"
 
         for name, param in params.items():
             if param.kind in [
                 inspect.Parameter.POSITIONAL_ONLY,
                 inspect.Parameter.POSITIONAL_OR_KEYWORD,
             ]:
-                args.append(tvm.tir.Var(name, default_index_dtype))
+                args.append(tvm.tir.Var(name, index_dtype))
 
             elif param.kind == inspect.Parameter.VAR_POSITIONAL:
                 var_arg_name = name
 
             elif param.kind == inspect.Parameter.KEYWORD_ONLY:
-                kwargs[name] = tvm.tir.Var(name, default_index_dtype)
+                kwargs[name] = tvm.tir.Var(name, index_dtype)
 
             else:
                 raise ValueError("transform_layout mapping may not have *args")
@@ -386,7 +394,7 @@ class IndexMap(Object):
             assert ndim is not None, "ndim must be specified when *args is 
used"
             num_var_args = ndim - len(args) - len(kwargs)
             for i in range(num_var_args):
-                args.append(tvm.tir.Var(f"{var_arg_name}_{i}", 
default_index_dtype))
+                args.append(tvm.tir.Var(f"{var_arg_name}_{i}", index_dtype))
 
         mapping = mapping_function(*args, **kwargs)
 
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index 4f717474e0..6c42f15a2f 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -93,6 +93,15 @@ def _parse_seed(seed: Optional[int]) -> int:
     return seed
 
 
+def _get_block_default_dtype(block: Block) -> str:
+    for i in block.iter_vars:
+        return i.var.dtype
+    for buffer_region in list(block.reads) + list(block.writes):
+        for dom in buffer_region.region:
+            return dom.min.dtype
+    return "int64"
+
+
 @_register_object("tir.Schedule")
 class Schedule(Object):
     """The user-facing schedule class
@@ -1492,7 +1501,10 @@ class Schedule(Object):
         block = self._normalize_block_arg(block)
 
         if callable(index_map):
-            index_map = IndexMap.from_func(index_map)
+            index_map = IndexMap.from_func(
+                index_map,
+                index_dtype=_get_block_default_dtype(self.get(block)),
+            )
         return _ffi_api.ScheduleReindexCacheRead(  # type: ignore # pylint: 
disable=no-member
             self, block, read_buffer_index, storage_scope, index_map
         )
@@ -1589,7 +1601,10 @@ class Schedule(Object):
         block = self._normalize_block_arg(block)
 
         if callable(index_map):
-            index_map = IndexMap.from_func(index_map)
+            index_map = IndexMap.from_func(
+                index_map,
+                index_dtype=_get_block_default_dtype(self.get(block)),
+            )
         return _ffi_api.ScheduleReindexCacheWrite(  # type: ignore # pylint: 
disable=no-member
             self, block, write_buffer_index, storage_scope, index_map
         )
@@ -3246,14 +3261,22 @@ class Schedule(Object):
 
         ndim = len(buffer_obj.shape)
         if callable(index_map):
-            index_map, axis_separators = 
IndexMap.from_func_with_separators(index_map, ndim=ndim)
+            index_map, axis_separators = IndexMap.from_func_with_separators(
+                index_map,
+                ndim=ndim,
+                index_dtype=_get_block_default_dtype(self.get(block)),
+            )
         else:
             axis_separators = []
 
         if pad_value is None:
             pass
         elif callable(pad_value):
-            pad_value = IndexMap.from_func(pad_value, 
ndim=len(index_map.final_indices))
+            pad_value = IndexMap.from_func(
+                pad_value,
+                ndim=len(index_map.final_indices),
+                index_dtype=_get_block_default_dtype(self.get(block)),
+            )
         elif not isinstance(pad_value, IndexMap):
             # Explicitly convert python int/float arguments to the
             # buffer's type.  If the default `tvm.runtime.convert`
@@ -3264,7 +3287,9 @@ class Schedule(Object):
             elif "float" in buffer_obj.dtype and isinstance(pad_value, float):
                 pad_value = FloatImm(buffer_obj.dtype, pad_value)
             pad_value = IndexMap.from_func(
-                lambda *indices: pad_value, ndim=len(index_map.final_indices)
+                lambda *indices: pad_value,
+                ndim=len(index_map.final_indices),
+                index_dtype=_get_block_default_dtype(self.get(block)),
             )
 
         buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
@@ -3337,7 +3362,10 @@ class Schedule(Object):
         """
         block = self._normalize_block_arg(block)
         if callable(index_map):
-            index_map = IndexMap.from_func(index_map)
+            index_map = IndexMap.from_func(
+                index_map,
+                index_dtype=_get_block_default_dtype(self.get(block)),
+            )
         _ffi_api.ScheduleTransformBlockLayout(  # type: ignore # pylint: 
disable=no-member
             self, block, index_map
         )
diff --git a/python/tvm/tir/schedule/testing.py 
b/python/tvm/tir/schedule/testing.py
index f38a657123..a293b54b46 100644
--- a/python/tvm/tir/schedule/testing.py
+++ b/python/tvm/tir/schedule/testing.py
@@ -51,6 +51,8 @@ def verify_trace_roundtrip(
         The text format or formats whose round-trip behavior should be
         validated.  If a single string, validate round-trips through
     """
+    from tvm.script import tir as T  # pylint: disable=import-outside-toplevel
+
     if not isinstance(text_format, str):
         for opt in text_format:
             new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask, 
text_format=opt)
@@ -66,7 +68,9 @@ def verify_trace_roundtrip(
         Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
     elif text_format == "python":
         py_trace = "\n".join(trace.as_python())
-        exec(py_trace, tvm.tir.__dict__, {"sch": new_sch})  # pylint: 
disable=exec-used
+        vars_dict = {"T": T}
+        vars_dict.update(tvm.tir.__dict__)
+        exec(py_trace, vars_dict, {"sch": new_sch})  # pylint: 
disable=exec-used
     else:
         assert text_format in ("json", "python"), f"Unknown text format: 
{text_format}"
 
diff --git a/src/arith/ir_mutator_with_analyzer.cc 
b/src/arith/ir_mutator_with_analyzer.cc
index 1f087d9934..2ee427beb8 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -22,6 +22,7 @@
  */
 #include "ir_mutator_with_analyzer.h"
 
+#include <tvm/arith/iter_affine_map.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/op.h>
 
@@ -39,6 +40,25 @@ void IRMutatorWithAnalyzer::MarkBufferMapShapes(const 
tir::PrimFunc& func) {
   }
 }
 
+Array<PrimExpr> IRMutatorWithAnalyzer::IterMapSimplifyWithContext(const 
Array<PrimExpr>& indices,
+                                                                  bool 
non_trivial_only) {
+  PrimExpr pred = const_true();
+  for (PrimExpr val : iter_predicates_) {
+    pred = pred && val;
+  }
+  int n = indices.size();
+  Array<PrimExpr> simplified = arith::IterMapSimplify(
+      indices, this->iter_vars_, pred, arith::IterMapLevel::Surjective, 
this->analyzer_);
+  if (non_trivial_only) {
+    for (int i = 0; i < n; ++i) {
+      if (simplified[i]->IsInstance<IntImmNode>() && 
indices[i]->IsInstance<VarNode>()) {
+        simplified.Set(i, indices[i]);
+      }
+    }
+  }
+  return simplified;
+}
+
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
   // record the loop variable as iterators
   Range dom = Range::FromMinExtent(op->min, op->extent);
diff --git a/src/arith/ir_mutator_with_analyzer.h 
b/src/arith/ir_mutator_with_analyzer.h
index f04b40e7ae..fb01fd19ce 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -70,6 +70,12 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
    */
   void MarkBufferMapShapes(const tir::PrimFunc& func);
 
+  /*!
+   * \brief Use internal bound information to perform iter map simplification of indices.
+   * \note Only do this during layout remapping
+   */
+  Array<PrimExpr> IterMapSimplifyWithContext(const Array<PrimExpr>& indices, 
bool non_trivial_only);
+
   /*! \brief internal analyzer field. */
   Analyzer* analyzer_;
   // the following two fields are useful in case we want
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index f8a36daf53..af1128aa27 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2061,6 +2061,12 @@ class IterMapToExprNormalizer : public ExprMutator {
     if (analyzer_->CanProve(expr->extent == expr->source->extent) && 
is_one(expr->lower_factor)) {
       return source * expr->scale;
     } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor 
* expr->extent)) {
+      // Simplify if `expr` is always 0. The 2nd condition guarantees that we do not aggressively
+      // simplify trivial iters like `vi \in [0, 1)`, which can be useful for subsequent analysis
+      // like tensorization.
+      if (is_one(expr->extent) && !is_one(expr->source->extent)) {
+        return make_const(expr->extent->dtype, 0);
+      }
       return floordiv(source, expr->lower_factor) * expr->scale;
     } else {
       return floordiv(floormod(source, expr->lower_factor * expr->extent), 
expr->lower_factor) *
diff --git a/src/relay/backend/te_compiler_cache.cc 
b/src/relay/backend/te_compiler_cache.cc
index 275d1b6bf7..b747855bff 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -19,6 +19,7 @@
 
 #include "./te_compiler_cache.h"
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/driver/driver_api.h>
 #include <tvm/ir/name_supply.h>
 #include <tvm/ir/type_functor.h>
@@ -594,7 +595,8 @@ class ScheduleBuilder : public ExprVisitor {
                       src_size_1d *= c->shape[i];
                       
orig_shape.push_back(PrimExpr(static_cast<int>((c->shape[i]))));
                     }
-                    auto dst_shape = index_map->MapShape(orig_shape);
+                    arith::Analyzer analyzer;
+                    auto dst_shape = index_map->MapShape(orig_shape, 
&analyzer);
                     std::vector<int64_t> dst_shape_int;
                     size_t dst_size_1d = 1;
                     for (size_t i = 0; i < dst_shape.size(); ++i) {
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index a0111ff7cd..fde6daa4d8 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -23,6 +23,7 @@
  */
 #include "transform.h"
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
@@ -3434,9 +3435,10 @@ Array<te::Tensor> 
MetaScheduleLayoutTransformCompute(const Attrs& attrs,
 bool MetaScheduleLayoutTransformRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
                                     const TypeReporter& reporter) {
   TensorType data_type = Downcast<TensorType>(types[0]);
+  arith::Analyzer analyzer;
   const MetaScheduleLayoutTransformAttrs* params = 
attrs.as<MetaScheduleLayoutTransformAttrs>();
   ICHECK(params);
-  Array<PrimExpr> new_shape = params->index_map->MapShape(data_type->shape);
+  Array<PrimExpr> new_shape = params->index_map->MapShape(data_type->shape, 
&analyzer);
   reporter->Assign(types[1], TensorType(new_shape, data_type->dtype));
   return true;
 }
diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc
index 5e7431e510..04b25f764c 100644
--- a/src/runtime/logging.cc
+++ b/src/runtime/logging.cc
@@ -130,6 +130,9 @@ int BacktraceFullCallback(void* data, uintptr_t pc, const 
char* filename, int li
     backtrace_syminfo(_bt_state, pc, BacktraceSyminfoCallback, 
BacktraceErrorCallback,
                       symbol_str.get());
   }
+  if (filename == nullptr && strstr(symbol_str.get()->data(), "ffi_call_")) {
+    return 0;
+  }
   s << *symbol_str;
 
   if (filename != nullptr) {
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index cd0ec0e34f..22103f7b0f 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -898,8 +898,9 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, 
std::ostream& os) {
         
runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout");
     ICHECK(index_map_func);
 
+    arith::Analyzer analyzer;
     auto inverse_index_map =
-        IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, 
n)});
+        IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, 
n)}, &analyzer);
     auto indices_16x16 = inverse_index_map->final_indices;
 
     // "//" and "%" in the index map are translated to FloorDiv/Mod, but the 
plain Div/Mod are fine.
diff --git a/src/te/schedule/message_passing.cc 
b/src/te/schedule/message_passing.cc
index 7041b751c5..233663feac 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -193,7 +193,7 @@ void PassDownDomain(const Stage& stage, 
std::unordered_map<IterVar, Range>* p_st
       for (const auto& iter_var : s->original_variables) {
         original_ranges.push_back(state[iter_var]);
       }
-      Array<Range> updated_ranges = 
s->forward_transformation->MapRanges(original_ranges);
+      Array<Range> updated_ranges = 
s->forward_transformation->MapRanges(original_ranges, actx);
 
       ICHECK_EQ(updated_ranges.size(), s->transformed_variables.size());
       for (size_t i = 0; i < updated_ranges.size(); i++) {
@@ -269,6 +269,7 @@ void PassUpIndex(const Stage& stage, const Map<IterVar, 
Range>& dom_map,
       }
     } else if (rel.as<SingletonNode>()) {
     } else if (const TransformNode* s = rel.as<TransformNode>()) {
+      arith::Analyzer analyzer;
       bool missing_transformed = false;
       for (const auto& iter_var : s->transformed_variables) {
         if (!state.count(iter_var)) {
@@ -284,7 +285,8 @@ void PassUpIndex(const Stage& stage, const Map<IterVar, 
Range>& dom_map,
       for (const auto& iter_var : s->transformed_variables) {
         transformed_indices.push_back(state[iter_var]);
       }
-      Array<PrimExpr> original_indices = 
s->inverse_transformation->MapIndices(transformed_indices);
+      Array<PrimExpr> original_indices =
+          s->inverse_transformation->MapIndices(transformed_indices, 
&analyzer);
 
       ICHECK_EQ(original_indices.size(), s->original_variables.size());
       for (size_t i = 0; i < original_indices.size(); i++) {
@@ -352,7 +354,9 @@ void PassDownIndex(const Stage& stage, const Map<IterVar, 
Range>& dom_map,
       for (const auto& iter_var : s->original_variables) {
         original_indices.push_back(state[iter_var]);
       }
-      Array<PrimExpr> transformed_indices = 
s->forward_transformation->MapIndices(original_indices);
+      arith::Analyzer analyzer;
+      Array<PrimExpr> transformed_indices =
+          s->forward_transformation->MapIndices(original_indices, &analyzer);
 
       ICHECK_EQ(transformed_indices.size(), s->transformed_variables.size());
       for (size_t i = 0; i < transformed_indices.size(); i++) {
@@ -449,7 +453,9 @@ Array<IntSet> PassUpDomain(const TransformNode* s,
     transformed_indices.push_back(iter_var->var);
   }
 
-  Array<PrimExpr> transformed_exprs = 
s->inverse_transformation->MapIndices(transformed_indices);
+  arith::Analyzer analyzer;
+  Array<PrimExpr> transformed_exprs =
+      s->inverse_transformation->MapIndices(transformed_indices, &analyzer);
 
   ICHECK_EQ(transformed_exprs.size(), s->original_variables.size());
   for (size_t i = 0; i < transformed_exprs.size(); i++) {
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index 56fe0cfc65..44e742eee4 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -21,6 +21,7 @@
  * \file schedule_lang.cc
  */
 #include <dmlc/thread_local.h>
+#include <tvm/arith/analyzer.h>
 #include <tvm/ir/transform.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
@@ -491,10 +492,11 @@ Stage& Stage::transform_layout(const Array<Var>& 
initial_indices,
   for (const auto& iter_var : compute->axis) {
     initial_ranges.push_back(iter_var->dom);
   }
-  Array<Range> final_ranges = map->MapRanges(initial_ranges);
+  arith::Analyzer analyzer;
+  Array<Range> final_ranges = map->MapRanges(initial_ranges, &analyzer);
 
   // Make IterVar objects to represent the new iterations.
-  auto inverse = map.Inverse(initial_ranges);
+  auto inverse = map.Inverse(initial_ranges, &analyzer);
   Array<IterVar> final_indices_iter;
   ICHECK_EQ(inverse->initial_indices.size(), final_ranges.size());
   for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index a39149ceba..149e4cecd4 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -55,7 +55,9 @@ IndexMap IndexMap::FromFunc(int ndim, 
runtime::TypedPackedFunc<Array<PrimExpr>(A
 
 std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self,
                                                   const Array<Range>& 
initial_ranges,
-                                                  arith::IterMapLevel 
check_level) {
+                                                  arith::IterMapLevel 
check_level,
+                                                  arith::Analyzer* analyzer) {
+  ICHECK(analyzer != nullptr);
   if (self->inverse_index_map.defined()) {
     // return the pre-defined inverse index map if exists.  In this
     // case, the user-defined inverse is assumed to be correct and
@@ -88,9 +90,8 @@ std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const 
IndexMap& self,
 
   // Unpack the output indices into linear combinations of the initial
   // indices.
-  arith::Analyzer analyzer;
-  auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /* 
predicate = */ 1,
-                                       /*check_level=*/check_level, &analyzer,
+  auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, 
/*predicate=*/1,
+                                       /*check_level=*/check_level, analyzer,
                                        /*simplify_trivial_iterators=*/false);
   CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of 
iterators.  "
                                          << "Error: " << 
padded_iter_map->errors[0];
@@ -110,15 +111,15 @@ std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const 
IndexMap& self,
     } else {
       expr = inverse_exprs_map.at(index);
     }
-    inverse_exprs.push_back(analyzer.Simplify(expr));
+    inverse_exprs.push_back(analyzer->Simplify(expr));
   }
 
   PrimExpr padding_predicate = padded_iter_map->padding_predicate;
   padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate);
   padding_predicate = Substitute(padding_predicate, inverse_exprs_map);
 
+  auto output_ranges = self->MapRanges(initial_ranges, analyzer);
   {
-    auto output_ranges = self->MapRanges(initial_ranges);
     ICHECK_EQ(output_ranges.size(), output_vars.size());
 
     arith::Analyzer analyzer;
@@ -133,15 +134,17 @@ std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const 
IndexMap& self,
   return {IndexMap(output_vars, inverse_exprs), padding_predicate};
 }
 
-std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> 
initial_ranges) const {
-  return IndexMapInverseImpl(*this, initial_ranges, 
arith::IterMapLevel::NoCheck);
+std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> 
initial_ranges,
+                                                             arith::Analyzer* 
analyzer) const {
+  ICHECK(analyzer != nullptr);
+  return IndexMapInverseImpl(*this, initial_ranges, 
arith::IterMapLevel::NoCheck, analyzer);
 }
 
-IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
+IndexMap IndexMap::Inverse(Array<Range> initial_ranges, arith::Analyzer* 
analyzer) const {
+  ICHECK(analyzer != nullptr);
   auto [inverse, padding_predicate] =
-      IndexMapInverseImpl(*this, initial_ranges, 
arith::IterMapLevel::Bijective);
-  arith::Analyzer analyzer;
-  CHECK(analyzer.CanProve(!padding_predicate))
+      IndexMapInverseImpl(*this, initial_ranges, 
arith::IterMapLevel::Bijective, analyzer);
+  CHECK(analyzer->CanProve(!padding_predicate))
       << "Bijective inverse should not contain padding, but inverse of " << 
*this << " over range "
       << initial_ranges << " resulted in a padding predicate of " << 
padding_predicate;
   return inverse;
@@ -149,6 +152,7 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) 
const {
 
 Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
                                          arith::Analyzer* analyzer) const {
+  ICHECK(analyzer != nullptr);
   ICHECK_EQ(indices.size(), initial_indices.size());
 
   Map<Var, PrimExpr> vmap;
@@ -157,11 +161,6 @@ Array<PrimExpr> IndexMapNode::MapIndices(const 
Array<PrimExpr>& indices,
     vmap.Set(initial_indices[i], indices[i]);
   }
 
-  arith::Analyzer local_analyzer;
-  if (!analyzer) {
-    analyzer = &local_analyzer;
-  }
-
   Array<PrimExpr> output = final_indices.Map([&](PrimExpr index) {
     PrimExpr result = SubstituteWithDataTypeLegalization(
         std::move(index), [&](const Var& var) { return vmap.Get(var); });
@@ -171,18 +170,13 @@ Array<PrimExpr> IndexMapNode::MapIndices(const 
Array<PrimExpr>& indices,
 }
 
 Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, 
arith::Analyzer* analyzer) const {
+  ICHECK(analyzer != nullptr);
   ICHECK_EQ(ranges.size(), initial_indices.size());
 
   Map<Var, Range> input_iters;
   for (size_t i = 0; i < initial_indices.size(); i++) {
     input_iters.Set(initial_indices[i], ranges[i]);
   }
-
-  arith::Analyzer local_analyzer;
-  if (!analyzer) {
-    analyzer = &local_analyzer;
-  }
-
   auto iter_map = DetectIterMap(final_indices, input_iters, /* predicate = */ 
1,
                                 /*check_level=*/arith::IterMapLevel::NoCheck, 
analyzer,
                                 /*simplify_trivial_iterators=*/false);
@@ -240,6 +234,7 @@ Array<Range> IndexMapNode::MapRanges(const Array<Range>& 
ranges, arith::Analyzer
 
 Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
                                        arith::Analyzer* analyzer) const {
+  ICHECK(analyzer != nullptr);
   ICHECK_EQ(shape.size(), initial_indices.size());
 
   Array<Range> ranges;
@@ -258,6 +253,7 @@ Array<PrimExpr> IndexMapNode::MapShape(const 
Array<PrimExpr>& shape,
 }
 
 runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const {
+  arith::Analyzer analyzer;
   auto shape = arr_src.Shape();
   ICHECK(shape.size() == initial_indices.size())
       << "The rank of the input array should be " << initial_indices.size() << 
" but got "
@@ -268,7 +264,7 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray 
arr_src) const {
     size_1d *= shape[i];
     orig_shape.push_back(PrimExpr(static_cast<int>((shape[i]))));
   }
-  auto dst_shape = MapShape(orig_shape);
+  auto dst_shape = MapShape(orig_shape, &analyzer);
 
   std::vector<int64_t> dst_shape_int;
   for (size_t i = 0; i < dst_shape.size(); ++i) {
@@ -292,7 +288,7 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray 
arr_src) const {
       src_indices.push_back(PrimExpr(static_cast<int>((src_linear_index / 
div_factor))));
       src_linear_index %= div_factor;
     }
-    auto dst_indices = MapIndices(src_indices);
+    auto dst_indices = MapIndices(src_indices, &analyzer);
 
     // Convert an N-d coordinate to a linear coordinate
     // (z, y, x) -> z * height * width + y * width + x
@@ -430,19 +426,29 @@ TVM_REGISTER_GLOBAL("tir.IndexMap")
     });
 
 TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
-    .set_body_typed([](IndexMap map, Array<PrimExpr> indices) { return 
map->MapIndices(indices); });
+    .set_body_typed([](IndexMap map, Array<PrimExpr> indices) {
+      arith::Analyzer analyzer;
+      return map->MapIndices(indices, &analyzer);
+    });
 
 TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map, 
Array<PrimExpr> shape) {
-  return map->MapShape(shape);
+  arith::Analyzer analyzer;
+  return map->MapShape(shape, &analyzer);
 });
-TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse);
+
+TVM_REGISTER_GLOBAL("tir.IndexMapInverse")
+    .set_body_typed([](IndexMap map, Array<Range> initial_ranges) {
+      arith::Analyzer analyzer;
+      return map.Inverse(initial_ranges, &analyzer);
+    });
 
 TVM_REGISTER_GLOBAL("tir.IndexMapMapNDArray")
     .set_body_typed([](IndexMap map, runtime::NDArray arr) { return 
map->MapNDArray(arr); });
 
 TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse")
     .set_body_typed([](IndexMap forward, Array<Range> initial_ranges) {
-      auto result = forward.NonSurjectiveInverse(initial_ranges);
+      arith::Analyzer analyzer;
+      auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer);
       return Array<ObjectRef>{result.first, result.second};
     });
 
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 6e361cbe05..868fbe8563 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -72,6 +72,16 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, 
const StmtNode* root_bl
  */
 StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
 
+/*!
+ * \brief Given an arbitrary sref, bind the shape var info of the PrimFunc it 
belongs to to the
+ * given analyzer
+ * \param state The schedule state
+ * \param sref The given sref
+ * \param analyzer The analyzer to be bound
+ */
+void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
+                       arith::Analyzer* analyzer);
+
 /******** Scope ********/
 /*!
  * \brief Checks if scope the specified sref is in is a stage-pipeline and 
return it
diff --git a/src/tir/schedule/analysis/analysis.cc 
b/src/tir/schedule/analysis/analysis.cc
index a765eb4d9f..17cc39d5cb 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1269,6 +1269,20 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
   return GetRef<StmtSRef>(p);
 }
 
+void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
+                       arith::Analyzer* analyzer) {
+  while (sref->parent != nullptr) {
+    sref = sref->parent;
+  }
+  const PrimFuncNode* f = GetRootPrimFunc(state->mod, sref->stmt, nullptr);
+  for (const auto& kv : f->buffer_map) {
+    const Buffer& buffer = kv.second;
+    for (const PrimExpr& e : buffer->shape) {
+      analyzer->MarkGlobalNonNegValue(e);
+    }
+  }
+}
+
 /******** Misc ********/
 
 bool HasOp(const Stmt& stmt, const Array<Op>& ops) {
diff --git a/src/tir/schedule/primitive/cache_read_write.cc 
b/src/tir/schedule/primitive/cache_read_write.cc
index 6f9aa11275..3fbdf856b5 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -1909,13 +1909,14 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
     ReindexCacheStageInfo* info, const IRModule& mod, const StmtSRef& 
block_sref,
     const String& storage_scope, const IndexMap& index_map, const Block& block,
     const BlockRealize& realize, const Buffer& old_buffer, const BufferRegion& 
cache_region) {
+  arith::Analyzer analyzer;
   Array<PrimExpr> block_iter_vars, block_shape;
   for (const IterVar& iter_var : block->iter_vars) {
     block_iter_vars.push_back(iter_var);
     block_shape.push_back(iter_var->dom->extent);
   }
-  Array<PrimExpr> new_indices = index_map->MapIndices(block_iter_vars);
-  Array<PrimExpr> new_shape = index_map->MapShape(block_shape);
+  Array<PrimExpr> new_indices = index_map->MapIndices(block_iter_vars, 
&analyzer);
+  Array<PrimExpr> new_shape = index_map->MapShape(block_shape, &analyzer);
   info->indices = new_indices;
 
   // Step 5. Update CacheTouchedInfo
@@ -1926,8 +1927,6 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
     old_indices.push_back(range->min);
   }
 
-  arith::Analyzer analyzer;
-
   VarUseDefAnalyzer collector_new(/*defined_vars=*/{});
   for (const PrimExpr& idx : new_indices) {
     collector_new(idx);
diff --git a/src/tir/schedule/primitive/compute_at.cc 
b/src/tir/schedule/primitive/compute_at.cc
index 45d0c81050..fc388b0048 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -680,20 +680,6 @@ void CalculateProvidedRequiredRegions(
 
 /******** Main Implementation ********/
 
-void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
-                       arith::Analyzer* analyzer) {
-  while (sref->parent != nullptr) {
-    sref = sref->parent;
-  }
-  const PrimFuncNode* f = GetRootPrimFunc(state->mod, sref->stmt, nullptr);
-  for (const auto& kv : f->buffer_map) {
-    const Buffer& buffer = kv.second;
-    for (const PrimExpr& e : buffer->shape) {
-      analyzer->MarkGlobalNonNegValue(e);
-    }
-  }
-}
-
 template <bool is_compute_at>
 void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& 
block_sref,
                                      const StmtSRef& loop_sref, bool 
preserve_unit_loops,
diff --git a/src/tir/schedule/primitive/layout_transformation.cc 
b/src/tir/schedule/primitive/layout_transformation.cc
index bb2abc559d..d9a9f3cfed 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/node/node.h>
 
 #include <optional>
@@ -93,12 +94,12 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 
   static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, 
IndexMap index_map,
                             IndexMap inverse, PrimExpr padding_predicate,
-                            Optional<IndexMap> pad_value) {
+                            Optional<IndexMap> pad_value, arith::Analyzer* 
analyzer) {
     ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 
1)
         << "Internal error: Should be caught by ScheduleError checks prior to 
this point";
     TransformLayoutPlanner visitor(old_buffer);
     visitor(block);
-    return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, 
pad_value);
+    return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, 
pad_value, analyzer);
   }
 
  private:
@@ -220,14 +221,15 @@ class TransformLayoutPlanner : private StmtExprVisitor {
    public:
     BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer, 
PrimExpr padding_predicate,
                         const IndexMap& inverse, const Optional<IndexMap>& 
pad_value,
-                        Map<Block, Block>* new_block_to_old)
+                        Map<Block, Block>* new_block_to_old, arith::Analyzer* 
analyzer)
         : info(info),
           new_buffer(new_buffer),
           new_indices(inverse->initial_indices),
           padding_predicate(padding_predicate),
           inverse(inverse),
           pad_value(pad_value),
-          new_block_to_old(*new_block_to_old) {
+          new_block_to_old(*new_block_to_old),
+          analyzer(analyzer) {
       ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
       for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
         Var var = info.dependent_loopnest[i]->loop_var;
@@ -353,7 +355,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
       if (can_replace) {
         Array<PrimExpr> new_index_exprs =
             new_indices.Map([](const auto& var) -> PrimExpr { return var; });
-        PrimExpr pad_value_at_index = 
pad_value.value()->MapIndices(new_index_exprs)[0];
+        PrimExpr pad_value_at_index = 
pad_value.value()->MapIndices(new_index_exprs, analyzer)[0];
         store =
             BufferStore(new_buffer, if_then_else(padding_predicate, 
pad_value_at_index, op->value),
                         new_index_exprs);
@@ -429,22 +431,24 @@ class TransformLayoutPlanner : private StmtExprVisitor {
     const Optional<IndexMap>& pad_value;
     Map<Block, Block>& new_block_to_old;
     bool all_stores_replaced{true};
+    arith::Analyzer* analyzer;
 
     Map<Var, PrimExpr> var_remap;
   };
 
   TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap 
inverse,
-                         PrimExpr padding_predicate, Optional<IndexMap> 
pad_value) const {
-    if (auto prologue_plan =
-            FinalizeProloguePlan(new_buffer, index_map, inverse, 
padding_predicate, pad_value);
+                         PrimExpr padding_predicate, Optional<IndexMap> 
pad_value,
+                         arith::Analyzer* analyzer) const {
+    if (auto prologue_plan = FinalizeProloguePlan(new_buffer, index_map, 
inverse, padding_predicate,
+                                                  pad_value, analyzer);
         prologue_plan.has_value()) {
       return prologue_plan.value();
-    } else if (auto replacement_plan = FinalizeReplacementPlan(new_buffer, 
index_map, inverse,
-                                                               
padding_predicate, pad_value);
+    } else if (auto replacement_plan = FinalizeReplacementPlan(
+                   new_buffer, index_map, inverse, padding_predicate, 
pad_value, analyzer);
                replacement_plan.has_value()) {
       return replacement_plan.value();
     } else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer, 
index_map, inverse,
-                                                         padding_predicate, 
pad_value);
+                                                         padding_predicate, 
pad_value, analyzer);
                epilogue_plan.has_value()) {
       return epilogue_plan.value();
     } else {
@@ -454,7 +458,8 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 
   std::optional<ProloguePlan> FinalizeProloguePlan(Buffer new_buffer, IndexMap 
index_map,
                                                    IndexMap inverse, PrimExpr 
padding_predicate,
-                                                   Optional<IndexMap> 
pad_value) const {
+                                                   Optional<IndexMap> 
pad_value,
+                                                   arith::Analyzer* analyzer) 
const {
     if (write_info_.size() || is_zero(padding_predicate) || 
!pad_value.defined()) {
       return std::nullopt;
     }
@@ -476,7 +481,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
     }
     padding_predicate = Substitute(std::move(padding_predicate), 
loop_indices_to_block_indices);
 
-    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0];
+    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices, 
analyzer)[0];
     PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) 
== pad_value_at_index);
     Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr}));
 
@@ -498,7 +503,8 @@ class TransformLayoutPlanner : private StmtExprVisitor {
   std::optional<ReplacementPlan> FinalizeReplacementPlan(Buffer new_buffer, 
IndexMap index_map,
                                                          IndexMap inverse,
                                                          PrimExpr 
padding_predicate,
-                                                         Optional<IndexMap> 
pad_value) const {
+                                                         Optional<IndexMap> 
pad_value,
+                                                         arith::Analyzer* 
analyzer) const {
     if (write_info_.empty() || is_zero(padding_predicate) || 
!pad_value.defined()) {
       return std::nullopt;
     }
@@ -511,7 +517,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
       }
 
       BufferStoreReplacer replacer(info, new_buffer, padding_predicate, 
inverse, pad_value,
-                                   &new_block_to_old);
+                                   &new_block_to_old, analyzer);
       Stmt stmt = replacer(info.dependent_loopnest.back()->body);
       if (!replacer.is_all_stores_replaced()) {
         return NullOpt;
@@ -547,7 +553,8 @@ class TransformLayoutPlanner : private StmtExprVisitor {
 
   std::optional<EpiloguePlan> FinalizeEpiloguePlan(Buffer new_buffer, IndexMap 
index_map,
                                                    IndexMap inverse, PrimExpr 
padding_predicate,
-                                                   Optional<IndexMap> 
pad_value) const {
+                                                   Optional<IndexMap> 
pad_value,
+                                                   arith::Analyzer* analyzer) 
const {
     if (write_info_.empty() || is_zero(padding_predicate) || 
!pad_value.defined()) {
       return std::nullopt;
     }
@@ -566,7 +573,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
       iter_values.push_back(loop_var);
     }
 
-    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0];
+    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices, 
analyzer)[0];
     Stmt stmt = BufferStore(new_buffer, pad_value_at_index, indices);
 
     std::stringstream block_name;
@@ -757,12 +764,13 @@ class TransformLayoutRewriter : private 
arith::IRMutatorWithAnalyzer {
       const Block& scope_stmt, const Buffer& old_buffer, const Buffer& 
new_buffer,
       const IndexMap& index_map, const Optional<IndexMap>& opt_inverse,
       const PrimExpr& padding_predicate, const Optional<IndexMap>& pad_value) {
-    auto plan = pad_value.defined() ? TransformLayoutPlanner::Plan(
-                                          scope_stmt, old_buffer, new_buffer, 
index_map,
-                                          opt_inverse.value(), 
padding_predicate, pad_value)
-                                    : 
TransformLayoutPlanner::NoPaddingRequired();
-
     arith::Analyzer analyzer;
+    auto plan = pad_value.defined()
+                    ? TransformLayoutPlanner::Plan(scope_stmt, old_buffer, 
new_buffer, index_map,
+                                                   opt_inverse.value(), 
padding_predicate,
+                                                   pad_value, &analyzer)
+                    : TransformLayoutPlanner::NoPaddingRequired();
+
     TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, 
&analyzer);
     Block result = Downcast<Block>(rewriter(scope_stmt));
     if (auto plan_ptr = 
std::get_if<TransformLayoutPlanner::ProloguePlan>(&plan)) {
@@ -794,9 +802,8 @@ class TransformLayoutRewriter : private 
arith::IRMutatorWithAnalyzer {
 
   void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
     *buffer = new_buffer_;
-    *indices = index_map_->MapIndices(*indices);
-    (*indices).MutateByApply(
-        [&](const PrimExpr& e) { return SimplifyNonTrivialExpr(e, analyzer_); 
});
+    *indices = index_map_->MapIndices(*indices, &index_simplifier_);
+    *indices = this->IterMapSimplifyWithContext(*indices, true);
   }
 
   using Parent = arith::IRMutatorWithAnalyzer;
@@ -913,6 +920,7 @@ class TransformLayoutRewriter : private 
arith::IRMutatorWithAnalyzer {
   const TransformLayoutPlanner::TransformPlan& plan_;
   Map<Var, Buffer> buffer_data_to_buffer_;
   Map<Block, Block> new_block_to_old_;
+  arith::Analyzer index_simplifier_;
 };
 
 class BufferIsSubregionError : public ScheduleError {
@@ -1069,7 +1077,8 @@ class TransformationIntroducesPaddingError : public 
ScheduleError {
   }
 
   String DetailRenderTemplate() const final {
-    auto new_shape = index_map_->MapShape(buffer_->shape);
+    arith::Analyzer analyzer;
+    auto new_shape = index_map_->MapShape(buffer_->shape, &analyzer);
     std::ostringstream os;
     os << "The transformation " << index_map_ << " applied on buffer " << 
buffer_->name
        << " of shape " << buffer_->shape << " would result in shape " << 
new_shape
@@ -1138,6 +1147,8 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, 
const Array<PrimExpr>&
 void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int 
buffer_index,
                      BufferIndexType buffer_index_type, const IndexMap& 
index_map_orig,
                      const Optional<IndexMap>& pad_value, bool 
assume_injective_transform) {
+  arith::Analyzer analyzer;
+  AddShapeVarBounds(self, block_sref.get(), &analyzer);
   // Step 1: Input handling and error checking
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
   Buffer old_buffer =
@@ -1173,7 +1184,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& 
block_sref, int buffer_
       for (const auto& dim : old_buffer->shape) {
         region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
       }
-      return index_map.NonSurjectiveInverse(region);
+      return index_map.NonSurjectiveInverse(region, &analyzer);
     }();
   }
 
@@ -1184,7 +1195,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& 
block_sref, int buffer_
 
   // Step 2: Infer the shape of the new buffer
   Buffer new_buffer = old_buffer;
-  new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape);
+  new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape, 
&analyzer);
 
   // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write 
regions, and block
   // alloc_buffers.
@@ -1336,6 +1347,7 @@ void TransformBlockLayout(ScheduleState self, const 
StmtSRef& block_sref,
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
   const Block& block = GetRef<Block>(block_ptr);
   arith::Analyzer analyzer;
+  AddShapeVarBounds(self, block_sref.get(), &analyzer);
 
   // Step 1: Collect outer loops and loop vars
   Array<StmtSRef> loops = GetLoops(block_sref);  // outer loops of the block
@@ -1375,8 +1387,8 @@ void TransformBlockLayout(ScheduleState self, const 
StmtSRef& block_sref,
 
   // Step 4: Apply the IndexMap to block iters.
   IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map);
-  Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars);
-  Array<PrimExpr> new_block_iter_range = 
index_map->MapShape(block_iter_range_array);
+  Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars, 
&analyzer);
+  Array<PrimExpr> new_block_iter_range = 
index_map->MapShape(block_iter_range_array, &analyzer);
 
   // Step 5: Create the new block after transformation.
 
@@ -1408,14 +1420,13 @@ void TransformBlockLayout(ScheduleState self, const 
StmtSRef& block_sref,
     }
     IndexMap inverse_index_map{nullptr};
     try {
-      inverse_index_map = index_map.Inverse(initial_ranges);
+      inverse_index_map = index_map.Inverse(initial_ranges, &analyzer);
     } catch (...) {
       throw NotBijectiveAffineIndexMapError(self->mod, index_map);
     }
-
-    Array<PrimExpr> inversed_new_block_vars = inverse_index_map->MapIndices(
-        new_block_vars);  // old block vars written in terms of new block vars
-
+    // old block vars written in terms of new block vars
+    Array<PrimExpr> inversed_new_block_vars =
+        inverse_index_map->MapIndices(new_block_vars, &analyzer);
     for (int i = 0, n = block_vars.size(); i < n; ++i) {
       inverse_subst_map.Set(Downcast<Var>(block_vars[i]), 
inversed_new_block_vars[i]);
     }
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index e4047273b6..62662a4ce1 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -388,16 +388,17 @@ 
TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWith
 /******** BlockBufferAccessSimplifier ********/
 void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>* 
old_access_regions) {
   auto fmutate = [this](const BufferRegion& buffer_region) {
-    std::vector<Range> new_buffer_region;
+    Array<Range> new_buffer_region;
+    Array<PrimExpr> simplified_min;
     for (const auto& range : buffer_region->region) {
-      if (is_one(range->extent) && range->min->IsInstance<VarNode>()) {
-        new_buffer_region.push_back(Range::FromMinExtent(
-            SimplifyNonTrivialExpr(range->min, analyzer_), 
make_const(range->min.dtype(), 1)));
-      } else {
-        new_buffer_region.push_back(
-            Range::FromMinExtent(SimplifyNonTrivialExpr(range->min, analyzer_),
-                                 SimplifyNonTrivialExpr(range->extent, 
analyzer_)));
-      }
+      simplified_min.push_back(range->min);
+    }
+    simplified_min = this->IterMapSimplifyWithContext(simplified_min, true);
+    int n = buffer_region->region.size();
+    for (int i = 0; i < n; ++i) {
+      PrimExpr min = simplified_min[i];
+      PrimExpr extent = analyzer_->Simplify(buffer_region->region[i]->extent);
+      new_buffer_region.push_back(Range::FromMinExtent(min, extent));
     }
     return BufferRegion(buffer_region->buffer, new_buffer_region);
   };
@@ -405,8 +406,7 @@ void 
BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>* old_
 }
 
 void BlockBufferAccessSimplifier::SimplifyBufferIndices(Array<PrimExpr>* 
indices) {
-  (*indices).MutateByApply(
-      [this](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, 
analyzer_); });
+  *indices = this->IterMapSimplifyWithContext(*indices, true);
 }
 
 Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) {
diff --git a/src/tir/transforms/flatten_buffer.cc 
b/src/tir/transforms/flatten_buffer.cc
index f37c21593f..c04e12b839 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -220,18 +220,7 @@ class BufferFlattener : public 
arith::IRMutatorWithAnalyzer {
 
   Array<PrimExpr> GetSimplifiedElemOffset(const Buffer& buffer, const 
Array<PrimExpr>& indices) {
     auto flattened_indices = buffer->ElemOffset(indices);
-    // Use IterMapSimplify to enable constant fold of fused indices
-    // IterMapSimplify is more powerful and time-consuming than normal
-    // simplify as it tries to deal with symbolic fusion
-    //
-    // Only use to handle indices during layout transformations
-    // So we restrict the use to here
-    PrimExpr pred = const_true();
-    for (PrimExpr val : iter_predicates_) {
-      pred = pred && val;
-    }
-    return arith::IterMapSimplify(flattened_indices, this->iter_vars_, pred,
-                                  arith::IterMapLevel::Surjective, 
this->analyzer_);
+    return this->IterMapSimplifyWithContext(flattened_indices, false);
   }
 
   template <typename Node>
diff --git a/src/tir/transforms/storage_flatten.cc 
b/src/tir/transforms/storage_flatten.cc
index 8c409fba5e..9c12448381 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -1265,7 +1265,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
 
       Array<IndexMap> transforms = lookup.value();
       for (const auto& transform : transforms) {
-        write_ptr->bounds = transform->MapRanges(realize->bounds);
+        write_ptr->bounds = transform->MapRanges(realize->bounds, &analyzer);
       }
     }
 
@@ -1292,7 +1292,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
 
       Array<IndexMap> transforms = lookup.value();
       for (const auto& transform : transforms) {
-        write_ptr->indices = transform->MapIndices(node->indices);
+        write_ptr->indices = transform->MapIndices(node->indices, &analyzer);
       }
     }
     return node;
@@ -1315,7 +1315,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
 
       auto write_ptr = buf.CopyOnWrite();
       for (const auto& transform : transforms) {
-        write_ptr->shape = transform->MapShape(buf->shape);
+        write_ptr->shape = transform->MapShape(buf->shape, &analyzer);
       }
     }
 
@@ -1326,6 +1326,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
   std::unordered_map<const BufferNode*, Buffer> buf_map_;
 
   Map<Buffer, Array<IndexMap>> layout_transforms_;
+  arith::Analyzer analyzer;
 };
 
 class StorageFlattener : public StmtExprMutator {
diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc 
b/src/tir/transforms/transform_mma_buffer_layout.cc
index 82fd6cfa9a..abe0bc3a3d 100644
--- a/src/tir/transforms/transform_mma_buffer_layout.cc
+++ b/src/tir/transforms/transform_mma_buffer_layout.cc
@@ -130,7 +130,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator {
         const auto* index_map_func = 
runtime::Registry::Get("tir.index_map_m16n8k8.matrixC");
         ICHECK(index_map_func);
         auto index_map = IndexMap::FromFunc(2, *index_map_func);
-        auto new_indices = index_map->MapIndices(store->indices);
+        auto new_indices = index_map->MapIndices(store->indices, &analyzer);
         n->buffer = buffer_map_[store->buffer];
         n->indices = std::move(new_indices);
       } else if (store->buffer.scope() == "m16n8k8.matrixA" ||
@@ -149,7 +149,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator {
         const auto* index_map_func = 
runtime::Registry::Get("tir.index_map_m16n8k8.matrixC");
         ICHECK(index_map_func);
         auto index_map = IndexMap::FromFunc(2, *index_map_func);
-        auto new_indices = index_map->MapIndices(load->indices);
+        auto new_indices = index_map->MapIndices(load->indices, &analyzer);
         n->buffer = buffer_map_[load->buffer];
         n->indices = std::move(new_indices);
       } else if (load->buffer.scope() == "m16n8k8.matrixA" ||
@@ -179,7 +179,7 @@ Pass TransformMmaBufferLayout() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
     n->body = MmaBufferLayoutTransformer()(std::move(n->body));
-    return std::move(f);
+    return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.TransformMmaBufferLayout", {});
 }
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py 
b/tests/python/unittest/test_arith_iter_affine_map.py
index 594dec73ea..cee9922e86 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -16,7 +16,7 @@
 # under the License.
 import tvm
 import tvm.testing
-from tvm.tir import floormod, floordiv
+from tvm.tir import floordiv, floormod
 
 
 def ifuse(inputs, pred_extent=None):
@@ -1211,7 +1211,7 @@ def test_iter_map_simplify_unit_loop_order():
     # When we have iterators that have same scale but one of them come
     # with unit extent, we should prioritize unit extent
     assert_iter_map_simplify(
-        {x // 128 + y + z: y + x // 128 + z},
+        {x // 128 + y + z: y + z},
         var_dom([(x, 128), (y, 128), (z, 1)]),
         simplify_trivial_iterators=False,
     )
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py 
b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 162dec6271..3a2ca69cba 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Integration test for MetaSchedule"""
-import tempfile
 import platform
+import tempfile
 from typing import List
 
 import numpy as np
@@ -56,6 +56,7 @@ class MockModule:
 # pylint: 
enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument
 
 
[email protected]("Integration tests")
 def test_meta_schedule_dynamic_loop_extent():
     a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
     b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
@@ -64,6 +65,7 @@ def test_meta_schedule_dynamic_loop_extent():
     assert not extracted_tasks
 
 
[email protected]("Integration tests")
 @pytest.mark.skipif(
     platform.machine() == "aarch64",
     reason="Currently torch.jit.trace fails on AArch64",
@@ -104,6 +106,7 @@ def test_meta_schedule_integration_extract_from_resnet():
         assert t.task_name in expected_task_names, t.task_name
 
 
[email protected]("Integration tests")
 @pytest.mark.skipif(
     platform.machine() == "aarch64",
     reason="Currently torch.jit.trace fails on AArch64",
@@ -126,6 +129,7 @@ def test_task_extraction_winograd_tensorcore():
     assert len([t for t in extracted_tasks if "winograd" in t.task_name]) == 4
 
 
[email protected]("Integration tests")
 @pytest.mark.skipif(
     platform.machine() == "aarch64",
     reason="Currently torch.jit.trace fails on AArch64",
@@ -165,6 +169,7 @@ def test_task_extraction_anchor_block():
         assert t.task_name in expected_task_names, t.task_name
 
 
[email protected]("Integration tests")
 @tvm.testing.requires_package("torch")
 def test_meta_schedule_integration_extract_from_bert_base():
     pytest.importorskip(
@@ -263,6 +268,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
         assert expected_shape == shape, t.task_name
 
 
[email protected]("Integration tests")
 @pytest.mark.skipif(
     platform.machine() == "aarch64",
     reason="Currently torch.jit.trace fails on AArch64",
@@ -374,6 +380,7 @@ def extract_task_qbert_avx512():
     extract_task_qbert("llvm -mcpu=skylake-avx512", "avx512")
 
 
[email protected]("Integration tests")
 @tvm.testing.skip_if_32bit(reason="Apparently the LLVM version on i386 image 
is too old")
 def test_extract_task_arm_conv2d_nchwc():
     data_shape = (1, 64, 128, 128)
@@ -419,6 +426,7 @@ def test_extract_task_arm_conv2d_nchwc():
     assert list(out_type.shape) == [1, 8, 130, 130, 4]
 
 
[email protected]("Integration tests")
 def test_meta_schedule_te2primfunc_argument_order_and_lowering():
     # pylint: 
disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
     # fmt: off
@@ -581,7 +589,9 @@ def 
test_meta_schedule_te2primfunc_argument_order_and_lowering():
         dev,
     )
 
-    with target, _create_verification_database(), PassContext(  # pylint: 
disable=not-context-manager
+    with (
+        target
+    ), _create_verification_database(), PassContext(  # pylint: 
disable=not-context-manager
         opt_level=3,
         config={
             "relay.backend.use_meta_schedule": True,
@@ -607,6 +617,7 @@ def 
test_meta_schedule_te2primfunc_argument_order_and_lowering():
     assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
 
 
[email protected]("Integration tests")
 def test_rewrite_layout_link_params():
     I, O, H, W = 64, 64, 56, 56
     kH = kW = 3
@@ -685,6 +696,7 @@ def test_rewrite_layout_link_params():
         np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)
 
 
[email protected]("Integration tests")
 def test_module_equality_ignore_ndarray():
     target = "llvm --num-cores=4"
 
@@ -800,6 +812,7 @@ def _test_anchor_tuning(target, space):
     np.testing.assert_allclose(ref, out, atol=1e-3)
 
 
[email protected]("Integration tests")
 @pytest.mark.parametrize(
     "space",
     [
@@ -811,6 +824,7 @@ def test_anchor_tuning_cpu(space):
     _test_anchor_tuning("llvm --num-cores=4", space)
 
 
[email protected]("Integration tests")
 def test_anchor_tuning_cpu_link_params():
     data_shape = (128, 128)
     weight_shape1 = (128, 128)
@@ -863,6 +877,7 @@ def test_anchor_tuning_cpu_link_params():
     np.testing.assert_allclose(ref, out, atol=1e-3)
 
 
[email protected]("Integration tests")
 @pytest.mark.xfail(raises=tvm.error.TVMError)
 def test_disabled_pass_param():
     """
@@ -908,6 +923,7 @@ def test_disabled_pass_param():
     pytest.fail("'disabled_pass' argument does not work")
 
 
[email protected]("Integration tests")
 def test_rewrite_layout_link_params_1x1_conv2d():
     I, O, H, W = 32, 16, 256, 256
     kH = kW = 1
diff --git 
a/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py 
b/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py
index d1ba84d836..437aae9e6b 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py
@@ -21,11 +21,14 @@ import tempfile
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import pytest
 
 import tvm
 import tvm.testing
 from tvm import meta_schedule, relay
-from tvm.meta_schedule.schedule.cuda.layout_transform import 
cuda_layout_transform_schedule_rule
+from tvm.meta_schedule.schedule.cuda.layout_transform import (
+    cuda_layout_transform_schedule_rule,
+)
 from tvm.relay.op import OpPattern
 from tvm.script import ir as I
 from tvm.script import tir as T
@@ -170,6 +173,7 @@ def run_primfunc(
     lib(*input_tensors)
 
 
[email protected]("Integration test")
 class TestRandomRelayE2ECorrectness:
     """Tests E2E correctness of layout transform schedule.
 
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py 
b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
index 7cf06b54ca..d1f4b6bdce 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -14,9 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: 
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+# pylint: 
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring,line-too-long,invalid-name,too-many-locals,too-many-statements,too-many-nested-blocks,too-many-branches,too-many-lines,chained-comparison
 
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import meta_schedule as ms
@@ -413,10 +414,10 @@ def test_conv2d(shared_scope):
                             with T.block("PadInput_reindex_shared.dyn"):
                                 v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused 
* 16 + ax0_ax1_fused // 288)
                                 v1 = T.axis.spatial(288, ax0_ax1_fused % 288)
-                                T.reads(PadInput[v0 // 256, v1 // 96 + v0 // 
16, v1 % 96 // 32 + v0 % 16, v1 % 32])
+                                T.reads(PadInput[0, v0 // 16 + v1 // 96, v0 % 
16 + v1 % 96 // 32, v1 % 32])
                                 T.writes(PadInput_reindex_shared_dyn[v0, v1])
                                 T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 2})
-                                PadInput_reindex_shared_dyn[v0, v1] = 
PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]
+                                PadInput_reindex_shared_dyn[v0, v1] = 
PadInput[0, v0 // 16 + v1 // 96, v0 % 16 + v1 % 96 // 32, v1 % 32]
                         for ax0_ax1_fused in range(4608):
                             with T.block("weight_reindex_shared.dyn"):
                                 v0 = T.axis.spatial(288, ax0_ax1_fused // 16)
@@ -497,9 +498,9 @@ def test_conv2d(shared_scope):
                             v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
// 16)
                             v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
                             T.reads(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, 
v3, v4, v5])
-                            T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + 
v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16])
+                            T.writes(conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 
+ v0 * 16) % 16, v5 + v1 * 16])
                             T.block_attr({"meta_schedule.cooperative_fetch": 
3})
-                            conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) 
// 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared_dyn[v0, 
v1, v2, v3, v4, v5]
+                            conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 
16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5]
     # fmt: on
     decision_0 = [
         ("SamplePerfectTile", [1, 16, 1, 1, 1]),
@@ -915,10 +916,10 @@ def test_conv_1x1():
                             with T.block("PadInput_reindex_shared"):
                                 v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused 
// 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64)
                                 v1 = T.axis.spatial(64, ax0_ax1_fused % 64)
-                                T.reads(inputs[v0 // 256, v0 // 16, v0 % 16, 
v1])
+                                T.reads(inputs[0, v0 // 16, v0 % 16, v1])
                                 T.writes(PadInput_reindex_shared[v0, v1])
                                 T.block_attr({"buffer_dim_align": [[0, 0, 32, 
8]], "meta_schedule.cooperative_fetch": 1})
-                                PadInput_reindex_shared[v0, v1] = inputs[v0 // 
256, v0 // 16, v0 % 16, v1]
+                                PadInput_reindex_shared[v0, v1] = inputs[0, v0 
// 16, v0 % 16, v1]
                         for ax0_ax1_ax2_ax3_fused in range(2048):
                             with T.block("weight_reindex_shared"):
                                 v0 = T.axis.spatial(1, 0)
@@ -1007,9 +1008,9 @@ def test_conv_1x1():
                             v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 256 // 16)
                             v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused 
% 16)
                             T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 
v4, v5])
-                            T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + 
v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16])
+                            T.writes(conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 
+ v0 * 16) % 16, v5 + v1 * 16])
                             T.block_attr({"meta_schedule.cooperative_fetch": 
2})
-                            conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) 
// 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, 
v2, v3, v4, v5]
+                            conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 * 
16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
     # fmt: on
 
     decision_0 = [
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py 
b/tests/python/unittest/test_meta_schedule_trace_apply.py
index 78b2fdbf3d..8562205753 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -15,16 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
-
 import tvm
-import tvm.testing
 import tvm.meta_schedule as ms
+import tvm.testing
 from tvm.script import tir as T
-from tvm.tir import Schedule, floormod, floordiv
-from tvm.tir.tensor_intrin.cuda import *
 from tvm.target import Target
 from tvm.target.codegen import llvm_lookup_intrinsic_id
-
+from tvm.tir import Schedule, floordiv, floormod
+from tvm.tir.tensor_intrin.cuda import *
 from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
 
 
@@ -1885,6 +1883,7 @@ def test_dense_add_cpu():
                     ((i0 * 64) + i2),
                     i1,
                 ),
+                index_dtype="int32",
             ),
             pad_value=None,
         )
@@ -1950,6 +1949,7 @@ def test_dense_add_cpu_no_write_cache():
                     ((i1 * 32) + i3),
                     ((i0 * 16) + i2),
                 ),
+                index_dtype="int32",
             ),
             pad_value=None,
         )
diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py 
b/tests/python/unittest/test_meta_schedule_tune_tir.py
index aa45120c23..c8fc4a73f5 100644
--- a/tests/python/unittest/test_meta_schedule_tune_tir.py
+++ b/tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -20,6 +20,7 @@ import tempfile
 
 import numpy as np
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import meta_schedule as ms
@@ -61,6 +62,7 @@ def two_step(a: T.handle, c: T.handle) -> None:
             C[vi, vj] = B[vi, vj] + 3.0
 
 
[email protected]("Integration test")
 @tvm.testing.requires_llvm
 def test_tune_matmul_cpu():
     with tempfile.TemporaryDirectory() as work_dir:
@@ -80,6 +82,7 @@ def test_tune_matmul_cpu():
             sch.trace.show()
 
 
[email protected]("Integration test")
 @tvm.testing.requires_cuda
 def test_tune_matmul_cuda():
     with tempfile.TemporaryDirectory() as work_dir:
@@ -99,6 +102,7 @@ def test_tune_matmul_cuda():
             sch.trace.show()
 
 
[email protected]("Integration test")
 def test_tune_run_module_via_rpc():
     target = tvm.target.Target("llvm")
     rt_mod = tvm.build(matmul, target)
@@ -141,6 +145,7 @@ def test_tune_run_module_via_rpc():
         tvm.testing.assert_allclose(result.numpy(), c_np, rtol=1e-3)
 
 
[email protected]("Integration test")
 def test_tune_block_cpu():
     @ms.derived_object
     class RemoveBlock(ms.schedule_rule.PyScheduleRule):
diff --git a/tests/python/unittest/test_transform_layout.py 
b/tests/python/unittest/test_te_transform_layout.py
similarity index 100%
rename from tests/python/unittest/test_transform_layout.py
rename to tests/python/unittest/test_te_transform_layout.py
diff --git a/tests/python/unittest/test_index_map.py 
b/tests/python/unittest/test_tir_index_map.py
similarity index 97%
rename from tests/python/unittest/test_index_map.py
rename to tests/python/unittest/test_tir_index_map.py
index 5eb31cd378..e893ed897d 100644
--- a/tests/python/unittest/test_index_map.py
+++ b/tests/python/unittest/test_tir_index_map.py
@@ -15,17 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
-
 import pytest
+
 import tvm
 import tvm.testing
 from tvm.ir import assert_structural_equal
-from tvm.tir import IndexMap, IntImm, floordiv, floormod
 from tvm.runtime import const
+from tvm.tir import IndexMap, IntImm, floordiv, floormod
 
 
 def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None:
-
     iters_1 = map1.map_indices(map2.initial_indices)
     iters_2 = map2.final_indices
     assert len(iters_1) == len(iters_2)
@@ -36,7 +35,7 @@ def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> 
None:
 
 
 def test_index_mapping():
-    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
+    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], 
index_dtype="int32")
 
     assert_structural_equal(index_map.map_indices([0]), [0, 0])
     assert_structural_equal(index_map.map_indices([3]), [0, 3])
@@ -48,7 +47,7 @@ def test_index_mapping():
 
 
 def test_shape_mapping():
-    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
+    index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], 
index_dtype="int32")
 
     assert_structural_equal(index_map.map_shape([4]), [1, 4])
     assert_structural_equal(index_map.map_shape([16]), [4, 4])
@@ -184,7 +183,7 @@ padding_test_case = tvm.testing.parameter(
 
 
 def test_nonsurjective_inverse(padding_test_case):
-    index_map = IndexMap.from_func(padding_test_case["forward"])
+    index_map = IndexMap.from_func(padding_test_case["forward"], 
index_dtype="int32")
 
     inverse, padding_predicate = 
index_map.non_surjective_inverse(padding_test_case["pre_shape"])
     expected_inverse = IndexMap.from_func(padding_test_case["inverse"])
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py 
b/tests/python/unittest/test_tir_schedule_transform_layout.py
index 8de11d8bd5..04bd00111e 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -154,11 +154,9 @@ def conv2d_nhwc_transformed(
     for ax0, ax1, ax2 in T.grid(12544, 64, 147):
         with T.block("conv2d_nhwc"):
             v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2])
-            T.reads(PadInput[v0 // 12544, v0 // 112 * 2 + v2 // 21, v0 % 112 * 
2 + v2 % 21 // 3, v2 % 3], Weight[v2 // 21, v2 % 21 // 3, v2 % 3, v1])
-            T.writes(Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1])
             with T.init():
-                Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1] = 
T.float32(0)
-            Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1] = Conv2d_nhwc[v0 
// 12544, v0 // 112, v0 % 112, v1] + PadInput[v0 // 12544, v0 // 112 * 2 + v2 
// 21, v0 % 112 * 2 + v2 % 21 // 3, v2 % 3] * Weight[v2 // 21, v2 % 21 // 3, v2 
% 3, v1]
+                Conv2d_nhwc[0, v0 // 112, v0 % 112, v1] = T.float32(0)
+            Conv2d_nhwc[0, v0 // 112, v0 % 112, v1] = Conv2d_nhwc[0, v0 // 
112, v0 % 112, v1] + PadInput[0, v0 // 112 * 2 + v2 // 21, v0 % 112 * 2 + v2 % 
21 // 3, v2 % 3] * Weight[v2 // 21, v2 % 21 // 3, v2 % 3, v1]
 
 
 @T.prim_func
@@ -461,11 +459,6 @@ def 
test_transform_block_layout_int64_extent(use_block_name):
     sch = tir.Schedule(elementwise_int64_extent, debug_mask="all")
     block = "B" if use_block_name else sch.get_block("B")
     sch.transform_block_layout(block, lambda i, j: (i * 128 + j,))
-    print(
-        tvm.ir.base.get_first_structural_mismatch(
-            elementwise_int64_extent_transformed, sch.mod["main"]
-        )
-    )
     tvm.ir.assert_structural_equal(elementwise_int64_extent_transformed, 
sch.mod["main"])
     verify_trace_roundtrip(sch=sch, mod=elementwise_int64_extent)
 
@@ -1085,5 +1078,106 @@ def test_index_map_dtype_legalize_with_constant():
     sch.transform_layout(block="block", buffer="A", index_map=func, 
pad_value=0)
 
 
+def test_transform_layout_with_symbolic_bound():  # C flattens to shape (n * 32,)
+    # fmt: off
+    # pylint: disable=invalid-name,line-too-long,too-many-locals
+    @T.prim_func
+    def before(a: T.handle, b: T.handle, c: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        n = T.int64()
+        A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), 
T.int64(128)), "float16")
+        B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), 
"float16")
+        C = T.match_buffer(c, (T.int64(1), T.int64(32), T.int64(1), n), 
"float16")
+        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), 
n, T.int64(128)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, 
i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+                T.writes(C[v_i0, v_i1, v_i2, v_i3])
+                with T.init():
+                    C[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
+                C[v_i0, v_i1, v_i2, v_i3] = C[v_i0, v_i1, v_i2, v_i3] + 
A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_i3, v_k]
+
+    @T.prim_func
+    def after(a: T.handle, b: T.handle, c: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        n = T.int64()
+        A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), 
T.int64(128)), "float16")
+        B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), 
"float16")
+        C = T.match_buffer(c, (n * T.int64(32),), "float16")
+        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), 
n, T.int64(128)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, 
i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+                T.writes(C[v_i1 * n + v_i3])
+                with T.init():
+                    C[v_i1 * n + v_i3] = T.float16(0)
+                C[v_i1 * n + v_i3] = C[v_i1 * n + v_i3] + A[v_i0, v_i1, v_i2, 
v_k] * B[v_i0, v_i1, v_i3, v_k]
+    # pylint: enable=invalid-name,line-too-long,too-many-locals
+    # fmt: on
+    # pylint: disable=invalid-name
+    _, _, n, _ = before.buffer_map[before.params[1]].shape  # symbolic n from B's shape
+    sch = tvm.tir.Schedule(before)
+    block = sch.get_block("NT_matmul")
+    sch.transform_layout(
+        block,
+        ("write", 0),
+        lambda x, y, z, w: x * 32 * n + y * n + z * n + w,  # x, z always 0 (extent 1)
+        assume_injective_transform=True,  # presumably skips injectivity check; confirm
+    )
+    # pylint: enable=invalid-name
+    tvm.ir.assert_structural_equal(after, sch.mod["main"])  # expect C[v_i1 * n + v_i3]
+
+
+def test_transform_block_layout_with_symbolic_bound():  # iters fuse to (n * 32, 128)
+    # fmt: off
+    # pylint: disable=invalid-name,line-too-long,too-many-locals
+    @T.prim_func
+    def before(a: T.handle, b: T.handle, c: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        n = T.int64()
+        A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), 
T.int64(128)), "float16")
+        B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), 
"float16")
+        C = T.match_buffer(c, (n * T.int64(32),), "float16")
+        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), 
n, T.int64(128)):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, 
i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+                T.writes(C[v_i1 * n + v_i3])
+                with T.init():
+                    C[v_i1 * n + v_i3] = T.float16(0)
+                C[v_i1 * n + v_i3] = C[v_i1 * n + v_i3] + A[v_i0, v_i1, v_i2, 
v_k] * B[v_i0, v_i1, v_i3, v_k]
+
+    @T.prim_func
+    def after(a: T.handle, b: T.handle, c: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        n = T.int64()
+        A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), 
T.int64(128)), "float16")
+        B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), 
"float16")
+        C = T.match_buffer(c, (n * T.int64(32),), "float16")
+        for ax0, ax1 in T.grid(n * T.int64(32), T.int64(128)):
+            with T.block("NT_matmul"):
+                v0, v1 = T.axis.remap("SR", [ax0, ax1])
+                T.reads(A[T.int64(0), v0 // n, T.int64(0), v1], B[T.int64(0), 
v0 // n, v0 % n, v1])
+                T.writes(C[v0])
+                with T.init():
+                    C[v0] = T.float16(0)
+                C[v0] = C[v0] + A[T.int64(0), v0 // n, T.int64(0), v1] * 
B[T.int64(0), v0 // n, v0 % n, v1]
+    # pylint: enable=invalid-name,line-too-long,too-many-locals
+    # fmt: on
+    # pylint: disable=invalid-name
+    _, _, n, _ = before.buffer_map[before.params[1]].shape  # symbolic n from B's shape
+    sch = tvm.tir.Schedule(before)
+    block = sch.get_block("NT_matmul")
+    sch.transform_block_layout(
+        block,
+        lambda x, y, z, w, k: (  # x, z have extent 1; reduction iter k is kept
+            x * 32 * n + y * n + z * n + w,
+            k,
+        ),
+    )
+    # pylint: enable=invalid-name
+    tvm.ir.assert_structural_equal(after, sch.mod["main"])  # v0 // n, v0 % n -> i1, i3
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to