This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 370ec6a2c9 [TE][CreatePrimFunc] Fix loop carried dependency case with nested block levels (#17474)
370ec6a2c9 is described below
commit 370ec6a2c93a444e5b8eb09877567995dfa866fd
Author: wrongtest <[email protected]>
AuthorDate: Thu Nov 14 12:51:27 2024 +0800
[TE][CreatePrimFunc] Fix loop carried dependency case with nested block levels (#17474)
Generate PrimFunc blocks at nested levels to handle axis domain dependencies
Co-authored-by: wrongtest <[email protected]>
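For context: the failure mode is a TE compute whose reduce-axis domains depend
on other axes (a loop carried dependency across the iteration domain). A minimal
sketch of such a workload, adapted from the new test below (names illustrative):

    import tvm
    from tvm import te

    x = te.placeholder([8, 8, 8, 8, 8], "float32", "x")

    def fcompute(*axes):
        # r1's extent depends on the spatial axis axes[1]; r2's extent on r1
        r1 = te.reduce_axis(tvm.ir.Range.from_min_extent(0, axes[1]))
        r2 = te.reduce_axis(tvm.ir.Range.from_min_extent(0, r1))
        return te.sum(x(*axes, r1, r2), [r1, r2])

    y = te.compute([8, 8, 8], fcompute)
    f = te.create_prim_func([x, y])  # previously mishandled; now emits nested blocks

Since a block var may only reference block vars of enclosing blocks, such
domains cannot be expressed in one flat block; the fix groups axes into
dependency levels and emits one nested block per level.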
---
python/tvm/script/ir_builder/tir/ir.py | 5 +
src/te/operation/create_primfunc.cc | 453 +++++++++++++++++++++--------
tests/python/te/test_te_create_primfunc.py | 103 ++++++-
3 files changed, 437 insertions(+), 124 deletions(-)
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index f7face272d..59548634fc 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -522,6 +522,11 @@ def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range:
if isinstance(dom, ir.Range):
return dom
if isinstance(dom, (list, tuple)):
+ from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel
+
+ extent = Analyzer().simplify(dom[1] - dom[0])
+ if isinstance(extent, tir.IntImm):
+ return ir.Range.from_min_extent(dom[0], extent)
return ir.Range(dom[0], dom[1])
if hasattr(dom, "dtype"):
return ir.Range(IntImm(dom.dtype, 0), dom)
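The hunk above makes (min, max) domain pairs with a provably constant extent
canonicalize to Range.from_min_extent, which appears to support the unit-extent
domains T.axis.spatial((v, v + 1), v) used in the new tests. A hedged sketch of
the behavior (illustrative, not part of the patch):

    from tvm import tir
    from tvm.arith import Analyzer

    v = tir.Var("v", "int32")
    extent = Analyzer().simplify((v + 1) - v)  # folds to the constant 1
    assert isinstance(extent, tir.IntImm)
    # so _as_range((v, v + 1)) now yields a Range with min=v and extent=1,
    # rather than a Range whose extent is the raw expression (v + 1) - v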
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 31815fc710..2709bd2f94 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -22,6 +22,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ir/name_supply.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
@@ -33,6 +34,7 @@
#include <utility>
#include <vector>
+#include "../../support/array.h"
#include "../../tir/ir/functor_common.h"
#include "../../tir/transforms/ir_utils.h"
#include "../schedule/graph.h"
@@ -180,30 +182,97 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
"workload"};
};
-BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
- const Array<te::Tensor>& tensors, Array<PrimExpr> bindings,
- PrimExpr expr_body, CreateFuncInfo* info,
- arith::Analyzer* analyzer) {
- // Step 1. Push_back data_par axis and reduce_axis into block_vars.
- Array<IterVar> iter_vars;
- std::unordered_map<const VarNode*, Var> var_map;
- iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size());
- auto f_push_block_vars = [&iter_vars, &var_map, &analyzer](const Array<IterVar>& iters) {
- for (IterVar iter_var : iters) {
- // Create new var
- Var new_var("v_" + iter_var->var->name_hint, iter_var->var->dtype);
- var_map[iter_var->var.get()] = new_var;
-
- PrimExpr dom_min = analyzer->Simplify(iter_var->dom->min);
- PrimExpr dom_extent = analyzer->Simplify(iter_var->dom->extent);
- iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var,
- iter_var->iter_type, iter_var->thread_tag, iter_var->span));
+/*!
+ * \brief The iter levels specify the nested structure wrt iteration domain dependencies.
+ * (1) Each iter should reside in exactly one level.
+ * (2) The domain of a lower-level iter should be either free or only depend on iters in higher levels.
+ **/
+using NestedIterLevels = std::vector<std::vector<IterVar>>;
+
+NestedIterLevels GenerateNestedIterLevels(const Array<IterVar>& axes, arith::Analyzer* analyzer) {
+ int global_max_depth = 0;
+ std::unordered_map<Var, int> depth;
+ std::unordered_map<Var, IterVar> var2iter;
+ for (const auto& axis : axes) {
+ var2iter[axis->var] = axis;
+ }
+
+ std::function<int(const IterVar&)> traverse = [&](const IterVar& axis) -> int {
+ auto depth_it = depth.find(axis->var);
+ if (depth_it != depth.end()) { // cache
+ return depth_it->second;
+ }
+ std::vector<Var> dep_vars;
+ for (const Var& v : UndefinedVars(analyzer->Simplify(axis->dom->min))) {
+ dep_vars.push_back(v);
}
+ for (const Var& v : UndefinedVars(analyzer->Simplify(axis->dom->extent))) {
+ dep_vars.push_back(v);
+ }
+ int cur_depth = 0;
+ for (const Var& v : dep_vars) {
+ auto it = var2iter.find(v);
+ if (it == var2iter.end()) {
+ // not axis var dependency, maybe a symbolic shape var or others.
+ continue;
+ }
+ int depth = traverse(it->second);
+ cur_depth = std::max(cur_depth, depth + 1);
+ }
+ depth.emplace_hint(depth_it, axis->var, cur_depth);
+ global_max_depth = std::max(global_max_depth, cur_depth);
+ return cur_depth;
};
- f_push_block_vars(compute_op->axis);
- f_push_block_vars(compute_op->reduce_axis);
- // Step 2.
+ for (const auto& axis : axes) {
+ traverse(axis);
+ }
+ NestedIterLevels levels;
+ levels.resize(global_max_depth + 1);
+ for (const auto& axis : axes) {
+ const Var& var = axis->var;
+ levels[depth[var]].push_back(axis);
+ }
+ return levels;
+}
+
+/*!
+ * \brief Generate output buffers from compute op's output tensors, and bind to context func info.
+ * \param compute_op The target compute op.
+ * \param info Generation context info.
+ * \returns The output buffer objects, ordered by compute op's outputs.
+ **/
+Array<Buffer> GenerateOutputBuffers(const te::ComputeOp& compute_op, CreateFuncInfo* info) {
+ // Step 1. Collect output tensors in TE operation.
+ Array<te::Tensor> tensors;
+ if (compute_op->body[0]->IsInstance<ReduceNode>()) {
+ auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
+ StructuralEqual eq;
+ return eq(a->combiner, b->combiner) && //
+ eq(a->source, b->source) && //
+ eq(a->axis, b->axis) && //
+ eq(a->condition, b->condition) && //
+ eq(a->init, b->init);
+ };
+ PrimExpr expr_body = compute_op->body[0];
+ tensors.push_back(compute_op.output(0));
+ const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
+ // Specially handle reduction inline for multiple reductions.
+ for (size_t k = 1; k < compute_op->body.size(); ++k) {
+ const tir::ReduceNode* reduce_ = compute_op->body[k].as<tir::ReduceNode>();
+ ICHECK(reduce_);
+ ICHECK(f_reducer_equal(reduce_, reduce))
+ << "The Reduce inputs of ComputeOp should have the same attribute
except value_index, "
+ << "but the first argument has body " << GetRef<PrimExpr>(reduce_)
<< ", while the " << k
+ << "-th argument has body " << GetRef<PrimExpr>(reduce);
+ tensors.push_back(compute_op.output(k));
+ }
+ } else {
+ for (size_t k = 0; k < compute_op->body.size(); ++k) {
+ tensors.push_back(compute_op.output(k));
+ }
+ }
+ // Step 2. Prepare buffers for compute outputs
// - Declare buffers
// - Update `op2buffers`
// - Add the non-argument tensors to `alloc_buffer` of the root block
@@ -212,32 +281,94 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
info->tensor2buffers[tensor] = buffer;
buffers.push_back(buffer);
-
if (!info->IsArg(tensor)) {
info->root_alloc.push_back(info->tensor2buffers[tensor]);
}
}
+ return buffers;
+}
- // Step 3. Calculate indices for BufferStore
- Array<PrimExpr> indices;
- indices.reserve(compute_op->axis.size());
- for (const IterVar& iter_var : compute_op->axis) {
- auto it = var_map.find(iter_var->var.get());
- ICHECK(it != var_map.end());
- indices.push_back(it->second);
+/*!
+ * \brief Generate block annotation dict from compute op attrs.
+ * \param compute_op The target compute op.
+ * \param info Generation context info.
+ * \returns The block annotation dict.
+ **/
+Map<String, ObjectRef> GenerateBlockAnnotations(const te::ComputeOp& compute_op,
+ CreateFuncInfo* info) {
+ Map<String, ObjectRef> annotations;
+ auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
+ if (auto tensor_value = value.as<te::Tensor>()) {
+ return info->tensor2buffers.at(tensor_value.value());
+ } else {
+ return value;
+ }
+ };
+ for (const auto& pair : compute_op->attrs) {
+ const String& key = pair.first;
+ const ObjectRef& value = pair.second;
+ // TensorIR will not allow Tensor data structure
+ if (value->IsInstance<ArrayNode>()) {
+ const auto array_value = Downcast<Array<ObjectRef>>(value);
+ annotations.Set(key, array_value.Map(mutate_attr));
+ } else {
+ annotations.Set(key, mutate_attr(value));
+ }
}
+ // Set script_parsing_detect_access
+ annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
+ return annotations;
+}
- // Step 4. Create block body.
+/*!
+ * \brief Generate init stmt for reduction.
+ * \param indices Target store indices for the block.
+ * \param buffers Target store buffers for the block.
+ * \param reduce Reduce description node.
+ * \param var_map Var re-mapping for TE compute axes.
+ * \param info Generation context info.
+ * \returns Init stmt.
+ **/
+Stmt GenerateInitStmt(const Array<PrimExpr>& indices, const Array<Buffer>& buffers,
+ const ReduceNode* reduce, const Map<Var, PrimExpr>& var_map,
+ CreateFuncInfo* info) {
// helper to transform the expr and remap iters to the block domain
auto f_transform_and_remap = [&](const PrimExpr& e) {
return Substitute(info->transformer(e), var_map);
};
- String block_name{nullptr};
Optional<Stmt> init = NullOpt;
Stmt body;
+ int n_buffers = buffers.size();
+ Array<Stmt> init_stmts;
+ init_stmts.reserve(n_buffers);
+ for (int i = 0; i < n_buffers; ++i) {
+ const Buffer& buffer = buffers[i];
+ PrimExpr identity = f_transform_and_remap(reduce->combiner->identity_element[i]);
+ init_stmts.push_back(BufferStore(buffer, identity, indices));
+ }
+ return SeqStmt::Flatten(init_stmts);
+}
+
+/*!
+ * \brief Generate body execution stmt.
+ * \param indices Target store indices for the block.
+ * \param buffers Target store buffers for the block.
+ * \param var_map Var re-mapping for TE compute axes.
+ * \param expr_body Target computation expression.
+ * \param info Generation context info.
+ * \param analyzer Arithmetic analyzer in context.
+ * \returns Body stmt.
+ **/
+Stmt GenerateBodyStmt(const Array<PrimExpr>& indices, const Array<Buffer>& buffers,
+ const Map<Var, PrimExpr>& var_map, PrimExpr expr_body, CreateFuncInfo* info,
+ arith::Analyzer* analyzer) {
+ // helper to transform the expr and remap iters to the block domain
+ auto f_transform_and_remap = [&](const PrimExpr& e) {
+ return Substitute(info->transformer(e), var_map);
+ };
+ Stmt body;
if (const auto* reduce = expr_body.as<ReduceNode>()) {
// Case 1. Reduce compute
- block_name = info->FreshName(compute_op->name);
int n_buffers = buffers.size();
Array<PrimExpr> lhs;
@@ -258,10 +389,8 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
Array<Var> temp_vars;
Array<Stmt> body_stmts;
- Array<Stmt> init_stmts;
temp_vars.reserve(n_buffers);
body_stmts.reserve(n_buffers);
- init_stmts.reserve(n_buffers);
// - When there is only one buffer, we directly create a BufferStore which stores "combiner(lhs,
// rhs)" into the target buffer position.
@@ -270,8 +399,6 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
// then store the value of the variables into the target buffer positions.
for (int i = 0; i < n_buffers; ++i) {
const Buffer& buffer = buffers[i];
- PrimExpr identity = f_transform_and_remap(reduce->combiner->identity_element[i]);
- init_stmts.push_back(BufferStore(buffer, identity, indices));
PrimExpr value{nullptr};
if (n_buffers > 1) {
temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype())));
@@ -282,8 +409,6 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
}
body_stmts.push_back(BufferStore(buffer, value, indices));
}
-
- init = SeqStmt::Flatten(init_stmts);
body = SeqStmt::Flatten(body_stmts);
if (n_buffers > 1) {
// When there are multiple buffers, we wrap the body with LetStmts.
@@ -294,116 +419,198 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
}
} else {
// Case 2. Data parallel compute
- ICHECK_EQ(tensors.size(), 1);
- block_name = info->FreshName(tensors[0]->GetNameHint());
+ ICHECK_EQ(buffers.size(), 1);
const PrimExpr& compute_body = f_transform_and_remap(expr_body);
- body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
+ body = BufferStore(buffers[0], analyzer->Simplify(compute_body), indices);
}
+ return std::move(body);
+}
- // Step 5. Add script_parsing_detect_access attr for auto complete the whole IR.
- Map<String, ObjectRef> annotations;
- auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
- if (auto tensor_value = value.as<te::Tensor>()) {
- return info->tensor2buffers.at(tensor_value.value());
- } else {
- return value;
+/*! \brief Record loops, block vars and bindings in a single level scope. */
+struct NestedScopeInfo {
+ // loop var and range in the scope.
+ std::vector<std::pair<Var, Range>> loop_vars;
+ // block iters for current level's block.
+ Array<IterVar> block_iters;
+ // block bindings for current level's block.
+ Array<PrimExpr> bindings;
+ // store indices for current level's block.
+ Array<PrimExpr> store_indices;
+ // mapping from original TE compute axes to new block vars.
+ Map<Var, PrimExpr> axes_remap;
+
+ // helper to add new block var
+ void AddBlockIter(const Optional<IterVar>& origin_axis, const IterVar& iter,
+ const PrimExpr& value) {
+ block_iters.push_back(iter);
+ bindings.push_back(value);
+ if (origin_axis.defined()) {
+ if (iter->iter_type != IterVarType::kCommReduce) {
+ store_indices.push_back(iter->var);
+ }
+ axes_remap.Set(origin_axis.value()->var, iter->var);
}
- };
+ }
- for (const auto& pair : compute_op->attrs) {
- const String& key = pair.first;
- const ObjectRef& value = pair.second;
- // TensorIR will not allow Tensor data structure
- if (value->IsInstance<ArrayNode>()) {
- const auto array_value = Downcast<Array<ObjectRef>>(value);
- annotations.Set(key, array_value.Map(mutate_attr));
- } else {
- annotations.Set(key, mutate_attr(value));
+ // helper to renew leaf block var defs to ensure SSA.
+ void Renew(const Array<IterVar>& origin_axes) {
+ block_iters.MutateByApply([](const IterVar& itervar) {
+ auto n = make_object<IterVarNode>(*itervar.get());
+ n->var = n->var.copy_with_suffix("");
+ return IterVar(n);
+ });
+ for (size_t i = 0; i < origin_axes.size(); ++i) {
+ Var block_var = block_iters[i]->var;
+ if (origin_axes[i]->iter_type != IterVarType::kCommReduce) {
+ store_indices.Set(i, block_var);
+ }
+ axes_remap.Set(origin_axes[i]->var, block_var);
}
}
- // Set script_parsing_detect_access
- annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
- if (iter_vars.empty()) {
- IterVar iter(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), IterVarType::kDataPar);
- PrimExpr binding(0);
- iter_vars.push_back(iter);
- bindings.push_back(binding);
- }
-
- // Step 6. Create Block and BlockRealize.
- return BlockRealize(/*iter_values=*/std::move(bindings),
- /*predicate=*/Bool(true),
- /*block=*/
- Block(/*iter_vars=*/std::move(iter_vars),
- /*reads=*/{},
- /*writes=*/{},
- /*name_hint=*/block_name,
- /*body=*/std::move(body),
- /*init=*/std::move(init),
- /*alloc_buffers=*/{},
- /*match_buffers=*/{},
- /*annotations=*/std::move(annotations)));
-}
+};
Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info,
arith::Analyzer* analyzer) {
- // Step 1. Creating loop vars for block bindings.
+ // Step 1. Collect all iter axes in original TE compute op
Array<IterVar> axes = compute_op->axis;
axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());
- Array<PrimExpr> bindings = axes.Map([&](IterVar iter_var) -> PrimExpr {
- int bits = std::max(iter_var->dom->min.dtype().bits(), iter_var->dom->extent.dtype().bits());
- return Var(iter_var->var->name_hint, runtime::DataType::Int(bits));
- });
+ // Step 2. Prepare nested iteration scopes.
+ // For each axis, we generate the loop and the first block binding at the level it belongs to.
+ // In lower levels, we just create a new block var and bind it to the previous level's block var.
+ auto axes_levels = GenerateNestedIterLevels(axes, analyzer);
+ ICHECK(!axes_levels.empty());
+ std::vector<NestedScopeInfo> scopes;
+ scopes.reserve(axes_levels.size());
+ std::unordered_set<Var> defined_axes;
+ for (size_t i = 0; i < axes_levels.size(); ++i) {
+ NestedScopeInfo cur_scope;
+ for (size_t j = 0; j < axes.size(); ++j) {
+ const IterVar& axis = axes[j];
+ DataType index_type =
+ DataType::Int(std::max(axis->dom->min.dtype().bits(), axis->dom->extent.dtype().bits()));
+ bool first_times_define =
+ std::find(axes_levels[i].begin(), axes_levels[i].end(), axis) != axes_levels[i].end();
+ if (first_times_define) {
+ Var loop_var = Var(axis->var->name_hint, index_type);
+ Var block_var("v_" + axis->var->name_hint, index_type);
+ PrimExpr min = axis->dom->min;
+ PrimExpr extent = axis->dom->extent;
+ if (i > 0) {
+ const auto& scope_repl = scopes[i - 1].axes_remap;
+ min = Substitute(min, scope_repl);
+ extent = Substitute(extent, scope_repl);
+ }
+ Range dom = Range::FromMinExtent(analyzer->Simplify(min), analyzer->Simplify(extent));
+ IterVar new_block_iter(dom, block_var, axis->iter_type, axis->thread_tag, axis->span);
+ cur_scope.loop_vars.emplace_back(loop_var, dom);
+ cur_scope.AddBlockIter(axis, new_block_iter, loop_var);
+ defined_axes.insert(axis->var);
+ } else if (defined_axes.count(axis->var)) {
+ ICHECK_GT(i, 0);
+ ICHECK(scopes[i - 1].axes_remap.count(axis->var));
+ PrimExpr prev_binding = scopes[i - 1].axes_remap.at(axis->var);
+ Var block_var("v_" + axis->var->name_hint, index_type);
+ Range dom = Range::FromMinExtent(prev_binding, make_const(index_type, 1));
+ IterVar new_block_iter(dom, block_var, axis->iter_type, axis->thread_tag, axis->span);
+ cur_scope.AddBlockIter(axis, new_block_iter, prev_binding);
+ }
+ }
+ if (i == axes_levels.size() - 1 && cur_scope.block_iters.empty()) {
+ // for the leaf scope, we ensure at least one block var exists
+ IterVar dummy(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)),
+ IterVarType::kDataPar);
+ cur_scope.AddBlockIter(NullOpt, dummy, 0);
+ }
+ scopes.push_back(cur_scope);
+ }
- // Step 2. Generate block bodies.
- Array<Stmt> seq_stmt;
- if (compute_op->body[0]->IsInstance<ReduceNode>()) {
- auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool {
- StructuralEqual eq;
- return eq(a->combiner, b->combiner) && //
- eq(a->source, b->source) && //
- eq(a->axis, b->axis) && //
- eq(a->condition, b->condition) && //
- eq(a->init, b->init);
- };
+ // Step 3. Generate output buffers for each output tensor
+ Array<Buffer> buffers = GenerateOutputBuffers(compute_op, info);
+ // Step 4. Generate leaf block stmts.
+ Array<Stmt> seq_stmt;
+ auto leaf = scopes.back();
+ Map<String, ObjectRef> annotations = GenerateBlockAnnotations(compute_op, info);
+ const ReduceNode* reduce = compute_op->body[0].as<ReduceNode>();
+ if (reduce) {
PrimExpr expr_body = compute_op->body[0];
- Array<te::Tensor> tensors = {compute_op.output(0)};
- const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
- // specially handle reduction inline for multiplre reductions.
- for (size_t k = 1; k < compute_op->body.size(); ++k) {
- const tir::ReduceNode* reduce_ = compute_op->body[k].as<tir::ReduceNode>();
- ICHECK(reduce_);
- ICHECK(f_reducer_equal(reduce_, reduce))
- << "The Reduce inputs of ComputeOp should have the same attribute
except value_index, "
- << "but the first argument has body " << GetRef<PrimExpr>(reduce_)
<< ", while the " << k
- << "-th argument has body " << GetRef<PrimExpr>(reduce);
- tensors.push_back(compute_op.output(k));
- }
+ Stmt init = GenerateInitStmt(leaf.store_indices, buffers, reduce, leaf.axes_remap, info);
+ Stmt body =
+ GenerateBodyStmt(leaf.store_indices, buffers, leaf.axes_remap, expr_body, info, analyzer);
+ seq_stmt.push_back(BlockRealize(/*iter_values=*/leaf.bindings,
+ /*predicate=*/Bool(true),
+ /*block=*/
+ Block(/*iter_vars=*/leaf.block_iters,
+ /*reads=*/{},
+ /*writes=*/{},
+ /*name_hint=*/info->FreshName(compute_op->name),
+ /*body=*/body,
+ /*init=*/init,
+ /*alloc_buffers=*/{},
+ /*match_buffers=*/{},
+ /*annotations=*/annotations)));
- seq_stmt.push_back(GenerateBlockFromTensors(compute_op, tensors, bindings, std::move(expr_body),
- info, analyzer));
} else {
for (int i = 0; i < compute_op->num_outputs(); ++i) {
- const te::Tensor& tensor = compute_op.output(i);
+ if (i > 0) {
+ // Renew block var defs to ensure SSA
+ leaf.Renew(axes);
+ }
PrimExpr expr_body = compute_op->body[i];
- seq_stmt.push_back(GenerateBlockFromTensors(compute_op, {tensor}, bindings,
- std::move(expr_body), info, analyzer));
+ Stmt body = GenerateBodyStmt(leaf.store_indices, {buffers[i]}, leaf.axes_remap, expr_body,
+ info, analyzer);
+ seq_stmt.push_back(BlockRealize(/*iter_values=*/leaf.bindings,
+ /*predicate=*/Bool(true),
+ /*block=*/
+ Block(/*iter_vars=*/leaf.block_iters,
+ /*reads=*/{},
+ /*writes=*/{},
+ /*name_hint=*/info->FreshName(buffers[i]->name),
+ /*body=*/body,
+ /*init=*/NullOpt,
+ /*alloc_buffers=*/{},
+ /*match_buffers=*/{},
+ /*annotations=*/annotations)));
}
}
-
Stmt body = SeqStmt::Flatten(seq_stmt);
- // Step 3. Generate loop nesting.
- for (size_t i = axes.size(); i > 0; --i) {
- const IterVar& axis = axes[i - 1];
- PrimExpr dom_min = analyzer->Simplify(axis->dom->min);
- PrimExpr dom_extent = analyzer->Simplify(axis->dom->extent);
- const Var& loop_var = Downcast<Var>(bindings[i - 1]);
- body = For(loop_var, dom_min, dom_extent, ForKind::kSerial, body);
- }
+ // Step 5. Generate nested parent scopes.
+ for (size_t i = scopes.size(); i > 0; --i) {
+ const auto& cur = scopes[i - 1];
+ if (i < scopes.size()) {
+ auto block_name = info->FreshName(compute_op->name + "_l" + std::to_string(i));
+ const auto& block_iters = cur.block_iters;
+
+ Optional<Stmt> init{NullOpt};
+ if (reduce && std::any_of(block_iters.begin(), block_iters.end(), [](const IterVar& iter) {
+ return iter->iter_type == IterVarType::kCommReduce;
+ })) {
+ // if the reduce axis is defined in non-leaf scopes, the nested block is also
+ // a reduction block, thus we should also insert an init stmt at the parent level.
+ init = GenerateInitStmt(cur.store_indices, buffers, reduce, cur.axes_remap, info);
+ }
+ // wrap nested block
+ body = BlockRealize(/*iter_values=*/cur.bindings,
+ /*predicate=*/Bool(true),
+ /*block=*/
+ Block(/*iter_vars=*/block_iters,
+ /*reads=*/{},
+ /*writes=*/{},
+ /*name_hint=*/block_name,
+ /*body=*/body,
+ /*init=*/init,
+ /*alloc_buffers=*/{},
+ /*match_buffers=*/{},
+ /*annotations=*/annotations));
+ }
+ for (size_t j = cur.loop_vars.size(); j > 0; --j) {
+ const auto& [loop_var, dom] = cur.loop_vars[j - 1];
+ body = For(loop_var, dom->min, dom->extent, ForKind::kSerial, body);
+ }
+ }
return body;
}
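For readers following the new create_primfunc.cc logic: GenerateNestedIterLevels
assigns every iter a nesting depth via memoized DFS over the variables its domain
references, so an iter at depth d only depends on iters at smaller depths. A
minimal plain-Python sketch of that assignment (axis names and the deps map are
illustrative stand-ins, not TVM API):

    def assign_levels(axes, deps):
        # axes: iteration axis names; deps[a]: axis names a's domain refers to
        depth = {}

        def traverse(a):
            if a in depth:  # memoized, mirrors the depth-map cache in the C++
                return depth[a]
            d = 0
            for v in deps.get(a, ()):
                if v in axes:  # skip free symbols, e.g. shape vars
                    d = max(d, traverse(v) + 1)
            depth[a] = d
            return d

        for a in axes:
            traverse(a)
        levels = [[] for _ in range(max(depth.values()) + 1)]
        for a in axes:
            levels[depth[a]].append(a)
        return levels

    # the nested-reduce test below: r1 depends on i1, r2 on r1
    assert assign_levels(["i0", "i1", "i2", "r1", "r2"],
                         {"r1": ["i1"], "r2": ["r1"]}) == \
        [["i0", "i1", "i2"], ["r1"], ["r2"]]

GenerateStmtFromCompute then materializes one loop nest and block per level,
re-binding already-defined axes as unit-extent block vars in the inner levels.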
diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py
index 1a7e03188a..0fb64e8d0f 100644
--- a/tests/python/te/test_te_create_primfunc.py
+++ b/tests/python/te/test_te_create_primfunc.py
@@ -45,8 +45,12 @@ def test_unique_name_reduction_block():
assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef)
-def _check_workload(te_workload, tir_workload, index_dtype_override=None):
+def _check_workload(te_workload, tir_workload, index_dtype_override=None, do_simplify=False):
func = te.create_prim_func(te_workload(), index_dtype_override)
+ if do_simplify:
+ simplify = tir.transform.Simplify()
+ func = simplify(tvm.IRModule.from_expr(func))["main"]
+ tir_workload = simplify(tvm.IRModule.from_expr(tir_workload))["main"]
tvm.ir.assert_structural_equal(func, tir_workload)
# make sure that we can create schedule from the func
s = tir.Schedule(func, debug_mask="all")
@@ -887,5 +891,102 @@ def test_loop_aware_reducer_combiner():
_check_workload(te_workload, tir_workload)
+def test_adaptive_pooling_window():
+ @T.prim_func
+ def tir_workload(
+ x: T.Buffer((1, 1024, 16, 40), "float32"),
+ adaptive_pool_avg: T.Buffer((1, 1024, 12, 30), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
+ # fmt: off
+ adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30))
+ for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
+ with T.block("adaptive_pool_sum_1"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)])
+ T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+ for rv0, rv1 in T.grid(T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30):
+ with T.block("adaptive_pool_sum"):
+ v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0)
+ v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1)
+ v_ax2_1 = T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2)
+ v_ax3_1 = T.axis.spatial((v_ax3, v_ax3 + 1), v_ax3)
+ v_rv0, v_rv1 = T.axis.remap("RR", [rv0, rv1])
+ T.reads(x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1])
+ T.writes(adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1])
+ with T.init():
+ adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = T.float32(0.0)
+ adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] + x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1]
+ for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
+ with T.block("adaptive_pool_avg"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
+ adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30))
+ # fmt: on
+
+ def te_workload():
+ x = te.placeholder([1, 1024, 16, 40], "float32", "x")
+ y = topi.nn.adaptive_pool(x, [12, 30], pool_type="avg")
+ f = te.create_prim_func([x, y])
+ return [x, y]
+
+ _check_workload(te_workload, tir_workload)
+
+
+def test_nested_reduce_domain_dependency():
+ @T.prim_func
+ def tir_workload(
+ x: T.Buffer((8, 8, 8, 8, 8), "float32"), compute: T.Buffer((8, 8, 8), "float32")
+ ):
+ T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
+ for i0, i1, i2 in T.grid(8, 8, 8):
+ with T.block("compute_2"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(x[v_i0, v_i1, v_i2, 0:v_i1, 0 : v_i1 - 1])
+ T.writes(compute[v_i0, v_i1, v_i2])
+ for rv in range(v_i1):
+ with T.block("compute_1"):
+ v_i0_1 = T.axis.spatial((v_i0, v_i0 + 1), v_i0)
+ v_i1_1 = T.axis.spatial((v_i1, v_i1 + 1), v_i1)
+ v_i2_1 = T.axis.spatial((v_i2, v_i2 + 1), v_i2)
+ v_rv = T.axis.reduce(v_i1, rv)
+ T.reads(x[v_i0_1, v_i1_1, v_i2_1, v_rv, 0:v_rv])
+ T.writes(compute[v_i0_1, v_i1_1, v_i2_1])
+ with T.init():
+ compute[v_i0_1, v_i1_1, v_i2_1] = T.float32(0.0)
+ for rv_1 in range(v_rv):
+ with T.block("compute"):
+ v_i0_2 = T.axis.spatial((v_i0_1, v_i0_1 + 1), v_i0_1)
+ v_i1_2 = T.axis.spatial((v_i1_1, v_i1_1 + 1), v_i1_1)
+ v_i2_2 = T.axis.spatial((v_i2_1, v_i2_1 + 1), v_i2_1)
+ v_rv_1 = T.axis.reduce((v_rv, v_rv + 1), v_rv)
+ v_rv_2 = T.axis.reduce(v_rv, rv_1)
+ T.reads(x[v_i0_2, v_i1_2, v_i2_2, v_rv_1, v_rv_2])
+ T.writes(compute[v_i0_2, v_i1_2, v_i2_2])
+ with T.init():
+ compute[v_i0_2, v_i1_2, v_i2_2] = T.float32(0.0)
+ compute[v_i0_2, v_i1_2, v_i2_2] = (
+ compute[v_i0_2, v_i1_2, v_i2_2]
+ + x[v_i0_2, v_i1_2, v_i2_2, v_rv_1, v_rv_2]
+ )
+
+ def te_workload():
+ x = te.placeholder([8, 8, 8, 8, 8], "float32", "x")
+
+ def fcompute(*axes):
+ r1 = te.reduce_axis(tvm.ir.Range.from_min_extent(0, axes[1]))
+ r2 = te.reduce_axis(tvm.ir.Range.from_min_extent(0, r1))
+ all_axes = [*axes, r1, r2]
+ return te.sum(x(*all_axes), [r1, r2])
+
+ y = te.compute([8, 8, 8], fcompute)
+ f = te.create_prim_func([x, y])
+ return [x, y]
+
+ _check_workload(te_workload, tir_workload)
+
+
if __name__ == "__main__":
tvm.testing.main()
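For reviewers who want to eyeball the nested block structure this produces, a
quick sketch (mirrors the adaptive pooling test above; assumes a TVM build with
this patch applied):

    import tvm
    from tvm import te, topi

    x = te.placeholder([1, 1024, 16, 40], "float32", "x")
    y = topi.nn.adaptive_pool(x, [12, 30], pool_type="avg")
    print(te.create_prim_func([x, y]).script())  # prints the two block levels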