junrushao1994 commented on a change in pull request #9871:
URL: https://github.com/apache/tvm/pull/9871#discussion_r780587280



##########
File path: include/tvm/tir/function.h
##########
@@ -187,6 +187,56 @@ class LinkedParam : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
 };
 
+/*!
+ * \brief Tensor intrinsics for tensorization
+ */
+class TensorIntrinNode : public Object {
+ public:
+  /*! \brief The function to describe the computation. */
+  PrimFunc description;
+  /*! \brief The intrinsic function for lower-level implementation. */
+  PrimFunc implementation;

Review comment:
       Which do you prefer: the shortened names or the full names? That is, `desc` and `impl` vs. `description` and `implementation`?

##########
File path: include/tvm/tir/function.h
##########
@@ -187,6 +187,56 @@ class LinkedParam : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
 };
 
+/*!
+ * \brief Tensor intrinsics for tensorization
+ */
+class TensorIntrinNode : public Object {
+ public:
+  /*! \brief The function to describe the computation. */
+  PrimFunc description;
+  /*! \brief The intrinsic function for lower-level implementation. */
+  PrimFunc implementation;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("description", &description);
+    v->Visit("implementation", &implementation);
+  }
+
+  static constexpr const char* _type_key = "tir.TensorIntrin";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
+};
+
+/*!
+ * \brief Managed reference to TensorIntrinNode.
+ */
+class TensorIntrin : public ObjectRef {
+ public:
+  /*!
+   * \brief Constructor
+   * \param desc_func The function to describe the computation.
+   * \param intrin_func The intrinsic function for lower-level implementation.
+   */
+  TVM_DLL explicit TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func);
+
+  /*!
+   * \brief Create and register a TensorIntrin. After registration, the 
TensorIntrin can be looked
+   * up with its name.
+   * \param name The name of the TensorIntrin to register
+   * \param desc_func The function to describe the computation.
+   * \param intrin_func The intrinsic function for lower-level implementation.
+   * \return The created TensorIntrin.
+   */
+  TVM_DLL static TensorIntrin Register(String name, PrimFunc desc_func, 
PrimFunc intrin_func);

Review comment:
       Let's make the registry design more idiomatic by registering an already-constructed `TensorIntrin` instead of its component functions:
   
   ```suggestion
     TVM_DLL static TensorIntrin Register(String name, TensorIntrin intrin);
   ```

##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -465,6 +465,25 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void SetScope(const BlockRV& block_rv, int buffer_index, const 
String& storage_scope) = 0;
   /******** Schedule: Blockize & Tensorize ********/
+  /*!
+   * \brief Convert the subtree rooted at a specific loop into a block.
+   * \param loop_rv the root of the subtree
+   * \return the new block
+   */
+  virtual BlockRV Blockize(const LoopRV& loop_rv) = 0;
+  /*!
+   * \brief Tensorize the computation enclosed by loop with tensor_intrin
+   * \param loop_rv the loop to be tensorized
+   * \param intrin the tensor intrinsic
+   */
+  virtual void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) = 
0;

