This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5e7438f  [TIR][Schedule] Blockize and Tensorize (#9871)
5e7438f is described below

commit 5e7438feaabf5b843fdb0fac9d93934ffb313c5a
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jan 26 03:00:07 2022 -0500

    [TIR][Schedule] Blockize and Tensorize (#9871)
    
    * WIP
    
    * WIP
    
    * WIP
    
    * test cases
    
    * add examples
    
    * lint
    
    * Amend co-authors information
    
    Co-authored-by: Siyuan Feng <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Junru Shao <[email protected]>
    Co-authored-by: Xiyou Zhou <[email protected]>
    
    * WIP
    
    * address comments and changed tensorized comparator
    
    * update
    
    * nit
    
    * fix example
    
    * lint
    
    * lint
    
    * lint
    
    * remove unused
    
    * trigger ci
    
    * clang-format
    
    * fix
    
    * rebase
    
    Co-authored-by: Siyuan Feng <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Junru Shao <[email protected]>
    Co-authored-by: Xiyou Zhou <[email protected]>
---
 include/tvm/arith/iter_affine_map.h                |   7 +
 include/tvm/tir/function.h                         |  52 ++
 include/tvm/tir/schedule/schedule.h                |  19 +
 python/tvm/tir/__init__.py                         |   2 +-
 python/tvm/tir/function.py                         |  48 ++
 python/tvm/tir/schedule/schedule.py                | 229 +++++++
 src/arith/int_set.cc                               |   2 +-
 src/tir/ir/function.cc                             |  53 ++
 src/tir/schedule/concrete_schedule.cc              |  23 +
 src/tir/schedule/concrete_schedule.h               |   3 +
 src/tir/schedule/ir_comparator.cc                  | 363 +++++++++++
 src/tir/schedule/ir_comparator.h                   | 116 ++++
 src/tir/schedule/primitive.h                       |  18 +
 src/tir/schedule/primitive/blockize_tensorize.cc   | 698 +++++++++++++++++++++
 src/tir/schedule/schedule.cc                       |  14 +
 src/tir/schedule/state.cc                          |   4 +-
 src/tir/schedule/traced_schedule.cc                |  31 +
 src/tir/schedule/traced_schedule.h                 |   3 +
 .../python/unittest/test_tir_schedule_blockize.py  | 210 +++++++
 .../python/unittest/test_tir_schedule_tensorize.py | 431 +++++++++++++
 20 files changed, 2321 insertions(+), 5 deletions(-)

diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h
index 22b4cd5..eb69c18 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -350,6 +350,13 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
                                       bool require_bijective, arith::Analyzer* 
analyzer,
                                       DiagnosticContext diag_ctx);
 
+/*!
+ * \brief Given an IterMapExpr, transform it back into a normal PrimExpr.
+ * \param expr The input IterMapExpr to convert.
+ * \return The corresponding normal PrimExpr.
+ */
+PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr);
+
 }  // namespace arith
 }  // namespace tvm
 #endif  // TVM_ARITH_ITER_AFFINE_MAP_H_
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index e482a18..1ab911b 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -188,6 +188,58 @@ class LinkedParam : public ObjectRef {
 };
 
 /*!
+ * \brief Tensor intrinsic used by tensorization: a computation description paired with the
+ *        implementation that replaces it.
+ */
+class TensorIntrinNode : public Object {
+ public:
+  /*! \brief The function to describe the computation. */
+  PrimFunc desc;
+  /*! \brief The function of the implementation for the execution. */
+  PrimFunc impl;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("desc", &desc);
+    v->Visit("impl", &impl);
+  }
+
+  static constexpr const char* _type_key = "tir.TensorIntrin";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
+};
+
+/*!
+ * \brief Managed reference to TensorIntrinNode.
+ */
+class TensorIntrin : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor
+   * \param desc The function to describe the computation.
+   * \param impl The function of the implementation for the execution.
+   */
+  TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl);
+
+  /*!
+   * \brief Create and register a TensorIntrin. After registration, the TensorIntrin can be looked
+   * up with its name.
+   * \param name The name of the TensorIntrin to register
+   * \param intrin The TensorIntrin to register.
+   * \throws This method throws an exception if the TensorIntrin with the specified name already
+   *         exists.
+   */
+  TVM_DLL static void Register(String name, TensorIntrin intrin);
+
+  /*!
+   * \brief Look up TensorIntrin by name. Raises an exception if not found.
+   * \param name The name of the TensorIntrin.
+   * \return The TensorIntrin with the specified name.
+   * \throws This method throws an exception if the TensorIntrin does not exist.
+   */
+  TVM_DLL static TensorIntrin Get(String name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode)
+};
+/*!
  * \brief Specialize parameters of PrimFunc.
  * \param func The PrimFunc to be specialized.
  * \param param_map The mapping from function params to the instance.
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 43f2379..be06b44 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -473,6 +473,25 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void SetScope(const BlockRV& block_rv, int buffer_index, const 
String& storage_scope) = 0;
   /******** Schedule: Blockize & Tensorize ********/
+  /*!
+   * \brief Convert the subtree rooted at a specific loop into a block.
+   * \param loop_rv The root of the subtree.
+   * \return The new block.
+   */
+  virtual BlockRV Blockize(const LoopRV& loop_rv) = 0;
+  /*!
+   * \brief Tensorize the computation enclosed by the loop with the tensor intrin.
+   * \param loop_rv The loop to be tensorized.
+   * \param intrin Name of the registered tensor intrinsic.
+   */
+  virtual void Tensorize(const LoopRV& loop_rv, const String& intrin) = 0;
+  /*!
+   * \brief Tensorize the computation of the given block with the tensor intrin.
+   * \param block_rv The block to be tensorized.
+   * \param intrin Name of the registered tensor intrinsic.
+   */
+  virtual void Tensorize(const BlockRV& block_rv, const String& intrin) = 0;
+
   /******** Schedule: Annotation ********/
   /*!
    * \brief Annotate a loop with a key value pair
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 07ceb29..5854b93 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -33,7 +33,7 @@ from .stmt import ProducerRealize, SeqStmt
 from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
 from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
 
-from .function import PrimFunc
+from .function import PrimFunc, TensorIntrin
 
 from .op import call_packed, call_intrin, call_pure_extern, call_extern
 from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, 
min_value, max_value, trace
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index ecbcd83..bcebab9 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -162,3 +162,51 @@ class PrimFunc(BaseFunc):
         return tvm._ffi.get_global_func("script.AsTVMScript")(
             self, tir_prefix, show_meta
         )  # type: ignore
+
+
+@tvm._ffi.register_object("tir.TensorIntrin")
+class TensorIntrin(Object):
+    """A tensor intrinsic.
+
+    Parameters
+    ----------
+    desc : PrimFunc
+        The function to describe the computation.
+
+    impl : PrimFunc
+        The function of the implementation for the execution.
+    """
+
+    def __init__(self, desc, impl):
+        self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc, impl)
+
+    @staticmethod
+    def register(name: str, desc: PrimFunc, impl: PrimFunc):
+        """Register a tensor intrinsic with its name.
+
+        Registration fails with an error if an intrinsic with the same name already exists.
+
+        Parameters
+        ----------
+        name : str
+            The name of the TensorIntrin to register.
+        desc : PrimFunc
+            The function to describe the computation.
+        impl : PrimFunc
+            The function of the implementation for the execution.
+        """
+        return _ffi_api.TensorIntrinRegister(  # type: ignore # pylint: disable=no-member
+            name, TensorIntrin(desc, impl)
+        )
+
+    @staticmethod
+    def get(name: str):
+        """Look up a tensor intrinsic by its name.
+
+        An error is raised if no intrinsic is registered under ``name``.
+
+        Parameters
+        ----------
+        name : str
+            The name of the TensorIntrin to look up.
+
+        Returns
+        -------
+        result : TensorIntrin
+            The TensorIntrin with the specified name.
+        """
+        return _ffi_api.TensorIntrinGet(name)  # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 7d352f1..96fa21f 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1759,6 +1759,235 @@ class Schedule(Object):
 
     ########## Schedule: Blockize & Tensorize ##########
 
+    @type_checked
+    def blockize(self, loop: LoopRV) -> BlockRV:
+        """Convert the subtree rooted at a specific loop into a block.
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The root of the subtree.
+
+        Returns
+        -------
+        result : BlockRV
+            The new block.
+
+        Examples
+        --------
+
+        Before blockize, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_blockize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
+                    with T.block("B"):
+                        vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                        vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                        T.reads(A[vi, vj])
+                        T.writes(B[vi, vj])
+                        B[vi, vj] = A[vi, vj] * T.float32(2)
+
+        Create the schedule and do blockize:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_blockize)
+            B = sch.get_block("B")
+            _, _, i1, _ = sch.get_loops(B)
+            sch.blockize(i1)
+            print(sch.mod["main"].script())
+
+        After applying blockize, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_blockize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                for i_0, j_0 in T.grid(8, 8):
+                    with T.block("B_o"):
+                        vio, vjo = T.axis.remap("SS", [i_0, j_0])
+                        T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
+                        T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
+                        for i_1, j_1 in T.grid(16, 16):
+                            with T.block("B"):
+                                vi, vj = T.axis.remap("SS", [i_1, j_1])
+                                T.reads(A[vio * 16 + vi, vjo * 16 + vj])
+                                T.writes(B[vio * 16 + vi, vjo * 16 + vj])
+                                B[vio * 16 + vi, vjo * 16 + vj] = (A[vio * 16 + vi, vjo * 16 + vj]
+                                                                   * T.float32(2))
+
+        Note
+        ----
+        blockize requires that there is exactly one block under the given loop and that the bindings
+        of the block are divisible by the subspace represented by the loops starting at the given loop.
+        """
+
+        return _ffi_api.ScheduleBlockize(self, loop)  # type: ignore # pylint: disable=no-member
+
+    @type_checked
+    def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin: str) -> None:
+        """Tensorize the computation enclosed by a loop or a block with the tensor intrinsic.
+
+        Parameters
+        ----------
+        block_or_loop : Union[BlockRV, LoopRV]
+            The loop or the block to be tensorized.
+        tensor_intrin : str
+            The name of a tensor intrinsic registered via ``TensorIntrin.register``.
+
+        Examples
+        --------
+
+        Before tensorize, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_tensorize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"],
+                C: T.Buffer[(128, 128), "float32"],
+            ) -> None:
+                # body
+                # with T.block("root")
+                for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(8, 8, 8, 16, 16, 16):
+                    with T.block("update"):
+                        vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                        vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                        vk = T.axis.reduce(128, k_0 * 16 + k_1)
+                        T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+                        T.writes(C[vi, vj])
+                        with T.init():
+                            C[vi, vj] = T.float32(0)
+                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        Declare and register the tensor intrinsic:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+                A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+                B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+                C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+                with T.block("root"):
+                    T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
+                    T.writes(C[0 : 16, 0 : 16])
+                    for i, j, k in T.grid(16, 16, 16):
+                        with T.block("update"):
+                            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+            @T.prim_func
+            def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+                A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+                B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+                C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+                with T.block("root"):
+                    T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
+                    T.writes(C[0 : 16, 0 : 16])
+                    T.evaluate(
+                        T.tvm_mma_sync(
+                            C.data,
+                            C.elem_offset // 256,
+                            A.data,
+                            A.elem_offset // 256,
+                            B.data,
+                            B.elem_offset // 256,
+                            C.data,
+                            C.elem_offset // 256,
+                            dtype="handle",
+                        )
+                    )
+
+            tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
+
+        Create the schedule and do tensorize:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_tensorize)
+            update = sch.get_block("update")
+            _, _, _, i1, _, _ = sch.get_loops(update)
+            sch.tensorize(i1, "test_mma_intrin")
+            print(sch.mod["main"].script())
+
+        After applying tensorize, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_tensorize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"],
+                C: T.Buffer[(128, 128), "float32"],
+            ) -> None:
+                # body
+                # with T.block("root")
+                for i_0, j_0, k_0 in T.grid(8, 8, 8):
+                    with T.block("update_o"):
+                        vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0])
+                        T.reads(
+                            C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
+                            A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16],
+                            B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16],
+                        )
+                        T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
+                        A_1 = T.match_buffer(
+                            A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16],
+                            [16, 16],
+                            dtype="float32",
+                            offset_factor=1,
+                        )
+                        B_1 = T.match_buffer(
+                            B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16],
+                            [16, 16],
+                            dtype="float32",
+                            offset_factor=1,
+                        )
+                        C_1 = T.match_buffer(
+                            C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
+                            [16, 16],
+                            dtype="float32",
+                            offset_factor=1,
+                        )
+                        with T.init():
+                            for i_1, j_1 in T.grid(16, 16):
+                                with T.block("update_init"):
+                                    vi_init, vj_init = T.axis.remap("SS", [i_1, j_1])
+                                    T.reads()
+                                    T.writes(C[vio * 16 + vi_init, vjo * 16 + vj_init])
+                                    C[vio * 16 + vi_init, vjo * 16 + vj_init] = T.float32(0)
+                        T.evaluate(
+                            T.tvm_mma_sync(
+                                C_1.data,
+                                C_1.elem_offset // 256,
+                                A_1.data,
+                                A_1.elem_offset // 256,
+                                B_1.data,
+                                B_1.elem_offset // 256,
+                                C_1.data,
+                                C_1.elem_offset // 256,
+                                dtype="handle",
+                            )
+                        )
+        """
+        _ffi_api.ScheduleTensorize(  # type: ignore # pylint: disable=no-member
+            self, block_or_loop, tensor_intrin
+        )
+
     ########## Schedule: Annotation ##########
 
     @type_checked
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index 55a1a5a..3d30eef 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -511,7 +511,7 @@ Range IntSet::CoverRange(Range max_range) const {
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   ICHECK(s_int != nullptr);
   if (s_int->HasUpperBound() && s_int->HasLowerBound()) {
-    return Range::FromMinExtent(s_int->min_value,
+    return Range::FromMinExtent(analyzer.Simplify(s_int->min_value),
                                 analyzer.Simplify(s_int->max_value + 1 - 
s_int->min_value));
   }
   return max_range;
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index 101d80a..1c34e34 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -64,6 +64,51 @@ FuncType PrimFuncNode::func_type_annotation() const {
 
 TVM_REGISTER_NODE_TYPE(PrimFuncNode);
 
+/*!
+ * \brief Process-wide, name-keyed registry of TensorIntrins.
+ * \note The singleton is heap-allocated and never freed; access to `reg` is not synchronized.
+ */
+class TensorIntrinManager {
+ public:
+  Map<String, tir::TensorIntrin> reg;
+
+  static TensorIntrinManager* Global() {
+    static TensorIntrinManager* inst = new TensorIntrinManager();
+    return inst;
+  }
+};
+
+TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) {
+  // The description and the implementation must take the same number of parameters.
+  CHECK_EQ(desc->params.size(), impl->params.size())
+      << "ValueError: The number of parameters of the description and the implementation of the "
+         "tensor intrinsic doesn't match.";
+  // Every parameter on both sides must be a handle (i.e. bound to a buffer).
+  for (size_t i = 0; i < desc->params.size(); i++) {
+    CHECK(desc->params[i]->dtype.is_handle()) << "ValueError: Parameters of the description of the "
+                                                 "tensor intrinsic should be handle only.";
+    CHECK(impl->params[i]->dtype.is_handle()) << "ValueError: Parameters of the implementation of "
+                                                 "the tensor intrinsic should be handle only.";
+  }
+  ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size());
+
+  ObjectPtr<TensorIntrinNode> n = make_object<TensorIntrinNode>();
+  n->desc = std::move(desc);
+  n->impl = std::move(impl);
+  data_ = std::move(n);
+}
+
+void TensorIntrin::Register(String name, TensorIntrin intrin) {
+  TensorIntrinManager* manager = TensorIntrinManager::Global();
+  CHECK_EQ(manager->reg.count(name), 0)
+      << "ValueError: TensorIntrin '" << name << "' has already been registered";
+  manager->reg.Set(name, intrin);
+}
+
+TensorIntrin TensorIntrin::Get(String name) {
+  const TensorIntrinManager* manager = TensorIntrinManager::Global();
+  auto it = manager->reg.find(name);
+  CHECK(it != manager->reg.end()) << "ValueError: TensorIntrin '" << name << "' is not registered";
+  // Reuse the iterator from find() rather than performing a second lookup with at().
+  return (*it).second;
+}
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
       // TODO(tvm-team) redirect to Text printer once we have a good text 
