This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 08898e1  [TensorIR] Cross-Thread Reduction (#9360)
08898e1 is described below

commit 08898e18628752d02fdb9e10f8135e1e3b95fb34
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Nov 15 01:37:43 2021 +0800

    [TensorIR] Cross-Thread Reduction (#9360)
    
    * [TensorIR] Cross-Thread Reduction
    
    * Code revision on analysis and misc
    
    * Refactor TransformReductionBlock
    
    * Refactor code organization
    
    * Address comment
    
    * Use `std::make_tuple`
    
    Co-authored-by: Junru Shao <[email protected]>
---
 include/tvm/tir/transform.h                        |   7 +
 python/tvm/tir/transform/transform.py              |  12 +
 src/driver/driver_api.cc                           |   1 +
 src/tir/schedule/analysis.h                        |  50 +-
 src/tir/schedule/analysis/analysis.cc              | 249 +++++--
 src/tir/schedule/primitive/reduction.cc            | 138 +---
 src/tir/transforms/lower_cross_thread_reduction.cc | 645 ++++++++++++++++++
 ...t_tir_transform_lower_cross_thread_reduction.py | 737 +++++++++++++++++++++
 8 files changed, 1662 insertions(+), 177 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index e6b0af9..7922e97 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -358,6 +358,13 @@ TVM_DLL Pass PointerValueTypeRewrite();
 TVM_DLL Pass HoistIfThenElse();
 
 /*!
+ * \brief Lower cross-thread reduction from thread
+ * bindings to intrinsic function calls.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerCrossThreadReduction();
+
+/*!
  * \brief Lower block init stmt into IfThenElse stmts
  * \return The pass.
  */
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 722810e..86f798c 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -577,6 +577,18 @@ def HoistIfThenElse(variant: Optional[str] = None):
         return _ffi_api.HoistIfThenElse()  # type: ignore
 
 
+def LowerCrossThreadReduction():
+    """Lower cross-thread reduction from thread bindings to
+    intrinsic function calls.
+
+    This is the Python binding of the C++ pass
+    ``tvm::tir::transform::LowerCrossThreadReduction``. In the default
+    lowering pass list it runs right after ``StorageFlatten`` and before
+    ``LowerInitBlock``.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerCrossThreadReduction()  # type: ignore
+
+
 def LowerInitBlock():
     """Lower block init stmt into IfThenElse statements.
 
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index ad1f51b..f49409c 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -234,6 +234,7 @@ Array<tvm::transform::Pass> CreatePassList(bool 
disable_loop_partition) {
   pass_list.push_back(tir::transform::InjectPrefetch());
   pass_list.push_back(tir::transform::TextureFlatten());
   pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
+  pass_list.push_back(tir::transform::LowerCrossThreadReduction());
   pass_list.push_back(tir::transform::LowerInitBlock());
   pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
   pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 5a2f46c..42e0e00 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -19,12 +19,17 @@
 #ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_
 #define TVM_TIR_SCHEDULE_ANALYSIS_H_
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/tir/schedule/state.h>
 
+#include <tuple>
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include <vector>
 
+#include "../../runtime/thread_storage_scope.h"
+
 namespace tvm {
 namespace tir {
 
@@ -323,6 +328,49 @@ struct ProducerConsumerSplit {
  */
 Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int 
n, bool is_write);
 
+/******** Reduction Block Related ********/
+
+/*!
+ * \brief Convert the `init` and `body` of the input block to BufferStores
+ * \param self The schedule state. If it is defined, failures are reported by throwing
+ * ScheduleError; otherwise they are reported via LOG(FATAL).
+ * \param block The block to be analyzed
+ * \return The BufferStores of the `init` and `body` of the input block
+ * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same
+ * buffer with the same indices
+ */
+std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
+    const Optional<ScheduleState>& self, const Block& block);
+
+/*!
+ * \brief Check whether the input array of IterVars only contains data-parallel and reduction block
+ * iters
+ * \param iters The input array of IterVars to be checked
+ * \return A boolean indicating whether the input array of IterVars only contains data-parallel and
+ * reduction block iters
+ */
+bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters);
+
+/*!
+ * \brief Check whether the block's reduction block iters are not used to index the block's output
+ * buffers
+ * \param block The block to be checked
+ * \return A boolean indicating whether the block's reduction block iters are not used to index the
+ * block's output buffers
+ * \note Fails an internal check if the block body stores into a buffer that is not listed in the
+ * block's write regions
+ */
+bool ReductionIterNotIndexOutputBuffer(const Block& block);
+
+/*!
+ * \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative
+ * reducer, and extract the combiner lhs and combiner rhs
+ * \param self The schedule state. If it is defined, a failed match is reported by throwing
+ * ScheduleError; otherwise it is reported via LOG(FATAL).
+ * \param identity The reduction identity to be analyzed
+ * \param combiner The reduction combiner to be analyzed
+ * \return The corresponding CommReducer, the combiner lhs and the combiner rhs
+ * \throw ScheduleError If no corresponding commutative reducer can be matched
+ */
+std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
+    const Optional<ScheduleState>& self, const PrimExpr& identity, const BufferStore& combiner);
+
 /******** Commutative Reducer ********/
 
 /*!
@@ -330,7 +378,7 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const 
Block& block, int n,
  * \return The list of the registered reducer-getter functions
  * \sa ReducerRegistry
  */
-std::vector<TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();
+std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> 
GetReducerGetters();
 
 /*!
  * \brief Given the input identity and the combiner BufferStore of a 
reduction, extract the
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index e3a535e..7e16bc9 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -153,15 +153,15 @@ Definition of a scope that is a stage pipeline:
 /*!
  * \brief Check the dominant property of a block:
  * the block is the only writer of its output, dominating the reader of its 
output buffers
- * \param self The schedule state
+ * \param scope The block-scope of the block to be checked
  * \param block_sref The block whose dominant property is to be checked
  * \return A boolean indicating if the block is a dominant block
  */
-bool IsDominantBlock(const BlockScope& self, const StmtSRef& block_sref) {
+bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) {
   // Check whether the input block is the only writer of its outputs
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, 
ObjectPtrEqual>& buffer_writers =
-      self->buffer_writers;
+      scope->buffer_writers;
   for (const BufferRegion& write_region : block->writes) {
     ICHECK(buffer_writers.count(write_region->buffer))
         << "InternalError: buffer \"" << write_region->buffer->name
@@ -279,14 +279,8 @@ int CheckReductionBlockErrorCode(const ScheduleState& 
self, const StmtSRef& bloc
   }
   // Cond 3. All block vars are either data parallel block vars or reduction 
block vars. Meanwhile,
   // we collect all the reduction block vars.
-  std::unordered_set<const VarNode*> reduction_block_vars;
-  reduction_block_vars.reserve(block->iter_vars.size());
-  for (const IterVar& iter_var : block->iter_vars) {
-    if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) 
{
-      return 3;
-    } else if (iter_var->iter_type == kCommReduce) {
-      reduction_block_vars.insert(iter_var->var.get());
-    }
+  if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
+    return 3;
   }
   // Cond 4. Dominant: the block is the only writer of its output, dominating 
the reader of its
   // output buffers.
@@ -294,33 +288,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& 
self, const StmtSRef& bloc
     return 4;
   }
   // Cond 5. The reduction block vars are not used to index the output buffers.
-  std::unordered_set<const BufferNode*> buffer_written;
-  buffer_written.reserve(block->writes.size());
-  for (const BufferRegion& write_region : block->writes) {
-    buffer_written.insert(write_region->buffer.get());
-  }
-  bool affected = false;
-  PreOrderVisit(block->body, [&](const ObjectRef& obj) {
-    if (affected) {
-      return false;
-    }
-    if (const auto* store = obj.as<BufferStoreNode>()) {
-      ICHECK(buffer_written.count(store->buffer.get()))
-          << "ValueError: The buffer \"" << store->buffer
-          << "\" is written in the block but is not in the block's signature";
-      for (const PrimExpr& index : store->indices) {
-        if (UsesVar(index, [&reduction_block_vars](const VarNode* var) {
-              return reduction_block_vars.count(var);
-            })) {
-          affected = true;
-          return false;
-        }
-      }
-      return false;
-    }
-    return true;
-  });
-  return !affected ? 0 : 5;
+  return ReductionIterNotIndexOutputBuffer(GetRef<Block>(block)) ? 0 : 5;
 }
 
 bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
@@ -552,7 +520,9 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& 
block_realize,
     } else {
       has_block_vars_of_other_types = true;
     }