Review comment:
       Shall we hide this method in the internal implementation, given that:
   - TracedSchedule is the default go-to schedule
   - this method isn't well supported by TracedSchedule

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -1733,6 +1733,258 @@ def after_set_scope(
 
     ########## Schedule: Blockize & Tensorize ##########
 
+    @type_checked
+    def blockize(self, loop: LoopRV) -> BlockRV:
+        """Convert the subtree rooted at a specific loop into a block.
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The root of the subtree.
+
+        Returns
+        -------
+        result : BlockRV
+            The new block.
+
+        Examples
+        --------
+
+        Before blockize, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_blockize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
+                    with T.block("B"):
+                        vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                        vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                        T.reads(A[vi, vj])
+                        T.writes(B[vi, vj])
+                        B[vi, vj] = A[vi, vj] * T.float32(2)
+
+        Create the schedule and do set_scope:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_blockize)
+            B = sch.get_block("B")
+            _, _, i1, _ = sch.get_loops(B)
+            sch.blockize(i1)
+            print(sch.mod["main"].script())
+
+        After applying blockize, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_blockize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            )-> None:
+                for i_0, j_0 in T.grid(8, 8):
+                    with T.block("blockized_B"):
+                        vio, vjo = T.axis.remap("SS", [i_0, j_0])
+                        T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 
16 + 16])
+                        T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 
16 + 16])
+                        for i_1, j_1 in T.grid(16, 16):
+                            with T.block("B"):
+                                vi = T.axis.spatial(128, vio * 16 + i_1)
+                                vj = T.axis.spatial(128, vjo * 16 + j_1)
+                                T.reads(A[vi, vj])
+                                T.writes(B[vi, vj])
+                                B[vi, vj] = A[vi, vj] * T.float32(2)
+
+        Note
+        ----
+        blockize requires there is exactly one block under the given loop and 
the bindings of the
+        block are divisible by the subspace represented by the loops starting 
at the given loop.
+
+        """
+
+        return _ffi_api.ScheduleBlockize(self, loop)  # type: ignore # pylint: 
disable=no-member
+
+    @type_checked
+    def tensorize(self, loop: LoopRV, tensor_intrin: Union[str, TensorIntrin]) 
-> None:
+        """Tensorize the computation enclosed by loop with the tensor 
intrinsic.
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be tensorized.
+        tensor_intrin : Union[str, TensorIntrin]
+            The tensor intrin or the name of the tensor intrin.
+
+        Examples
+        --------
+
+        Before tensorize, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_tensorize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"],
+                C: T.Buffer[(128, 128), "float32"],
+            ) -> None:
+                # body
+                # with T.block("root")
+                for i_0, j_0 in T.grid(8, 8):
+                    for i_1_init, j_1_init in T.grid(16, 16):
+                        with T.block("init"):
+                            vi = T.axis.spatial(128, i_0 * 16 + i_1_init)
+                            vj = T.axis.spatial(128, j_0 * 16 + j_1_init)
+                            T.reads()
+                            T.writes(C[vi, vj])
+                            C[vi, vj] = T.float32(0)
+                    for k_0, i_1, j_1, k_1 in T.grid(8, 16, 16, 16):
+                        with T.block("update"):
+                            vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                            vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                            vk = T.axis.reduce(128, k_0 * 16 + k_1)
+                            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+                            T.writes(C[vi, vj])
+                            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        Declare and register the tensor intrinsic:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+                A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+                B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+                C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+                with T.block("root"):
+                    vi = T.axis.S(16, 0)
+                    vj = T.axis.S(16, 0)
+                    vk = T.axis.R(16, 0)
+                    for i, j, k in T.grid(16, 16, 16):
+                        with T.block("update"):
+                            vii = T.axis.S(16, vi + i)
+                            vjj = T.axis.S(16, vj + j)
+                            vkk = T.axis.R(16, vk + k)
+                            C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, 
vkk]
+
+
+            @T.prim_func
+            def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+                A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+                B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+                C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+                with T.block("root"):
+                    vi = T.axis.S(16, 0)
+                    vj = T.axis.S(16, 0)
+                    vk = T.axis.R(16, 0)
+                    T.reads(
+                        [
+                            C[vi : vi + 16, vj : vj + 16],
+                            A[vi : vi + 16, vk : vk + 16],
+                            B[vj : vj + 16, vk : vk + 16],
+                        ]
+                    )
+                    T.writes(C[vi : vi + 16, vj : vj + 16])
+                    T.evaluate(
+                        T.tvm_mma_sync(
+                            C.data,
+                            C.elem_offset // 256,
+                            A.data,
+                            A.elem_offset // 256,
+                            B.data,
+                            B.elem_offset // 256,
+                            C.data,
+                            C.elem_offset // 256,
+                            dtype="handle",
+                        )
+                    )
+
+            tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
+
+        Create the schedule and do tensorize:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_tensorize)
+            update = s.get_block("update")
+            _, _, _, i1, _, _ = s.get_loops(update)
+            s.tensorize(ii, "test_mma_intrin")
+            print(sch.mod["main"].script())
+
+        After applying tensorize, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_tensoirze(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"],
+                C: T.Buffer[(128, 128), "float32"],
+            ) -> None:
+                # body
+                # with T.block("root")
+                for i_0, j_0 in T.grid(8, 8):
+                    for i_1_init, j_1_init in T.grid(16, 16):
+                        with T.block("init"):
+                            vi = T.axis.spatial(128, i_0 * 16 + i_1_init)
+                            vj = T.axis.spatial(128, j_0 * 16 + j_1_init)
+                            T.reads()
+                            T.writes(C[vi, vj])
+                            C[vi, vj] = T.float32(0)
+                    for k_0 in T.serial(8):
+                        with T.block("blockized_update"):
+                            vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, 
k_0])
+                            T.reads(
+                                C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 
16 + 16],
+                                A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 
16 + 16],
+                                B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 
16 + 16],
+                            )
+                            T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : 
vjo * 16 + 16])
+                            A_1 = T.match_buffer(
+                                A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 
16 + 16],
+                                [16, 16],
+                                dtype="float32",
+                                offset_factor=1,
+                            )
+                            B_1 = T.match_buffer(
+                                B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 
16 + 16],
+                                [16, 16],
+                                dtype="float32",
+                                offset_factor=1,
+                            )
+                            C_1 = T.match_buffer(
+                                C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 
16 + 16],
+                                [16, 16],
+                                dtype="float32",
+                                offset_factor=1,
+                            )
+                            T.evaluate(
+                                T.tvm_mma_sync(
+                                    C_1.data,
+                                    C_1.elem_offset // 256,
+                                    A_1.data,
+                                    A_1.elem_offset // 256,
+                                    B_1.data,
+                                    B_1.elem_offset // 256,
+                                    C_1.data,
+                                    C_1.elem_offset // 256,
+                                    dtype="handle",
+                                )
+                            )
+
+        """
+        if isinstance(tensor_intrin, str):
+            tensor_intrin = String(tensor_intrin)

Review comment:
       Do we need this conversion? I suppose `PackedFuncValueConverter` is able to handle this :-)