format.
@@ -85,5 +130,13 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
       return PrimFunc(params, body, ret_type, buffer_map, attrs, span);
     });
 
+TVM_REGISTER_GLOBAL("tir.TensorIntrin").set_body_typed([](PrimFunc desc, PrimFunc impl) {
+  return TensorIntrin(desc, impl);
+});
+
+// Name-based registry entry points used by TensorIntrin.register / TensorIntrin.get in Python.
+TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register);
+TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 9f8dc6d..fc63f30 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -606,6 +606,29 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
 }
 
 /******** Schedule: Blockize & Tensorize ********/
+BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::Blockize(state_, this->GetSRef(loop_rv));
+  this->state_->DebugVerify();
+  TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_);
+  return CreateRV<BlockRV>(result);
+}
+
+// Both Tensorize overloads resolve the intrinsic by its registered name before rewriting.
+void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin));
+  this->state_->DebugVerify();
+  TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
+}
+
+void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin));
+  this->state_->DebugVerify();
+  TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
+}
+
 /******** Schedule: Annotation ********/
 
 ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& 
ann_val) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 96cb0f7..5f10817 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -123,6 +123,9 @@ class ConcreteScheduleNode : public ScheduleNode {
                     int offset) override;
   void SetScope(const BlockRV& block_rv, int buffer_index, const String& 
storage_scope) override;
   /******** Schedule: Blockize & Tensorize ********/
+  BlockRV Blockize(const LoopRV& loop_rv) override;
+  void Tensorize(const BlockRV& block_rv, const String& intrin) override;
+  void Tensorize(const LoopRV& loop_rv, const String& intrin) override;
   /******** Schedule: Annotation ********/
   void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& 
ann_val) override;
   void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc
new file mode 100644
index 0000000..3e61e95
--- /dev/null
+++ b/src/tir/schedule/ir_comparator.cc
@@ -0,0 +1,363 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./ir_comparator.h"
+
+namespace tvm {
+
+namespace tir {
+
+/******** Tensorize Comparator ********/
+
+/*! \brief Raised when a For/Block stmt does not structurally match the tensor intrinsic. */
+class TensorIntrinMismatchError : public ScheduleError {
+ public:
+  explicit TensorIntrinMismatchError(IRModule lhs_mod, Stmt lhs_stmt, Stmt rhs_stmt,
+                                     std::vector<std::string> error_messages)
+      : lhs_mod_(std::move(lhs_mod)),
+        lhs_stmt_(std::move(lhs_stmt)),
+        rhs_stmt_(std::move(rhs_stmt)),
+        error_messages_(std::move(error_messages)) {
+    ICHECK(lhs_stmt_->IsInstance<ForNode>() || lhs_stmt_->IsInstance<BlockNode>());
+  }
+
+  String FastErrorString() const final {
+    return "ScheduleError: The stmt doesn't match the tensor intrin.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    // Separate the printed rhs stmt from the accumulated mismatch messages.
+    os << "The stmt {0} doesn't match the tensor intrin\n " << rhs_stmt_ << '\n';
+    for (const auto& msg : error_messages_) {
+      os << msg << '\n';
+    }
+    return os.str();
+  }
+
+  IRModule mod() const final { return lhs_mod_; }
+
+  Array<ObjectRef> LocationsOfInterest() const final { return {lhs_stmt_}; }
+
+ private:
+  IRModule lhs_mod_;
+  Stmt lhs_stmt_;
+  Stmt rhs_stmt_;
+  std::vector<std::string> error_messages_;
+};
+
+/* Override the dispatcher so that the per-node visitors only run when lhs and rhs have the same
+ * node type, which lets them use `other.as<...>()` without a null check. */
+bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) {
+  bool equal = n.same_as(other) ||
+               ((n->type_index() == other->type_index()) && StmtComparator::VisitStmt(n, other));
+  // In assert mode, report the mismatch at a For/Block, carrying the collected messages.
+  if (!equal && assert_mode_ && (n->IsInstance<ForNode>() || n->IsInstance<BlockNode>())) {
+    throw TensorIntrinMismatchError(lhs_mod_, n, other, std::move(error_messages_));
+  }
+  return equal;
+}
+
+// Expressions match only if they share node type and dtype; mismatches are recorded in assert mode.
+bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
+  bool equal =
+      n.same_as(other) || ((n->type_index() == other->type_index()) && n->dtype == other->dtype &&
+                           ExprComparator::VisitExpr(n, other));
+  if (!equal && assert_mode_) {
+    std::ostringstream os;
+    os << "Expression mismatch: " << n << " vs " << other;
+    EmitError(os.str());
+  }
+  return equal;
+}
+
+// Loops match when loop vars remap consistently and all loop attributes and bodies agree.
+bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
+  const auto* rhs = other.as<ForNode>();
+  if (!DefEqual(op->loop_var, rhs->loop_var)) return false;
+  if (!VisitExpr(op->min, rhs->min)) return false;
+  if (!VisitExpr(op->extent, rhs->extent)) return false;
+  if (op->thread_binding.defined() != rhs->thread_binding.defined()) return false;
+  if (op->thread_binding.defined() &&
+      !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) {
+    return false;
+  }
+  if (op->kind != rhs->kind) return false;
+  if (!CompareAnnotationMap(op->annotations, rhs->annotations)) return false;
+  return VisitStmt(op->body, rhs->body);
+}
+
+// Sequences match element-wise, in order.
+bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other) {
+  const auto* rhs = other.as<SeqStmtNode>();
+  return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt);
+}
+
+// Stores match when the accessed buffer/indices and the stored value match.
+bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) {
+  const auto* rhs = other.as<BufferStoreNode>();
+  return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
+}
+
+bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) {
+  const auto* rhs = other.as<BlockRealizeNode>();
+  // Iter values are not compared while is_scope_block is still set (before any BlockNode visit).
+  if (!is_scope_block) {
+    if (!CompareArray(op->iter_values, rhs->iter_values, &TensorizeComparator::VisitExpr)) {
+      return false;
+    }
+  }
+  return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block);
+}
+
+bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) {
+  const auto* rhs = other.as<BlockNode>();
+  // Check block equality.
+  // All iter vars and buffer regions including the order should match.
+  // When checking iter vars, DefEqual is used to remap variables.
+  // Iter vars, annotations and alloc buffers are skipped for the scope block.
+  if (!is_scope_block) {
+    if (!CompareArray(op->iter_vars, rhs->iter_vars, &TensorizeComparator::CompareIterVar)) {
+      return false;
+    }
+    if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
+      return false;
+    }
+    if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) {
+      return false;
+    }
+  }
+  if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) {
+    return false;
+  }
+  if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) {
+    return false;
+  }
+  // Only the first block visited is treated as the scope block.
+  is_scope_block = false;
+  return VisitStmt(op->body, rhs->body);
+}
+
+// Exprs
+// Binary ops match when both operands match pairwise.
+#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName)                            \
+  bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr& other) { \
+    const auto* rhs = other.as<OpName>();                                         \
+    return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b);                  \
+  }
+
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode);
+
+bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
+  // Integer literals match exactly when their values are identical.
+  const int64_t rhs_value = other.as<IntImmNode>()->value;
+  return rhs_value == op->value;
+}
+
+bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
+  // Float literals are compared with exact equality (no tolerance).
+  const double rhs_value = other.as<FloatImmNode>()->value;
+  return rhs_value == op->value;
+}
+
+bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) {
+  // Only the operand being cast is compared here; this visitor does not look at the target dtype.
+  const CastNode* rhs_cast = other.as<CastNode>();
+  return VisitExpr(op->value, rhs_cast->value);
+}
+
+bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) {
+  // The very same variable always matches itself.
+  if (op == other.get()) return true;
+  const VarNode* rhs_var = other.as<VarNode>();
+  if (rhs_var->dtype != op->dtype) return false;
+  // Distinct variables only match if DefEqual previously paired this LHS var with `other`.
+  auto mapped = equal_map_.find(GetRef<Var>(op));
+  if (mapped == equal_map_.end()) return false;
+  return mapped->second.same_as(other);
+}
+
+bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) {
+  // Loads are matched by buffer and by indices relative to the buffer's recorded base indices.
+  return CompareBufferAccess(op, other.as<BufferLoadNode>());
+}
+
+bool TensorizeComparator::VisitExpr_(const SelectNode* op, const PrimExpr& other) {
+  // The condition and both branches must match, checked in that order.
+  const SelectNode* rhs_select = other.as<SelectNode>();
+  if (!VisitExpr(op->condition, rhs_select->condition)) return false;
+  if (!VisitExpr(op->true_value, rhs_select->true_value)) return false;
+  return VisitExpr(op->false_value, rhs_select->false_value);
+}
+
+bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) {
+  if (lhs.same_as(rhs)) return true;
+  auto existing = equal_map_.find(lhs);
+  if (existing == equal_map_.end()) {
+    // First time lhs is defined: pair it with rhs in the remap table and in the analyzer.
+    equal_map_.emplace(lhs, rhs);
+    analyzer_.Bind(lhs, rhs);
+    return true;
+  }
+  // lhs was already paired; the definition is consistent only if it was paired with rhs.
+  return existing->second.same_as(rhs);
+}
+
+/*!
+ * \brief Compare one annotation entry (key, value) of the LHS block with one of the RHS block.
+ * \param lhs The LHS annotation entry.
+ * \param rhs The RHS annotation entry.
+ * \return Whether the keys are equal and the values match.
+ */
+bool TensorizeComparator::CompareAnnotation(const std::pair<String, ObjectRef>& lhs,
+                                            const std::pair<String, ObjectRef>& rhs) {
+  // Keys must be identical.
+  if (lhs.first != rhs.first) return false;
+  // The very same value object always matches, whatever its type.
+  if (lhs.second.same_as(rhs.second)) return true;
+  // Distinct expression values are compared structurally, honoring the variable remapping.
+  const auto* lhs_expr = lhs.second.as<PrimExprNode>();
+  const auto* rhs_expr = rhs.second.as<PrimExprNode>();
+  if (lhs_expr != nullptr && rhs_expr != nullptr) {
+    return VisitExpr(GetRef<PrimExpr>(lhs_expr), GetRef<PrimExpr>(rhs_expr));
+  }
+  // Distinct string values match when their contents are equal.
+  const auto* lhs_str = lhs.second.as<runtime::StringObj>();
+  const auto* rhs_str = rhs.second.as<runtime::StringObj>();
+  if (lhs_str != nullptr && rhs_str != nullptr) {
+    return GetRef<String>(lhs_str) == GetRef<String>(rhs_str);
+  }
+  // Any other pair of distinct values is conservatively treated as a mismatch.
+  return false;
+}
+
+/*!
+ * \brief Compare two annotation maps entry by entry, independently of their iteration order.
+ * \param lhs The annotations of the LHS block.
+ * \param rhs The annotations of the RHS block.
+ * \return Whether both maps have the same keys and every pair of values matches.
+ */
+bool TensorizeComparator::CompareAnnotationMap(const Map<String, ObjectRef>& lhs,
+                                               const Map<String, ObjectRef>& rhs) {
+  if (lhs.same_as(rhs)) return true;
+  if (lhs.size() != rhs.size()) return false;
+
+  // The two maps may enumerate their entries in different orders, so put both into key order.
+  // Keys are unique within a map, so ordering by key alone is total and never needs to compare
+  // the ObjectRef values.
+  auto sort_by_key =
+      [](const Map<String, ObjectRef>& map) -> std::vector<std::pair<String, ObjectRef>> {
+    std::vector<std::pair<String, ObjectRef>> entries(map.begin(), map.end());
+    std::sort(entries.begin(), entries.end(),
+              [](const std::pair<String, ObjectRef>& a, const std::pair<String, ObjectRef>& b) {
+                return a.first < b.first;
+              });
+    return entries;
+  };
+
+  const std::vector<std::pair<String, ObjectRef>> lhs_array = sort_by_key(lhs);
+  const std::vector<std::pair<String, ObjectRef>> rhs_array = sort_by_key(rhs);
+
+  for (size_t i = 0; i < lhs_array.size(); ++i) {
+    if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false;
+  }
+  return true;
+}
+
+bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) {
+  if (lhs.same_as(rhs)) return true;
+  // An RHS buffer that was already paired must be paired with exactly this LHS buffer.
+  auto known = rhs_buffer_map_.find(rhs);
+  if (known != rhs_buffer_map_.end()) {
+    return known->second.same_as(lhs);
+  }
+  // First encounter: remap the data vars and require the same dtype and storage scope.
+  // The buffer shapes are intentionally not compared.
+  if (!DefEqual(lhs->data, rhs->data)) return false;
+  if (lhs->dtype != rhs->dtype) return false;
+  if (!(lhs.scope() == rhs.scope())) return false;
+  rhs_buffer_map_[rhs] = lhs;
+  return true;
+}
+
+/*!
+ * \brief Compare a buffer region of the LHS with one of the RHS (desc of the tensor intrinsic).
+ *
+ * On the first visit of an LHS buffer (only allowed while visiting the scope block) the LHS region
+ * mins are recorded as the buffer's base indices. Later visits require the regions to agree once
+ * the LHS mins are taken relative to those base indices.
+ *
+ * \param lhs The LHS buffer region.
+ * \param rhs The RHS buffer region.
+ * \return Whether the regions match. In assert mode every mismatch records a diagnostic.
+ */
+bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) {
+  // Records why the regions differ (assert mode only) and returns false. The message is only
+  // built when it is actually recorded.
+  auto report_mismatch = [this](const auto& write_message) -> bool {
+    if (assert_mode_) {
+      std::ostringstream os;
+      write_message(os);
+      EmitError(os.str());
+    }
+    return false;
+  };
+
+  if (!CompareBuffer(lhs->buffer, rhs->buffer)) {
+    return report_mismatch([&](std::ostream& os) {
+      os << "Buffer mismatch: " << lhs->buffer << " vs " << rhs->buffer;
+    });
+  }
+  // Number of leading LHS dimensions without an RHS counterpart. The RHS (desc of the tensor
+  // intrinsic) must not have more dimensions than the LHS.
+  const int offset = static_cast<int>(lhs->region.size()) - static_cast<int>(rhs->region.size());
+  if (offset < 0) {
+    return report_mismatch([&](std::ostream& os) {
+      os << "Buffer region rank mismatch: " << lhs->buffer << " vs " << rhs->buffer;
+    });
+  }
+
+  auto it = buffer_indices_.find(lhs->buffer);
+  if (it == buffer_indices_.end()) {
+    // Update base indices for the buffer, this can only happen if it is visiting the scope block.
+    ICHECK(is_scope_block);
+    std::vector<PrimExpr> indices_base;
+    indices_base.reserve(lhs->region.size());
+    for (int i = 0; i < offset; i++) {
+      // High-dim region must be element-wise
+      if (!is_one(lhs->region[i]->extent)) {
+        return report_mismatch([&](std::ostream& os) {
+          os << "Buffer region of " << lhs->buffer << " is not element-wise in dimension " << i;
+        });
+      }
+      indices_base.emplace_back(lhs->region[i]->min);
+    }
+    for (size_t i = 0; i < rhs->region.size(); i++) {
+      const Range lhs_range = lhs->region[i + offset];
+      const Range rhs_range = rhs->region[i];
+      // Save the base index, then check that the extents match.
+      indices_base.emplace_back(lhs_range->min);
+      if (!analyzer_.CanProveEqual(lhs_range->extent, rhs_range->extent)) {
+        return report_mismatch([&](std::ostream& os) {
+          os << "Buffer region extent mismatch: " << lhs_range->extent << " vs "
+             << rhs_range->extent;
+        });
+      }
+    }
+    buffer_indices_.emplace(lhs->buffer, std::move(indices_base));
+    return true;
+  }
+
+  // Check the base indices are consistent.
+  const std::vector<PrimExpr>& indices_base = it->second;
+  for (int i = 0; i < offset; i++) {
+    // High-dim region must be element-wise and sit exactly at the recorded base index.
+    if (!is_one(lhs->region[i]->extent)) {
+      return report_mismatch([&](std::ostream& os) {
+        os << "Buffer region of " << lhs->buffer << " is not element-wise in dimension " << i;
+      });
+    }
+    if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) {
+      return report_mismatch([&](std::ostream& os) {
+        os << "Buffer region base index mismatch: " << lhs->region[i]->min << " vs "
+           << indices_base[i];
+      });
+    }
+  }
+  for (size_t i = 0; i < rhs->region.size(); i++) {
+    const Range lhs_range = lhs->region[i + offset];
+    const Range rhs_range = rhs->region[i];
+    if (!analyzer_.CanProveEqual(lhs_range->extent, rhs_range->extent)) {
+      return report_mismatch([&](std::ostream& os) {
+        os << "Buffer region extent mismatch: " << lhs_range->extent << " vs "
+           << rhs_range->extent;
+      });
+    }
+    // The LHS min, taken relative to the base index, must equal the RHS min.
+    PrimExpr normalized_lhs_min = (lhs_range->min - indices_base[i + offset]);
+    if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs_range->min)) {
+      return report_mismatch([&](std::ostream& os) {
+        os << "Buffer region min mismatch: " << normalized_lhs_min << " vs " << rhs_range->min;
+      });
+    }
+  }
+  return true;
+}
+
+// Comparator for BufferStoreNode and BufferLoadNode
+// The LHS indices are taken relative to the base indices that CompareBufferRegion recorded for
+// the LHS buffer, and must then equal the RHS indices. The LHS may carry extra leading indices.
+template <typename T>
+bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
+  if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false;
+  const int offset = static_cast<int>(lhs->indices.size()) - static_cast<int>(rhs->indices.size());
+  if (offset < 0) return false;
+  auto base_it = buffer_indices_.find(lhs->buffer);
+  ICHECK(base_it != buffer_indices_.end());
+  const std::vector<PrimExpr>& base = base_it->second;
+  ICHECK_EQ(base.size(), rhs->indices.size() + offset);
+  const size_t num_rhs_indices = rhs->indices.size();
+  for (size_t k = 0; k < num_rhs_indices; ++k) {
+    const size_t j = k + offset;
+    if (analyzer_.CanProveEqual(lhs->indices[j] - base[j], rhs->indices[k])) continue;
+    if (assert_mode_) {
+      std::ostringstream os;
+      os << "Buffer indices mismatch: " << lhs->indices[j] << " vs " << rhs->indices[k];
+      EmitError(os.str());
+    }
+    return false;
+  }
+  return true;
+}
+
+// Element-wise comparison of two arrays with a member comparator `cmp` of this class. Arrays of
+// different lengths never match; the scan stops at the first mismatching pair.
+template <typename T, typename F>
+bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>& rhs, F cmp) {
+  if (lhs.same_as(rhs)) return true;
+  if (lhs.size() != rhs.size()) return false;
+  for (size_t idx = 0; idx < lhs.size(); ++idx) {
+    const bool element_equal = (this->*cmp)(lhs[idx], rhs[idx]);
+    if (!element_equal) return false;
+  }
+  return true;
+}
+
+bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) {
+  // A range matches when both its start and its extent match, start first.
+  if (!VisitExpr(lhs->min, rhs->min)) return false;
+  return VisitExpr(lhs->extent, rhs->extent);
+}
+
+bool TensorizeComparator::CompareIterVar(const IterVar& lhs, const IterVar& rhs) {
+  // Pair (or check) the iter variables first, then require the same iteration kind.
+  // The iteration domains are not compared here.
+  if (!DefEqual(lhs->var, rhs->var)) return false;
+  return lhs->iter_type == rhs->iter_type;
+}
+
+void TensorizeComparator::EmitError(const std::string& error_message) {
+  // Only accumulates the message in error_messages_; nothing is raised here.
+  error_messages_.emplace_back(error_message);
+}
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h
new file mode 100644
index 0000000..359677d
--- /dev/null
+++ b/src/tir/schedule/ir_comparator.h
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_IR_COMPARATOR_H_
+#define TVM_TIR_SCHEDULE_IR_COMPARATOR_H_
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+// Functor signatures of the comparator: each visit receives an LHS node (`n`) and the
+// corresponding RHS node (`other`) and returns whether they match.
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)>;
+using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
+
+/*!
+ * \brief Deep comparison to check if two IR ASTs are equivalent for tensorization.
+ *
+ * The LHS is the IR being tensorized and the RHS is the description of the tensor intrinsic.
+ * Newly defined LHS variables are paired with their RHS counterparts (DefEqual), buffers are
+ * paired by data variable, dtype and storage scope (CompareBuffer), and buffer accesses are
+ * compared relative to base indices recorded while visiting the scope block.
+ */
+class TensorizeComparator : public ExprComparator, public StmtComparator {
+ public:
+  /*!
+   * \brief Constructor of TensorizeComparator
+   * \param lhs_mod The IRModule of the LHS. This is used for error reporting.
+   * \param assert_mode Whether to raise an error if the two IR ASTs do not match.
+   */
+  explicit TensorizeComparator(IRModule lhs_mod, bool assert_mode = true)
+      : lhs_mod_(std::move(lhs_mod)), assert_mode_(assert_mode) {}
+
+  // Entry points: `n` belongs to the LHS, `other` to the RHS.
+  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
+  bool VisitStmt(const Stmt& n, const Stmt& other) override;
+
+  // Statement visitors.
+  bool VisitStmt_(const ForNode* op, const Stmt& other) override;
+  bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BlockNode* op, const Stmt& other) override;
+
+  // Expression visitors.
+  bool VisitExpr_(const AddNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const SubNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const MulNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const DivNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const ModNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const EQNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const NENode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const LTNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const LENode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const GTNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const GENode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const AndNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const OrNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const MinNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const CastNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const VarNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const SelectNode* op, const PrimExpr& other) override;
+
+  /*! \brief Map from RHS buffer to LHS buffer, filled as buffers are paired by CompareBuffer. */
+  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_;
+  /*!
+   * \brief Base indices of each LHS buffer, recorded from the regions of the scope block. LHS
+   * accesses are compared after subtracting them.
+   */
+  std::unordered_map<Buffer, std::vector<PrimExpr>, ObjectPtrHash, ObjectPtrEqual> buffer_indices_;
+
+ protected:
+  /*! \brief Pair a newly defined LHS var with an RHS var, or check an existing pairing. */
+  bool DefEqual(const Var& lhs, const Var& rhs);
+  /*! \brief Pair (or check the pairing of) an LHS buffer with an RHS buffer. */
+  virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs);
+  /*! \brief Compare buffer regions, recording base indices on the first visit of a buffer. */
+  bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs);
+  /*! \brief Compare one (key, value) annotation entry. */
+  bool CompareAnnotation(const std::pair<String, ObjectRef>& lhs,
+                         const std::pair<String, ObjectRef>& rhs);
+  /*! \brief Compare two annotation maps independently of their iteration order. */
+  bool CompareAnnotationMap(const Map<String, ObjectRef>& lhs, const Map<String, ObjectRef>& rhs);
+  /*! \brief Compare a BufferLoad/BufferStore pair relative to the recorded base indices. */
+  template <typename T>
+  bool CompareBufferAccess(const T* lhs, const T* rhs);
+  /*! \brief Compare two arrays element-wise with the member comparator `cmp`. */
+  template <typename T, typename F>
+  bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F cmp);
+  /*! \brief Compare the min and extent of two ranges. */
+  bool CompareRange(const Range& lhs, const Range& rhs);
+  /*! \brief Pair the iter variables and compare the iter types. */
+  bool CompareIterVar(const IterVar& lhs, const IterVar& rhs);
+  /*! \brief Append a message to error_messages_. */
+  void EmitError(const std::string& error_message);
+
+  /*! \brief IRModule of the LHS stmt. */
+  IRModule lhs_mod_;
+  /*! \brief Whether assertion mode is enabled. */
+  bool assert_mode_;
+  /*!
+   * \brief Whether it is visiting the scope block (the outermost block). Base indices of buffers
+   * may only be recorded while this is true.
+   */
+  bool is_scope_block = true;
+  /*! \brief The arithmetic analyzer. Paired variables are bound into it by DefEqual. */
+  arith::Analyzer analyzer_;
+  /*! \brief Additional error messages. Only used when assert_mode is true. */
+  std::vector<std::string> error_messages_;
+  // Variable remap from LHS vars to the RHS vars they were paired with (see DefEqual).
+  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_;
+};
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_SCHEDULE_IR_COMPARATOR_H_
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index f0b38af..2368411 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -378,6 +378,24 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& 
block_sref, int buffer
                       const String& storage_scope);
 
 /******** Schedule: Blockize & Tensorize ********/