-
+    if (set == nullptr) {
+      continue;
+    }
     Array<Var> vars_in_binding = UndefinedVars(iter_value);
     for (const Var& var : vars_in_binding) {
       set->insert(var.get());
@@ -1128,6 +1098,207 @@ class PatternMatcher : public ExprVisitor {
   std::unordered_map<const VarNode*, PrimExpr> filled_map_;
 };
 
+/******** Reduction Block Related ********/
+
+/*!
+ * \brief Error raised when the `init` or the `body` of a reduction block is not a BufferStore,
+ * which both rfactor and cross-thread reduction require
+ */
+class InitBodyNotBufferStoreError : public ScheduleError {
+ public:
+  explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool init_is_bufferstore,
+                                       bool body_is_bufferstore)
+      : mod_(std::move(mod)),
+        block_(std::move(block)),
+        init_is_bufferstore_(init_is_bufferstore),
+        body_is_bufferstore_(body_is_bufferstore) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The `init` and `body` of reduction block are required to be both "
+           "BufferStore so that rfactor or cross-thread reduction can be applied";
+  }
+
+  String DetailRenderTemplate() const final {
+    if (!init_is_bufferstore_ && !body_is_bufferstore_) {
+      return "The `init` and `body` of block {0} are required to be BufferStore so that rfactor or "
+             "cross-thread reduction can be applied";
+    } else if (!init_is_bufferstore_) {
+      return "The `init` of block {0} is required to be BufferStore so that rfactor or cross-thread"
+             " reduction can be applied";
+    } else {
+      // The error is only raised when at least one of `init`/`body` is not a BufferStore.
+      ICHECK(!body_is_bufferstore_);
+      return "The `body` of block {0} is required to be BufferStore so that rfactor or cross-thread"
+             " reduction can be applied";
+    }
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+  IRModule mod_;
+  Block block_;
+  bool init_is_bufferstore_;
+  bool body_is_bufferstore_;
+};
+
+/*!
+ * \brief Error raised when the `init` and `body` of a reduction block do not write the same buffer
+ * at the same indices
+ */
+class InitBodyNotSameBufferAccessError : public ScheduleError {
+ public:
+  explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The `init` and `body` of the reduction block are required to have the "
+           "same buffer access pattern";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    // Only raised after both `init` and `body` were verified to be BufferStores.
+    const auto* init = block_->init.as<BufferStoreNode>();
+    const auto* update = block_->body.as<BufferStoreNode>();
+    os << "The `init` and `body` of the block {0} are required to have the same buffer access "
+          "pattern. However, in block {0} the `init` writes to "
+       << init->buffer->name << init->indices << ", and the `body` writes to "
+       << update->buffer->name << update->indices;
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+  IRModule mod_;
+  Block block_;
+};
+
+std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
+    const Optional<ScheduleState>& self, const Block& block) {
+  // Fallback messages, used when no schedule state is available to build a ScheduleError.
+  static constexpr const char* error_str1 =
+      "ValueError: The `init` and `body` of the reduction block are required to be both "
+      "BufferStore so that rfactor or cross-thread reduction can be applied. However, a reduction "
+      "block that doesn't meet this requirement is ";
+  static constexpr const char* error_str2 =
+      "ValueError: The `init` and `body` of the reduction block are required to have the same "
+      "buffer access pattern so that rfactor or cross-thread reduction can be applied. However, a "
+      "reduction block that doesn't meet this requirement is ";
+
+  const auto* init = block->init.as<BufferStoreNode>();
+  const auto* body = block->body.as<BufferStoreNode>();
+  if (!(init && body)) {
+    if (self.defined()) {
+      throw InitBodyNotBufferStoreError(self.value()->mod, block, init != nullptr, body != nullptr);
+    } else {
+      LOG(FATAL) << error_str1 << block;
+    }
+  }
+  // `init` and `body` must write the same buffer at structurally equal indices. The index counts
+  // are compared before the element-wise check so the loop never reads past the end of either
+  // `indices` array.
+  bool same_access =
+      init->buffer.same_as(body->buffer) && init->indices.size() == body->indices.size();
+  for (size_t i = 0; same_access && i < init->indices.size(); ++i) {
+    same_access = ExprDeepEqual()(init->indices[i], body->indices[i]);
+  }
+  if (!same_access) {
+    if (self.defined()) {
+      throw InitBodyNotSameBufferAccessError(self.value()->mod, block);
+    } else {
+      LOG(FATAL) << error_str2 << block;
+    }
+  }
+  return std::make_pair(GetRef<BufferStore>(init), GetRef<BufferStore>(body));
+}
+
+bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters) {
+  // Reject as soon as one iter is neither data-parallel (kDataPar) nor reduction (kCommReduce).
+  for (const IterVar& iter_var : iters) {
+    if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool ReductionIterNotIndexOutputBuffer(const Block& block) {
+  // Step 1. Collect the reduction block iters.
+  std::unordered_set<const VarNode*> reduction_block_iters;
+  reduction_block_iters.reserve(block->iter_vars.size());
+  for (const IterVar& iter_var : block->iter_vars) {
+    if (iter_var->iter_type == kCommReduce) {
+      reduction_block_iters.insert(iter_var->var.get());
+    }
+  }
+  // Step 2. Check if the reduction block iters are used to index the output buffer.
+  // The set of written buffers is only used to sanity-check that every BufferStore in the body
+  // targets a buffer declared in the block's write regions.
+  std::unordered_set<const BufferNode*> buffer_written;
+  buffer_written.reserve(block->writes.size());
+  for (const BufferRegion& write_region : block->writes) {
+    buffer_written.insert(write_region->buffer.get());
+  }
+  auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool {
+    return UsesVar(expr, [&](const VarNode* var) {  //
+      return reduction_block_iters.count(var);
+    });
+  };
+  bool affected = false;
+  PreOrderVisit(block->body, [&](const ObjectRef& obj) {
+    // Stop descending once a violation has been found.
+    if (affected) {
+      return false;
+    }
+    const auto* store = obj.as<BufferStoreNode>();
+    if (!store) {
+      return true;
+    }
+    ICHECK(buffer_written.count(store->buffer.get()))
+        << "ValueError: The buffer \"" << store->buffer
+        << "\" is written in the block but is not in the block's signature";
+    for (const PrimExpr& index : store->indices) {
+      if (f_uses_reduction_block_var(index)) {
+        affected = true;
+        return false;
+      }
+    }
+    // Only the store indices matter; the stored value is not visited.
+    return false;
+  });
+  return !affected;
+}
+
+/*!
+ * \brief Error raised when no registered commutative reducer matches the identity and combiner of
+ * a reduction block
+ */
+class NoMatchedReducerError : public ScheduleError {
+ public:
+  explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore combiner)
+      : mod_(std::move(mod)), identity_(std::move(identity)), combiner_(std::move(combiner)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: No matched reducer for the identity and the combiner of this reduction "
+           "block. So rfactor and cross-thread reduction cannot be applied.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    // The sentence separator after the combiner was previously missing, and the message only
+    // mentioned rfactor although the check is shared with cross-thread reduction.
+    os << "No matched reducer for identity " << identity_ << " and combiner " << combiner_
+       << ". In this case rfactor and cross-thread reduction cannot be applied. You can check "
+          "tvm::tir::ReducerRegistry for default reducers or register new reducers.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+  IRModule mod_;
+  PrimExpr identity_;
+  BufferStore combiner_;
+};
+
+std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
+    const Optional<ScheduleState>& self, const PrimExpr& identity, const BufferStore& combiner) {
+  CommReducer reducer{nullptr};
+  PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr};
+  bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs);
+  if (!matched) {
+    if (self.defined()) {
+      throw NoMatchedReducerError(self.value()->mod, identity, combiner);
+    } else {
+      // Like the ScheduleError detail, report the offending identity and combiner.
+      LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the "
+                    "reduction block. So rfactor and cross-thread reduction cannot be applied. "
+                 << "The identity is " << identity << ", and the combiner is " << combiner;
+    }
+  }
+  return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs));
+}
+
 /******** Commutative Reducer ********/
 
 bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const 
PrimExpr& combiner,
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 0f85168..9c33076 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -370,69 +370,6 @@ class NotSerialLoopKindError : public ScheduleError {
   For loop_;
 };
 