##########
File path: src/tir/ir/function.cc
##########
@@ -64,6 +64,67 @@ FuncType PrimFuncNode::func_type_annotation() const {
 
 TVM_REGISTER_NODE_TYPE(PrimFuncNode);
 
+class TensorIntrinManager {
+ public:
+  Map<String, tir::TensorIntrin> reg;
+
+  static TensorIntrinManager* Global() {
+    static TensorIntrinManager* inst = new TensorIntrinManager();
+    return inst;
+  }
+};
+
+TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
+  // check the number of func var is equal
+  CHECK_EQ(desc_func->params.size(), intrin_func->params.size());
+  CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size());

Review comment:
       Let's write informative error messages here, given that these are user-facing CHECKs.

##########
File path: src/tir/ir/function.cc
##########
@@ -64,6 +64,67 @@ FuncType PrimFuncNode::func_type_annotation() const {
 
 TVM_REGISTER_NODE_TYPE(PrimFuncNode);
 
+class TensorIntrinManager {
+ public:
+  Map<String, tir::TensorIntrin> reg;
+
+  static TensorIntrinManager* Global() {
+    static TensorIntrinManager* inst = new TensorIntrinManager();
+    return inst;
+  }
+};
+
+TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
+  // check the number of func var is equal
+  CHECK_EQ(desc_func->params.size(), intrin_func->params.size());
+  CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size());
+
+  // check both functions' bodies are directly block
+  const auto* desc_realize =
+      
Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
+  const auto* intrin_realize =
+      
Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
+  CHECK(desc_realize != nullptr) << "description function's body expect a 
directly block";
+  CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a 
directly block";
+
+  const Block& desc_block = desc_realize->block;
+  const Block& intrin_block = intrin_realize->block;
+
+  // check block var number and iter type
+  CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size())
+      << "Two blocks should have the same number of block vars";
+  for (size_t i = 0; i < desc_block->iter_vars.size(); i++) {
+    const IterVar& desc_var = desc_block->iter_vars[i];
+    const IterVar& intrin_var = intrin_block->iter_vars[i];
+    CHECK(desc_var->iter_type == intrin_var->iter_type)
+        << "Block iter_type mismatch between " << desc_var->iter_type << " and 
"
+        << intrin_var->iter_type;

Review comment:
       Let's use `IterVarType2String` to print the iter types in this error message.

##########
File path: src/tir/ir/function.cc
##########
@@ -64,6 +64,67 @@ FuncType PrimFuncNode::func_type_annotation() const {
 
 TVM_REGISTER_NODE_TYPE(PrimFuncNode);
 
+class TensorIntrinManager {
+ public:
+  Map<String, tir::TensorIntrin> reg;
+
+  static TensorIntrinManager* Global() {
+    static TensorIntrinManager* inst = new TensorIntrinManager();
+    return inst;
+  }
+};
+
+TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
+  // check the number of func var is equal
+  CHECK_EQ(desc_func->params.size(), intrin_func->params.size());
+  CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size());
+
+  // check both functions' bodies are directly block
+  const auto* desc_realize =
+      
Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
+  const auto* intrin_realize =
+      
Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
+  CHECK(desc_realize != nullptr) << "description function's body expect a 
directly block";
+  CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a 
directly block";
+
+  const Block& desc_block = desc_realize->block;
+  const Block& intrin_block = intrin_realize->block;
+
+  // check block var number and iter type
+  CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size())
+      << "Two blocks should have the same number of block vars";
+  for (size_t i = 0; i < desc_block->iter_vars.size(); i++) {
+    const IterVar& desc_var = desc_block->iter_vars[i];
+    const IterVar& intrin_var = intrin_block->iter_vars[i];
+    CHECK(desc_var->iter_type == intrin_var->iter_type)
+        << "Block iter_type mismatch between " << desc_var->iter_type << " and 
"
+        << intrin_var->iter_type;
+  }
+
+  auto n = make_object<TensorIntrinNode>();

Review comment:
       nit
   
   ```suggestion
     ObjectPtr<TensorIntrinNode> n = make_object<TensorIntrinNode>();
   ```