+
+/*!
+ * \brief Convert the subtree rooted at a specific loop into a block.
+ * \param self The state of the schedule
+ * \param loop_sref The sref of the loop that roots the subtree to be blockized
+ * \return The sref of the new block
+ */
+TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref);
+
+/*!
+ * \brief Tensorize the computation enclosed by a block or a loop with the tensor intrinsic.
+ * \param self The state of the schedule
+ * \param block_or_loop_sref The sref of the block or loop to be tensorized.
+ * \param intrin The tensor intrinsic.
+ */
+TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
+                       const TensorIntrin& intrin);
+
 /******** Schedule: Annotation ********/
 /*!
  * \brief Annotate a block/loop with a key value pair
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc 
b/src/tir/schedule/primitive/blockize_tensorize.cc
new file mode 100644
index 0000000..bbeb9ca
--- /dev/null
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -0,0 +1,698 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <functional>
+
+#include "../ir_comparator.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief ScheduleError raised when the iter bindings of the inner block cannot be divided by the
+ * subspace formed by the loops starting at the scope loop, so the block cannot be blockized there.
+ */
+class SubspaceNotDivisibleError : public ScheduleError {
+ public:
+  explicit SubspaceNotDivisibleError(IRModule mod, For scope_loop, Block inner_block)
+      : ir_module_(std::move(mod)),
+        loop_(std::move(scope_loop)),
+        block_(std::move(inner_block)) {}
+
+  IRModule mod() const final { return ir_module_; }
+
+  String FastErrorString() const final {
+    return "ScheduleError: The bindings of the inner block can not be blockized.";
+  }
+
+  // The placeholders {0} and {1} correspond, in order, to the entries of LocationsOfInterest().
+  String DetailRenderTemplate() const final {
+    return "ScheduleError: The bindings of the inner block {0} can not be blockized by the loops "
+           "starting at {1}.";
+  }
+
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_, loop_}; }
+
+ private:
+  /*! \brief The module in which the error was found. */
+  IRModule ir_module_;
+  /*! \brief The loop at which blockize was attempted. */
+  For loop_;
+  /*! \brief The block whose bindings cannot be divided. */
+  Block block_;
+};
+
+/*!
+ * \brief Detect if bindings are a trivial case of the subspace division where we can divide the
+ * block iter bindings into two categories:
+ *   1. The binding covers no inner loop vars.
+ *   2. The binding covers only inner loop vars.
+ * A binding that uses both outer and inner loop vars makes the trivial division fail.
+ *
+ * The bindings are not required to be quasi-affine.
+ *
+ * \param iter_vars The input iterators
+ * \param bindings The values of iter_vars
+ * \param outer_iters Iterators outside the subspace.
+ * \param inner_iters Iterators of the subspace
+ * \param predicate The predicate constraint on the input iterators. Only a constant-true
+ * predicate is supported; anything else makes the division fail.
+ * \return The result of the subspace division: an [outer mark, inner mark] pair for each binding,
+ * followed by one trailing pair whose extents are constant true (presumably the outer and inner
+ * predicates, mirroring the layout of the arith::SubspaceDivide result). An empty array means the
+ * bindings cannot be divided trivially.
+ */
+Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter_vars,
+                                                      const Array<PrimExpr>& bindings,
+                                                      const Array<Var>& outer_iters,
+                                                      const Array<Var>& inner_iters,
+                                                      const PrimExpr& predicate) {
+  if (!is_one(predicate)) return {};
+
+  // Builds a function that tells whether an expression references any variable of `vars`.
+  auto make_uses_var = [](const Array<Var>& vars) -> std::function<bool(const PrimExpr& expr)> {
+    std::unordered_set<const VarNode*> var_set;
+    var_set.reserve(vars.size());
+    for (const Var& var : vars) {
+      var_set.insert(var.get());
+    }
+    return [var_set = std::move(var_set)](const PrimExpr& expr) -> bool {
+      return UsesVar(expr, [&var_set](const VarNode* var) { return var_set.count(var) != 0; });
+    };
+  };
+  auto use_outer_loop_vars = make_uses_var(outer_iters);
+  auto use_inner_loop_vars = make_uses_var(inner_iters);
+  // A mark of extent 1 stands for the side of the division that a binding does not use.
+  arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1);
+
+  Array<Array<arith::IterMark>> res;
+  res.reserve(bindings.size() + 1);
+  for (size_t i = 0; i < bindings.size(); ++i) {
+    const PrimExpr binding = bindings[i];
+    const PrimExpr extent = iter_vars[i]->dom->extent;
+    const bool outer = use_outer_loop_vars(binding);
+    const bool inner = use_inner_loop_vars(binding);
+    // Bindings mixing outer and inner loop vars need a real subspace division.
+    if (outer && inner) return {};
+
+    arith::IterMark iter_mark;
+    if (binding->IsInstance<VarNode>()) {
+      iter_mark = arith::IterMark(arith::IterSplitExpr(arith::IterMark(binding, extent)), extent);
+    } else {
+      iter_mark = arith::IterMark(arith::IterSumExpr({}, binding), extent);
+    }
+    if (outer) {
+      res.push_back({/*outer_iter=*/iter_mark, /*inner_iter=*/unit_iter_mark});
+    } else if (inner) {
+      res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/iter_mark});
+    } else {
+      res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/unit_iter_mark});
+    }
+  }
+  res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)),
+                 arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))});
+  return res;
+}
+
+/*!
+ * \brief Generate the blockized init block.
+ * \param block The original block with init.
+ * \param inner_block_realize The block realize of the inner block after blockize.
+ * \param inner_loops The inner loops after blockize. The kept loops are wrapped around the init
+ * block in the given order, so the list is expected innermost first.
+ * \return The subtree of the init block and its outer loops.
+ */
+Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_realize,
+                           const std::vector<const ForNode*>& inner_loops) {
+  const Block& inner_block = inner_block_realize->block;
+  const Stmt& init_body = block->init.value();
+
+  // Step 1: Keep the data-parallel iters referenced by the init statement, with their bindings.
+  Array<IterVar> init_iters;
+  Array<PrimExpr> init_values;
+  for (size_t k = 0; k < inner_block->iter_vars.size(); ++k) {
+    const IterVar& iv = inner_block->iter_vars[k];
+    if (iv->iter_type != IterVarType::kDataPar) continue;
+    const VarNode* iv_var = iv->var.get();
+    if (!UsesVar(init_body, [iv_var](const VarNode* v) { return v == iv_var; })) continue;
+    init_iters.push_back(iv);
+    init_values.push_back(inner_block_realize->iter_values[k]);
+  }
+
+  // Step 2: Keep the inner loops whose loop var appears in at least one of those bindings.
+  std::vector<const ForNode*> init_loops;
+  for (const ForNode* loop : inner_loops) {
+    const VarNode* loop_var = loop->loop_var.get();
+    bool used = false;
+    for (const PrimExpr& value : init_values) {
+      if (UsesVar(value, [loop_var](const VarNode* v) { return v == loop_var; })) {
+        used = true;
+        break;
+      }
+    }
+    if (used) init_loops.push_back(loop);
+  }
+
+  // Step 3: Give the init block fresh iter vars (suffixed "_init").
+  Map<Var, PrimExpr> subst;
+  for (size_t k = 0; k < init_iters.size(); ++k) {
+    IterVar renamed = init_iters[k];
+    const Var old_var = renamed->var;
+    const Var fresh_var = old_var.copy_with_suffix("_init");
+    renamed.CopyOnWrite()->var = fresh_var;
+    subst.Set(old_var, fresh_var);
+    init_iters.Set(k, std::move(renamed));
+  }
+
+  // Step 4: Build the init block realize. It reads nothing, writes the regions of the original
+  // block and reuses the predicate of the inner block realize.
+  Block init_block{/*iter_vars=*/init_iters,
+                   /*reads=*/{},
+                   /*writes=*/block->writes,
+                   /*name_hint=*/block->name_hint + "_init",
+                   /*body=*/init_body,
+                   /*init=*/NullOpt};
+  Stmt result = BlockRealize(/*iter_values=*/init_values,
+                             /*predicate=*/inner_block_realize->predicate,
+                             /*block=*/std::move(init_block));
+
+  // Step 5: Wrap the init block in copies of the kept loops, each with a fresh loop var.
+  for (const ForNode* loop : init_loops) {
+    ObjectPtr<ForNode> loop_copy = make_object<ForNode>(*loop);
+    loop_copy->loop_var = loop->loop_var.copy_with_suffix("");
+    subst.Set(loop->loop_var, loop_copy->loop_var);
+    loop_copy->body = std::move(result);
+    result = For(loop_copy);
+  }
+
+  // Step 6: Apply the renaming so the init subtree shares no variables with the outer block.
+  return Substitute(result, subst);
+}
+
+/*!
+ * \brief Collects the loops enclosing a block and splits them at a separator loop.
+ *
+ * The loops from the block up to and including the separator form 'inner_loops'. The loops above
+ * the separator form 'outer_loops', which may be empty. Only the chain of loops directly above the
+ * block is visited: collection stops at the first ancestor that is not a loop. Both groups are
+ * stored innermost first.
+ */
+class LoopSubspaceCollector {
+ public:
+  /*!
+   * \brief Collect the enclosing loops of the block into the public fields.
+   * \param block_sref The sref to the target block.
+   * \param loop_sref The sref to the separator loop. The loop itself is counted as an inner loop.
+   */
+  void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
+    bool past_separator = false;
+    StmtSRefNode* sref = block_sref->parent;
+    while (sref != nullptr && sref->stmt->IsInstance<ForNode>()) {
+      const ForNode* loop = sref->StmtAs<ForNode>();
+      ICHECK(loop);
+      std::vector<const ForNode*>& loops = past_separator ? outer_loops : inner_loops;
+      Array<Var>& loop_vars = past_separator ? outer_loop_vars : inner_loop_vars;
+      loops.push_back(loop);
+      loop_vars.push_back(loop->loop_var);
+      loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+      // Every loop visited after the separator is an outer loop.
+      if (sref == loop_sref.get()) past_separator = true;
+      sref = sref->parent;
+    }
+  }
+  /*! \brief Loops above the separator, innermost first. */
+  std::vector<const ForNode*> outer_loops;
+  /*! \brief The separator and the loops below it, innermost first. */
+  std::vector<const ForNode*> inner_loops;
+  /*! \brief Loop variables of outer_loops, in the same order. */
+  Array<Var> outer_loop_vars;
+  /*! \brief Loop variables of inner_loops, in the same order. */
+  Array<Var> inner_loop_vars;
+  /*! \brief Domain [min, min + extent) of every collected loop variable. */
+  Map<Var, Range> loop_var_domain;
+};
+
+/*!
+ * \brief Divide the bindings of the block iters by the loop subspace found by the collector.
+ *
+ * An affine subspace division is tried first. If it fails, the trivial division
+ * (TrivialSubspaceDivision) is used as a fallback.
+ *
+ * \param mod The current IR module.
+ * \param block_realize The block realize to be checked.
+ * \param collector The collector which has collected the loops of the block.
+ * \param analyzer The arithmetic analyzer.
+ * \return The result of the subspace division.
+ * \throws ScheduleError If neither division succeeds.
+ */
+Array<Array<arith::IterMark>> CheckSubspaceDivisible(const IRModule& mod,
+                                                     const BlockRealize& block_realize,
+                                                     const LoopSubspaceCollector& collector,
+                                                     arith::Analyzer* analyzer) {
+  const Block& block = block_realize->block;
+  DiagnosticContext diag_ctx(DiagnosticContext::Default(mod));
+
+  Array<Array<arith::IterMark>> division =
+      arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain,
+                            collector.inner_loop_vars, block_realize->predicate,
+                            /*require_bijective=*/false, analyzer, diag_ctx);
+  if (!division.empty()) return division;
+
+  // The trivial case still allows blockize when each binding depends on only one side.
+  division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values,
+                                     collector.outer_loop_vars, collector.inner_loop_vars,
+                                     block_realize->predicate);
+  if (division.empty()) {
+    // inner_loops.back() is the outermost collected inner loop (presumably the separator).
+    throw SubspaceNotDivisibleError(mod, GetRef<For>(collector.inner_loops.back()), block);
+  }
+  return division;
+}
+
+/*!
+ * \brief The binding extractor to compute the bindings of the outer and the inner blocks after
+ * blockize.
+ */
+class BlockizedBindingExtractor {
+ public:
+  /*!
+   * \brief Extract bindings for blockize.
+   * \param iter_vars The iter vars of the original inner block.
+   * \param division The result of the subspace division: one [outer mark, inner mark] pair per
+   * iter var, plus one trailing entry that is not read here.
+   * \param analyzer The arithmetic analyzer. The domains of the new inner block iters are bound
+   * into it.
+   */
+  void ExtractBindings(const Array<IterVar>& iter_vars,
+                       const Array<Array<arith::IterMark>>& division, arith::Analyzer* analyzer) {
+    ICHECK_EQ(iter_vars.size() + 1, division.size());
+    for (size_t i = 0; i < iter_vars.size(); ++i) {
+      const IterVar& iter_var = iter_vars[i];
+      arith::IterMark outer_mark = division[i][0];
+      arith::IterMark inner_mark = division[i][1];
+      const auto* outer_binding =
+          TVM_TYPE_AS(outer_binding, outer_mark->source, arith::IterMapExprNode);
+      const auto* inner_binding =
+          TVM_TYPE_AS(inner_binding, inner_mark->source, arith::IterMapExprNode);
+
+      // After computing the subspace division, bindings[i] can be written as
+      //   outer_binding * inner_mark->extent + inner_binding
+      // The outer block will have binding: iter_outer -> outer_binding
+      // The inner block will have binding: iter_inner -> inner_binding
+      // The iter in the original block will be substituted with base + iter_inner where
+      //   base == iter_outer * inner_mark->extent
+
+      // Every original iter contributes exactly one binding to the outer block.
+      outer_bindings.push_back(
+          arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(outer_binding)));
+
+      if (is_one(inner_mark->extent)) {
+        // The iter has no inner part: move it to the outer block unchanged.
+        outer_iter_vars.push_back(iter_var);
+        continue;
+      }
+
+      // Create a fresh iter var for the outer block, ranging over the outer extent.
+      const IterVar outer_var(/*dom=*/Range::FromMinExtent(0, outer_mark->extent),
+                              /*var=*/iter_var->var.copy_with_suffix("_o"),
+                              /*iter_type=*/iter_var->iter_type);
+      outer_iter_vars.push_back(outer_var);
+      // With an outer extent of 1 the outer part adds nothing to the original index.
+      PrimExpr base = is_one(outer_mark->extent) ? 0 : outer_var * inner_mark->extent;
+
+      // Reuse the original iter var for the inner block, restricted to the inner extent.
+      IterVar new_iter = iter_var;
+      auto* new_iter_node = new_iter.CopyOnWrite();
+      new_iter_node->dom = Range::FromMinExtent(0, inner_mark->extent);
+      inner_iter_dom_map.Set(new_iter->var, arith::IntSet::FromRange(new_iter->dom));
+      analyzer->Bind(new_iter->var, new_iter->dom);
+      inner_iter_vars.push_back(new_iter);
+      inner_bindings.push_back(
+          arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(inner_binding)));
+      inner_iter_subst_map.Set(iter_var->var, base + new_iter->var);
+    }
+  }
+  /*!
+   * \brief Substitution of each split original iter var with base + its inner iter var. Iters
+   * moved unchanged to the outer block have no entry.
+   */
+  Map<Var, PrimExpr> inner_iter_subst_map;
+  /*! \brief Iters of the outer block. */
+  Array<IterVar> outer_iter_vars;
+  /*! \brief Iters of the inner block. */
+  Array<IterVar> inner_iter_vars;
+  /*! \brief Binding values of the outer block. */
+  Array<PrimExpr> outer_bindings;
+  /*! \brief Binding values of the inner block. */
+  Array<PrimExpr> inner_bindings;
+  /*! \brief The domain of the inner block iters. */
+  Map<Var, arith::IntSet> inner_iter_dom_map;
+};
+
+/*!
+ * \brief Replacer for the inner block after blockize. Inner block iters are replaced with
+ * base + inner_iter, and every expression changed by the substitution is simplified.
+ */
+class InnerIterReplacer : public StmtExprMutator {
+ public:
+  /*!
+   * \brief The constructor
+   * \param subst_map The substitution map of the inner block iters.
+   * \param analyzer The arithmetic analyzer used for simplification.
+   * \param block_sref_reuse Output map that receives (original block -> rewritten block) pairs.
+   */
+  InnerIterReplacer(Map<Var, PrimExpr> subst_map, arith::Analyzer* analyzer,
+                    Map<Block, Block>* block_sref_reuse)
+      : subst_map_(std::move(subst_map)),
+        analyzer_(analyzer),
+        block_sref_reuse_(block_sref_reuse) {}
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto hit = subst_map_.find(GetRef<Var>(op));
+    if (hit == subst_map_.end()) return StmtExprMutator::VisitExpr_(op);
+    return (*hit).second;
+  }
+
+  PrimExpr VisitExpr(const PrimExpr& op) final {
+    PrimExpr mutated = StmtExprMutator::VisitExpr(op);
+    // Only expressions that the substitution actually changed are simplified.
+    return mutated.same_as(op) ? mutated : analyzer_->Simplify(mutated);
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    // Mutate first; no extra reference to the original block is held while mutating.
+    Stmt mutated = StmtExprMutator::VisitStmt_(op);
+    Block original = GetRef<Block>(op);
+    if (!mutated.same_as(original)) {
+      // Record the original -> rewritten correspondence in block_sref_reuse_.
+      block_sref_reuse_->Set(original, Downcast<Block>(mutated));
+    }
+    return mutated;
+  }
+
+ private:
+  /*! \brief Replacement expression for each inner block iter var. */
+  Map<Var, PrimExpr> subst_map_;
+  /*! \brief Analyzer used to simplify substituted expressions (not owned). */
+  arith::Analyzer* analyzer_;
+  /*! \brief Output map from original blocks to rewritten blocks (not owned). */
+  Map<Block, Block>* block_sref_reuse_;
+};
+
+/*!
+ * \brief Compute the access region of the outer block by relaxing the inner block iters.
+ * \param buffer_region The original buffer region.
+ * \param inner_iter_relaxed_range The domain of each inner block iter to relax over.
+ * \return The new buffer region.
+ */
+BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region,
+                                      const Map<Var, arith::IntSet>& inner_iter_relaxed_range) {
+  const Buffer& buffer = buffer_region->buffer;
+  const Array<Range>& region = buffer_region->region;
+  ICHECK_EQ(region.size(), buffer->shape.size())
+      << "The region rank does not match the rank of buffer " << buffer->name;
+  Array<arith::IntSet> relaxed_int_set = arith::EvalSet(region, inner_iter_relaxed_range);
+  Array<Range> new_region;
+  new_region.reserve(region.size());
+  for (size_t i = 0; i < region.size(); i++) {
+    // The full extent of dimension i is passed to CoverRange as the bounding range, which is
+    // presumably used when the relaxed set is unbounded.
+    Range max_range = Range::FromMinExtent(0, buffer->shape[i]);
+    new_region.push_back(relaxed_int_set[i].CoverRange(max_range));
+  }
+  return BufferRegion(buffer, std::move(new_region));
+}
+
+/*!
+ * \brief Generate the outer block after blockize.
+ * \param extractor The binding extractor which has extracted the blockized bindings.
+ * \param block The original inner block (its reads, writes and init are used here).
+ * \param inner_block_realize The block realize of the inner block after blockize.
+ * \param inner_loops The inner loops after blockize.
+ * \param predicate The outer predicate of the subspace division.
+ * \return The block realize of the outer block after blockize.
+ */
+BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extractor,
+                                         const Block& block, BlockRealize inner_block_realize,
+                                         const std::vector<const ForNode*>& inner_loops,
+                                         PrimExpr predicate) {
+  // Step 1: Generate the init block if needed
+  Optional<Stmt> new_init = NullOpt;
+  if (block->init.defined()) {
+    new_init = GenerateBlockizedInit(block, inner_block_realize, inner_loops);
+  }
+
+  // Step 2: Compute the access regions of the outer block by relaxing the inner loops
+  Array<BufferRegion> new_reads = block->reads;
+  Array<BufferRegion> new_writes = block->writes;
+
+  auto f_mutate = [&](const BufferRegion& buffer_region) {
+    return RelaxBlockizedInnerIters(buffer_region, extractor.inner_iter_dom_map);
+  };
+  new_reads.MutateByApply(f_mutate);
+  new_writes.MutateByApply(f_mutate);
+
+  // Step 3: Generate the body of the outer block. The body of the outer block is the inner block
+  // realize and its surrounding loops. Each loop wraps the body built so far, so the first element
+  // of `inner_loops` ends up directly around the inner block realize.
+  Stmt outer_block_body = inner_block_realize;
+  for (const ForNode* loop : inner_loops) {
+    ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
+    new_loop->body = std::move(outer_block_body);
+    outer_block_body = For(new_loop);
+  }
+
+  // Step 4: Generate the outer block and block realize.
+  // `extractor` is a const reference, so its members are copied here; std::move on them would be
+  // a no-op and is intentionally not used.
+  return BlockRealize(/*iter_values=*/extractor.outer_bindings,
+                      /*predicate=*/std::move(predicate),
+                      /*block=*/
+                      Block(/*iter_vars=*/extractor.outer_iter_vars,  //
+                            /*reads=*/std::move(new_reads),            //
+                            /*writes=*/std::move(new_writes),          //
+                            /*name_hint=*/block->name_hint + "_o",     //
+                            /*body=*/std::move(outer_block_body),      //
+                            /*init=*/std::move(new_init)));
+}
+
+/*!
+ * \brief Blockize the subtree rooted at `loop_sref`. The loops under it and the single block
+ *        they contain become the body of a new outer block, whose name is the original block's
+ *        name with the suffix "_o".
+ * \param self The schedule state.
+ * \param loop_sref The sref of the loop whose subtree is blockized.
+ * \return The sref of the new outer block.
+ */
+StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
+  // Validates that `loop_sref` points to a For; `loop` itself is not used further below.
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  arith::Analyzer analyzer;
+
+  // Step 1: Check the loop has a single child BlockRealize on the sref tree.
+  BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref);
+  Block block = block_realize->block;
+  StmtSRef block_sref = self->stmt2ref.at(block.get());
+
+  // Step 2: Collect loops inside and outside loop_sref.
+  LoopSubspaceCollector collector;
+  collector.Collect(block_sref, loop_sref);
+
+  // Step 3: Calculate subspace division for the inner loops.
+  Array<Array<arith::IterMark>> division =
+      CheckSubspaceDivisible(self->mod, block_realize, collector, &analyzer);
+
+  // Step 4: Generate bindings for the outer block and the inner block based on the result of
+  // the subspace division. The extents of the last entry of `division` are used as the outer
+  // (index 0) and inner (index 1) predicates.
+  BlockizedBindingExtractor extractor;
+  extractor.ExtractBindings(block->iter_vars, division, &analyzer);
+  const PrimExpr& outer_pred = division.back()[0]->extent;
+  const PrimExpr& inner_pred = division.back()[1]->extent;
+
+  // Step 5: Substitute the iter vars in the original block with the inner iters after the subspace
+  // division. The replacer records in `block_sref_reuse` every block it rewrites.
+  Map<Block, Block> block_sref_reuse;
+  InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map), &analyzer,
+                             &block_sref_reuse);
+  Block new_block = Downcast<Block>(replacer(block));
+
+  // Step 6: Generate the inner block.
+  // CopyOnWrite gives `block_realize` its own node, which is then turned into the inner realize.
+  BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite();
+  inner_block_realize->iter_values = extractor.inner_bindings;
+  inner_block_realize->predicate = inner_pred;
+  inner_block_realize->block = new_block;
+  // `new_block` still references the node assigned above, so CopyOnWrite creates a private copy:
+  // the inner block gets new iter vars and no init, while `new_block` keeps its original iter vars
+  // and init, which Step 7 reads (via block->init, block->reads and block->writes).
+  BlockNode* inner_block = inner_block_realize->block.CopyOnWrite();
+  inner_block->iter_vars = extractor.inner_iter_vars;
+  inner_block->init = NullOpt;
+  // Reuse the original block's sref for the final inner block; this overrides any entry the
+  // replacer may have recorded for `block`.
+  block_sref_reuse.Set(block, inner_block_realize->block);
+
+  // Step 7: Generate the outer block.
+  BlockRealize outer_realize =
+      GenerateBlockizedOuterBlock(extractor, new_block, GetRef<BlockRealize>(inner_block_realize),
+                                  collector.inner_loops, outer_pred);
+  // Step 8: Do the actual replacement
+  self->Replace(loop_sref, outer_realize, block_sref_reuse);
+
+  // Step 9: Update the cached flags
+  // The scope root's affine-binding flag is read before UpdateScopeBlockInfo and written back
+  // afterwards. NOTE(review): presumably because re-collecting the scope root's BlockInfo
+  // overwrites it — confirm against BlockInfoCollector.
+  StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get());
+  StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root);
+  self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
+  self->block_info[scope_root].affine_binding = scope_block_affine_binding;
+  return outer_block_sref;
+}
+
+/*!
+ * \brief Fill a map from each buffer of the impl function of a tensor intrinsic to the buffer
+ * bound at the same parameter position in the desc function.
+ * \param intrinsic The tensor intrinsic.
+ * \param buffer_map The map to be updated; keys are impl buffers, values are desc buffers.
+ */
+void RemapTensorIntrinBuffers(
+    const TensorIntrin& intrinsic,
+    std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>* buffer_map) {
+  const PrimFunc& desc = intrinsic->desc;
+  const PrimFunc& impl = intrinsic->impl;
+  // Parameters are paired positionally, so both functions must have the same arity.
+  ICHECK_EQ(desc->params.size(), impl->params.size());
+  for (size_t i = 0; i < desc->params.size(); ++i) {
+    const Buffer& desc_buffer = desc->buffer_map.at(desc->params[i]);
+    const Buffer& impl_buffer = impl->buffer_map.at(impl->params[i]);
+    (*buffer_map)[impl_buffer] = desc_buffer;
+  }
+}
+
+/*!
+ * \brief Tensorize a block, or the subtree rooted at a loop, with a tensor intrinsic.
+ * \param self The schedule state.
+ * \param block_or_loop_sref The sref of a block, or of a loop that is blockized first.
+ * \param intrinsic The tensor intrinsic: its desc is matched against the block and its impl
+ *        body replaces the block body.
+ */
+void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
+               const TensorIntrin& intrinsic) {
+  /*!
+   * Check:
+   *   - Check buffer binding, including type, alignment, shape and etc.
+   *   - Check the sub AST is equal to the desc function.
+   *
+   * Mutate:
+   *   - Blockize the sub AST (please refer blockize for details)
+   *   - Bind buffers
+   *   - Mutate the impl of the tensor intrinsic by replacing its buffers with new
+   *     buffers created via match buffer region.
+   *   - Replace the sub tree with the mutated function.
+   */
+  BlockRealize desc_block_realize = Downcast<BlockRealize>(intrinsic->desc->body);
+  BlockRealize impl_block_realize = Downcast<BlockRealize>(intrinsic->impl->body);
+  Block impl_block = impl_block_realize->block;
+
+  // Step 1: Blockize the subtree rooted at the given loop if needed
+  StmtSRef block_sref{nullptr};
+  if (block_or_loop_sref->StmtAs<ForNode>()) {
+    block_sref = Blockize(self, block_or_loop_sref);
+  } else {
+    ICHECK(block_or_loop_sref->StmtAs<BlockNode>());
+    block_sref = block_or_loop_sref;
+  }
+  BlockRealize block_realize = GetBlockRealize(self, block_sref);
+
+  // Step 2: Compare the block with the desc of the tensor intrinsic, find the correspondence
+  // between buffers in the block and the desc.
+  TensorizeComparator comparator(self->mod, /*assert_mode=*/true);
+  comparator.VisitStmt(block_realize, desc_block_realize);
+
+  // Step 3: Find the correspondence between buffers in the current AST and the impl of
+  // the tensor intrinsic
+  // Step 3.1: Map from intrinsic func buffer to desc func buffer
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> intrin_buffer_map;
+  RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map);
+  // Step 3.2: Map from intrinsic func buffer to current AST buffer
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map;
+  for (const auto& pair : intrin_buffer_map) {
+    auto it = comparator.rhs_buffer_map_.find(pair.second);
+    ICHECK(it != comparator.rhs_buffer_map_.end()) << pair.second;
+    buffer_map[pair.first] = it->second;
+  }
+
+  // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor
+  // intrin to make them subregions of the buffer in the original IR.
+  // emplace keeps the first region seen, so a buffer that is both read and written keeps its
+  // read region.
+  std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual> buffer_region_map;
+  for (const BufferRegion& read : impl_block->reads) {
+    buffer_region_map.emplace(read->buffer, read->region);
+  }
+  for (const BufferRegion& write : impl_block->writes) {
+    buffer_region_map.emplace(write->buffer, write->region);
+  }
+  Array<MatchBufferRegion> match_buffer_regions;
+  match_buffer_regions.reserve(intrinsic->impl->params.size());
+  for (size_t i = 0; i < intrinsic->impl->params.size(); ++i) {
+    const auto& param = intrinsic->impl->params[i];
+    const auto& buffer = intrinsic->impl->buffer_map.at(param);
+    const auto& source = buffer_map.at(buffer);
+    // add the detected base indices to each buffer access region of the tensor intrinsic
+    Region old_region = buffer_region_map.at(buffer);
+    const auto& indices_base = comparator.buffer_indices_.at(source);
+    // `indices_base` may have more entries than the intrinsic region has dimensions; those extra
+    // leading dimensions are pinned to their base index with extent 1.
+    int offset = static_cast<int>(indices_base.size()) - static_cast<int>(old_region.size());
+    ICHECK_GE(offset, 0);
+    Region new_region;
+    new_region.reserve(source->shape.size());
+    for (int dim = 0; dim < offset; ++dim) {
+      new_region.push_back(Range::FromMinExtent(indices_base[dim], 1));
+    }
+    for (int dim = 0; dim < static_cast<int>(old_region.size()); ++dim) {
+      new_region.push_back(
+          Range::FromMinExtent(indices_base[dim + offset], old_region[dim]->extent));
+    }
+    match_buffer_regions.push_back(
+        MatchBufferRegion(buffer, BufferRegion(source, std::move(new_region))));
+  }
+
+  // Step 5: Replace the subtree in the original IR with the tensor intrin impl.
+  ObjectPtr<BlockNode> new_block_ptr = make_object<BlockNode>(*block_realize->block.get());
+  new_block_ptr->body = impl_block->body;
+  ICHECK(new_block_ptr->match_buffers.empty());
+  new_block_ptr->match_buffers = std::move(match_buffer_regions);
+  Block new_block(new_block_ptr);
+
+  // The old block is mapped to the new one so that `block_sref` is reused by the replacement.
+  self->Replace(block_sref, new_block, {{block_realize->block, new_block}});
+
+  // Step 6: Update the cached flags.
+  StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_root);
+  self->UpdateScopeBlockInfo(scope_block->body);
+}
+
+/******** InstructionKind Registration ********/
+
+struct BlockizeTraits : public UnpackedInstTraits<BlockizeTraits> {
+  static constexpr const char* kName = "Blockize";
+  static constexpr bool kIsPure = false;
+
+ private:
+  // One input (the loop); no attributes and no sampling decisions.
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  /*! \brief Replay the instruction: blockize `loop_rv` and return the resulting block. */
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) {
+    BlockRV outer_block = sch->Blockize(loop_rv);
+    return outer_block;
+  }
+
+  /*! \brief Render the instruction as a call to the Python `blockize` API. */
+  static String UnpackedAsPython(Array<String> outputs, String loop_rv) {
+    PythonAPICall call("blockize");
+    call.Input("loop", loop_rv);
+    call.SingleOutput(outputs);
+    return call.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+struct TensorizeTraits : public UnpackedInstTraits<TensorizeTraits> {
+  static constexpr const char* kName = "Tensorize";
+  static constexpr bool kIsPure = false;
+
+ private:
+  // One input (a block or a loop), one attribute (the intrinsic name), no decisions.
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 1;
+  static constexpr size_t kNumDecisions = 0;
+
+  /*! \brief Replay the instruction, dispatching on whether the input is a block or a loop. */
+  static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin) {
+    if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
+      sch->Tensorize(GetRef<BlockRV>(block), intrin);
+      return;
+    }
+    if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
+      sch->Tensorize(GetRef<LoopRV>(loop), intrin);
+      return;
+    }
+    LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: "
+               << block_or_loop_rv->GetTypeKey();
+  }
+
+  /*! \brief Render the instruction as a call to the Python `tensorize` API. */
+  static String UnpackedAsPython(Array<String> outputs, String block_or_loop_rv, String intrin) {
+    // The instruction has no outputs, so `outputs` is not rendered.
+    PythonAPICall call("tensorize");
+    call.Input("block_or_loop", block_or_loop_rv);
+    call.Input("intrin", intrin);
+    return call.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+// Register the "Blockize" and "Tensorize" instruction kinds; traced schedules look these kinds up
+// by name through InstructionKind::Get.
+TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits);
+TVM_REGISTER_INST_KIND_TRAITS(TensorizeTraits);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 6e33862..b466843 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -185,6 +185,20 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
     .set_body_method<Schedule>(&ScheduleNode::SetScope);
 /******** (FFI) Blockize & Tensorize ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
+    .set_body_method<Schedule>(&ScheduleNode::Blockize);
+// Tensorize is overloaded on the random-variable kind, so dispatch on the runtime type of `rv`.
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize")
+    .set_body_typed([](Schedule self, ObjectRef rv, String intrin) {
+      if (const auto* block_rv = rv.as<BlockRVNode>()) {
+        self->Tensorize(GetRef<BlockRV>(block_rv), intrin);
+        return;
+      }
+      if (const auto* loop_rv = rv.as<LoopRVNode>()) {
+        self->Tensorize(GetRef<LoopRV>(loop_rv), intrin);
+        return;
+      }
+      LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
+                 << ". Its value is: " << rv;
+    });
+
 /******** (FFI) Annotation ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate")
     .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key,
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 04b7dd5..3a37f81 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -201,9 +201,7 @@ class BlockInfoCollector : private StmtVisitor {
     bool is_root_block = srefs_.empty();
     // Calculate `BlockInfo::scope`
     Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
-    BlockInfo& info =
-        self_->block_info.emplace(scope_root, 
BlockInfo(BlockScope(child_block_srefs)))
-            .first->second;
+    BlockInfo& info = self_->block_info[scope_root] = 
BlockInfo(BlockScope(child_block_srefs));
     // Set `affine_binding`
     if (is_root_block) {
       // If the block doesn't have outer loops and BlockRealize,
diff --git a/src/tir/schedule/traced_schedule.cc 
b/src/tir/schedule/traced_schedule.cc
index da7a264..1e2e57e 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -356,6 +356,37 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, 
int buffer_index,
 
 /******** Schedule: Blockize & Tensorize ********/
 
+BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) {
+  // Apply the primitive first; if it throws, nothing is appended to the trace.
+  BlockRV outer_block = ConcreteScheduleNode::Blockize(loop_rv);
+  static const InstructionKind& kind = InstructionKind::Get("Blockize");
+  Instruction inst(/*kind=*/kind,
+                   /*inputs=*/{loop_rv},
+                   /*attrs=*/{},
+                   /*outputs=*/{outer_block});
+  trace_->Append(/*inst=*/inst);
+  return outer_block;
+}
+
+void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) {
+  // Apply the primitive first; if it throws, nothing is appended to the trace.
+  ConcreteScheduleNode::Tensorize(loop_rv, intrin);
+  static const InstructionKind& kind = InstructionKind::Get("Tensorize");
+  // The intrinsic name is recorded as the sole attribute; the instruction has no outputs.
+  Instruction inst(/*kind=*/kind,
+                   /*inputs=*/{loop_rv},
+                   /*attrs=*/{intrin},
+                   /*outputs=*/{});
+  trace_->Append(/*inst=*/inst);
+}
+
+void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) {
+  // Apply the primitive first; if it throws, nothing is appended to the trace.
+  ConcreteScheduleNode::Tensorize(block_rv, intrin);
+  static const InstructionKind& kind = InstructionKind::Get("Tensorize");
+  // The intrinsic name is recorded as the sole attribute; the instruction has no outputs.
+  Instruction inst(/*kind=*/kind,
+                   /*inputs=*/{block_rv},
+                   /*attrs=*/{intrin},
+                   /*outputs=*/{});
+  trace_->Append(/*inst=*/inst);
+}
+
 /******** Schedule: Annotation ********/
 
 void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key,