-class InitBodyNotBufferStoreError : public ScheduleError {
- public:
-  explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool 
init_is_bufferstore,
-                                       bool body_is_bufferstore)
-      : mod_(std::move(mod)),
-        block_(std::move(block)),
-        init_is_bufferstore_(init_is_bufferstore),
-        body_is_bufferstore_(body_is_bufferstore) {}
-
-  String FastErrorString() const final {
-    return "ScheduleError: The `init` and `body` of reduction block are 
required to be both "
-           "BufferStore";
-  }
-
-  String DetailRenderTemplate() const final {
-    if (!init_is_bufferstore_ && !body_is_bufferstore_) {
-      return "The `init` and `body` of block {0} are required to be 
BufferStore so that rfactor "
-             "can be applied";
-    } else if (!init_is_bufferstore_) {
-      return "The `init` of block {0} is required to be BufferStore so that 
rfactor can be applied";
-    } else {
-      ICHECK(!body_is_bufferstore_);
-      return "The `body` of block {0} is required to be BufferStore so that 
rfactor can be applied";
-    }
-  }
-
-  IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
-
-  IRModule mod_;
-  Block block_;
-  bool init_is_bufferstore_;
-  bool body_is_bufferstore_;
-};
-
-class InitBodyNotSameBufferAccessError : public ScheduleError {
- public:
-  explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block)
-      : mod_(std::move(mod)), block_(std::move(block)) {}
-
-  String FastErrorString() const final {
-    return "ScheduleError: The `init` and `body` of the reduction block are 
required to have the "
-           "same buffer access pattern";
-  }
-
-  String DetailRenderTemplate() const final {
-    std::ostringstream os;
-    const auto* init = block_->init.as<BufferStoreNode>();
-    const auto* update = block_->body.as<BufferStoreNode>();
-    os << "The `init` and `body` of the block {0} is required to have the same 
buffer access "
-          "pattern. However, in block {0} the `init` writes to "
-       << init->buffer->name << init->indices << ", and the `body` writes to "
-       << update->buffer->name << update->indices;
-    return os.str();
-  }
-
-  IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
-
-  IRModule mod_;
-  Block block_;
-};
-
 class FactorAxisOutOfRangeError : public ScheduleError {
  public:
   explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int 
factor_axis)
@@ -473,32 +410,6 @@ class FactorAxisOutOfRangeError : public ScheduleError {
   int factor_axis_;
 };
 