##########
File path: src/tir/ir/function.cc
##########
@@ -64,6 +64,67 @@ FuncType PrimFuncNode::func_type_annotation() const {
 
 TVM_REGISTER_NODE_TYPE(PrimFuncNode);
 
+class TensorIntrinManager {
+ public:
+  Map<String, tir::TensorIntrin> reg;
+
+  static TensorIntrinManager* Global() {
+    static TensorIntrinManager* inst = new TensorIntrinManager();
+    return inst;
+  }
+};
+
+TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
+  // check the number of func var is equal
+  CHECK_EQ(desc_func->params.size(), intrin_func->params.size());
+  CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size());
+
+  // check both functions' bodies are directly block
+  const auto* desc_realize =
+      
Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
+  const auto* intrin_realize =
+      
Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
+  CHECK(desc_realize != nullptr) << "description function's body expect a 
directly block";
+  CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a 
directly block";
+
+  const Block& desc_block = desc_realize->block;
+  const Block& intrin_block = intrin_realize->block;
+
+  // check block var number and iter type
+  CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size())
+      << "Two blocks should have the same number of block vars";
+  for (size_t i = 0; i < desc_block->iter_vars.size(); i++) {
+    const IterVar& desc_var = desc_block->iter_vars[i];
+    const IterVar& intrin_var = intrin_block->iter_vars[i];
+    CHECK(desc_var->iter_type == intrin_var->iter_type)
+        << "Block iter_type mismatch between " << desc_var->iter_type << " and 
"
+        << intrin_var->iter_type;
+  }
+
+  auto n = make_object<TensorIntrinNode>();
+  n->description = std::move(desc_func);
+  n->implementation = std::move(intrin_func);
+  data_ = std::move(n);
+}
+
+TensorIntrin TensorIntrin::Register(String name, PrimFunc desc_func, PrimFunc 
intrin_func) {
+  TensorIntrinManager* manager = TensorIntrinManager::Global();
+  ICHECK_EQ(manager->reg.count(name), 0)
+      << "ValueError: TensorIntrin '" << name << "' has already been 
registered";

Review comment:
       This is a user-facing API, so it should use `CHECK` rather than `ICHECK`:
   
   ```suggestion
     CHECK_EQ(manager->reg.count(name), 0)
         << "ValueError: TensorIntrin '" << name << "' has already been 
registered";
   ```

##########
File path: src/tir/ir/function.cc
##########
@@ -64,6 +64,67 @@ FuncType PrimFuncNode::func_type_annotation() const {
 
 TVM_REGISTER_NODE_TYPE(PrimFuncNode);
 
+class TensorIntrinManager {
+ public:
+  Map<String, tir::TensorIntrin> reg;
+
+  static TensorIntrinManager* Global() {
+    static TensorIntrinManager* inst = new TensorIntrinManager();
+    return inst;
+  }
+};
+
+TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
+  // check the number of func var is equal
+  CHECK_EQ(desc_func->params.size(), intrin_func->params.size());
+  CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size());
+
+  // check both functions' bodies are directly block
+  const auto* desc_realize =
+      
Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
+  const auto* intrin_realize =
+      
Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
+  CHECK(desc_realize != nullptr) << "description function's body expect a 
directly block";
+  CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a 
directly block";
+
+  const Block& desc_block = desc_realize->block;
+  const Block& intrin_block = intrin_realize->block;
+
+  // check block var number and iter type
+  CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size())
+      << "Two blocks should have the same number of block vars";
+  for (size_t i = 0; i < desc_block->iter_vars.size(); i++) {
+    const IterVar& desc_var = desc_block->iter_vars[i];
+    const IterVar& intrin_var = intrin_block->iter_vars[i];
+    CHECK(desc_var->iter_type == intrin_var->iter_type)
+        << "Block iter_type mismatch between " << desc_var->iter_type << " and 
"
+        << intrin_var->iter_type;
+  }
+
+  auto n = make_object<TensorIntrinNode>();
+  n->description = std::move(desc_func);
+  n->implementation = std::move(intrin_func);
+  data_ = std::move(n);
+}
+
+TensorIntrin TensorIntrin::Register(String name, PrimFunc desc_func, PrimFunc 
intrin_func) {
+  TensorIntrinManager* manager = TensorIntrinManager::Global();
+  ICHECK_EQ(manager->reg.count(name), 0)
+      << "ValueError: TensorIntrin '" << name << "' has already been 
registered";
+  TensorIntrin intrin(desc_func, intrin_func);
+  manager->reg.Set(name, intrin);
+  return intrin;
+}
+
+TensorIntrin TensorIntrin::Get(String name) {
+  const TensorIntrinManager* manager = TensorIntrinManager::Global();
+  ICHECK_EQ(manager->reg.count(name), 1)
+      << "ValueError: TensorIntrin '" << name << "' is not registered";

Review comment:
       Ditto: use `CHECK` here as well, since this is user-facing.
   
   ```suggestion
     CHECK_EQ(manager->reg.count(name), 1)
         << "ValueError: TensorIntrin '" << name << "' is not registered";
   ```

##########
File path: src/tir/ir/function.cc
##########
@@ -64,6 +64,67 @@ FuncType PrimFuncNode::func_type_annotation() const {
 
 TVM_REGISTER_NODE_TYPE(PrimFuncNode);
 
+class TensorIntrinManager {
+ public:
+  Map<String, tir::TensorIntrin> reg;
+
+  static TensorIntrinManager* Global() {
+    static TensorIntrinManager* inst = new TensorIntrinManager();
+    return inst;
+  }
+};
+
+TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
+  // check the number of func var is equal
+  CHECK_EQ(desc_func->params.size(), intrin_func->params.size());
+  CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size());
+
+  // check both functions' bodies are directly block
+  const auto* desc_realize =
+      
Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
+  const auto* intrin_realize =
+      
Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
+  CHECK(desc_realize != nullptr) << "description function's body expect a 
directly block";
+  CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a 
directly block";
+
+  const Block& desc_block = desc_realize->block;
+  const Block& intrin_block = intrin_realize->block;
+
+  // check block var number and iter type
+  CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size())
+      << "Two blocks should have the same number of block vars";
+  for (size_t i = 0; i < desc_block->iter_vars.size(); i++) {
+    const IterVar& desc_var = desc_block->iter_vars[i];
+    const IterVar& intrin_var = intrin_block->iter_vars[i];
+    CHECK(desc_var->iter_type == intrin_var->iter_type)
+        << "Block iter_type mismatch between " << desc_var->iter_type << " and 
"
+        << intrin_var->iter_type;
+  }
+
+  auto n = make_object<TensorIntrinNode>();
+  n->description = std::move(desc_func);
+  n->implementation = std::move(intrin_func);
+  data_ = std::move(n);
+}
+
+TensorIntrin TensorIntrin::Register(String name, PrimFunc desc_func, PrimFunc 
intrin_func) {
+  TensorIntrinManager* manager = TensorIntrinManager::Global();
+  ICHECK_EQ(manager->reg.count(name), 0)
+      << "ValueError: TensorIntrin '" << name << "' has already been 
registered";
+  TensorIntrin intrin(desc_func, intrin_func);
+  manager->reg.Set(name, intrin);
+  return intrin;
+}
+
+TensorIntrin TensorIntrin::Get(String name) {
+  const TensorIntrinManager* manager = TensorIntrinManager::Global();
+  ICHECK_EQ(manager->reg.count(name), 1)
+      << "ValueError: TensorIntrin '" << name << "' is not registered";
+  return manager->reg.at(name);

Review comment:
       Use `manager->reg.find` instead of `count` followed by `at` to save one lookup :-)

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -1733,6 +1733,258 @@ def after_set_scope(
 
     ########## Schedule: Blockize & Tensorize ##########
 
+    @type_checked
+    def blockize(self, loop: LoopRV) -> BlockRV:
+        """Convert the subtree rooted at a specific loop into a block.
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The root of the subtree.
+
+        Returns
+        -------
+        result : BlockRV
+            The new block.
+
+        Examples
+        --------
+
+        Before blockize, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_blockize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
+                    with T.block("B"):
+                        vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                        vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                        T.reads(A[vi, vj])
+                        T.writes(B[vi, vj])
+                        B[vi, vj] = A[vi, vj] * T.float32(2)
+
+        Create the schedule and do set_scope:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_blockize)
+            B = sch.get_block("B")
+            _, _, i1, _ = sch.get_loops(B)
+            sch.blockize(i1)
+            print(sch.mod["main"].script())
+
+        After applying blockize, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_blockize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            )-> None:
+                for i_0, j_0 in T.grid(8, 8):
+                    with T.block("blockized_B"):
+                        vio, vjo = T.axis.remap("SS", [i_0, j_0])
+                        T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 
16 + 16])
+                        T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 
16 + 16])
+                        for i_1, j_1 in T.grid(16, 16):
+                            with T.block("B"):
+                                vi = T.axis.spatial(128, vio * 16 + i_1)
+                                vj = T.axis.spatial(128, vjo * 16 + j_1)
+                                T.reads(A[vi, vj])
+                                T.writes(B[vi, vj])
+                                B[vi, vj] = A[vi, vj] * T.float32(2)
+
+        Note
+        ----
+        blockize requires there is exactly one block under the given loop and 
the bindings of the
+        block are divisible by the subspace represented by the loops starting 
at the given loop.
+
+        """
+
+        return _ffi_api.ScheduleBlockize(self, loop)  # type: ignore # pylint: 
disable=no-member
+
+    @type_checked
+    def tensorize(self, loop: LoopRV, tensor_intrin: Union[str, TensorIntrin]) 
-> None:
+        """Tensorize the computation enclosed by loop with the tensor 
intrinsic.
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be tensorized.
+        tensor_intrin : Union[str, TensorIntrin]
+            The tensor intrin or the name of the tensor intrin.
+
+        Examples
+        --------
+
+        Before tensorize, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_tensorize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"],
+                C: T.Buffer[(128, 128), "float32"],
+            ) -> None:
+                # body
+                # with T.block("root")
+                for i_0, j_0 in T.grid(8, 8):
+                    for i_1_init, j_1_init in T.grid(16, 16):
+                        with T.block("init"):
+                            vi = T.axis.spatial(128, i_0 * 16 + i_1_init)
+                            vj = T.axis.spatial(128, j_0 * 16 + j_1_init)
+                            T.reads()
+                            T.writes(C[vi, vj])
+                            C[vi, vj] = T.float32(0)
+                    for k_0, i_1, j_1, k_1 in T.grid(8, 16, 16, 16):
+                        with T.block("update"):
+                            vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                            vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                            vk = T.axis.reduce(128, k_0 * 16 + k_1)
+                            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+                            T.writes(C[vi, vj])
+                            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        Declare and register the tensor intrinsic:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+                A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+                B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+                C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+                with T.block("root"):
+                    vi = T.axis.S(16, 0)
+                    vj = T.axis.S(16, 0)
+                    vk = T.axis.R(16, 0)
+                    for i, j, k in T.grid(16, 16, 16):
+                        with T.block("update"):
+                            vii = T.axis.S(16, vi + i)
+                            vjj = T.axis.S(16, vj + j)
+                            vkk = T.axis.R(16, vk + k)
+                            C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, 
vkk]
+
+
+            @T.prim_func
+            def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+                A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+                B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+                C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+                with T.block("root"):
+                    vi = T.axis.S(16, 0)
+                    vj = T.axis.S(16, 0)
+                    vk = T.axis.R(16, 0)
+                    T.reads(
+                        [
+                            C[vi : vi + 16, vj : vj + 16],
+                            A[vi : vi + 16, vk : vk + 16],
+                            B[vj : vj + 16, vk : vk + 16],
+                        ]
+                    )
+                    T.writes(C[vi : vi + 16, vj : vj + 16])
+                    T.evaluate(
+                        T.tvm_mma_sync(
+                            C.data,
+                            C.elem_offset // 256,
+                            A.data,
+                            A.elem_offset // 256,
+                            B.data,
+                            B.elem_offset // 256,
+                            C.data,
+                            C.elem_offset // 256,
+                            dtype="handle",
+                        )
+                    )
+
+            tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
+
+        Create the schedule and do tensorize:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_tensorize)
+            update = sch.get_block("update")
+            _, _, _, i1, _, _ = sch.get_loops(update)
+            sch.tensorize(i1, "test_mma_intrin")
+            print(sch.mod["main"].script())
+
+        After applying tensorize, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_tensorize(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"],
+                C: T.Buffer[(128, 128), "float32"],
+            ) -> None:
+                # body
+                # with T.block("root")
+                for i_0, j_0 in T.grid(8, 8):
+                    for i_1_init, j_1_init in T.grid(16, 16):
+                        with T.block("init"):
+                            vi = T.axis.spatial(128, i_0 * 16 + i_1_init)
+                            vj = T.axis.spatial(128, j_0 * 16 + j_1_init)
+                            T.reads()
+                            T.writes(C[vi, vj])
+                            C[vi, vj] = T.float32(0)
+                    for k_0 in T.serial(8):
+                        with T.block("blockized_update"):
+                            vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, 
k_0])
+                            T.reads(
+                                C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 
16 + 16],
+                                A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 
16 + 16],
+                                B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 
16 + 16],
+                            )
+                            T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : 
vjo * 16 + 16])
+                            A_1 = T.match_buffer(
+                                A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 
16 + 16],
+                                [16, 16],
+                                dtype="float32",
+                                offset_factor=1,
+                            )
+                            B_1 = T.match_buffer(
+                                B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 
16 + 16],
+                                [16, 16],
+                                dtype="float32",
+                                offset_factor=1,
+                            )
+                            C_1 = T.match_buffer(
+                                C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 
16 + 16],
+                                [16, 16],
+                                dtype="float32",
+                                offset_factor=1,
+                            )
+                            T.evaluate(
+                                T.tvm_mma_sync(
+                                    C_1.data,
+                                    C_1.elem_offset // 256,
+                                    A_1.data,
+                                    A_1.elem_offset // 256,
+                                    B_1.data,
+                                    B_1.elem_offset // 256,
+                                    C_1.data,
+                                    C_1.elem_offset // 256,
+                                    dtype="handle",
+                                )
+                            )
+
+        """
+        if isinstance(tensor_intrin, str):
+            tensor_intrin = String(tensor_intrin)

Review comment:
       Oh, I see. It's because our C++ registry is trying to decide between 
String and TensorIntrin - that makes sense to me now.

##########
File path: src/tir/schedule/schedule.cc
##########
@@ -183,6 +183,20 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
     .set_body_method<Schedule>(&ScheduleNode::SetScope);
 /******** (FFI) Blockize & Tensorize ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
+    .set_body_method<Schedule>(&ScheduleNode::Blockize);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize")
+    .set_body_typed([](Schedule self, LoopRV loop_rv, ObjectRef intrin) {
+      if (const auto* str = intrin.as<runtime::StringObj>()) {
+        return self->Tensorize(loop_rv, GetRef<String>(str));
+      }
+      if (const auto* p_intrin = intrin.as<TensorIntrinNode>()) {
+        return self->Tensorize(loop_rv, GetRef<TensorIntrin>(p_intrin));
+      }
+      LOG(FATAL) << "TypeError: Cannot handle type: " << intrin->GetTypeKey();
+      throw;

Review comment:
       I'm wondering: is it reasonable to expose only the `String` API and not 
the `TensorIntrin` one? I'm asking because `TensorIntrin` isn't available 
for `TracedSchedule` anyway, which is the default go-to for Python frontend 
users.

##########
File path: src/tir/schedule/state.cc
##########
@@ -201,9 +201,8 @@ class BlockInfoCollector : private StmtVisitor {
     bool is_root_block = srefs_.empty();
     // Calculate `BlockInfo::scope`
     Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
-    BlockInfo& info =
-        self_->block_info.emplace(scope_root, 
BlockInfo(BlockScope(child_block_srefs)))
-            .first->second;
+    self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs));
+    BlockInfo& info = self_->block_info[scope_root];

Review comment:
       Why this change? I suppose they are logically equivalent and the 
original one saves one lookup?

##########
File path: src/tir/schedule/state.cc
##########
@@ -201,9 +201,8 @@ class BlockInfoCollector : private StmtVisitor {
     bool is_root_block = srefs_.empty();
     // Calculate `BlockInfo::scope`
     Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
-    BlockInfo& info =
-        self_->block_info.emplace(scope_root, 
BlockInfo(BlockScope(child_block_srefs)))
-            .first->second;
+    self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs));
+    BlockInfo& info = self_->block_info[scope_root];

Review comment:
       Oh I see - that makes sense to me! Then, to save one lookup, let's do:
   
   ```suggestion
       BlockInfo& info = self_->block_info[scope_root] = 
BlockInfo(BlockScope(child_block_srefs));
   ```

##########
File path: src/tir/schedule/analysis.h
##########
@@ -442,6 +442,79 @@ bool CanComputeAt(const ScheduleState& self, const 
StmtSRef& block_sref, const S
 bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref, bool preserve_unit_loops);
 
+/******** Tensorization Related ********/
+
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& 
other)>;