diff --git a/src/tir/schedule/traced_schedule.h 
b/src/tir/schedule/traced_schedule.h
index b35f1b6..3a88e86 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -87,6 +87,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
                     int offset) final;
   void SetScope(const BlockRV& block_rv, int buffer_index, const String& 
storage_scope) final;
   /******** Schedule: Blockize & Tensorize ********/
+  BlockRV Blockize(const LoopRV& loop_rv) final;
+  void Tensorize(const BlockRV& block_rv, const String& intrin) final;
+  void Tensorize(const LoopRV& loop_rv, const String& intrin) final;
   /******** Schedule: Annotation ********/
   void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& 
ann_val) override;
   void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
diff --git a/tests/python/unittest/test_tir_schedule_blockize.py 
b/tests/python/unittest/test_tir_schedule_blockize.py
new file mode 100644
index 0000000..b4a16a8
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_blockize.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm.script import tir as T
+from tvm import tir
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint: 
disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
+
[email protected]_func
+def single_elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 
128), "float32"]):
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+
+
[email protected]_func
+def single_elementwise_blockized1(
+    A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
+) -> None:
+    with T.block("blockized_B"):
+        vio = T.axis.spatial(1, 0)
+        vjo = T.axis.spatial(1, 0)
+        T.reads(A[0:128, 0:128])
+        T.writes(B[0:128, 0:128])
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(A[vi, vj])
+                T.writes(B[vi, vj])
+                B[vi, vj] = A[vi, vj] * T.float32(2)
+
+
[email protected]_func
+def single_elementwise_blockized2(
+    A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
+) -> None:
+    for i in T.serial(128):
+        with T.block("blockized_B"):
+            vi = T.axis.spatial(128, i)
+            vjo = T.axis.spatial(1, 0)
+            T.reads(A[vi, 0:128])
+            T.writes(B[vi, 0:128])
+            for j in T.serial(128):
+                with T.block("B"):
+                    vj = T.axis.remap("S", [j])
+                    T.reads(A[vi, vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+
+
[email protected]_func
+def two_elementwise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 
128), "float32"]) -> None:
+    B = T.alloc_buffer([128, 128], dtype="float32")
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(A[vi, vj])
+            T.writes(B[vi, vj])
+            B[vi, vj] = A[vi, vj] * T.float32(2)
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(B[vi, vj])
+            T.writes(C[vi, vj])
+            C[vi, vj] = B[vi, vj] + T.float32(1)
+
+
[email protected]_func
+def two_elementwise_blockized(
+    A: T.Buffer[(128, 128), "float32"],
+    C: T.Buffer[(128, 128), "float32"]
+) -> None:
+    B = T.alloc_buffer([128, 128], dtype="float32")
+    for i_0, j_0 in T.grid(8, 8):
+        with T.block("blockized_B"):
+            vio, vjo = T.axis.remap("SS", [i_0, j_0])
+            T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
+            T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
+            for i_1, j_1 in T.grid(16, 16):
+                with T.block("B"):
+                    vi, vj = T.axis.remap("SS", [i_1, j_1])
+                    T.reads(A[vio * 16 + vi, vjo * 16 + vj])
+                    T.writes(B[vio * 16 + vi, vjo * 16 + vj])
+                    B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 
16 + vj] * T.float32(2)
+        with T.block("blockized_C"):
+            vio, vjo = T.axis.remap("SS", [i_0, j_0])
+            T.reads(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
+            T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
+            for ax0, ax1 in T.grid(16, 16):
+                with T.block("C"):
+                    vi, vj = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(B[vio * 16 + vi, vjo * 16 + vj])
+                    T.writes(C[vio * 16 + vi, vjo * 16 + vj])
+                    C[vio * 16 + vi, vjo * 16 + vj] = B[vio * 16 + vi, vjo * 
16 + vj] + T.float32(1)
+
+
[email protected]_func
+def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) 
-> None:
+    for k, i in T.grid(128, 128):
+        with T.block("B"):
+            vk, vi = T.axis.remap("RS", [k, i])
+            with T.init():
+                B[vi] = 0.0
+            B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def rowsum_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), 
"float32"]) -> None:
+    with T.block("blockized_B"):
+        vko = T.axis.R(1, 0)
+        vio = T.axis.S(1, 0)
+        with T.init():
+            for i1 in T.serial(0, 128):
+                with T.block("B_init"):
+                    vi_init = T.axis.S(128, i1)
+                    B[vi_init] = T.float32(0)
+        for i0, i1_1 in T.grid(128, 128):
+            with T.block("B"):
+                vk, vi = T.axis.remap("RS", [i0, i1_1])
+                B[vi] = B[vi] + A[vi, vk]
+
+
+# fmt: on
+# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
+
+def test_blockize_outer():
+    func = single_elementwise
+    # schedule
+    s = tir.Schedule(func, debug_mask="all")
+    B = s.get_block("B")
+    x, y = s.get_loops(B)
+    s.blockize(x)
+    print(s.mod['main'].script())
+    tvm.ir.assert_structural_equal(s.mod["main"], 
single_elementwise_blockized1)
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_blockize_inner():
+    func = single_elementwise
+    # schedule
+    s = tir.Schedule(func, debug_mask="all")
+    B = s.get_block("B")
+    x, y = s.get_loops(B)
+    s.blockize(y)
+    tvm.ir.assert_structural_equal(s.mod["main"], 
single_elementwise_blockized2)
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_two_elementwise_blockize_reverse_compute_at():
+    func = two_elementwise
+    s = tir.Schedule(func, debug_mask="all")
+    B = s.get_block("B")
+    C = s.get_block("C")
+    x, y = s.get_loops(B)
+    xo, xi = s.split(x, factors=[None, 16])
+    yo, yi = s.split(y, factors=[None, 16])
+    s.reorder(xo, yo, xi, yi)
+    s.blockize(xi)
+    s.reverse_compute_at(C, yo)
+    s.blockize(s.get_loops(C)[-2])
+    tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized)
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_two_elementwise_blockize_compute_at():
+    func = two_elementwise
+    s = tir.Schedule(func, debug_mask="all")
+    B = s.get_block("B")
+    C = s.get_block("C")
+    x, y = s.get_loops(C)
+    xo, xi = s.split(x, factors=[None, 16])
+    yo, yi = s.split(y, factors=[None, 16])
+    s.reorder(xo, yo, xi, yi)
+    s.blockize(xi)
+    s.compute_at(B, yo)
+    s.blockize(s.get_loops(B)[-2])
+    tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized)
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_blockize_init_loops():
+    s = tir.Schedule(rowsum, debug_mask="all")
+    k, _ = s.get_loops(s.get_block("B"))
+    s.blockize(k)
+    tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized)
+    verify_trace_roundtrip(sch=s, mod=rowsum)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py 
b/tests/python/unittest/test_tir_schedule_tensorize.py
new file mode 100644
index 0000000..401a39f
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -0,0 +1,431 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
+
[email protected]_func
+def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
+        T.writes(C[0 : 16, 0 : 16])
+        for i, j, k in T.grid(16, 16, 16):
+            with T.block("update"):
+                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
+                C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
+
+
[email protected]_func
+def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
+        T.writes(C[0 : 16, 0 : 16])
+        T.evaluate(
+            T.tvm_mma_sync(
+                C.data,
+                C.elem_offset // 256,
+                A.data,
+                A.elem_offset // 256,
+                B.data,
+                B.elem_offset // 256,
+                C.data,
+                C.elem_offset // 256,
+                dtype="handle",
+            )
+        )
+
+
[email protected]_func
+def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,))
+    B = T.match_buffer(b, (4,))
+    C = T.match_buffer(c, ())
+
+    with T.block("root"):
+        T.reads(C[()], A[0 : 4], B[0 : 4])
+        T.writes(C[()])
+        for i in range(0, 4):
+            with T.block("update"):
+                vi = T.axis.remap("R", [i])
+                C[()] = C[()] + A[vi] * B[vi]
+
+
[email protected]_func
+def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), offset_factor=1)
+    B = T.match_buffer(b, (4,), offset_factor=1)
+    C = T.match_buffer(c, (), offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[()], A[0 : 4], B[0 : 4])
+        T.writes(C[()])
+        T.evaluate(
+            T.call_extern(
+                "vec4add",
+                C.data,
+                C.elem_offset,
+                A.data,
+                A.elem_offset,
+                B.data,
+                B.elem_offset,
+                dtype="int32",
+            )
+        )
+
+
[email protected]_func
+def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (16, 1), offset_factor=1)
+    B = T.match_buffer(b, (16, 1), offset_factor=1)
+    C = T.match_buffer(c, (16, 16), offset_factor=1)
+
+    with T.block("root"):
+        T.reads(
+            C[0 : 16, 0 : 16],
+            A[0 : 16, 0 : 1],
+            B[0 : 16, 0 : 1],
+        )
+        T.writes(C[0 : 16, 0 : 16])
+        for i, j in T.grid(16, 16):
+            with T.block("update"):
+                vii, vjj = T.axis.remap("SS", [i, j])
+                C[vii, vjj] = C[vii, vjj] + A[vii, 0] * B[vjj, 0]
+
+
[email protected]_func
+def outer_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (16, 1), offset_factor=1)
+    B = T.match_buffer(b, (16, 1), offset_factor=1)
+    C = T.match_buffer(c, (16, 16), offset_factor=1)
+
+    with T.block("root"):
+        T.reads(
+            C[0 : 16, 0 : 16],
+            A[0 : 16, 0 : 1],
+            B[0 : 16, 0 : 1],
+        )
+        T.writes(C[0 : 16, 0 : 16])
+        T.evaluate(
+            T.call_extern(
+                "outer_product",
+                C.data,
+                C.elem_offset,
+                A.data,
+                A.elem_offset,
+                B.data,
+                B.elem_offset,
+                dtype="int32",
+            )
+        )
+
+
[email protected]_func
+def matmul(
+    A: T.Buffer[(128, 128), "float32"],
+    B: T.Buffer[(128, 128), "float32"],
+    C: T.Buffer[(128, 128), "float32"],
+) -> None:
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block("update"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = T.float32(0)
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
[email protected]_func
+def tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, 
offset_factor=1)
+    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, 
offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, 
offset_factor=1)
+
+    for i_outer, j_outer in T.grid(8, 8):
+        for i_inner_init, j_inner_init in T.grid(16, 16):
+            with T.block("init"):
+                vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init))
+                vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init))
+                C[vi_init, vj_init] = T.float32(0)
+        for k_outer in T.grid(8):
+            with T.block("update"):
+                vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer])
+                T.reads(
+                    [
+                        C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
+                        A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
+                        B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
+                    ]
+                )
+                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
+                A_elem_offset = T.var("int32")
+                B_elem_offset = T.var("int32")
+                C_elem_offset = T.var("int32")
+                A_sub = T.match_buffer(
+                    A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
+                    [16, 16],
+                    elem_offset=A_elem_offset,
+                )
+                B_sub = T.match_buffer(
+                    B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
+                    [16, 16],
+                    elem_offset=B_elem_offset,
+                )
+                C_sub = T.match_buffer(
+                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
+                    [16, 16],
+                    elem_offset=C_elem_offset,
+                )
+                T.evaluate(
+                    T.tvm_mma_sync(
+                        C_sub.data,
+                        T.floordiv(C_sub.elem_offset, 256),
+                        A_sub.data,
+                        T.floordiv(A_sub.elem_offset, 256),
+                        B_sub.data,
+                        T.floordiv(B_sub.elem_offset, 256),
+                        C_sub.data,
+                        T.floordiv(C_sub.elem_offset, 256),
+                        dtype="handle",
+                    )
+                )
+
+
[email protected]_func
+def batch_matmul(
+    A: T.Buffer[(16, 128, 128), "float32"],
+    B: T.Buffer[(16, 128, 128), "float32"],
+    C: T.Buffer[(16, 128, 128), "float32"],
+) -> None:
+    for n, i, j in T.grid(16, 128, 128):
+        with T.block("init"):
+            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
+            C[vn, vi, vj] = T.float32(0)
+
+    for n, i, j, k in T.grid(16, 128, 128, 128):
+        with T.block("update"):
+            vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
+            C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]
+
+
[email protected]_func
+def tensorized_batch_matmul_mma(
+    A: T.Buffer[(16, 128, 128), "float32"],
+    B: T.Buffer[(16, 128, 128), "float32"],
+    C: T.Buffer[(16, 128, 128), "float32"],
+) -> None:
+    for n, i, j in T.grid(16, 128, 128):
+        with T.block("init"):
+            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
+            T.reads()
+            T.writes(C[vn, vi, vj])
+            C[vn, vi, vj] = T.float32(0)
+    for n in range(0, 16):
+        for i, j, k in T.grid(8, 8, 8):
+            with T.block("update"):
+                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
+                T.reads(
+                    C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 
16],
+                    A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 
16],
+                    B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 
16],
+                )
+                T.writes(C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 
16 + 16])
+                A_elem_offset = T.var("int32")
+                B_elem_offset = T.var("int32")
+                C_elem_offset = T.var("int32")
+                A_sub = T.match_buffer(
+                    A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 
16],
+                    (16, 16),
+                    elem_offset=A_elem_offset,
+                )
+                B_sub = T.match_buffer(
+                    B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 
16],
+                    (16, 16),
+                    elem_offset=B_elem_offset,
+                )
+                C_sub = T.match_buffer(
+                    C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 
16],
+                    (16, 16),
+                    elem_offset=C_elem_offset,
+                )
+                T.evaluate(
+                    T.tvm_mma_sync(
+                        C_sub.data,
+                        T.floordiv(C_sub.elem_offset, 256),
+                        A_sub.data,
+                        T.floordiv(A_sub.elem_offset, 256),
+                        B_sub.data,
+                        T.floordiv(B_sub.elem_offset, 256),
+                        C_sub.data,
+                        T.floordiv(C_sub.elem_offset, 256),
+                        dtype="handle",
+                    )
+                )
+
+
[email protected]_func
+def tensorized_batch_matmul_dot_product(
+    A: T.Buffer[(16, 128, 128), "float32"],
+    B: T.Buffer[(16, 128, 128), "float32"],
+    C: T.Buffer[(16, 128, 128), "float32"],
+) -> None:
+    for n, i, j in T.grid(16, 128, 128):
+        with T.block("init"):
+            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
+            T.reads()
+            T.writes(C[vn, vi, vj])
+            C[vn, vi, vj] = T.float32(0)
+    for n, i, j, k_0 in T.grid(16, 128, 128, 32):
+        with T.block("blockized_update"):
+            vn, vi, vj, vko = T.axis.remap("SSSR", [n, i, j, k_0])
+            T.reads(
+                C[vn, vi, vj], A[vn, vi, vko * 4 : vko * 4 + 4], B[vn, vj, vko 
* 4 : vko * 4 + 4]
+            )
+            T.writes(C[vn, vi, vj])
+            A_1 = T.match_buffer(
+                A[vn, vi, vko * 4 : vko * 4 + 4], [4], dtype="float32", 
offset_factor=1
+            )
+            B_1 = T.match_buffer(
+                B[vn, vj, vko * 4 : vko * 4 + 4], [4], dtype="float32", 
offset_factor=1
+            )
+            C_1 = T.match_buffer(C[vn, vi, vj], [], dtype="float32", 
offset_factor=1)
+            T.evaluate(
+                T.call_extern(
+                    "vec4add",
+                    C_1.data,
+                    C_1.elem_offset,
+                    A_1.data,
+                    A_1.elem_offset,
+                    B_1.data,
+                    B_1.elem_offset,
+                    dtype="int32",
+                )
+            )
+
+
[email protected]_func
+def tensorized_batch_matmul_outer_product(
+    A: T.Buffer[(16, 128, 128), "float32"],
+    B: T.Buffer[(16, 128, 128), "float32"],
+    C: T.Buffer[(16, 128, 128), "float32"],
+) -> None:
+    for n, i, j in T.grid(16, 128, 128):
+        with T.block("init"):
+            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
+            T.reads()
+            T.writes(C[vn, vi, vj])
+            C[vn, vi, vj] = T.float32(0)
+    for n, i_0, j_0, k in T.grid(16, 8, 8, 128):
+        with T.block("blockized_update"):
+            vn, vio, vjo, vk = T.axis.remap("SSSR", [n, i_0, j_0, k])
+            T.reads(
+                C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
+                A[vn, vio * 16 : vio * 16 + 16, vk],
+                B[vn, vjo * 16 : vjo * 16 + 16, vk],
+            )
+            T.writes(C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
+            A_1 = T.match_buffer(A[vn, vio * 16 : vio * 16 + 16, vk], [16, 1], 
dtype="float32", offset_factor=1)
+            B_1 = T.match_buffer(B[vn, vjo * 16 : vjo * 16 + 16, vk], [16, 1], 
dtype="float32", offset_factor=1
+            )
+            C_1 = T.match_buffer(
+                C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], 
[16, 16], dtype="float32", offset_factor=1
+            )
+            T.evaluate(
+                T.call_extern("outer_product", C_1.data, C_1.elem_offset, 
A_1.data, A_1.elem_offset,
+                              B_1.data, B_1.elem_offset, dtype="int32"
+                )
+            )
+
+
+# fmt: off
+# pylint: 
disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
+
+tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
+tir.TensorIntrin.register("test_dot_product_intrin", dot_product_desc, 
dot_product_intrin)
+tir.TensorIntrin.register("test_outer_product_intrin", outer_product_desc, 
outer_product_intrin)
+
+
+def test_tensorize_matmul():
+    func = matmul
+    # schedule
+    s = tir.Schedule(func, debug_mask="all")
+    update = s.get_block("update")
+    i, j, k = s.get_loops(update)
+    io, ii = s.split(i, factors=[None, 16])
+    jo, ji = s.split(j, factors=[None, 16])
+    ko, ki = s.split(k, factors=[None, 16])
+    s.reorder(io, jo, ko, ii, ji, ki)
+    s.decompose_reduction(update, ko)
+    s.tensorize(ii, "test_mma_intrin")
+    tvm.ir.assert_structural_equal(tensorized_matmul, s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_tensorize_batch_matmul():
+    func = batch_matmul
+    s = tir.Schedule(func, debug_mask="all")
+    update = s.get_block("update")
+    _, i, j, k = s.get_loops(update)
+    io, ii = s.split(i, factors=[None, 16])
+    jo, ji = s.split(j, factors=[None, 16])
+    ko, ki = s.split(k, factors=[None, 16])
+    s.reorder(io, jo, ko, ii, ji, ki)
+    s.tensorize(ii, "test_mma_intrin")
+    tvm.ir.assert_structural_equal(tensorized_batch_matmul_mma, s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=batch_matmul)
+
+
+def test_tensorize_dot_product():
+    func = batch_matmul
+    s = tir.Schedule(func, debug_mask="all")
+    C = s.get_block("update")
+    _, _, _, k = s.get_loops(C)
+    _, ki = s.split(k, factors=[None, 4])
+    s.tensorize(ki, "test_dot_product_intrin")
+    tvm.ir.assert_structural_equal(tensorized_batch_matmul_dot_product, 
s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_tensorize_outer_product():
+    func = batch_matmul
+    s = tir.Schedule(func, debug_mask="all")
+    C = s.get_block("update")
+    _, i, j, k = s.get_loops(C)
+    io, ii = s.split(i, factors=[None, 16])
+    jo, ji = s.split(j, factors=[None, 16])
+    s.reorder(io, jo, k, ii, ji)
+    s.tensorize(ii, "test_outer_product_intrin")
+    tvm.ir.assert_structural_equal(tensorized_batch_matmul_outer_product, 
s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to