This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 08898e1 [TensorIR] Cross-Thread Reduction (#9360)
08898e1 is described below
commit 08898e18628752d02fdb9e10f8135e1e3b95fb34
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Nov 15 01:37:43 2021 +0800
[TensorIR] Cross-Thread Reduction (#9360)
* [TensorIR] Cross-Thread Reduction
* Code revision on analysis and misc
* Refactor TransformReductionBlock
* Refactor code organization
* Address comment
* Use `std::make_tuple`
Co-authored-by: Junru Shao <[email protected]>
---
include/tvm/tir/transform.h | 7 +
python/tvm/tir/transform/transform.py | 12 +
src/driver/driver_api.cc | 1 +
src/tir/schedule/analysis.h | 50 +-
src/tir/schedule/analysis/analysis.cc | 249 +++++--
src/tir/schedule/primitive/reduction.cc | 138 +---
src/tir/transforms/lower_cross_thread_reduction.cc | 645 ++++++++++++++++++
...t_tir_transform_lower_cross_thread_reduction.py | 737 +++++++++++++++++++++
8 files changed, 1662 insertions(+), 177 deletions(-)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index e6b0af9..7922e97 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -358,6 +358,13 @@ TVM_DLL Pass PointerValueTypeRewrite();
TVM_DLL Pass HoistIfThenElse();
/*!
+ * \brief Lower cross-thread reduction from thread
+ * bindings to intrinsic function calls.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerCrossThreadReduction();
+
+/*!
* \brief Lower block init stmt into IfThenElse stmts
* \return The pass.
*/
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index 722810e..86f798c 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -577,6 +577,18 @@ def HoistIfThenElse(variant: Optional[str] = None):
return _ffi_api.HoistIfThenElse() # type: ignore
+def LowerCrossThreadReduction():
+ """Lower cross-thread reduction from thread bindings to
+ intrinsic function calls.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerCrossThreadReduction() # type: ignore
+
+
def LowerInitBlock():
"""Lower block init stmt into IfThenElse statements.
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index ad1f51b..f49409c 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -234,6 +234,7 @@ Array<tvm::transform::Pass> CreatePassList(bool
disable_loop_partition) {
pass_list.push_back(tir::transform::InjectPrefetch());
pass_list.push_back(tir::transform::TextureFlatten());
pass_list.push_back(tir::transform::StorageFlatten(64,
instrument_bound_checkers));
+ pass_list.push_back(tir::transform::LowerCrossThreadReduction());
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 5a2f46c..42e0e00 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -19,12 +19,17 @@
#ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_
#define TVM_TIR_SCHEDULE_ANALYSIS_H_
+#include <tvm/arith/analyzer.h>
#include <tvm/tir/schedule/state.h>
+#include <tuple>
#include <unordered_map>
#include <unordered_set>
+#include <utility>
#include <vector>
+#include "../../runtime/thread_storage_scope.h"
+
namespace tvm {
namespace tir {
@@ -323,6 +328,49 @@ struct ProducerConsumerSplit {
*/
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int
n, bool is_write);
+/******** Reduction Block Related ********/
+
+/*!
+ * \brief Convert the `init` and `body` of the input block to BufferStores
+ * \param self The schedule state
+ * \param block The block to be analyzed
+ * \return The BufferStores of the `init` and `body` of the input block
+ * \throw ScheduleError If the `init` or `body` is not BufferStore, or they
don't write to the same
+ * buffer
+ */
+std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
+ const Optional<ScheduleState>& self, const Block& block);
+
+/*!
+ * \brief Check whether the input array of IterVars only contains
data-parallel and reduction block
+ * iters
+ * \param iters The input array of IterVars to be checked
+ * \return A boolean indicating whether the input array of IterVars only
contains data-parallel and
+ * reduction block iters
+ */
+bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters);
+
+/*!
+ * \brief Check whether the block's reduction block iters are not used to
index the block's output
+ * buffers
+ * \param block The block to be checked
+ * \return A boolean indicating whether the block's reduction block iters are
not used to index the
+ * block's output buffer
+ */
+bool ReductionIterNotIndexOutputBuffer(const Block& block);
+
+/*!
+ * \brief Given a reduction identity and a reduction combiner, detect the
corresponding commutative
+ * reducer, and extract the combiner lhs and combiner rhs
+ * \param self The schedule state
+ * \param identity The reduction identity to be analyzed
+ * \param combiner The reduction combiner to be analyzed
+ * \return The corresponding CommReducer, the combiner lhs and the combiner rhs
+ * \throw ScheduleError If no corresponding commutative reducer can be matched
+ */
+std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
+ const Optional<ScheduleState>& self, const PrimExpr& identity, const
BufferStore& combiner);
+
/******** Commutative Reducer ********/
/*!
@@ -330,7 +378,7 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const
Block& block, int n,
* \return The list of the registered reducer-getter functions
* \sa ReducerRegistry
*/
-std::vector<TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();
+std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>>
GetReducerGetters();
/*!
* \brief Given the input identity and the combiner BufferStore of a
reduction, extract the
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index e3a535e..7e16bc9 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -153,15 +153,15 @@ Definition of a scope that is a stage pipeline:
/*!
* \brief Check the dominant property of a block:
* the block is the only writer of its output, dominating the reader of its
output buffers
- * \param self The schedule state
+ * \param scope The block-scope of the block to be checked
* \param block_sref The block whose dominant property is to be checked
* \return A boolean indicating if the block is a dominant block
*/
-bool IsDominantBlock(const BlockScope& self, const StmtSRef& block_sref) {
+bool IsDominantBlock(const BlockScope& scope, const StmtSRef& block_sref) {
// Check whether the input block is the only writer of its outputs
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash,
ObjectPtrEqual>& buffer_writers =
- self->buffer_writers;
+ scope->buffer_writers;
for (const BufferRegion& write_region : block->writes) {
ICHECK(buffer_writers.count(write_region->buffer))
<< "InternalError: buffer \"" << write_region->buffer->name
@@ -279,14 +279,8 @@ int CheckReductionBlockErrorCode(const ScheduleState&
self, const StmtSRef& bloc
}
// Cond 3. All block vars are either data parallel block vars or reduction
block vars. Meanwhile,
// we collect all the reduction block vars.
- std::unordered_set<const VarNode*> reduction_block_vars;
- reduction_block_vars.reserve(block->iter_vars.size());
- for (const IterVar& iter_var : block->iter_vars) {
- if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce)
{
- return 3;
- } else if (iter_var->iter_type == kCommReduce) {
- reduction_block_vars.insert(iter_var->var.get());
- }
+ if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
+ return 3;
}
// Cond 4. Dominant: the block is the only writer of its output, dominating
the reader of its
// output buffers.
@@ -294,33 +288,7 @@ int CheckReductionBlockErrorCode(const ScheduleState&
self, const StmtSRef& bloc
return 4;
}
// Cond 5. The reduction block vars are not used to index the output buffers.
- std::unordered_set<const BufferNode*> buffer_written;
- buffer_written.reserve(block->writes.size());
- for (const BufferRegion& write_region : block->writes) {
- buffer_written.insert(write_region->buffer.get());
- }
- bool affected = false;
- PreOrderVisit(block->body, [&](const ObjectRef& obj) {
- if (affected) {
- return false;
- }
- if (const auto* store = obj.as<BufferStoreNode>()) {
- ICHECK(buffer_written.count(store->buffer.get()))
- << "ValueError: The buffer \"" << store->buffer
- << "\" is written in the block but is not in the block's signature";
- for (const PrimExpr& index : store->indices) {
- if (UsesVar(index, [&reduction_block_vars](const VarNode* var) {
- return reduction_block_vars.count(var);
- })) {
- affected = true;
- return false;
- }
- }
- return false;
- }
- return true;
- });
- return !affected ? 0 : 5;
+ return ReductionIterNotIndexOutputBuffer(GetRef<Block>(block)) ? 0 : 5;
}
bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
@@ -552,7 +520,9 @@ bool GetVarsTouchedByBlockIters(const BlockRealize&
block_realize,
} else {
has_block_vars_of_other_types = true;
}
-
+ if (set == nullptr) {
+ continue;
+ }
Array<Var> vars_in_binding = UndefinedVars(iter_value);
for (const Var& var : vars_in_binding) {
set->insert(var.get());
@@ -1128,6 +1098,207 @@ class PatternMatcher : public ExprVisitor {
std::unordered_map<const VarNode*, PrimExpr> filled_map_;
};
+/******** Reduction Block Related ********/
+
+/*!
+ * \brief ScheduleError raised when the `init` or the `body` of a reduction block is not a
+ * BufferStore. The two flags record which of the two statements actually is a BufferStore.
+ */
+class InitBodyNotBufferStoreError : public ScheduleError {
+ public:
+  explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool init_is_bufferstore,
+                                       bool body_is_bufferstore)
+      : mod_(std::move(mod)),
+        block_(std::move(block)),
+        init_is_bufferstore_(init_is_bufferstore),
+        body_is_bufferstore_(body_is_bufferstore) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The `init` and `body` of reduction block are required to be both "
+           "BufferStore so that rfactor or cross-thread reduction can be applied";
+  }
+
+  String DetailRenderTemplate() const final {
+    // `init` is fine, so the offending statement must be the `body`.
+    if (init_is_bufferstore_) {
+      ICHECK(!body_is_bufferstore_);
+      return "The `body` of block {0} is required to be BufferStore so that rfactor or cross-thread"
+             " reduction can be applied";
+    }
+    // Only the `init` is wrong.
+    if (body_is_bufferstore_) {
+      return "The `init` of block {0} is required to be BufferStore so that rfactor or cross-thread"
+             " reduction can be applied";
+    }
+    // Neither statement is a BufferStore.
+    return "The `init` and `body` of block {0} are required to be BufferStore so that rfactor or "
+           "cross-thread reduction can be applied";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+  IRModule mod_;
+  Block block_;
+  bool init_is_bufferstore_;
+  bool body_is_bufferstore_;
+};
+
+/*!
+ * \brief ScheduleError raised when the `init` and `body` of a reduction block are both
+ * BufferStores but do not write to the same buffer with the same indices.
+ */
+class InitBodyNotSameBufferAccessError : public ScheduleError {
+ public:
+  explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The `init` and `body` of the reduction block are required to have the "
+           "same buffer access pattern";
+  }
+
+  String DetailRenderTemplate() const final {
+    // Within this file the error is thrown only after both statements were confirmed to be
+    // BufferStores; guard the dereferences anyway so misuse fails loudly instead of crashing.
+    const auto* init = block_->init.as<BufferStoreNode>();
+    const auto* update = block_->body.as<BufferStoreNode>();
+    ICHECK(init != nullptr && update != nullptr)
+        << "InternalError: the `init` and `body` of block " << block_->name_hint
+        << " are expected to be BufferStore";
+    std::ostringstream os;
+    os << "The `init` and `body` of the block {0} are required to have the same buffer access "
+          "pattern. However, in block {0} the `init` writes to "
+       << init->buffer->name << init->indices << ", and the `body` writes to "
+       << update->buffer->name << update->indices;
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+  IRModule mod_;
+  Block block_;
+};
+
+/*!
+ * \brief Extract the `init` and `body` BufferStores of a reduction block, requiring that both
+ * write the same buffer at identical indices.
+ * \param self The schedule state; when undefined, failures are reported via LOG(FATAL) instead of
+ * a ScheduleError
+ * \param block The reduction block to be analyzed
+ * \return The pair (init, body)
+ */
+std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
+    const Optional<ScheduleState>& self, const Block& block) {
+  static constexpr const char* error_str1 =
+      "ValueError: The `init` and `body` of the reduction block are required to be both "
+      "BufferStore so that rfactor or cross-thread reduction can be applied. However, a reduction "
+      "block that doesn't meet this requirement is ";
+  static constexpr const char* error_str2 =
+      "ValueError: The `init` and `body` of the reduction block are required to have the same "
+      "buffer access pattern so that rfactor or cross-thread reduction can be applied. However, a "
+      "reduction block that doesn't meet this requirement is ";
+  // Single reporting path for every "different buffer access" failure below.
+  auto report_access_mismatch = [&]() {
+    if (self.defined()) {
+      throw InitBodyNotSameBufferAccessError(self.value()->mod, block);
+    }
+    LOG(FATAL) << error_str2 << block;
+  };
+
+  const auto* init = block->init.as<BufferStoreNode>();
+  const auto* body = block->body.as<BufferStoreNode>();
+  if (!(init && body)) {
+    if (self.defined()) {
+      throw InitBodyNotBufferStoreError(self.value()->mod, block, init != nullptr, body != nullptr);
+    }
+    LOG(FATAL) << error_str1 << block;
+  }
+  if (!init->buffer.same_as(body->buffer)) {
+    report_access_mismatch();
+  }
+  // Compare the index lists themselves (not the buffer rank) so a length mismatch is reported
+  // rather than causing an out-of-range index access.
+  if (init->indices.size() != body->indices.size()) {
+    report_access_mismatch();
+  }
+  for (size_t i = 0; i < init->indices.size(); ++i) {
+    if (!ExprDeepEqual()(init->indices[i], body->indices[i])) {
+      report_access_mismatch();
+    }
+  }
+  return std::make_pair(GetRef<BufferStore>(init), GetRef<BufferStore>(body));
+}
+
+// Returns false as soon as an iterator is found that is neither data-parallel nor a
+// commutative reduction; an empty array is accepted.
+bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters) {
+  for (const IterVar& iter : iters) {
+    const bool supported = iter->iter_type == kDataPar || iter->iter_type == kCommReduce;
+    if (!supported) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool ReductionIterNotIndexOutputBuffer(const Block& block) {
+  // Variables of all commutative-reduction block iterators.
+  std::unordered_set<const VarNode*> reduce_vars;
+  reduce_vars.reserve(block->iter_vars.size());
+  for (const IterVar& iv : block->iter_vars) {
+    if (iv->iter_type == kCommReduce) {
+      reduce_vars.insert(iv->var.get());
+    }
+  }
+  // Buffers the block declares in its write signature.
+  std::unordered_set<const BufferNode*> written_buffers;
+  written_buffers.reserve(block->writes.size());
+  for (const BufferRegion& region : block->writes) {
+    written_buffers.insert(region->buffer.get());
+  }
+  auto is_reduce_var = [&reduce_vars](const VarNode* var) { return reduce_vars.count(var) != 0; };
+
+  bool indexed_by_reduce_var = false;
+  PreOrderVisit(block->body, [&](const ObjectRef& node) -> bool {
+    // Once a violation is found, stop descending anywhere.
+    if (indexed_by_reduce_var) {
+      return false;
+    }
+    const auto* store = node.as<BufferStoreNode>();
+    if (store == nullptr) {
+      return true;
+    }
+    ICHECK(written_buffers.count(store->buffer.get()))
+        << "ValueError: The buffer \"" << store->buffer
+        << "\" is written in the block but is not in the block's signature";
+    for (const PrimExpr& index : store->indices) {
+      if (UsesVar(index, is_reduce_var)) {
+        indexed_by_reduce_var = true;
+        break;
+      }
+    }
+    // The children of a BufferStore are never visited.
+    return false;
+  });
+  return !indexed_by_reduce_var;
+}
+
+/*!
+ * \brief ScheduleError raised when no registered commutative reducer matches the identity and
+ * combiner of a reduction block.
+ */
+class NoMatchedReducerError : public ScheduleError {
+ public:
+  explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore combiner)
+      : mod_(std::move(mod)), identity_(std::move(identity)), combiner_(std::move(combiner)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: No matched reducer for the identity and the combiner of this reduction "
+           "block. So rfactor and cross-thread reduction cannot be applied.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    // The explanation comes first and the printed IR last. Previously the combiner was streamed
+    // directly into "In this case ..." with no separator between them.
+    os << "No matched reducer for the identity and the combiner of this reduction block, so "
+          "rfactor and cross-thread reduction cannot be applied. You can check "
+          "tvm::tir::ReducerRegistry for default reducers or register new reducers. The identity "
+          "is "
+       << identity_ << ", and the combiner is:\n"
+       << combiner_;
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+  IRModule mod_;
+  PrimExpr identity_;
+  BufferStore combiner_;
+};
+
+std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
+    const Optional<ScheduleState>& self, const PrimExpr& identity, const BufferStore& combiner) {
+  // Out-parameters handed to FromIdentityCombiner.
+  CommReducer matched_reducer{nullptr};
+  PrimExpr lhs{nullptr};
+  PrimExpr rhs{nullptr};
+  const bool found = FromIdentityCombiner(identity, combiner, &matched_reducer, &lhs, &rhs);
+  if (!found && self.defined()) {
+    // A schedule state is available: report through the ScheduleError machinery.
+    throw NoMatchedReducerError(self.value()->mod, identity, combiner);
+  }
+  if (!found) {
+    LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the "
+                  "reduction block. So rfactor and cross-thread reduction cannot be applied.";
+  }
+  return std::make_tuple(std::move(matched_reducer), std::move(lhs), std::move(rhs));
+}
+
/******** Commutative Reducer ********/
bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const
PrimExpr& combiner,
diff --git a/src/tir/schedule/primitive/reduction.cc
b/src/tir/schedule/primitive/reduction.cc
index 0f85168..9c33076 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -370,69 +370,6 @@ class NotSerialLoopKindError : public ScheduleError {
For loop_;
};
-class InitBodyNotBufferStoreError : public ScheduleError {
- public:
- explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool
init_is_bufferstore,
- bool body_is_bufferstore)
- : mod_(std::move(mod)),
- block_(std::move(block)),
- init_is_bufferstore_(init_is_bufferstore),
- body_is_bufferstore_(body_is_bufferstore) {}
-
- String FastErrorString() const final {
- return "ScheduleError: The `init` and `body` of reduction block are
required to be both "
- "BufferStore";
- }
-
- String DetailRenderTemplate() const final {
- if (!init_is_bufferstore_ && !body_is_bufferstore_) {
- return "The `init` and `body` of block {0} are required to be
BufferStore so that rfactor "
- "can be applied";
- } else if (!init_is_bufferstore_) {
- return "The `init` of block {0} is required to be BufferStore so that
rfactor can be applied";
- } else {
- ICHECK(!body_is_bufferstore_);
- return "The `body` of block {0} is required to be BufferStore so that
rfactor can be applied";
- }
- }
-
- IRModule mod() const final { return mod_; }
- Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
-
- IRModule mod_;
- Block block_;
- bool init_is_bufferstore_;
- bool body_is_bufferstore_;
-};
-
-class InitBodyNotSameBufferAccessError : public ScheduleError {
- public:
- explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block)
- : mod_(std::move(mod)), block_(std::move(block)) {}
-
- String FastErrorString() const final {
- return "ScheduleError: The `init` and `body` of the reduction block are
required to have the "
- "same buffer access pattern";
- }
-
- String DetailRenderTemplate() const final {
- std::ostringstream os;
- const auto* init = block_->init.as<BufferStoreNode>();
- const auto* update = block_->body.as<BufferStoreNode>();
- os << "The `init` and `body` of the block {0} is required to have the same
buffer access "
- "pattern. However, in block {0} the `init` writes to "
- << init->buffer->name << init->indices << ", and the `body` writes to "
- << update->buffer->name << update->indices;
- return os.str();
- }
-
- IRModule mod() const final { return mod_; }
- Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
-
- IRModule mod_;
- Block block_;
-};
-
class FactorAxisOutOfRangeError : public ScheduleError {
public:
explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int
factor_axis)
@@ -473,32 +410,6 @@ class FactorAxisOutOfRangeError : public ScheduleError {
int factor_axis_;
};
-class NoMatchedReducerError : public ScheduleError {
- public:
- explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore
combiner)
- : mod_(std::move(mod)), identity_(std::move(identity)),
combiner_(std::move(combiner)) {}
-
- String FastErrorString() const final {
- return "ScheduleError: No matched reducer for the identity and the
combiner of this reduction "
- "block. So rfactor cannot be applied.";
- }
-
- String DetailRenderTemplate() const final {
- std::ostringstream os;
- os << "No matched reducer for identity " << identity_ << " and combiner "
<< combiner_
- << "In this case rfactor cannot be applied. You can check
tvm::tir::ReducerRegistry for "
- "default reducers or registering new reducers.";
- return os.str();
- }
-
- IRModule mod() const final { return mod_; }
- Array<ObjectRef> LocationsOfInterest() const final { return {}; }
-
- IRModule mod_;
- PrimExpr identity_;
- BufferStore combiner_;
-};
-
class LoopPropertyError : public ScheduleError {
public:
enum ErrorType {
@@ -592,53 +503,6 @@ class LoopPropertyError : public ScheduleError {
};
/*!
- * \brief Convert the `init` and `body` of the input block to BufferStores
- * \param self The schedule state
- * \param block The block to be analyzed
- * \return The BufferStores of the `init` and `body` of the input block
- * \throw ScheduleError If the `init` or `body` is not BufferStore, or they
don't write to the same
- * buffer
- */
-std::pair<BufferStore, BufferStore> GetBufferStoreNodes(const ScheduleState&
self,
- const Block& block) {
- const auto* init = block->init.as<BufferStoreNode>();
- const auto* body = block->body.as<BufferStoreNode>();
- if (!(init && body)) {
- throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body
!= nullptr);
- }
- if (!init->buffer.same_as(body->buffer)) {
- throw InitBodyNotSameBufferAccessError(self->mod, block);
- }
- int ndim = static_cast<int>(init->buffer->shape.size());
- for (int i = 0; i < ndim; ++i) {
- if (!ExprDeepEqual()(init->indices[i], body->indices[i])) {
- throw InitBodyNotSameBufferAccessError(self->mod, block);
- }
- }
- return std::make_pair(GetRef<BufferStore>(init), GetRef<BufferStore>(body));
-}
-
-/*!
- * \brief Given a reduction identity and a reduction combiner, detect the
corresponding commutative
- * reducer, and extract the combiner lhs and combiner rhs
- * \param self The schedule state
- * \param identity The reduction identity to be analyzed
- * \param combiner The reduction combiner to be analyzed
- * \return The corresponding CommReducer, the combiner lhs and the combiner rhs
- * \throw ScheduleError If no corresponding commutative reducer can be matched
- */
-std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
- const ScheduleState& self, const PrimExpr& identity, const BufferStore&
combiner) {
- CommReducer reducer{nullptr};
- PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr};
- bool matched = FromIdentityCombiner(identity, combiner, &reducer,
&combiner_lhs, &combiner_rhs);
- if (!matched) {
- throw NoMatchedReducerError(self->mod, identity, combiner);
- }
- return std::make_tuple(std::move(reducer), std::move(combiner_lhs),
std::move(combiner_rhs));
-}
-
-/*!
* \brief For each loop in the given array of loop, associate its loop var
with the loop itself
* using a mapping
* \param loops The loops to be analyzed
@@ -1177,7 +1041,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef&
rf_loop_sref, int factor_ax
BufferStore update;
CommReducer reducer;
PrimExpr combiner_lhs, combiner_rhs;
- std::tie(init, update) = GetBufferStoreNodes(self, block);
+ std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block);
std::tie(reducer, combiner_lhs, combiner_rhs) =
GetReducerAndCombinerLhsRhs(self, init->value, update);
diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc
b/src/tir/transforms/lower_cross_thread_reduction.cc
new file mode 100644
index 0000000..630c00f
--- /dev/null
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -0,0 +1,645 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_cross_thread_reduction.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../schedule/analysis.h"
+#include "./ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Checks if a loop is bound to threadIdx.x/y/z
+ * \param loop The loop to be checked
+ * \return True if the loop is bound to threadIdx.x/y/z
+ */
+bool IsBoundToThreadIdx(const ForNode* loop) {
+  if (!loop->thread_binding.defined()) {
+    return false;
+  }
+  runtime::ThreadScope scope =
+      runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+  // Presumably rank 1 identifies thread-level tags and a negative dim_index marks tags without a
+  // concrete x/y/z dimension (e.g. virtual threads).
+  // NOTE(review): confirm against runtime/thread_storage_scope.h.
+  return scope.rank == 1 && scope.dim_index >= 0;
+}
+
+/*!
+ * \brief Check whether `block` is the only writer of each of its output buffers among the blocks
+ * that sit directly in the body of `scope_block` (blocks nested inside another block are not
+ * counted).
+ * \param scope_block The scope block of the block to be checked
+ * \param block The block whose dominant property is to be checked
+ * \return A boolean indicating if the block is a dominant block
+ */
+bool IsDominantBlock(const Block& scope_block, const Block& block) {
+  // Tally the writers of every buffer among the blocks directly under the scope.
+  std::unordered_map<const BufferNode*, int> num_writers;
+  PreOrderVisit(scope_block->body, [&num_writers](const ObjectRef& node) -> bool {
+    const auto* child = node.as<BlockNode>();
+    if (child == nullptr) {
+      return true;
+    }
+    for (const BufferRegion& region : child->writes) {
+      num_writers[region->buffer.get()] += 1;
+    }
+    // Do not descend into the child block itself.
+    return false;
+  });
+  // `block` must be the sole writer of every buffer it writes.
+  for (const BufferRegion& region : block->writes) {
+    auto it = num_writers.find(region->buffer.get());
+    ICHECK(it != num_writers.end());
+    if (it->second != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/*!
+ * \brief Check whether the input block is a reduction block.
+ * \param realize The block to be checked
+ * \param loop_range_map The mapping from the loop variables outside the input block to their ranges
+ * \param scope_block The scope block of the input block
+ * \param analyzer The analyzer
+ * \return A boolean indicating whether the input block is a reduction block.
+ * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is
+ * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the
+ * check again.
+ */
+bool IsReductionBlock(const BlockRealize& realize, const Map<Var, Range>& loop_range_map,
+                      const Block& scope_block, arith::Analyzer* analyzer) {
+  const Block& block = realize->block;
+  // Cond 1. The block has the `init` statement.
+  if (!block->init.defined()) {
+    return false;
+  }
+  // Cond 2. All the block bindings are quasi-affine expressions.
+  if (!IsAffineBinding(realize, loop_range_map, analyzer)) {
+    return false;
+  }
+  // Cond 3. All block vars are either data parallel block vars or reduction block vars.
+  if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
+    return false;
+  }
+  // Cond 4. Dominant: the block is the only writer of its output buffers within `scope_block`.
+  if (!IsDominantBlock(scope_block, block)) {
+    return false;
+  }
+  // Cond 5. The reduction block vars are not used to index the output buffers.
+  return ReductionIterNotIndexOutputBuffer(block);
+}
+
+/*!
+ * \brief Create an intermediate buffer with specified name and data type
+ * \param name The specified name, used for both the buffer and its data pointer
+ * \param dtype The specified data type
+ * \return The created single-element buffer, whose data pointer is typed with the "local" scope
+ */
+Buffer MakeScratchpad(String name, const DataType& dtype) {
+  // Build the data pointer (which copies `name`) before `name` is moved into the Buffer.
+  // Function arguments are evaluated in an unspecified order, so copying `name` and moving it
+  // within the same call could hand the Var a moved-from String.
+  Var data(name, PointerType(PrimType(dtype), "local"));
+  return Buffer(/*ptr=*/std::move(data),
+                /*dtype=*/dtype,
+                /*shape=*/{Integer(1)},
+                /*strides=*/{Integer(1)},
+                /*elem_offset=*/PrimExpr{nullptr},
+                /*name=*/std::move(name),
+                /*data_alignment=*/0,
+                /*offset_factor=*/0,
+                /*buffer_type=*/kDefault);
+}
+
+/*!
+ * \brief Remove the BufferRegions whose buffer is the input buffer
+ * \param buffer_regions The array of BufferRegions to be filtered
+ * \param buffer_to_remove The specified buffer
+ * \return A new array holding the input BufferRegions, in their original order, except those whose
+ * buffer is the input buffer
+ * \note Buffers are compared by reference identity (`same_as`).
+ */
+Array<BufferRegion> RemoveBufferFromBufferRegions(const Array<BufferRegion>& buffer_regions,
+                                                  const Buffer& buffer_to_remove) {
+  Array<BufferRegion> res;
+  res.reserve(buffer_regions.size());
+  for (const BufferRegion& buffer_region : buffer_regions) {
+    if (!buffer_region->buffer.same_as(buffer_to_remove)) {
+      res.push_back(buffer_region);
+    }
+  }
+  return res;
+}
+
+/*!
+ * \brief Substitute a given source buffer with a given target buffer in statements or expressions.
+ * \note Every load/store of the source buffer is redirected to index 0 of the target buffer, i.e.
+ * the target is treated as a single-element scratchpad.
+ */
+class BufferReplacer : private StmtExprMutator {
+ public:
+  static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) {
+    return BufferReplacer(std::move(src_buffer), std::move(tgt_buffer))(std::move(stmt));
+  }
+
+ private:
+  explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer)
+      : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {}
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    if (load->buffer.same_as(src_buffer_)) {
+      return BufferLoad(tgt_buffer_, {0});
+    }
+    // Recurse into the indices so that loads of the source buffer nested inside another buffer's
+    // indices are replaced too.
+    return StmtExprMutator::VisitExpr_(load);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    if (store->buffer.same_as(src_buffer_)) {
+      PrimExpr value = StmtExprMutator::VisitExpr(store->value);
+      return BufferStore(tgt_buffer_, std::move(value), {0});
+    }
+    return StmtExprMutator::VisitStmt_(store);
+  }
+
+  Buffer src_buffer_;
+  Buffer tgt_buffer_;
+};
+
+/*!
+ * \brief Substitute a given source block with a given target block, or remove the source block
+ * branch from the AST if the target block is undefined. Along the way every thread-bound loop is
+ * replaced by its body, and loops / sequences left empty by the removal are dropped.
+ */
+class InThreadReducerMaker : private StmtMutator {
+ public:
+  static Optional<Stmt> Make(const BlockRealizeNode* src_realize,
+                             Optional<BlockRealize> tgt_realize, Stmt stmt) {
+    return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt));
+  }
+
+ private:
+  explicit InThreadReducerMaker(const BlockRealizeNode* src_realize,
+                                Optional<BlockRealize> tgt_realize)
+      : src_realize_(src_realize), tgt_realize_(std::move(tgt_realize)) {}
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    if (realize == src_realize_) {
+      // An undefined result removes the source block from the AST.
+      return tgt_realize_.defined()  //
+                 ? tgt_realize_.value()
+                 : Stmt{nullptr};
+    }
+    // Other blocks are kept untouched; their bodies are not visited.
+    return GetRef<BlockRealize>(realize);
+  }
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    // StmtMutator never returns an undefined Stmt for a loop; when the loop body was removed it
+    // yields a loop whose body is undefined, so check the body explicitly.
+    For res = Downcast<For>(StmtMutator::VisitStmt_(loop));
+    if (!res->body.defined()) {
+      return Stmt{nullptr};
+    }
+    if (res->thread_binding.defined()) {
+      // Thread-bound loops are stripped and only their bodies are kept.
+      return res->body;
+    }
+    return std::move(res);
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> stmts;
+    stmts.reserve(seq->size());
+    for (const Stmt& stmt : seq->seq) {
+      // Drop children that became empty.
+      if (Optional<Stmt> opt_res = VisitStmt(stmt)) {
+        stmts.push_back(opt_res.value());
+      }
+    }
+    return stmts.empty() ? Stmt{nullptr} : SeqStmt::Flatten(stmts);
+  }
+
+  const BlockRealizeNode* src_realize_;
+  Optional<BlockRealize> tgt_realize_;
+};
+
+/*!
+ * \brief Create the lowered allreduce block transformed from the input reduction block
+ * \param realize The BlockRealize of the input reduction block
+ * \param it_buffer The buffer to store in-thread reduction results, or NullOpt when every
+ * reduction-related loop is thread-bound and no in-thread reduction is generated
+ * \param ct_buffer The buffer to store cross-thread reduction results
+ * \param reducer The reduction function
+ * \param combiner_rhs The RHS of the combiner
+ * \param reduction_loops The reduction-related loops, outermost first
+ * \return The (at most four) lowered statements, wrapped by the thread-bound reduction loops
+ */
+Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional<Buffer>& it_buffer,
+                             const Buffer& ct_buffer, const CommReducer& reducer,
+                             const PrimExpr& combiner_rhs,
+                             const std::vector<const ForNode*>& reduction_loops) {
+  const BlockNode* block = realize->block.get();
+  Buffer wb_buffer = block->writes[0]->buffer;
+  Array<Range> wb_region = block->writes[0]->region;
+
+  // Both scratch buffers are accessed at the single element [0].
+  BufferRegion ct_buffer_region(ct_buffer, {Range::FromMinExtent(0, 1)});
+  Optional<BufferRegion> it_buffer_region = NullOpt;
+  if (it_buffer.defined()) {
+    it_buffer_region = BufferRegion(it_buffer.value(), {Range::FromMinExtent(0, 1)});
+  }
+  // In total, the block is transformed into at most 4 statements
+  // - Stmt 1: initialize the buffer for in-thread reduction
+  // - Stmt 2: do in-thread reduction
+  // - Stmt 3: do cross-thread reduction
+  // - Stmt 4: write cross-thread reduction result to the original buffer
+  Array<Stmt> stmts;
+  stmts.reserve(4);
+  // Stmt 1: initialize the buffer for in-thread reduction
+  if (it_buffer.defined()) {
+    BufferStore init = Downcast<BufferStore>(block->init);
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/{},
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/{},
+              /*reads=*/{},
+              /*writes=*/{it_buffer_region.value()},
+              /*name_hint=*/block->name_hint + "_in_thread_init",
+              /*body=*/
+              BufferStore(/*buffer=*/it_buffer.value(),
+                          /*value=*/init->value,
+                          /*indices=*/{Integer(0)}))));
+  }
+  // Stmt 2: do in-thread reduction
+  {
+    Optional<BlockRealize> new_realize = NullOpt;
+    // If need to generate in-thread reduction,
+    // then replace `wb_buffer` with `it_buffer` accordingly in given BlockRealize
+    // otherwise, directly remove given BlockRealize
+    if (it_buffer.defined()) {
+      ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
+      // The in-thread block reads the accumulator it updates instead of the original output.
+      new_block->reads = RemoveBufferFromBufferRegions(std::move(new_block->reads), wb_buffer);
+      new_block->reads.push_back(it_buffer_region.value());
+      new_block->writes = {it_buffer_region.value()};
+      new_block->name_hint = new_block->name_hint + "_in_thread";
+      new_block->body =
+          BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body));
+      // Initialization is done separately by Stmt 1.
+      new_block->init = NullOpt;
+      ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize);
+      n->block = Block(new_block);
+      new_realize = BlockRealize(n);
+    }
+    For loop = GetRef<For>(reduction_loops[0]);
+    if (Optional<Stmt> stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) {
+      stmts.push_back(stmt.value());
+    }
+  }
+  // Stmt 3: do cross-thread reduction
+  {
+    // Step 3.1. Create the parameters to the intrinsic
+    Array<PrimExpr> parameters;
+    parameters.reserve(reduction_loops.size() + 4);
+    // 1-st argument: size
+    parameters.push_back(make_const(DataType::UInt(32), 1));
+    // 2-nd argument: source
+    if (it_buffer.defined()) {
+      parameters.push_back(BufferLoad(it_buffer.value(), {Integer(0)}));
+    } else {
+      parameters.push_back(combiner_rhs);
+    }
+    // 3-rd argument: predicate
+    parameters.push_back(const_true());
+    // 4-th argument: destination
+    parameters.push_back(ct_buffer->data);
+    // next arguments: all the reduction threads
+    for (const ForNode* reduction_loop : reduction_loops) {
+      if (reduction_loop->thread_binding.defined()) {
+        parameters.push_back(reduction_loop->loop_var);
+      }
+    }
+    // Step 3.2. Create the block and the block-realize.
+    // Without an in-thread buffer the source is `combiner_rhs`, which refers to the original
+    // block iterators, so this block must keep them (and their bindings).
+    Array<IterVar> iter_vars{nullptr};
+    Array<PrimExpr> bindings{nullptr};
+    Array<BufferRegion> reads{nullptr};
+    if (it_buffer.defined()) {
+      iter_vars = Array<IterVar>{};
+      bindings = Array<PrimExpr>{};
+      reads = {it_buffer_region.value()};
+    } else {
+      iter_vars = block->iter_vars;
+      bindings = realize->iter_values;
+      reads = {RemoveBufferFromBufferRegions(block->reads, wb_buffer)};
+    }
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/std::move(bindings),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/std::move(iter_vars),
+              /*reads=*/std::move(reads),
+              /*writes=*/{ct_buffer_region},
+              /*name_hint=*/block->name_hint + "_cross_thread",
+              /*body=*/
+              AttrStmt(/*node=*/reducer,
+                       /*attr_key=*/tir::attr::reduce_scope,
+                       /*value=*/make_zero(DataType::Handle()),
+                       /*body=*/
+                       Evaluate(Call(/*dtype=*/DataType::Handle(),
+                                     /*op=*/tir::builtin::tvm_thread_allreduce(),
+                                     /*args=*/std::move(parameters)))))));
+  }
+  // Stmt 4: write cross-thread reduction result to the original buffer
+  {
+    ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
+    int n_iter = static_cast<int>(block->iter_vars.size());
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> bindings;
+    Map<Var, PrimExpr> var_map;
+    iter_vars.reserve(n_iter);
+    bindings.reserve(n_iter);
+    for (int i = 0; i < n_iter; ++i) {
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      // Only non-reduction iterators are kept. Each gets a fresh Var copy (presumably so that
+      // this block does not re-define Vars already defined by another block) and the
+      // write-back indices are rewritten through `var_map`.
+      if (iter_var->iter_type != kCommReduce) {
+        IterVar new_iter_var{nullptr};
+        {
+          ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get());
+          ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get());
+          n->var = Var(v);
+          new_iter_var = IterVar(n);
+        }
+        iter_vars.push_back(new_iter_var);
+        bindings.push_back(binding);
+        var_map.Set(iter_var->var, new_iter_var->var);
+      }
+    }
+    BufferStore update = Downcast<BufferStore>(block->body);
+    update = Downcast<BufferStore>(Substitute(std::move(update), var_map));
+    stmts.push_back(BlockRealize(
+        /*iter_values=*/std::move(bindings),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(
+            /*iter_vars=*/std::move(iter_vars),
+            /*reads=*/{std::move(ct_buffer_region)},
+            /*writes=*/{BufferRegion(wb_buffer, Substitute(wb_region, var_map))},
+            /*name_hint=*/block->name_hint + "_write_back",
+            /*body=*/
+            BufferStore(/*buffer=*/wb_buffer,
+                        /*value=*/BufferLoad(ct_buffer, {Integer(0)}),
+                        /*indices=*/update->indices))));
+  }
+  // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx
+  Stmt new_stmt = SeqStmt::Flatten(std::move(stmts));
+  for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) {
+    const ForNode* loop = *rit;
+    if (loop->thread_binding.defined()) {
+      ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
+      n->body = std::move(new_stmt);
+      new_stmt = For(n);
+    }
+  }
+  return new_stmt;
+}
+
+/*!
+ * \brief Detect cross-thread reduction pattern and then transform
+ *
+ * A reduction block needs the transformation when at least one loop whose var feeds its
+ * reduction iterators is bound to a thread axis. The lowered statements (built by
+ * TransformReductionBlock) replace the first reduction-related loop, and the scratch buffers
+ * are allocated by the enclosing block.
+ */
+class CrossThreadReductionTransformer : public StmtMutator {
+ private:
+  // Check if the input block needs cross-thread reduction. Returns the reduction-related loops
+  // (outermost first) when it does, and an empty vector otherwise.
+  std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode* realize) {
+    // Step 0. If the block is the root block, just return.
+    if (block_stack_.empty()) {
+      return {};
+    }
+
+    // Step 1. If the block is not a reduction block, cross-thread reduction is not needed.
+    if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
+                          GetRef<Block>(block_stack_.back()), &analyzer_)) {
+      return {};
+    }
+
+    // Step 2. Collect all the vars that appear in the bindings of reduction block iters.
+    std::unordered_set<const VarNode*> reduction_vars;
+    GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr, &reduction_vars);
+
+    // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters.
+    // We call these loops "reduction-related".
+    // Step 4. See whether at least one reduction-related loop is bound to thread axis in GPU - if
+    // so, cross-thread reduction is needed. If none of the reduction-related loops is bound to
+    // thread axis, cross-thread reduction is not needed for the input block.
+    bool need = false;
+    std::vector<const ForNode*> reduction_loops;
+    for (const ForNode* loop : loop_stack_) {
+      if (reduction_vars.count(loop->loop_var.get())) {
+        // Step 3. Collect the loop.
+        reduction_loops.push_back(loop);
+        // Step 4. See whether the loop is bound to some thread axis.
+        if (loop->thread_binding.defined()) {
+          need = true;
+        }
+      }
+    }
+    return need ? reduction_loops : std::vector<const ForNode*>{};
+  }
+
+  // Given that the input block needs cross-thread reduction, check if cross-thread reduction can
+  // be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread
+  // reduction). Returns (number of thread-bound reduction-related loops, reducer, combiner RHS).
+  std::tuple<int, CommReducer, PrimExpr> CheckCanApplyCrossThreadReduction(
+      const BlockNode* block, const std::vector<const ForNode*>& reduction_loops) const {
+    // Condition 1. The block being applied cross-thread reduction should write to single buffer.
+    CHECK_EQ(block->writes.size(), 1)
+        << "ValueError: Cross-thread reduction requires the block to only "
+           "write to single buffer. However, the block "
+        << block->name_hint << " writes to " << block->writes.size() << " buffer(s).";
+
+    // Condition 2. All the reduction-related loops should be the deepest among all statements
+    // outside the block (ignoring SeqStmt here).
+    int n_deepest_reduction_loops = 0;
+    // `statement_stack_.back()` is the BlockRealize of `block` itself, hence the `+ 1`.
+    for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) {
+      const StmtNode* stmt = *rit;
+      if (stmt->IsInstance<SeqStmtNode>()) {
+        // Skip SeqStmt.
+        continue;
+      }
+      // Compare through the common `StmtNode` base: `stmt` need not be a ForNode at all, so it
+      // must not be cast to one.
+      if (std::find(reduction_loops.begin(), reduction_loops.end(), stmt) ==
+          reduction_loops.end()) {
+        break;
+      }
+      ++n_deepest_reduction_loops;
+    }
+    CHECK_EQ(n_deepest_reduction_loops, static_cast<int>(reduction_loops.size()))
+        << "ValueError: Cross-thread reduction requires all the reduction-related loops to be the "
+           "deepest among all statements outside the desired block. However, block "
+        << block->name_hint
+        << " needs cross-thread reduction, while the reduction-related loops outside of it are not "
+           "the deepest statements, which violates the condition.";
+
+    // Condition 3. All the reduction-related loops that are bound to thread axes should only be
+    // bound to `threadIdx.x/y/z`.
+    int n_bound_reduction_loops = 0;
+    for (const ForNode* reduction_loop : reduction_loops) {
+      if (reduction_loop->thread_binding.defined()) {
+        ++n_bound_reduction_loops;
+        CHECK(IsBoundToThreadIdx(reduction_loop))
+            << "ValueError: Cross-thread reduction requires all the reduction-related loops that "
+               "are bound to GPU thread axes to only be bound `threadIdx.x/y/z`. However, loop "
+            << reduction_loop->loop_var->name_hint << " violates the condition.";
+      }
+    }
+
+    // Condition 4. Get the `init` identity and the `update` combiner of the reduction. They should
+    // both be BufferStores with the same buffer and indices;
+    // Extract the commutative reducer, combiner lhs and combiner rhs from the reduction identity
+    // and the reduction combiner.
+    BufferStore init{nullptr};
+    BufferStore update{nullptr};
+    CommReducer reducer{nullptr};
+    PrimExpr combiner_lhs{nullptr};
+    PrimExpr combiner_rhs{nullptr};
+    std::tie(init, update) = GetBufferStoresFromReductionBlock(NullOpt, GetRef<Block>(block));
+    std::tie(reducer, combiner_lhs, combiner_rhs) =
+        GetReducerAndCombinerLhsRhs(NullOpt, init->value, update);
+
+    // Condition 5. The block should be the last block under the first reduction-related loop.
+    bool visit = false;
+    PreOrderVisit(GetRef<For>(reduction_loops[0]), [block, &visit](const ObjectRef& obj) {
+      if (const auto* realize = obj.as<BlockRealizeNode>()) {
+        CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction "
+                         "block isn't the last block under its first reduction-related loop";
+        if (realize->block.get() == block) {
+          visit = true;
+        }
+        // Do not descend into nested blocks.
+        return false;
+      }
+      return true;
+    });
+    return std::make_tuple(n_bound_reduction_loops, reducer, combiner_rhs);
+  }
+
+  // Record every statement on the path from the root so Condition 2 can inspect the ancestors.
+  Stmt VisitStmt(const Stmt& stmt) final {
+    statement_stack_.push_back(stmt.get());
+    Stmt result = StmtMutator::VisitStmt(stmt);
+    statement_stack_.pop_back();
+    return result;
+  }
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    loop_stack_.push_back(loop);
+    loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+    Stmt result = StmtMutator::VisitStmt_(loop);
+    loop_stack_.pop_back();
+    loop_range_map_.erase(loop->loop_var);
+
+    // Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`.
+    auto it = loop2new_stmt_.find(loop);
+    if (it != loop2new_stmt_.end()) {
+      return it->second;
+    } else {
+      return result;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    // The loop-range map starts empty inside each block; the outer map is restored afterwards.
+    Map<Var, Range> old_loop_range_map;
+
+    block_stack_.push_back(block);
+    std::swap(old_loop_range_map, loop_range_map_);
+    Block new_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
+    block_stack_.pop_back();
+    std::swap(old_loop_range_map, loop_range_map_);
+
+    // Insert the new allocated buffers into the block's `alloc_buffers` field.
+    auto it = block2new_buffers_.find(block);
+    if (it != block2new_buffers_.end()) {
+      BlockNode* p_new_block = new_block.CopyOnWrite();
+      for (const Buffer& new_buffer : it->second) {
+        if (new_buffer.defined()) {
+          p_new_block->alloc_buffers.push_back(new_buffer);
+        }
+      }
+    }
+    return new_block;
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    const BlockNode* block = realize->block.get();
+    // Step 1. Check whether cross-thread reduction is needed. If no, skip this block.
+    std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
+    if (reduction_loops.empty()) {
+      return StmtMutator::VisitStmt_(realize);
+    }
+    ++reduction_id_;
+    // Step 2. Check whether cross-thread reduction can be applied. If no, throw an exception on
+    // which condition the block violates.
+    int n_bound_reduction_loops = 0;
+    CommReducer reducer{nullptr};
+    PrimExpr combiner_rhs{nullptr};
+    std::tie(n_bound_reduction_loops, reducer, combiner_rhs) =
+        CheckCanApplyCrossThreadReduction(block, reduction_loops);
+    // Step 3. When not all the reduction-related loops are bound to thread axes, in-thread
+    // reduction is needed in this cross-thread reduction.
+    bool need_in_thread_reduction =
+        n_bound_reduction_loops < static_cast<int>(reduction_loops.size());
+    // Step 4. Create intermediate buffers, storing them in `ct_buffer` and
+    // `it_buffer`. Let the scope block allocate these new buffers.
+    std::vector<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
+    DataType dtype = block->writes[0]->buffer->dtype;
+    Buffer ct_buffer = MakeScratchpad("cross_thread_" + std::to_string(reduction_id_), dtype);
+    new_buffers.push_back(ct_buffer);
+    Optional<Buffer> it_buffer = NullOpt;
+    if (need_in_thread_reduction) {
+      it_buffer = MakeScratchpad("in_thread_" + std::to_string(reduction_id_), dtype);
+      new_buffers.push_back(it_buffer.value());
+    }
+    // Step 5. Transform.
+    loop2new_stmt_[reduction_loops[0]] = TransformReductionBlock(
+        realize, it_buffer, ct_buffer, reducer, combiner_rhs, reduction_loops);
+    // Step 6. Return an empty statement, because the transformation result will be inserted when
+    // returning to the first reduction-related loop.
+    return Stmt{nullptr};
+  }
+
+ private:
+  /*! \brief Counter used to give each lowered reduction unique scratch-buffer names. */
+  int reduction_id_ = -1;
+  /*! \brief Statements from the root down to the one currently visited. */
+  std::vector<const StmtNode*> statement_stack_;
+  /*! \brief Enclosing loops of the current statement, outermost first. */
+  std::vector<const ForNode*> loop_stack_;
+  /*! \brief Enclosing blocks of the current statement, outermost first. */
+  std::vector<const BlockNode*> block_stack_;
+  /*! \brief Scratch buffers each block must additionally allocate. */
+  std::unordered_map<const BlockNode*, std::vector<Buffer>> block2new_buffers_;
+  /*! \brief Replacement statement for each first reduction-related loop. */
+  std::unordered_map<const ForNode*, Stmt> loop2new_stmt_;
+  /*! \brief Ranges of the loops between the current statement and its innermost block. */
+  Map<Var, Range> loop_range_map_;
+  arith::Analyzer analyzer_;
+};
+
+/*!
+ * \brief Lower cross-thread reductions in a single PrimFunc.
+ * \param f The function to transform.
+ * \return `f` with its body rewritten, or `f` unchanged when it comes from a legacy TE schedule.
+ */
+PrimFunc LowerCrossThreadReduction(PrimFunc f) {
+  // TIR produced by legacy TE schedules is left as is.
+  if (IsFromLegacyTESchedule(f)) {
+    return f;
+  }
+  PrimFuncNode* fptr = f.CopyOnWrite();
+  fptr->body = CrossThreadReductionTransformer()(f->body);
+  return f;
+}
+
+namespace transform {
+
+/*! \brief Build the `tir.LowerCrossThreadReduction` PrimFunc pass. */
+Pass LowerCrossThreadReduction() {
+  auto lower = [](PrimFunc f, IRModule m, PassContext ctx) {
+    return LowerCrossThreadReduction(std::move(f));
+  };
+  return CreatePrimFuncPass(lower, 0, "tir.LowerCrossThreadReduction", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction")
+    .set_body_typed(LowerCrossThreadReduction);
+
+}  // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
new file mode 100644
index 0000000..4fa3ab0
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -0,0 +1,737 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys
+
+import pytest
+import tvm
+from tvm import te
+from tvm.script import tir as T
+
+
+def _check(original, transformed):
+ mod = tvm.IRModule.from_expr(original)
+ mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
+ tvm.ir.assert_structural_equal(mod["main"], transformed, True)
+
+
+def _check_fail(original):
+ mod = tvm.IRModule.from_expr(original)
+ with pytest.raises(ValueError):
+ tvm.tir.transform.LowerCrossThreadReduction()(mod)
+
+
[email protected]_func
+def loop_split(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ for i, ko in T.grid(128, 4):
+ for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
+ with T.block("B"):
+ vi = T.axis.S(128, i)
+ vk = T.axis.R(128, ko * 32 + ki)
+ T.reads([B[vi], A[vi, vk]])
+ T.writes([B[vi]])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def lowered_loop_split(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
+ normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
+ for i in T.serial(0, 128):
+ for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
+ with T.block("B_in_thread_init"):
+ T.reads([])
+ T.writes([normal_reduce_temp0[0]])
+ normal_reduce_temp0[0] = T.float32(0)
+ for ko in T.serial(0, 4):
+ with T.block("B_normal_reduction"):
+ vi = T.axis.S(128, i)
+ vk = T.axis.R(128, ko * 32 + ki)
+ T.reads([A[vi, vk], normal_reduce_temp0[0]])
+ T.writes([normal_reduce_temp0[0]])
+ normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
+ with T.block("B_cross_thread_reduction"):
+ T.reads([normal_reduce_temp0[0]])
+ T.writes([reduce_temp0[0]])
+ T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret(T.uint64(0), dtype="handle"),
+ )
+ T.evaluate(
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ normal_reduce_temp0[0],
+ True,
+ reduce_temp0.data,
+ ki,
+ dtype="handle",
+ )
+ )
+ with T.block("B_write_back"):
+ vi = T.axis.S(128, i)
+ T.reads([reduce_temp0[0]])
+ T.writes([B[vi]])
+ B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def no_normal_reduction(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ for i in T.serial(0, 128):
+ for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+ with T.block("B"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ T.reads([B[vi], A[vi, vk]])
+ T.writes([B[vi]])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
+ for i in T.serial(0, 128):
+ for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+ with T.block("B_cross_thread_reduction"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ T.reads([A[vi, vk]])
+ T.writes([reduce_temp0[0]])
+ T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret(T.uint64(0), dtype="handle"),
+ )
+ T.evaluate(
+ T.tvm_thread_allreduce(
+ T.uint32(1), A[vi, vk], True, reduce_temp0.data, k,
dtype="handle"
+ )
+ )
+ with T.block("B_write_back"):
+ vi = T.axis.spatial(128, i)
+ T.reads([reduce_temp0[0]])
+ T.writes([B[vi]])
+ B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def two_bound_loops(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ for i in T.serial(0, 128):
+ for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
+ for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
+ with T.block("B"):
+ vi = T.axis.spatial(128, i)
+ vk = T.axis.reduce(128, ko * 32 + ki)
+ T.reads([B[vi], A[vi, vk]])
+ T.writes([B[vi]])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
+ for i in T.serial(0, 128):
+ for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
+ for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
+ with T.block("B_cross_thread_reduction"):
+ vi = T.axis.spatial(128, i)
+ vk = T.axis.reduce(128, ko * 32 + ki)
+ T.reads([A[vi, vk]])
+ T.writes([reduce_temp0[0]])
+ T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret(T.uint64(0), dtype="handle"),
+ )
+ T.evaluate(
+ T.tvm_thread_allreduce(
+ T.uint32(1), A[vi, vk], True, reduce_temp0.data,
ko, ki, dtype="handle"
+ )
+ )
+ with T.block("B_write_back"):
+ vi = T.axis.spatial(128, i)
+ T.reads([reduce_temp0[0]])
+ T.writes([B[vi]])
+ B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [16, 16, 16], dtype="float32")
+ B = T.match_buffer(b, [16], dtype="float32")
+ B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
+ for i in T.thread_binding(0, 16, thread="blockIdx.x"):
+ for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
+ for k0i0, k1 in T.grid(4, 16):
+ with T.block("B_rf"):
+ vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
+ vi, vk1 = T.axis.remap("SR", [i, k1])
+ T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
+ T.writes([B_rf_local[vk0, vi]])
+ with T.init():
+ B_rf_local[vk0, vi] = T.float32(0)
+ B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
+ for k0i1 in T.serial(0, 4):
+ with T.block("B"):
+ vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
+ vi = T.axis.spatial(16, i)
+ T.reads([B[vi], B_rf_local[vk0, vi]])
+ T.writes([B[vi]])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + B_rf_local[vk0, vi]
+
+
[email protected]_func
+def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) ->
None:
+ A = T.match_buffer(a, [16, 16, 16], dtype="float32")
+ B = T.match_buffer(b, [16], dtype="float32")
+ B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
+ reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
+ normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
+ for i in T.thread_binding(0, 16, thread="blockIdx.x"):
+ for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
+ with T.block("B_in_thread_init"):
+ T.reads([])
+ T.writes([normal_reduce_temp0[0]])
+ normal_reduce_temp0[0] = T.float32(0)
+ for k0i0, k1 in T.grid(4, 16):
+ with T.block("B_rf"):
+ vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
+ vi, vk1 = T.axis.remap("SR", [i, k1])
+ T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
+ T.writes([B_rf_local[vk0, vi]])
+ with T.init():
+ B_rf_local[vk0, vi] = T.float32(0)
+ B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
+ for k0i1 in T.serial(0, 4):
+ with T.block("B_normal_reduction"):
+ vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
+ vi = T.axis.spatial(16, i)
+ T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]])
+ T.writes([normal_reduce_temp0[0]])
+ normal_reduce_temp0[0] = normal_reduce_temp0[0] +
B_rf_local[vk0, vi]
+ with T.block("B_cross_thread_reduction"):
+ T.reads([normal_reduce_temp0[0]])
+ T.writes([reduce_temp0[0]])
+ T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret(T.uint64(0), dtype="handle"),
+ )
+ T.evaluate(
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ normal_reduce_temp0[0],
+ True,
+ reduce_temp0.data,
+ k0o,
+ dtype="handle",
+ )
+ )
+ with T.block("B_write_back"):
+ vi = T.axis.spatial(16, i)
+ T.reads([reduce_temp0[0]])
+ T.writes([B[vi]])
+ B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def with_block_predicate(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 120], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ for i, ko in T.grid(128, 4):
+ for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
+ with T.block("B"):
+ vi = T.axis.spatial(128, i)
+ vk = T.axis.reduce(120, ko * 32 + ki)
+ T.where(ko * 32 + ki < 120)
+ T.reads([B[vi], A[vi, vk]])
+ T.writes([B[vi]])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 120], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
+ normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
+ for i in T.serial(0, 128):
+ for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
+ with T.block("B_in_thread_init"):
+ T.reads([])
+ T.writes([normal_reduce_temp0[0]])
+ normal_reduce_temp0[0] = T.float32(0)
+ for ko in T.serial(0, 4):
+ with T.block("B_normal_reduction"):
+ vi = T.axis.spatial(128, i)
+ vk = T.axis.reduce(120, ko * 32 + ki)
+ T.where(ko * 32 + ki < 120)
+ T.reads([A[vi, vk], normal_reduce_temp0[0]])
+ T.writes([normal_reduce_temp0[0]])
+ normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
+ with T.block("B_cross_thread_reduction"):
+ T.reads([normal_reduce_temp0[0]])
+ T.writes([reduce_temp0[0]])
+ T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret(T.uint64(0), dtype="handle"),
+ )
+ T.evaluate(
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ normal_reduce_temp0[0],
+ True,
+ reduce_temp0.data,
+ ki,
+ dtype="handle",
+ )
+ )
+ with T.block("B_write_back"):
+ vi = T.axis.spatial(128, i)
+ T.reads([reduce_temp0[0]])
+ T.writes([B[vi]])
+ B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def reducer_max(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ for i in T.serial(0, 128):
+ for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+ with T.block("B"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ T.reads([B[vi], A[vi, vk]])
+ T.writes([B[vi]])
+ with T.init():
+ B[vi] = T.min_value("float32")
+ B[vi] = T.max(B[vi], A[vi, vk])
+
+
[email protected]_func
+def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
+ for i in T.serial(0, 128):
+ for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+ with T.block("B_cross_thread_reduction"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ T.reads([A[vi, vk]])
+ T.writes([reduce_temp0[0]])
+ T.attr(
+ T.comm_reducer(lambda x, y: T.max(x, y),
[T.min_value("float32")]),
+ "reduce_scope",
+ T.reinterpret(T.uint64(0), dtype="handle"),
+ )
+ T.evaluate(
+ T.tvm_thread_allreduce(
+ T.uint32(1), A[vi, vk], True, reduce_temp0.data, k,
dtype="handle"
+ )
+ )
+ with T.block("B_write_back"):
+ vi = T.axis.spatial(128, i)
+ T.reads([reduce_temp0[0]])
+ T.writes([B[vi]])
+ B[vi] = reduce_temp0[0]
+
+
[email protected]_func
+def zero_rank_buffer(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128], dtype="float32")
+ B = T.match_buffer(b, [], dtype="float32")
+ for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+ with T.block("B"):
+ vk = T.axis.reduce(128, k)
+ T.reads([B[()], A[vk]])
+ T.writes([B[()]])
+ with T.init():
+ B[()] = T.float32(0)
+ B[()] = B[()] + A[vk]
+
+
[email protected]_func
+def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128], dtype="float32")
+ B = T.match_buffer(b, [], dtype="float32")
+ reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1],
scope="local")
+ for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+ with T.block("B_cross_thread_reduction"):
+ vk = T.axis.reduce(128, k)
+ T.reads([A[vk]])
+ T.writes([reduce_temp0[0]])
+ T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret(T.uint64(0), dtype="handle"),
+ )
+ T.evaluate(
+ T.tvm_thread_allreduce(
+ T.uint32(1), A[vk], True, reduce_temp0.data, k,
dtype="handle"
+ )
+ )
+ with T.block("B_write_back"):
+ T.reads([reduce_temp0[0]])
+ T.writes([B[()]])
+ B[()] = reduce_temp0[0]
+
+
[email protected]_func
+def multiple_bufferstore(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ C = T.alloc_buffer([], dtype="float32")
+ for i in T.serial(0, 128):
+ for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+ with T.block("B"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ T.reads([A[vi, vk], B[vi], C[()]])
+ T.writes([B[vi], C[()]])
+ with T.init():
+ B[vi] = T.float32(0)
+ C[()] = A[vi, vk]
+ B[vi] = B[vi] + C[()]
+
+
[email protected]_func
+def reduction_loop_not_deepest(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+ for i in T.serial(0, 128):
+ with T.block("B"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ T.reads([B[vi], A[vi, vk]])
+ T.writes([B[vi]])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vi, vk]
+
+
[email protected]_func
+def reduction_loop_bound_to_blockidx(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], dtype="float32")
+ B = T.match_buffer(b, [128], dtype="float32")
+ for i in T.serial(0, 128):
+ for k in T.thread_binding(0, 128, thread="blockIdx.x"):
+ with T.block("B"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ T.reads([B[vi], A[vi, vk]])
+ T.writes([B[vi]])
+ with T.init():
+ B[vi] = T.float32(0)
+ B[vi] = B[vi] + A[vi, vk]
+
+
+# Negative case: the init statement writes B[vj, vi] while the update reads and
+# writes B[vi, vj], so init and update access the buffer with different
+# indices. Passed to `_check_fail` in test_different_access_indices.
+@T.prim_func
+def different_access_indices(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128, 128], dtype="float32")
+    B = T.match_buffer(b, [128, 128], dtype="float32")
+    for i, j in T.grid(128, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                T.reads([B[vi, vj], A[vi, vj, vk]])
+                T.writes(
+                    [
+                        B[
+                            T.min(vj, vi) : T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)),
+                            T.min(vi, vj) : T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)),
+                        ]
+                    ]
+                )
+                with T.init():
+                    B[vj, vi] = T.float32(0)
+                B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
+
+
+# Negative case: the update is `B[vi] - A[vi, vk]`. Subtraction is neither
+# commutative nor associative, so it cannot serve as the combiner of a
+# cross-thread reduction. Passed to `_check_fail` in test_invalid_reducer.
+@T.prim_func
+def invalid_reducer(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float32")
+    B = T.match_buffer(b, [128], dtype="float32")
+    for i in T.serial(0, 128):
+        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
+            with T.block("B"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads([B[vi], A[vi, vk]])
+                T.writes([B[vi]])
+                with T.init():
+                    B[vi] = T.float32(0)
+                B[vi] = B[vi] - A[vi, vk]
+
+
+# Input for test_softmax: a row-wise softmax over a 256x256 buffer, one row per
+# blockIdx.x. It has two reductions (row max, then sum of exp(A - max)) whose
+# reduce axis k = ax0_0 * 32 + ax0_1 is partly bound to threadIdx.x, followed
+# by a normalization block. The expected lowering is `lowered_softmax`.
+@T.prim_func
+def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
+    A = T.match_buffer(var_A, [256, 256], dtype="float32")
+    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
+    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
+        for ax0_0 in T.serial(0, 8):
+            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("T_softmax_maxelem"):
+                    i0_1 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
+                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
+                    T.writes([T_softmax_maxelem_shared[i0_1]])
+                    with T.init():
+                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
+                    T_softmax_maxelem_shared[i0_1] = T.max(
+                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
+                    )
+        for ax0_0 in T.serial(0, 8):
+            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("T_softmax_expsum"):
+                    i0_2 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
+                    T.reads(
+                        [
+                            T_softmax_expsum_shared[i0_2],
+                            A[i0_2, k],
+                            T_softmax_maxelem_shared[i0_2],
+                        ]
+                    )
+                    T.writes([T_softmax_expsum_shared[i0_2]])
+                    with T.init():
+                        T_softmax_expsum_shared[i0_2] = T.float32(0)
+                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
+                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
+                    )
+        for i1_0 in T.serial(0, 8):
+            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("T_softmax_norm"):
+                    i0_3 = T.axis.spatial(256, i0)
+                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
+                    T.reads(
+                        [
+                            A[i0_3, i1],
+                            T_softmax_maxelem_shared[i0_3],
+                            T_softmax_expsum_shared[i0_3],
+                        ]
+                    )
+                    T.writes([T_softmax_norm[i0_3, i1]])
+                    T.block_attr({"axis": 1})
+                    T_softmax_norm[i0_3, i1] = (
+                        T.exp(
+                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
+                            dtype="float32",
+                        )
+                        / T_softmax_expsum_shared[i0_3]
+                    )
+
+
+# Expected result of LowerCrossThreadReduction on `softmax` (see test_softmax).
+# Each threadIdx.x-bound reduction becomes four blocks under the ax0_1 loop:
+#   *_normal_reduction_init   seeds the local `normal_reduce_temp*` buffer,
+#   *_normal_reduction        accumulates the serial ax0_0 part into it,
+#   *_cross_thread_reduction  combines across ax0_1 via tvm_thread_allreduce
+#                             into the local `reduce_temp*` buffer,
+#   *_write_back              copies `reduce_temp*[0]` into the shared buffer.
+@T.prim_func
+def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
+    A = T.match_buffer(var_A, [256, 256], dtype="float32")
+    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
+    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
+        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("T_softmax_maxelem_normal_reduction_init"):
+                T.reads([])
+                T.writes([normal_reduce_temp0[0]])
+                normal_reduce_temp0[0] = T.min_value("float32")
+            for ax0_0 in T.serial(0, 8):
+                with T.block("T_softmax_maxelem_normal_reduction"):
+                    i0_1 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
+                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
+                    T.writes([normal_reduce_temp0[0]])
+                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], A[i0_1, k])
+            with T.block("T_softmax_maxelem_cross_thread_reduction"):
+                T.reads([normal_reduce_temp0[0]])
+                T.writes([reduce_temp0[0]])
+                T.attr(
+                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        normal_reduce_temp0[0],
+                        True,
+                        reduce_temp0.data,
+                        ax0_1,
+                        dtype="handle",
+                    )
+                )
+            with T.block("T_softmax_maxelem_write_back"):
+                i0_2 = T.axis.spatial(256, i0)
+                T.reads([reduce_temp0[0]])
+                T.writes([T_softmax_maxelem_shared[i0_2]])
+                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
+        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("T_softmax_expsum_normal_reduction_init"):
+                T.reads([])
+                T.writes([normal_reduce_temp1[0]])
+                normal_reduce_temp1[0] = T.float32(0)
+            for ax0_0 in T.serial(0, 8):
+                with T.block("T_softmax_expsum_normal_reduction"):
+                    i0_3 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
+                    T.reads(
+                        [
+                            A[i0_3, k],
+                            T_softmax_maxelem_shared[i0_3],
+                            normal_reduce_temp1[0],
+                        ]
+                    )
+                    T.writes([normal_reduce_temp1[0]])
+                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
+                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32"
+                    )
+            with T.block("T_softmax_expsum_cross_thread_reduction"):
+                T.reads([normal_reduce_temp1[0]])
+                T.writes([reduce_temp1[0]])
+                T.attr(
+                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret(T.uint64(0), dtype="handle"),
+                )
+                T.evaluate(
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        normal_reduce_temp1[0],
+                        True,
+                        reduce_temp1.data,
+                        ax0_1,
+                        dtype="handle",
+                    )
+                )
+            with T.block("T_softmax_expsum_write_back"):
+                i0_4 = T.axis.spatial(256, i0)
+                T.reads([reduce_temp1[0]])
+                T.writes([T_softmax_expsum_shared[i0_4]])
+                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
+        for i1_0 in T.serial(0, 8):
+            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+                with T.block("T_softmax_norm"):
+                    i0_5 = T.axis.spatial(256, i0)
+                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
+                    T.reads(
+                        [
+                            A[i0_5, i1],
+                            T_softmax_maxelem_shared[i0_5],
+                            T_softmax_expsum_shared[i0_5],
+                        ]
+                    )
+                    T.writes([T_softmax_norm[i0_5, i1]])
+                    T.block_attr({"axis": 1})
+                    T_softmax_norm[i0_5, i1] = (
+                        T.exp(
+                            A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
+                            dtype="float32",
+                        )
+                        / T_softmax_expsum_shared[i0_5]
+                    )
+
+
+# Positive cases hand an input prim_func and its expected `lowered_*`
+# counterpart to `_check`. Negative cases hand only the input to `_check_fail`.
+# Both helpers are defined elsewhere in this file.
+def test_loop_split():
+    _check(loop_split, lowered_loop_split)
+
+
+def test_no_normal_reduction():
+    _check(no_normal_reduction, lowered_no_normal_reduction)
+
+
+def test_two_bound_loops():
+    _check(two_bound_loops, lowered_two_bound_loops)
+
+
+def test_multiple_blocks_under_reduction_loop():
+    _check(multiple_blocks_under_reduction_loop, lowered_multiple_blocks_under_reduction_loop)
+
+
+def test_with_block_predicate():
+    _check(with_block_predicate, lowered_with_block_predicate)
+
+
+def test_reducer_max():
+    _check(reducer_max, lowered_reducer_max)
+
+
+def test_zero_rank_buffer():
+    _check(zero_rank_buffer, lowered_zero_rank_buffer)
+
+
+def test_multiple_bufferstore():
+    _check_fail(multiple_bufferstore)
+
+
+def test_reduction_block_not_deepest():
+    # Exercises the `reduction_loop_not_deepest` prim_func.
+    _check_fail(reduction_loop_not_deepest)
+
+
+def test_reduction_loop_bound_to_blockidx():
+    _check_fail(reduction_loop_bound_to_blockidx)
+
+
+def test_different_access_indices():
+    _check_fail(different_access_indices)
+
+
+def test_invalid_reducer():
+    _check_fail(invalid_reducer)
+
+
+def test_softmax():
+    _check(softmax, lowered_softmax)
+
+
+def test_lower_te():
+    # Build a TE schedule whose two reduce axes are bound to threadIdx.x/y,
+    # lower it to a module, and check that LowerCrossThreadReduction leaves
+    # that module structurally unchanged.
+    a = te.placeholder((32, 2, 2))
+    k1 = te.reduce_axis((0, 2), "k1")
+    k2 = te.reduce_axis((0, 2), "k2")
+    b = te.compute((32,), lambda i: te.sum(a[i, k1, k2], axis=[k1, k2]))
+    s = te.create_schedule(b.op)
+    s[b].bind(k1, te.thread_axis("threadIdx.x"))
+    s[b].bind(k2, te.thread_axis("threadIdx.y"))
+    orig_mod = tvm.driver.build_module.schedule_to_module(s, [a, b])
+    mod = tvm.tir.transform.LowerCrossThreadReduction()(orig_mod)
+    tvm.ir.assert_structural_equal(
+        mod, orig_mod
+    )  # LowerCrossThreadReduction should do nothing on TE
+
+
+if __name__ == "__main__":
+    # Run this file's tests directly; extra command-line args go to pytest.
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))