-class NoMatchedReducerError : public ScheduleError {
- public:
-  explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore 
combiner)
-      : mod_(std::move(mod)), identity_(std::move(identity)), 
combiner_(std::move(combiner)) {}
-
-  String FastErrorString() const final {
-    return "ScheduleError: No matched reducer for the identity and the 
combiner of this reduction "
-           "block. So rfactor cannot be applied.";
-  }
-
-  String DetailRenderTemplate() const final {
-    std::ostringstream os;
-    os << "No matched reducer for identity " << identity_ << " and combiner " 
<< combiner_
-       << "In this case rfactor cannot be applied. You can check 
tvm::tir::ReducerRegistry for "
-          "default reducers or registering new reducers.";
-    return os.str();
-  }
-
-  IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
-
-  IRModule mod_;
-  PrimExpr identity_;
-  BufferStore combiner_;
-};
-
 class LoopPropertyError : public ScheduleError {
  public:
   enum ErrorType {
@@ -592,53 +503,6 @@ class LoopPropertyError : public ScheduleError {
 };
 
 /*!
- * \brief Convert the `init` and `body` of the input block to BufferStores
- * \param self The schedule state
- * \param block The block to be analyzed
- * \return The BufferStores of the `init` and `body` of the input block
- * \throw ScheduleError If the `init` or `body` is not BufferStore, or they 
don't write to the same
- * buffer
- */
-std::pair<BufferStore, BufferStore> GetBufferStoreNodes(const ScheduleState& 
self,
-                                                        const Block& block) {
-  const auto* init = block->init.as<BufferStoreNode>();
-  const auto* body = block->body.as<BufferStoreNode>();
-  if (!(init && body)) {
-    throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body 
!= nullptr);
-  }
-  if (!init->buffer.same_as(body->buffer)) {
-    throw InitBodyNotSameBufferAccessError(self->mod, block);
-  }
-  int ndim = static_cast<int>(init->buffer->shape.size());
-  for (int i = 0; i < ndim; ++i) {
-    if (!ExprDeepEqual()(init->indices[i], body->indices[i])) {
-      throw InitBodyNotSameBufferAccessError(self->mod, block);
-    }
-  }
-  return std::make_pair(GetRef<BufferStore>(init), GetRef<BufferStore>(body));
-}
-
-/*!
- * \brief Given a reduction identity and a reduction combiner, detect the 
corresponding commutative
- * reducer, and extract the combiner lhs and combiner rhs
- * \param self The schedule state
- * \param identity The reduction identity to be analyzed
- * \param combiner The reduction combiner to be analyzed
- * \return The corresponding CommReducer, the combiner lhs and the combiner rhs
- * \throw ScheduleError If no corresponding commutative reducer can be matched
- */
-std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
-    const ScheduleState& self, const PrimExpr& identity, const BufferStore& 
combiner) {
-  CommReducer reducer{nullptr};
-  PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr};
-  bool matched = FromIdentityCombiner(identity, combiner, &reducer, 
&combiner_lhs, &combiner_rhs);
-  if (!matched) {
-    throw NoMatchedReducerError(self->mod, identity, combiner);
-  }
-  return std::make_tuple(std::move(reducer), std::move(combiner_lhs), 
std::move(combiner_rhs));
-}
-
-/*!
  * \brief For each loop in the given array of loop, associate its loop var 
with the loop itself
  * using a mapping
  * \param loops The loops to be analyzed
@@ -1177,7 +1041,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& 
rf_loop_sref, int factor_ax
   BufferStore update;
   CommReducer reducer;
   PrimExpr combiner_lhs, combiner_rhs;
-  std::tie(init, update) = GetBufferStoreNodes(self, block);
+  std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block);
   std::tie(reducer, combiner_lhs, combiner_rhs) =
       GetReducerAndCombinerLhsRhs(self, init->value, update);
 
diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
new file mode 100644
index 0000000..630c00f
--- /dev/null
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -0,0 +1,645 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_cross_thread_reduction.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../schedule/analysis.h"
+#include "./ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Checks if a loop is bound to threadIdx.x/y/z
+ * \param loop The loop to be checked
+ * \return True if the loop is bound to threadIdx.x/y/z
+ */
+bool IsBoundToThreadIdx(const ForNode* loop) {
+  if (!loop->thread_binding.defined()) {
+    return false;
+  }
+  runtime::ThreadScope scope =
+      runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+  // Rank 1 is the thread level of runtime::ThreadScope; virtual-thread tags share that rank but
+  // report a negative dim_index, so the second test keeps only threadIdx.x/y/z.
+  return scope.rank == 1 && scope.dim_index >= 0;
+}
+
+/*!
+ * \brief Check the dominant property of a block: the block must be the only writer of every
+ * buffer it writes, so that it dominates all readers of its output buffers.
+ * \param scope_block The scope block that contains the block to be checked
+ * \param block The block whose dominant property is to be checked
+ * \return True iff each buffer written by `block` has exactly one writer among the outermost
+ * blocks under `scope_block`'s body
+ */
+bool IsDominantBlock(const Block& scope_block, const Block& block) {
+  // Tally the writers of every buffer. The visitor does not descend into a block once it has
+  // been counted, so only the outermost blocks under the scope body contribute.
+  std::unordered_map<const BufferNode*, int> buffer_writer_cnt;
+  auto tally_writers = [&buffer_writer_cnt](const ObjectRef& node) -> bool {
+    const auto* child_block = node.as<BlockNode>();
+    if (child_block == nullptr) {
+      return true;
+    }
+    for (const BufferRegion& written : child_block->writes) {
+      buffer_writer_cnt[written->buffer.get()] += 1;
+    }
+    return false;
+  };
+  PreOrderVisit(scope_block->body, tally_writers);
+  // `block` dominates only when it is the sole writer of each of its output buffers.
+  return std::all_of(block->writes.begin(), block->writes.end(),
+                     [&buffer_writer_cnt](const BufferRegion& buffer_region) {
+                       ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get()));
+                       return buffer_writer_cnt.at(buffer_region->buffer.get()) == 1;
+                     });
+}
+
+/*!
+ * \brief Check whether the input block is a reduction block.
+ * \param realize The block to be checked
+ * \param loop_range_map The mapping from the loop variables outside the input block to their ranges
+ * \param scope_block The scope block of the input block
+ * \param analyzer The analyzer
+ * \return A boolean indicating whether the input block is a reduction block.
+ * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is
+ * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the
+ * check again.
+ */
+bool IsReductionBlock(const BlockRealize& realize, const Map<Var, Range>& loop_range_map,
+                      const Block& scope_block, arith::Analyzer* analyzer) {
+  // The realize already holds a Block reference; use it directly instead of round-tripping
+  // through the raw node pointer.
+  const Block& block = realize->block;
+  // Cond 1. The block has the `init` statement.
+  if (!block->init.defined()) {
+    return false;
+  }
+  // Cond 2. All the block bindings are quasi-affine expressions.
+  if (!IsAffineBinding(realize, loop_range_map, analyzer)) {
+    return false;
+  }
+  // Cond 3. All block vars are either data parallel block vars or reduction block vars.
+  if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
+    return false;
+  }
+  // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its
+  // output buffers.
+  if (!IsDominantBlock(scope_block, block)) {
+    return false;
+  }
+  // Cond 5. The reduction block vars are not used to index the output buffers.
+  return ReductionIterNotIndexOutputBuffer(block);
+}
+
+/*!
+ * \brief Create an intermediate buffer with specified name and data type
+ * \param name The specified name, used for both the buffer and its data variable
+ * \param dtype The specified data type
+ * \return The created buffer: shape {1}, stride {1}, data pointer in "local" storage scope
+ */
+Buffer MakeScratchpad(String name, const DataType& dtype) {
+  // Construct the data var before `name` is moved into the Buffer constructor. The order in which
+  // function arguments (and by-value parameters) are initialized is unspecified, so reading `name`
+  // in one argument while moving it in another could hand the Var a moved-from String.
+  Var data(name, PointerType(PrimType(dtype), "local"));
+  return Buffer(/*ptr=*/std::move(data),
+                /*dtype=*/dtype,
+                /*shape=*/{Integer(1)},
+                /*strides=*/{Integer(1)},
+                /*elem_offset=*/PrimExpr{nullptr},
+                /*name=*/std::move(name),
+                /*data_alignment=*/0,
+                /*offset_factor=*/0,
+                /*buffer_type=*/kDefault);
+}
+
+/*!
+ * \brief Remove the BufferRegions whose buffer is the input buffer
+ * \param buffer_regions The array of BufferRegions to be filtered
+ * \param buffer_to_remove The specified buffer
+ * \return A new array holding, in their original order, the BufferRegions of `buffer_regions`
+ * whose buffer is not `buffer_to_remove`
+ */
+Array<BufferRegion> RemoveBufferFromBufferRegions(const Array<BufferRegion>& buffer_regions,
+                                                  const Buffer& buffer_to_remove) {
+  Array<BufferRegion> res;
+  res.reserve(buffer_regions.size());
+  for (const BufferRegion& buffer_region : buffer_regions) {
+    // Buffers are compared by object identity, not structurally.
+    if (!buffer_region->buffer.same_as(buffer_to_remove)) {
+      res.push_back(buffer_region);
+    }
+  }
+  return res;
+}
+
+/*!
+ * \brief Substitute a given source buffer with a given target buffer in statements or expressions.
+ * Every load from / store to the source buffer is redirected to element 0 of the target buffer;
+ * the original indices are dropped, which matches the single-element scratchpads of this pass.
+ */
+class BufferReplacer : private StmtExprMutator {
+ public:
+  static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) {
+    return BufferReplacer(src_buffer, tgt_buffer)(std::move(stmt));
+  }
+
+ private:
+  explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer)
+      : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {}
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    if (load->buffer.same_as(src_buffer_)) {
+      return BufferLoad(tgt_buffer_, {Integer(0)});
+    }
+    // Recurse into loads of other buffers so that accesses to the source buffer nested inside
+    // them (e.g. in their indices) are replaced as well.
+    return StmtExprMutator::VisitExpr_(load);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    if (store->buffer.same_as(src_buffer_)) {
+      PrimExpr value = StmtExprMutator::VisitExpr(store->value);
+      return BufferStore(tgt_buffer_, value, {Integer(0)});
+    } else {
+      return StmtMutator::VisitStmt_(store);
+    }
+  }
+
+  Buffer src_buffer_;
+  Buffer tgt_buffer_;
+};
+
+/*!
+ * \brief Substitute a given source block with a given target block, or remove the source block
+ * branch from the AST if the target block is undefined. Along the way, loops bound to a thread
+ * axis are replaced by their bodies, and loops / SeqStmts left without a body are removed.
+ */
+class InThreadReducerMaker : private StmtMutator {
+ public:
+  static Optional<Stmt> Make(const BlockRealizeNode* src_realize,
+                             Optional<BlockRealize> tgt_realize, Stmt stmt) {
+    return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt));
+  }
+
+ private:
+  explicit InThreadReducerMaker(const BlockRealizeNode* src_realize,
+                                Optional<BlockRealize> tgt_realize)
+      : src_realize_(src_realize), tgt_realize_(std::move(tgt_realize)) {}
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    if (realize == src_realize_) {
+      return tgt_realize_.defined()  //
+                 ? tgt_realize_.value()
+                 : Stmt{nullptr};
+    }
+    return GetRef<BlockRealize>(realize);
+  }
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Optional<For> opt_res = Downcast<Optional<For>>(StmtMutator::VisitStmt_(loop));
+    // StmtMutator keeps the For node even when its body was removed; drop such a loop here
+    // rather than emitting a For with an undefined body.
+    if (!opt_res.defined() || !opt_res.value()->body.defined()) {
+      return Stmt{nullptr};
+    }
+    For res = opt_res.value();
+    // Thread-bound loops are stripped here; the caller is expected to re-wrap the result with
+    // them.
+    if (res->thread_binding.defined()) {
+      return res->body;
+    }
+    return res;
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> stmts;
+    stmts.reserve(seq->size());
+    for (const Stmt& stmt : seq->seq) {
+      if (Optional<Stmt> opt_res = VisitStmt(stmt)) {
+        stmts.push_back(opt_res.value());
+      }
+    }
+    return stmts.empty() ? Stmt{nullptr} : SeqStmt::Flatten(stmts);
+  }
+
+  const BlockRealizeNode* src_realize_;
+  Optional<BlockRealize> tgt_realize_;
+};
+
+/*!
+ * \brief Create the lowered allreduce block transformed from the input reduction block
+ * \param realize The BlockRealize of the input reduction block
+ * \param it_buffer The buffer to store in-thread reduction results; NullOpt when no in-thread
+ * reduction is needed
+ * \param ct_buffer The buffer to store cross-thread reduction results
+ * \param reducer The reduction function
+ * \param combiner_rhs The RHS of the combiner
+ * \param reduction_loops The reduction-related loops, outermost first
+ * \return Up to four statements (in-thread init, in-thread reduction, cross-thread reduction,
+ * write-back) wrapped in the thread-bound loops among `reduction_loops`
+ */
+Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional<Buffer>& it_buffer,
+                             const Buffer& ct_buffer, const CommReducer& reducer,
+                             const PrimExpr& combiner_rhs,
+                             const std::vector<const ForNode*>& reduction_loops) {
+  const BlockNode* block = realize->block.get();
+  Buffer wb_buffer = block->writes[0]->buffer;
+  Array<Range> wb_region = block->writes[0]->region;
+
+  BufferRegion ct_buffer_region(ct_buffer, {Range::FromMinExtent(0, 1)});
+  Optional<BufferRegion> it_buffer_region = NullOpt;
+  if (it_buffer.defined()) {
+    it_buffer_region = BufferRegion(it_buffer.value(), {Range::FromMinExtent(0, 1)});
+  }
+  // In total, the block is transformed into at most 4 statements
+  // - Stmt 1: initialize the buffer for in-thread reduction
+  // - Stmt 2: do in-thread reduction
+  // - Stmt 3: do cross-thread reduction
+  // - Stmt 4: write cross-thread reduction result to the original buffer
+  Array<Stmt> stmts;
+  stmts.reserve(4);
+  // Stmt 1: initialize the buffer for in-thread reduction
+  if (it_buffer.defined()) {
+    BufferStore init = Downcast<BufferStore>(block->init);
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/{},
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/{},
+              /*reads=*/{},
+              /*writes=*/{it_buffer_region.value()},
+              /*name_hint=*/block->name_hint + "_in_thread_init",
+              /*body=*/
+              BufferStore(/*buffer=*/it_buffer.value(),
+                          /*value=*/init->value,
+                          /*indices=*/{Integer(0)}))));
+  }
+  // Stmt 2: do in-thread reduction
+  {
+    Optional<BlockRealize> new_realize = NullOpt;
+    // If need to generate in-thread reduction,
+    // then replace `wb_buffer` with `it_buffer` accordingly in given BlockRealize
+    // otherwise, directly remove given BlockRealize
+    if (it_buffer.defined()) {
+      ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
+      new_block->reads = RemoveBufferFromBufferRegions(new_block->reads, wb_buffer);
+      new_block->reads.push_back(it_buffer_region.value());
+      new_block->writes = {it_buffer_region.value()};
+      new_block->name_hint = new_block->name_hint + "_in_thread";
+      new_block->body =
+          BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body));
+      new_block->init = NullOpt;
+      ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize);
+      n->block = Block(new_block);
+      new_realize = BlockRealize(n);
+    }
+    For loop = GetRef<For>(reduction_loops[0]);
+    if (Optional<Stmt> stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) {
+      stmts.push_back(stmt.value());
+    }
+  }
+  // Stmt 3: do cross-thread reduction
+  {
+    // Step 3.1. Create the parameters to the intrinsic
+    Array<PrimExpr> parameters;
+    parameters.reserve(reduction_loops.size() + 4);
+    // 1-st argument: size
+    parameters.push_back(make_const(DataType::UInt(32), 1));
+    // 2-nd argument: source
+    if (it_buffer.defined()) {
+      parameters.push_back(BufferLoad(it_buffer.value(), {Integer(0)}));
+    } else {
+      parameters.push_back(combiner_rhs);
+    }
+    // 3-rd argument: predicate
+    parameters.push_back(const_true());
+    // 4-th argument: destination
+    parameters.push_back(ct_buffer->data);
+    // next arguments: all the reduction threads
+    for (const ForNode* reduction_loop : reduction_loops) {
+      if (reduction_loop->thread_binding.defined()) {
+        parameters.push_back(reduction_loop->loop_var);
+      }
+    }
+    // Step 3.2. Create the block and the block-realize.
+    Array<IterVar> iter_vars{nullptr};
+    Array<PrimExpr> bindings{nullptr};
+    Array<BufferRegion> reads{nullptr};
+    if (it_buffer.defined()) {
+      iter_vars = Array<IterVar>{};
+      bindings = Array<PrimExpr>{};
+      reads = {it_buffer_region.value()};
+    } else {
+      // Without an in-thread buffer, the source is `combiner_rhs`, which refers to the original
+      // block iters, so the original iters/bindings are kept.
+      iter_vars = block->iter_vars;
+      bindings = realize->iter_values;
+      reads = RemoveBufferFromBufferRegions(block->reads, wb_buffer);
+    }
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/std::move(bindings),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/std::move(iter_vars),
+              /*reads=*/std::move(reads),
+              /*writes=*/{ct_buffer_region},
+              /*name_hint=*/block->name_hint + "_cross_thread",
+              /*body=*/
+              AttrStmt(/*node=*/reducer,
+                       /*attr_key=*/tir::attr::reduce_scope,
+                       /*value=*/make_zero(DataType::Handle()),
+                       /*body=*/
+                       Evaluate(Call(/*dtype=*/DataType::Handle(),
+                                     /*op=*/tir::builtin::tvm_thread_allreduce(),
+                                     /*args=*/std::move(parameters)))))));
+  }
+  // Stmt 4: write cross-thread reduction result to the original buffer
+  {
+    ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
+    int n_iter = static_cast<int>(block->iter_vars.size());
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> bindings;
+    Map<Var, PrimExpr> var_map;
+    iter_vars.reserve(n_iter);
+    bindings.reserve(n_iter);
+    for (int i = 0; i < n_iter; ++i) {
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      // Only non-reduction iters survive in the write-back block; each gets a fresh var copy.
+      if (iter_var->iter_type != kCommReduce) {
+        IterVar new_iter_var{nullptr};
+        {
+          ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get());
+          ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get());
+          n->var = Var(v);
+          new_iter_var = IterVar(n);
+        }
+        iter_vars.push_back(new_iter_var);
+        bindings.push_back(binding);
+        var_map.Set(iter_var->var, new_iter_var->var);
+      }
+    }
+    BufferStore update = Downcast<BufferStore>(block->body);
+    update = Downcast<BufferStore>(Substitute(std::move(update), var_map));
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/std::move(bindings),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(
+            /*iter_vars=*/std::move(iter_vars),
+            /*reads=*/{std::move(ct_buffer_region)},
+            /*writes=*/{BufferRegion(wb_buffer, Substitute(wb_region, var_map))},
+            /*name_hint=*/block->name_hint + "_write_back",
+            /*body=*/
+            BufferStore(/*buffer=*/wb_buffer,
+                        /*value=*/BufferLoad(ct_buffer, {Integer(0)}),
+                        /*indices=*/update->indices))));
+  }
+  // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx
+  Stmt new_stmt = SeqStmt::Flatten(std::move(stmts));
+  for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) {
+    const ForNode* loop = *rit;
+    if (loop->thread_binding.defined()) {
+      ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
+      n->body = std::move(new_stmt);
+      new_stmt = For(n);
+    }
+  }
+  return new_stmt;
+}
+
+/*!
+ * \brief Detect cross-thread reduction pattern and then transform. Each qualifying reduction
+ * block is replaced (at its first reduction-related loop) by the output of
+ * TransformReductionBlock, and the scratchpad buffers it needs are allocated by the enclosing
+ * scope block.
+ */
+class CrossThreadReductionTransformer : public StmtMutator {
+ private:
+  // Check if the input block needs cross-thread reduction.
+  std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode* realize) {
+    // Step 0. If the block is the root block, just return.
+    if (block_stack_.empty()) {
+      return {};
+    }
+
+    // Step 1. If the block is not a reduction block, cross-thread reduction is not needed.
+    if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
+                          GetRef<Block>(block_stack_.back()), &analyzer_)) {
+      return {};
+    }
+
+    // Step 2. Collect all the vars that appear in the bindings of reduction block iters.
+    std::unordered_set<const VarNode*> reduction_vars;
+    GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr, &reduction_vars);
+
+    // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters.
+    // We call these loops "reduction-related".
+    // Step 4. See whether at least one reduction-related loop is bound to thread axis in GPU - if
+    // so, cross-thread reduction is needed. If none of the reduction-related loops is bound to
+    // thread axis, cross-thread reduction is not needed for the input block.
+    bool need = false;
+    std::vector<const ForNode*> reduction_loops;
+    for (const ForNode* loop : loop_stack_) {
+      if (reduction_vars.count(loop->loop_var.get())) {
+        // Step 3. Collect the loop.
+        reduction_loops.push_back(loop);
+        // Step 4. See whether the loop is bound to some thread axis.
+        if (loop->thread_binding.defined()) {
+          need = true;
+        }
+      }
+    }
+    return need ? reduction_loops : std::vector<const ForNode*>{};
+  }
+
+  // Given that the input block needs cross-thread reduction, check if cross-thread reduction can
+  // be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread
+  // reduction).
+  std::tuple<int, CommReducer, PrimExpr> CheckCanApplyCrossThreadReduction(
+      const BlockNode* block, const std::vector<const ForNode*>& reduction_loops) const {
+    // Condition 1. The block being applied cross-thread reduction should write to single buffer.
+    CHECK_EQ(block->writes.size(), 1)
+        << "ValueError: Cross-thread reduction requires the block to only "
+           "write to single buffer. However, the block "
+        << block->name_hint << " writes to " << block->writes.size() << " buffer(s).";
+
+    // Condition 2. All the reduction-related loops should be the deepest among all statements
+    // outside the block (ignoring SeqStmt here).
+    // `statement_stack_.back()` is the BlockRealize itself, hence the `+ 1`.
+    size_t n_deepest_reduction_loops = 0;
+    for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) {
+      const StmtNode* stmt = *rit;
+      if (stmt->IsInstance<SeqStmtNode>()) {
+        // Skip SeqStmt.
+        continue;
+      }
+      // Only a ForNode can be a reduction-related loop; any other statement ends the walk. The
+      // type is checked before casting instead of reinterpret_cast-ing arbitrary nodes.
+      const ForNode* loop =
+          stmt->IsInstance<ForNode>() ? static_cast<const ForNode*>(stmt) : nullptr;
+      if (loop == nullptr ||
+          std::find(reduction_loops.begin(), reduction_loops.end(), loop) == reduction_loops.end()) {
+        break;
+      }
+      ++n_deepest_reduction_loops;
+    }
+    CHECK_EQ(n_deepest_reduction_loops, reduction_loops.size())
+        << "ValueError: Cross-thread reduction requires all the reduction-related loops to be the "
+           "deepest among all statements outside the desired block. However, block "
+        << block->name_hint
+        << " needs cross-thread reduction, while the reduction-related loops outside of it are not "
+           "the deepest statements, which violates the condition.";
+
+    // Condition 3. All the reduction-related loops that are bound to thread axes should only be
+    // bound to `threadIdx.x/y/z`.
+    int n_bound_reduction_loops = 0;
+    for (const ForNode* reduction_loop : reduction_loops) {
+      if (reduction_loop->thread_binding.defined()) {
+        ++n_bound_reduction_loops;
+        CHECK(IsBoundToThreadIdx(reduction_loop))
+            << "ValueError: Cross-thread reduction requires all the reduction-related loops that "
+               "are bound to GPU thread axes to only be bound `threadIdx.x/y/z`. However, loop "
+            << reduction_loop->loop_var->name_hint << " violates the condition.";
+      }
+    }
+
+    // Condition 4. Get the `init` identity and the `update` combiner of the reduction. They should
+    // both be BufferStores with the same buffer and indices;
+    // Extract the commutative reducer, combiner lhs and combiner rhs from the reduction identity
+    // and the reduction combiner.
+    BufferStore init{nullptr};
+    BufferStore update{nullptr};
+    CommReducer reducer{nullptr};
+    PrimExpr combiner_lhs{nullptr};
+    PrimExpr combiner_rhs{nullptr};
+    std::tie(init, update) = GetBufferStoresFromReductionBlock(NullOpt, GetRef<Block>(block));
+    std::tie(reducer, combiner_lhs, combiner_rhs) =
+        GetReducerAndCombinerLhsRhs(NullOpt, init->value, update);
+
+    // Condition 5. The block should be the last block under the first reduction-related loop.
+    bool visit = false;
+    PreOrderVisit(GetRef<For>(reduction_loops[0]), [block, &visit](const ObjectRef& obj) {
+      if (const auto* realize = obj.as<BlockRealizeNode>()) {
+        CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction "
+                         "block isn't the last block under its first reduction-related loop";
+        if (realize->block.get() == block) {
+          visit = true;
+        }
+        return false;
+      }
+      return true;
+    });
+    return std::make_tuple(n_bound_reduction_loops, reducer, combiner_rhs);
+  }
+
+  Stmt VisitStmt(const Stmt& stmt) final {
+    statement_stack_.push_back(stmt.get());
+    Stmt result = StmtMutator::VisitStmt(stmt);
+    statement_stack_.pop_back();
+    return result;
+  }
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    loop_stack_.push_back(loop);
+    loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+    Stmt result = StmtMutator::VisitStmt_(loop);
+    loop_stack_.pop_back();
+    loop_range_map_.erase(loop->loop_var);
+
+    // Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`.
+    auto it = loop2new_stmt_.find(loop);
+    if (it != loop2new_stmt_.end()) {
+      return it->second;
+    } else {
+      return result;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    Map<Var, Range> old_loop_range_map;
+
+    block_stack_.push_back(block);
+    std::swap(old_loop_range_map, loop_range_map_);
+    Block new_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
+    block_stack_.pop_back();
+    std::swap(old_loop_range_map, loop_range_map_);
+
+    // Insert the new allocated buffers into the block's `alloc_buffers` field.
+    auto it = block2new_buffers_.find(block);
+    if (it != block2new_buffers_.end()) {
+      BlockNode* p_new_block = new_block.CopyOnWrite();
+      for (const Buffer& new_buffer : it->second) {
+        if (new_buffer.defined()) {
+          p_new_block->alloc_buffers.push_back(new_buffer);
+        }
+      }
+    }
+    return new_block;
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    const BlockNode* block = realize->block.get();
+    // Step 1. Check whether cross-thread reduction is needed. If no, skip this block.
+    std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
+    if (reduction_loops.empty()) {
+      return StmtMutator::VisitStmt_(realize);
+    }
+    ++reduction_id_;
+    // Step 2. Check whether cross-thread reduction can be applied. If no, throw an exception on
+    // which condition the block violates.
+    int n_bound_reduction_loops = 0;
+    CommReducer reducer{nullptr};
+    PrimExpr combiner_rhs{nullptr};
+    std::tie(n_bound_reduction_loops, reducer, combiner_rhs) =
+        CheckCanApplyCrossThreadReduction(block, reduction_loops);
+    // Step 3. When not all the reduction-related loops are bound to thread axes, in-thread
+    // reduction is needed in this cross-thread reduction.
+    bool need_in_thread_reduction =
+        n_bound_reduction_loops < static_cast<int>(reduction_loops.size());
+    // Step 4. Create intermediate buffers, storing them in `ct_buffer` and
+    // `it_buffer`. Let the scope block allocate these new buffers.
+    std::vector<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
+    DataType dtype = block->writes[0]->buffer->dtype;
+    Buffer ct_buffer = MakeScratchpad("cross_thread_" + std::to_string(reduction_id_), dtype);
+    new_buffers.push_back(ct_buffer);
+    Optional<Buffer> it_buffer = NullOpt;
+    if (need_in_thread_reduction) {
+      it_buffer = MakeScratchpad("in_thread_" + std::to_string(reduction_id_), dtype);
+      new_buffers.push_back(it_buffer.value());
+    }
+    // Step 5. Transform.
+    loop2new_stmt_[reduction_loops[0]] = TransformReductionBlock(
+        realize, it_buffer, ct_buffer, reducer, combiner_rhs, reduction_loops);
+    // Step 6. Return an empty statement, because the transformation result will be inserted when
+    // returning to the first reduction-related loop.
+    return Stmt{nullptr};
+  }
+
+ private:
+  int reduction_id_ = -1;
+  std::vector<const StmtNode*> statement_stack_;
+  std::vector<const ForNode*> loop_stack_;
+  std::vector<const BlockNode*> block_stack_;
+  std::unordered_map<const BlockNode*, std::vector<Buffer>> block2new_buffers_;
+  std::unordered_map<const ForNode*, Stmt> loop2new_stmt_;
+  Map<Var, Range> loop_range_map_;
+  arith::Analyzer analyzer_;
+};
+
+/*!
+ * \brief Lower cross-thread reductions in a PrimFunc.
+ * \param f The PrimFunc to be transformed
+ * \return `f` with its body rewritten by CrossThreadReductionTransformer, or `f` unchanged when
+ * it comes from a legacy TE schedule
+ */
+PrimFunc LowerCrossThreadReduction(PrimFunc f) {
+  // TIR produced by legacy TE schedules is left untouched by this pass.
+  if (IsFromLegacyTESchedule(f)) {
+    return f;
+  }
+  PrimFuncNode* func_node = f.CopyOnWrite();
+  func_node->body = CrossThreadReductionTransformer()(func_node->body);
+  return f;
+}
+
+namespace transform {
+
+/*!
+ * \brief Create the PrimFunc pass "tir.LowerCrossThreadReduction" (opt level 0, no required
+ * passes), which applies tir::LowerCrossThreadReduction to every PrimFunc.
+ * \return The pass.
+ */
+Pass LowerCrossThreadReduction() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    // Qualified on purpose: unqualified lookup finds this nullary transform:: overload first, and
+    // the PrimFunc overload would only be reachable through argument-dependent lookup.
+    return tir::LowerCrossThreadReduction(std::move(f));
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerCrossThreadReduction", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction")
+    .set_body_typed(LowerCrossThreadReduction);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git 
a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py 
b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
new file mode 100644
index 0000000..4fa3ab0
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -0,0 +1,737 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys
+
+import pytest
+import tvm
+from tvm import te
+from tvm.script import tir as T
+
+
+def _check(original, transformed):
+    mod = tvm.IRModule.from_expr(original)
+    mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], transformed, True)
+
+
+def _check_fail(original):
+    mod = tvm.IRModule.from_expr(original)
+    with pytest.raises(ValueError):
+        tvm.tir.transform.LowerCrossThreadReduction()(mod)
+
+
[email protected]_func
+def loop_split(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i, ko in T.grid(128, 4):
+        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("B"):
+                vi = T.axis.S(128, i)
+                vk = T.axis.R(128, ko * 32 + ki)
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def lowered_loop_split(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    for i in T.serial(0, 128):
+        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("B_in_thread_init"):
+                T.reads([])
+                T.writes([normal_reduce_temp0[0]])
+                normal_reduce_temp0[0] = T.float32(0)
+            for ko in T.serial(0, 4):
+                with T.block("B_normal_reduction"):
+                    vi = T.axis.S(128, i)
+                    vk = T.axis.R(128, ko * 32 + ki)
+                    T.reads([A[vi, vk], normal_reduce_temp0[0]])
+                    T.writes([normal_reduce_temp0[0]])
+                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
+            with T.block("B_cross_thread_reduction"):
+                T.reads([normal_reduce_temp0[0]])
+                T.writes([reduce_temp0[0]])
+                T.attr(
+                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        normal_reduce_temp0[0],
+                        True,
+                        reduce_temp0.data,
+                        ki,
+                        dtype="handle",
+                    )
+                )
+            with T.block("B_write_back"):
+                vi = T.axis.S(128, i)
+                T.reads([reduce_temp0[0]])
+                T.writes([B[vi]])
+                B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def no_normal_reduction(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B_cross_thread_reduction"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([A[vi, vk]])
+                T.writes([reduce_temp0[0]])
+                T.attr(
+                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, 
dtype="handle"
+                    )
+                )
+            with T.block("B_write_back"):
+                vi = T.axis.spatial(128, i)
+                T.reads([reduce_temp0[0]])
+                T.writes([B[vi]])
+                B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def two_bound_loops(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i in T.serial(0, 128):
+        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
+            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
+                with T.block("B"):
+                    vi = T.axis.spatial(128, i)
+                    vk = T.axis.reduce(128, ko * 32 + ki)
+                    T.reads([B[vi], A[vi, vk]])
+                    T.writes([B[vi]])
+                    with T.init():
+                        B[vi] = T.float32(0)
+                    B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    for i in T.serial(0, 128):
+        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
+            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
+                with T.block("B_cross_thread_reduction"):
+                    vi = T.axis.spatial(128, i)
+                    vk = T.axis.reduce(128, ko * 32 + ki)
+                    T.reads([A[vi, vk]])
+                    T.writes([reduce_temp0[0]])
+                    T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                    )
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1), A[vi, vk], True, reduce_temp0.data, 
ko, ki, dtype="handle"
+                        )
+                    )
+                with T.block("B_write_back"):
+                    vi = T.axis.spatial(128, i)
+                    T.reads([reduce_temp0[0]])
+                    T.writes([B[vi]])
+                    B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
+    B = T.match_buffer(b, [16], dtype="float32")
+    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
+    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
+        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
+            for k0i0, k1 in T.grid(4, 16):
+                with T.block("B_rf"):
+                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
+                    vi, vk1 = T.axis.remap("SR", [i, k1])
+                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
+                    T.writes([B_rf_local[vk0, vi]])
+                    with T.init():
+                        B_rf_local[vk0, vi] = T.float32(0)
+                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
+            for k0i1 in T.serial(0, 4):
+                with T.block("B"):
+                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
+                    vi = T.axis.spatial(16, i)
+                    T.reads([B[vi], B_rf_local[vk0, vi]])
+                    T.writes([B[vi]])
+                    with T.init():
+                        B[vi] = T.float32(0)
+                    B[vi] = B[vi] + B_rf_local[vk0, vi]
+
+
[email protected]_func
+def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> 
None:
+    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
+    B = T.match_buffer(b, [16], dtype="float32")
+    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
+    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
+        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
+            with T.block("B_in_thread_init"):
+                T.reads([])
+                T.writes([normal_reduce_temp0[0]])
+                normal_reduce_temp0[0] = T.float32(0)
+            for k0i0, k1 in T.grid(4, 16):
+                with T.block("B_rf"):
+                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
+                    vi, vk1 = T.axis.remap("SR", [i, k1])
+                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
+                    T.writes([B_rf_local[vk0, vi]])
+                    with T.init():
+                        B_rf_local[vk0, vi] = T.float32(0)
+                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
+            for k0i1 in T.serial(0, 4):
+                with T.block("B_normal_reduction"):
+                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
+                    vi = T.axis.spatial(16, i)
+                    T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]])
+                    T.writes([normal_reduce_temp0[0]])
+                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + 
B_rf_local[vk0, vi]
+            with T.block("B_cross_thread_reduction"):
+                T.reads([normal_reduce_temp0[0]])
+                T.writes([reduce_temp0[0]])
+                T.attr(
+                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        normal_reduce_temp0[0],
+                        True,
+                        reduce_temp0.data,
+                        k0o,
+                        dtype="handle",
+                    )
+                )
+            with T.block("B_write_back"):
+                vi = T.axis.spatial(16, i)
+                T.reads([reduce_temp0[0]])
+                T.writes([B[vi]])
+                B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def with_block_predicate(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 120], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i, ko in T.grid(128, 4):
+        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("B"):
+                vi = T.axis.spatial(128, i)
+                vk = T.axis.reduce(120, ko * 32 + ki)
+                T.where(ko * 32 + ki < 120)
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 120], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    for i in T.serial(0, 128):
+        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("B_in_thread_init"):
+                T.reads([])
+                T.writes([normal_reduce_temp0[0]])
+                normal_reduce_temp0[0] = T.float32(0)
+            for ko in T.serial(0, 4):
+                with T.block("B_normal_reduction"):
+                    vi = T.axis.spatial(128, i)
+                    vk = T.axis.reduce(120, ko * 32 + ki)
+                    T.where(ko * 32 + ki < 120)
+                    T.reads([A[vi, vk], normal_reduce_temp0[0]])
+                    T.writes([normal_reduce_temp0[0]])
+                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
+            with T.block("B_cross_thread_reduction"):
+                T.reads([normal_reduce_temp0[0]])
+                T.writes([reduce_temp0[0]])
+                T.attr(
+                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        normal_reduce_temp0[0],
+                        True,
+                        reduce_temp0.data,
+                        ki,
+                        dtype="handle",
+                    )
+                )
+            with T.block("B_write_back"):
+                vi = T.axis.spatial(128, i)
+                T.reads([reduce_temp0[0]])
+                T.writes([B[vi]])
+                B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def reducer_max(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.min_value("float32")
+                B[vi] = T.max(B[vi], A[vi, vk])
+
+
[email protected]_func
+def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B_cross_thread_reduction"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([A[vi, vk]])
+                T.writes([reduce_temp0[0]])
+                T.attr(
+                    T.comm_reducer(lambda x, y: T.max(x, y), 
[T.min_value("float32")]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, 
dtype="handle"
+                    )
+                )
+            with T.block("B_write_back"):
+                vi = T.axis.spatial(128, i)
+                T.reads([reduce_temp0[0]])
+                T.writes([B[vi]])
+                B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def zero_rank_buffer(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128], dtype="float32")
+    B = T.match_buffer(b, [], dtype="float32")
+    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+        with T.block("B"):
+            vk = T.axis.reduce(128, k)
+            T.reads([B[()], A[vk]])
+            T.writes([B[()]])
+            with T.init():
+                B[()] = T.float32(0)
+            B[()] = B[()] + A[vk]
+
+
[email protected]_func
+def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128], dtype="float32")
+    B = T.match_buffer(b, [], dtype="float32")
+    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+        with T.block("B_cross_thread_reduction"):
+            vk = T.axis.reduce(128, k)
+            T.reads([A[vk]])
+            T.writes([reduce_temp0[0]])
+            T.attr(
+                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret(T.uint64(0), dtype="handle"),
+            )
+            T.evaluate(
+                T.tvm_thread_allreduce(
+                    T.uint32(1), A[vk], True, reduce_temp0.data, k, 
dtype="handle"
+                )
+            )
+        with T.block("B_write_back"):
+            T.reads([reduce_temp0[0]])
+            T.writes([B[()]])
+            B[()] = reduce_temp0[0]
+
+
[email protected]_func
+def multiple_bufferstore(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    C = T.alloc_buffer([], dtype="float32")
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([A[vi, vk], B[vi], C[()]])
+                T.writes([B[vi], C[()]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                C[()] = A[vi, vk]
+                B[vi] = B[vi] + C[()]
+
+
[email protected]_func
+def reduction_loop_not_deepest(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+        for i in T.serial(0, 128):
+            with T.block("B"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def reduction_loop_bound_to_blockidx(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="blockIdx.x"):
+            with T.block("B"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def different_access_indices(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128, 128], dtype="float32")
+    B = T.match_buffer(b, [128, 128], dtype="float32")
+    for i, j in T.grid(128, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                T.reads([B[vi, vj], A[vi, vj, vk]])
+                T.writes(
+                    [
+                        B[
+                            T.min(vj, vi) : T.min(vj, vi) + (T.max(vj, vi) + 1 
- T.min(vj, vi)),
+                            T.min(vi, vj) : T.min(vi, vj) + (T.max(vi, vj) + 1 
- T.min(vi, vj)),
+                        ]
+                    ]
+                )
+                with T.init():
+                    B[vj, vi] = T.float32(0)
+                B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
+
+
[email protected]_func
+def invalid_reducer(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] - A[vi, vk]
+
+
[email protected]_func
+def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
+    A = T.match_buffer(var_A, [256, 256], dtype="float32")
+    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], 
dtype="float32")
+    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", 
scope="shared")
+    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", 
scope="shared")
+    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
+        for ax0_0 in T.serial(0, 8):
+            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("T_softmax_maxelem"):
+                    i0_1 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
+                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
+                    T.writes([T_softmax_maxelem_shared[i0_1]])
+                    with T.init():
+                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
+                    T_softmax_maxelem_shared[i0_1] = T.max(
+                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
+                    )
+        for ax0_0 in T.serial(0, 8):
+            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("T_softmax_expsum"):
+                    i0_2 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
+                    T.reads(
+                        [
+                            T_softmax_expsum_shared[i0_2],
+                            A[i0_2, k],
+                            T_softmax_maxelem_shared[i0_2],
+                        ]
+                    )
+                    T.writes([T_softmax_expsum_shared[i0_2]])
+                    with T.init():
+                        T_softmax_expsum_shared[i0_2] = T.float32(0)
+                    T_softmax_expsum_shared[i0_2] = 
T_softmax_expsum_shared[i0_2] + T.exp(
+                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], 
dtype="float32"
+                    )
+        for i1_0 in T.serial(0, 8):
+            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("T_softmax_norm"):
+                    i0_3 = T.axis.spatial(256, i0)
+                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
+                    T.reads(
+                        [
+                            A[i0_3, i1],
+                            T_softmax_maxelem_shared[i0_3],
+                            T_softmax_expsum_shared[i0_3],
+                        ]
+                    )
+                    T.writes([T_softmax_norm[i0_3, i1]])
+                    T.block_attr({"axis": 1})
+                    T_softmax_norm[i0_3, i1] = (
+                        T.exp(
+                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
+                            dtype="float32",
+                        )
+                        / T_softmax_expsum_shared[i0_3]
+                    )
+
+
[email protected]_func
+def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
+    A = T.match_buffer(var_A, [256, 256], dtype="float32")
+    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], 
dtype="float32")
+    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", 
scope="shared")
+    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", 
scope="shared")
+    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], 
scope="local")
+    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
+        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("T_softmax_maxelem_normal_reduction_init"):
+                T.reads([])
+                T.writes([normal_reduce_temp0[0]])
+                normal_reduce_temp0[0] = T.min_value("float32")
+            for ax0_0 in T.serial(0, 8):
+                with T.block("T_softmax_maxelem_normal_reduction"):
+                    i0_1 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
+                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
+                    T.writes([normal_reduce_temp0[0]])
+                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], 
A[i0_1, k])
+            with T.block("T_softmax_maxelem_cross_thread_reduction"):
+                T.reads([normal_reduce_temp0[0]])
+                T.writes([reduce_temp0[0]])
+                T.attr(
+                    T.comm_reducer(lambda x, y: T.max(x, y), 
[T.min_value("float32")]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        normal_reduce_temp0[0],
+                        True,
+                        reduce_temp0.data,
+                        ax0_1,
+                        dtype="handle",
+                    )
+                )
+            with T.block("T_softmax_maxelem_write_back"):
+                i0_2 = T.axis.spatial(256, i0)
+                T.reads([reduce_temp0[0]])
+                T.writes([T_softmax_maxelem_shared[i0_2]])
+                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
+        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("T_softmax_expsum_normal_reduction_init"):
+                T.reads([])
+                T.writes([normal_reduce_temp1[0]])
+                normal_reduce_temp1[0] = T.float32(0)
+            for ax0_0 in T.serial(0, 8):
+                with T.block("T_softmax_expsum_normal_reduction"):
+                    i0_3 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
+                    T.reads(
+                        [
+                            A[i0_3, k],
+                            T_softmax_maxelem_shared[i0_3],
+                            normal_reduce_temp1[0],
+                        ]
+                    )
+                    T.writes([normal_reduce_temp1[0]])
+                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
+                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], 
dtype="float32"
+                    )
+            with T.block("T_softmax_expsum_cross_thread_reduction"):
+                T.reads([normal_reduce_temp1[0]])
+                T.writes([reduce_temp1[0]])
+                T.attr(
+                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        normal_reduce_temp1[0],
+                        True,
+                        reduce_temp1.data,
+                        ax0_1,
+                        dtype="handle",
+                    )
+                )
+            with T.block("T_softmax_expsum_write_back"):
+                i0_4 = T.axis.spatial(256, i0)
+                T.reads([reduce_temp1[0]])
+                T.writes([T_softmax_expsum_shared[i0_4]])
+                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
+        for i1_0 in T.serial(0, 8):
+            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("T_softmax_norm"):
+                    i0_5 = T.axis.spatial(256, i0)
+                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
+                    T.reads(
+                        [
+                            A[i0_5, i1],
+                            T_softmax_maxelem_shared[i0_5],
+                            T_softmax_expsum_shared[i0_5],
+                        ]
+                    )
+                    T.writes([T_softmax_norm[i0_5, i1]])
+                    T.block_attr({"axis": 1})
+                    T_softmax_norm[i0_5, i1] = (
+                        T.exp(
+                            A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
+                            dtype="float32",
+                        )
+                        / T_softmax_expsum_shared[i0_5]
+                    )
+
+
+def test_loop_split():
+    _check(loop_split, lowered_loop_split)
+
+
+def test_no_normal_reduction():
+    _check(no_normal_reduction, lowered_no_normal_reduction)
+
+
+def test_two_bound_loops():
+    _check(two_bound_loops, lowered_two_bound_loops)
+
+
+def test_multiple_blocks_under_reduction_loop():
+    _check(multiple_blocks_under_reduction_loop, 
lowered_multiple_blocks_under_reduction_loop)
+
+
+def test_with_block_predicate():
+    _check(with_block_predicate, lowered_with_block_predicate)
+
+
+def test_reducer_max():
+    _check(reducer_max, lowered_reducer_max)
+
+
+def test_zero_rank_buffer():
+    _check(zero_rank_buffer, lowered_zero_rank_buffer)
+
+
+def test_multiple_bufferstore():
+    _check_fail(multiple_bufferstore)
+
+
+def test_reduction_block_not_deepest():
+    _check_fail(reduction_loop_not_deepest)
+
+
+def test_reduction_loop_bound_to_blockidx():
+    _check_fail(reduction_loop_bound_to_blockidx)
+
+
+def test_different_access_indices():
+    _check_fail(different_access_indices)
+
+
+def test_invalid_reducer():
+    _check_fail(invalid_reducer)
+
+
+def test_softmax():
+    _check(softmax, lowered_softmax)
+
+
+def test_lower_te():
+    a = te.placeholder((32, 2, 2))
+    k1 = te.reduce_axis((0, 2), "k1")
+    k2 = te.reduce_axis((0, 2), "k2")
+    b = te.compute((32,), lambda i: te.sum(a[i, k1, k2], axis=[k1, k2]))
+    s = te.create_schedule(b.op)
+    s[b].bind(k1, te.thread_axis("threadIdx.x"))
+    s[b].bind(k2, te.thread_axis("threadIdx.y"))
+    orig_mod = tvm.driver.build_module.schedule_to_module(s, [a, b])
+    mod = tvm.tir.transform.LowerCrossThreadReduction()(orig_mod)
+    tvm.ir.assert_structural_equal(
+        mod, orig_mod
+    )  # LowerCrossThreadReduction should do nothing on TE
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to