Review comment:
       Shall we move this into `src/tir/schedule/ir_comparator.h` and the impl 
to a corresponding cc file? The reason I'm considering this is that it's 
relatively independent but takes up a huge chunk of code.

##########
File path: src/tir/schedule/analysis.h
##########
@@ -442,6 +442,79 @@ bool CanComputeAt(const ScheduleState& self, const 
StmtSRef& block_sref, const S
 bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref, bool preserve_unit_loops);
 
+/******** Tensorization Related ********/
+
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& 
other)>;
+using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
+
+/*! \brief Deep comparison to check if two IR ASTs are equivalent */
+class TensorizeComparator : public ExprComparator, public StmtComparator {
+ public:
+  /*!
+   * \brief Constructor of TensorizeComparator
+   * \param assert_mode Whether to raise an error if the two IR ASTs do not 
match.
+   */
+  explicit TensorizeComparator(bool assert_mode = true) : 
assert_mode_(assert_mode) {}
+
+  // Map from rhs buffer to lhs buffer
+  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_;
+  // Buffer indices mapping
+  std::unordered_map<Buffer, std::vector<PrimExpr>, ObjectPtrHash, 
ObjectPtrEqual> buffer_indices_;
+  std::vector<IterVar> extra_block_vars_;
+  // variable remap if any
+  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> 
equal_map_;

Review comment:
       Let's move the fields to the end of the class, and document them properly

##########
File path: src/tir/schedule/analysis.h
##########
@@ -442,6 +442,79 @@ bool CanComputeAt(const ScheduleState& self, const 
StmtSRef& block_sref, const S
 bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref, bool preserve_unit_loops);
 
+/******** Tensorization Related ********/
+
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& 
other)>;
+using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
+
+/*! \brief Deep comparison to check if two IR ASTs are equivalent */
+class TensorizeComparator : public ExprComparator, public StmtComparator {
+ public:
+  /*!
+   * \brief Constructor of TensorizeComparator
+   * \param assert_mode Whether to raise an error if the two IR ASTs do not 
match.
+   */
+  explicit TensorizeComparator(bool assert_mode = true) : 
assert_mode_(assert_mode) {}
+
+  // Map from rhs buffer to lhs buffer
+  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_;
+  // Buffer indices mapping
+  std::unordered_map<Buffer, std::vector<PrimExpr>, ObjectPtrHash, 
ObjectPtrEqual> buffer_indices_;
+  std::vector<IterVar> extra_block_vars_;
+  // variable remap if any
+  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> 
equal_map_;
+
+  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
+  bool VisitStmt(const Stmt& n, const Stmt& other) override;
+
+  bool VisitStmt_(const ForNode* op, const Stmt& other) override;
+  bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BlockNode* op, const Stmt& other) override;

Review comment:
       Just curious, does it mean that we only allow 5 types of statements?

##########
File path: src/tir/schedule/transform.cc
##########
@@ -136,5 +136,52 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const 
StmtSRef& leaf_block_
   throw OnlyLeafError(self->mod, GetRef<Block>(leaf_block), 
GetRef<Block>(scope_block));
 }
 
+/******** IR Substitution ********/
+class IRSubstituteInScope : public StmtExprMutator {
+ public:
+  explicit IRSubstituteInScope(std::function<PrimExpr(const VarNode*)> fmap)
+      : fmap_(std::move(fmap)) {}
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = fmap_(op);
+    if (it.defined()) {
+      return it;
+    } else {
+      return GetRef<PrimExpr>(op);
+    }
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* op) final {
+    auto fmutate = [&](const PrimExpr& e) { return this->VisitExpr(e); };
+    Array<PrimExpr> v = op->iter_values;
+    v.MutateByApply(fmutate);
+    PrimExpr pred = this->VisitExpr(op->predicate);
+    if (v.same_as(op->iter_values) && pred.same_as(op->predicate)) {
+      return GetRef<Stmt>(op);
+    } else {
+      auto n = CopyOnWrite(op);
+      n->iter_values = std::move(v);
+      n->predicate = std::move(analyzer.Simplify(pred));
+      return Stmt(n);
+    }
+  }
+
+ private:
+  const std::function<PrimExpr(const VarNode*)> fmap_;
+  arith::Analyzer analyzer;
+};
+
+Stmt SubstituteInScope(const Stmt& stmt, const Map<Var, PrimExpr>& subst_map) {

Review comment:
       I'm reading the implementation. It looks like what `SubstituteInScope` 
really does is substitute only block bindings and predicates: it stops at 
the `BlockRealize` boundary and does nothing else. Is that accurate? If so, 
shall we rename this to `SubstituteBlockBinding`?

##########
File path: src/tir/schedule/analysis.h
##########
@@ -442,6 +442,79 @@ bool CanComputeAt(const ScheduleState& self, const 
StmtSRef& block_sref, const S
 bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref, bool preserve_unit_loops);
 
+/******** Tensorization Related ********/
+
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& 
other)>;
+using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
+
+/*! \brief Deep comparison to check if two IR ASTs are equivalent */
+class TensorizeComparator : public ExprComparator, public StmtComparator {
+ public:
+  /*!
+   * \brief Constructor of TensorizeComparator
+   * \param assert_mode Whether to raise an error if the two IR ASTs do not 
match.
+   */
+  explicit TensorizeComparator(bool assert_mode = true) : 
assert_mode_(assert_mode) {}
+
+  // Map from rhs buffer to lhs buffer
+  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_;

Review comment:
       They are not :-(

##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -1376,5 +1376,333 @@ void CheckStorageScope(const ScheduleState& self, 
String storage_scope) {
   }
 }
 
+/******** Tensorize Comparator ********/
+
+bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) {
+  if (n.same_as(other)) return true;
+  if (n->type_index() != other->type_index()) return false;
+  bool equal = StmtComparator::VisitStmt(n, other);
+  if (!equal && assert_mode_)
+    LOG(FATAL) << "Stmts are not matching between:\n" << n << "\nand\n" << 
other;
+  return equal;
+}
+
+bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
+  const auto* rhs = other.as<ForNode>();
+  if (!DefEqual(op->loop_var, rhs->loop_var)) return false;
+  if (!VisitExpr(op->min, rhs->min)) return false;
+  if (!VisitExpr(op->extent, rhs->extent)) return false;
+  if (!VisitStmt(op->body, rhs->body)) return false;
+  if (op->kind != rhs->kind) return false;
+  if (op->thread_binding.defined() ^ rhs->thread_binding.defined()) return 
false;

Review comment:
       ```suggestion
     if (op->thread_binding.defined() != rhs->thread_binding.defined()) return 
false;
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to