This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 370ec6a2c9 [TE][CreatePrimFunc] Fix loop carried dependency case with 
nested block levels (#17474)
370ec6a2c9 is described below

commit 370ec6a2c93a444e5b8eb09877567995dfa866fd
Author: wrongtest <[email protected]>
AuthorDate: Thu Nov 14 12:51:27 2024 +0800

    [TE][CreatePrimFunc] Fix loop carried dependency case with nested block 
levels (#17474)
    
    Nested level primfunc block generation for axis dependencies
    
    Co-authored-by: wrongtest <[email protected]>
---
 python/tvm/script/ir_builder/tir/ir.py     |   5 +
 src/te/operation/create_primfunc.cc        | 453 +++++++++++++++++++++--------
 tests/python/te/test_te_create_primfunc.py | 103 ++++++-
 3 files changed, 437 insertions(+), 124 deletions(-)

diff --git a/python/tvm/script/ir_builder/tir/ir.py 
b/python/tvm/script/ir_builder/tir/ir.py
index f7face272d..59548634fc 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -522,6 +522,11 @@ def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> 
ir.Range:
     if isinstance(dom, ir.Range):
         return dom
     if isinstance(dom, (list, tuple)):
+        from tvm.arith import Analyzer  # pylint: 
disable=import-outside-toplevel
+
+        extent = Analyzer().simplify(dom[1] - dom[0])
+        if isinstance(extent, tir.IntImm):
+            return ir.Range.from_min_extent(dom[0], extent)
         return ir.Range(dom[0], dom[1])
     if hasattr(dom, "dtype"):
         return ir.Range(IntImm(dom.dtype, 0), dom)
diff --git a/src/te/operation/create_primfunc.cc 
b/src/te/operation/create_primfunc.cc
index 31815fc710..2709bd2f94 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -22,6 +22,7 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/ir/name_supply.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/data_type_rewriter.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt_functor.h>
@@ -33,6 +34,7 @@
 #include <utility>
 #include <vector>
 
+#include "../../support/array.h"
 #include "../../tir/ir/functor_common.h"
 #include "../../tir/transforms/ir_utils.h"
 #include "../schedule/graph.h"
@@ -180,30 +182,97 @@ class LayoutFreePlaceholdersNormalizer : public 
StmtMutator {
                                    "workload"};
 };
 
-BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
-                                      const Array<te::Tensor>& tensors, 
Array<PrimExpr> bindings,
-                                      PrimExpr expr_body, CreateFuncInfo* info,
-                                      arith::Analyzer* analyzer) {
-  // Step 1. Push_back data_par axis and reduce_axis into block_vars.
-  Array<IterVar> iter_vars;
-  std::unordered_map<const VarNode*, Var> var_map;
-  iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size());
-  auto f_push_block_vars = [&iter_vars, &var_map, &analyzer](const 
Array<IterVar>& iters) {
-    for (IterVar iter_var : iters) {
-      // Create new var
-      Var new_var("v_" + iter_var->var->name_hint, iter_var->var->dtype);
-      var_map[iter_var->var.get()] = new_var;
-
-      PrimExpr dom_min = analyzer->Simplify(iter_var->dom->min);
-      PrimExpr dom_extent = analyzer->Simplify(iter_var->dom->extent);
-      iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), 
new_var,
-                                  iter_var->iter_type, iter_var->thread_tag, 
iter_var->span));
+/*!
+ * \brief The iter levels specify nested structure wrt iteration domain 
dependencies.
+ * (1) Each iter should reside in exactly one level.
+ * (2) The domain of low level iter should be either free or only depend on 
iters in high level.
+ **/
+using NestedIterLevels = std::vector<std::vector<IterVar>>;
+
+NestedIterLevels GenerateNestedIterLevels(const Array<IterVar>& axes, 
arith::Analyzer* analyzer) {
+  int global_max_depth = 0;
+  std::unordered_map<Var, int> depth;
+  std::unordered_map<Var, IterVar> var2iter;
+  for (const auto& axis : axes) {
+    var2iter[axis->var] = axis;
+  }
+
+  std::function<int(const IterVar&)> traverse = [&](const IterVar& axis) -> 
int {
+    auto depth_it = depth.find(axis->var);
+    if (depth_it != depth.end()) {  // cache
+      return depth_it->second;
+    }
+    std::vector<Var> dep_vars;
+    for (const Var& v : UndefinedVars(analyzer->Simplify(axis->dom->min))) {
+      dep_vars.push_back(v);
     }
+    for (const Var& v : UndefinedVars(analyzer->Simplify(axis->dom->extent))) {
+      dep_vars.push_back(v);
+    }
+    int cur_depth = 0;
+    for (const Var& v : dep_vars) {
+      auto it = var2iter.find(v);
+      if (it == var2iter.end()) {
+        // not axis var dependency, maybe a symbolic shape var or others.
+        continue;
+      }
+      int depth = traverse(it->second);
+      cur_depth = std::max(cur_depth, depth + 1);
+    }
+    depth.emplace_hint(depth_it, axis->var, cur_depth);
+    global_max_depth = std::max(global_max_depth, cur_depth);
+    return cur_depth;
   };
-  f_push_block_vars(compute_op->axis);
-  f_push_block_vars(compute_op->reduce_axis);
 
-  // Step 2.
+  for (const auto& axis : axes) {
+    traverse(axis);
+  }
+  NestedIterLevels levels;
+  levels.resize(global_max_depth + 1);
+  for (const auto& axis : axes) {
+    const Var& var = axis->var;
+    levels[depth[var]].push_back(axis);
+  }
+  return levels;
+}
+
+/*!
+ * \brief Generate output buffers from compute op's output tensors, and bind 
to context func info.
+ * \param compute_op The target compute op.
+ * \param info Generation context info.
+ * \returns The output buffer objects, ordered by compute op's outputs.
+ **/
+Array<Buffer> GenerateOutputBuffers(const te::ComputeOp& compute_op, 
CreateFuncInfo* info) {
+  // Step 1. Collect output tensors in TE operation.
+  Array<te::Tensor> tensors;
+  if (compute_op->body[0]->IsInstance<ReduceNode>()) {
+    auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> 
bool {
+      StructuralEqual eq;
+      return eq(a->combiner, b->combiner) &&    //
+             eq(a->source, b->source) &&        //
+             eq(a->axis, b->axis) &&            //
+             eq(a->condition, b->condition) &&  //
+             eq(a->init, b->init);
+    };
+    PrimExpr expr_body = compute_op->body[0];
+    tensors.push_back(compute_op.output(0));
+    const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
+    // specially handle reduction inline for multiple reductions.
+    for (size_t k = 1; k < compute_op->body.size(); ++k) {
+      const tir::ReduceNode* reduce_ = 
compute_op->body[k].as<tir::ReduceNode>();
+      ICHECK(reduce_);
+      ICHECK(f_reducer_equal(reduce_, reduce))
+          << "The Reduce inputs of ComputeOp should have the same attribute 
except value_index, "
+          << "but the first argument has body " << GetRef<PrimExpr>(reduce_) 
<< ", while the " << k
+          << "-th argument has body " << GetRef<PrimExpr>(reduce);
+      tensors.push_back(compute_op.output(k));
+    }
+  } else {
+    for (size_t k = 0; k < compute_op->body.size(); ++k) {
+      tensors.push_back(compute_op.output(k));
+    }
+  }
+  // Step 2. Prepare buffers for compute outputs
   //  - Declare buffers
   //  - Update `op2buffers`
   //  - Add the non-argument tensors to `alloc_buffer` of the root block
@@ -212,32 +281,94 @@ BlockRealize GenerateBlockFromTensors(const 
te::ComputeOp& compute_op,
     Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, 
tensor->GetNameHint(), "global");
     info->tensor2buffers[tensor] = buffer;
     buffers.push_back(buffer);
-
     if (!info->IsArg(tensor)) {
       info->root_alloc.push_back(info->tensor2buffers[tensor]);
     }
   }
+  return buffers;
+}
 
-  // Step 3. Calculate indices for BufferStore
-  Array<PrimExpr> indices;
-  indices.reserve(compute_op->axis.size());
-  for (const IterVar& iter_var : compute_op->axis) {
-    auto it = var_map.find(iter_var->var.get());
-    ICHECK(it != var_map.end());
-    indices.push_back(it->second);
+/*!
+ * \brief Generate block annotation dict from compute op attrs.
+ * \param compute_op The target compute op.
+ * \param info Generation context info.
+ * \returns The block annotation dict.
+ **/
+Map<String, ObjectRef> GenerateBlockAnnotations(const te::ComputeOp& 
compute_op,
+                                                CreateFuncInfo* info) {
+  Map<String, ObjectRef> annotations;
+  auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
+    if (auto tensor_value = value.as<te::Tensor>()) {
+      return info->tensor2buffers.at(tensor_value.value());
+    } else {
+      return value;
+    }
+  };
+  for (const auto& pair : compute_op->attrs) {
+    const String& key = pair.first;
+    const ObjectRef& value = pair.second;
+    // TensorIR will not allow Tensor data structure
+    if (value->IsInstance<ArrayNode>()) {
+      const auto array_value = Downcast<Array<ObjectRef>>(value);
+      annotations.Set(key, array_value.Map(mutate_attr));
+    } else {
+      annotations.Set(key, mutate_attr(value));
+    }
   }
+  // Set script_parsing_detect_access
+  annotations.Set(tir::attr::script_parsing_detect_access, 
IntImm(DataType::Int(32), 3));
+  return annotations;
+}
 
-  // Step 4. Create block body.
+/*!
+ * \brief Generate init stmt for reduction.
+ * \param indices Target store indices for the block.
+ * \param buffers Target store buffers for the block.
+ * \param reduce Reduce description node.
+ * \param var_map Var re-mapping for TE compute axes.
+ * \param info Generation context info.
+ * \returns Init stmt.
+ **/
+Stmt GenerateInitStmt(const Array<PrimExpr>& indices, const Array<Buffer>& 
buffers,
+                      const ReduceNode* reduce, const Map<Var, PrimExpr>& 
var_map,
+                      CreateFuncInfo* info) {
   // helper to transform the expr and remap iters to the block domain
   auto f_transform_and_remap = [&](const PrimExpr& e) {
     return Substitute(info->transformer(e), var_map);
   };
-  String block_name{nullptr};
   Optional<Stmt> init = NullOpt;
   Stmt body;
+  int n_buffers = buffers.size();
+  Array<Stmt> init_stmts;
+  init_stmts.reserve(n_buffers);
+  for (int i = 0; i < n_buffers; ++i) {
+    const Buffer& buffer = buffers[i];
+    PrimExpr identity = 
f_transform_and_remap(reduce->combiner->identity_element[i]);
+    init_stmts.push_back(BufferStore(buffer, identity, indices));
+  }
+  return SeqStmt::Flatten(init_stmts);
+}
+
+/*!
+ * \brief Generate body execution stmt.
+ * \param indices Target store indices for the block.
+ * \param buffers Target store buffers for the block.
+ * \param var_map Var re-mapping for TE compute axes.
+ * \param expr_body Target computation expression.
+ * \param info Generation context info.
+ * \param analyzer Arithmetic analyzer in context.
+ * \returns Body stmt.
+ **/
+Stmt GenerateBodyStmt(const Array<PrimExpr>& indices, const Array<Buffer>& 
buffers,
+                      const Map<Var, PrimExpr>& var_map, PrimExpr expr_body, 
CreateFuncInfo* info,
+                      arith::Analyzer* analyzer) {
+  // helper to transform the expr and remap iters to the block domain
+  auto f_transform_and_remap = [&](const PrimExpr& e) {
+    return Substitute(info->transformer(e), var_map);
+  };
+  Stmt body;
   if (const auto* reduce = expr_body.as<ReduceNode>()) {
     // Case 1. Reduce compute
-    block_name = info->FreshName(compute_op->name);
     int n_buffers = buffers.size();
 
     Array<PrimExpr> lhs;
@@ -258,10 +389,8 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& 
compute_op,
 
     Array<Var> temp_vars;
     Array<Stmt> body_stmts;
-    Array<Stmt> init_stmts;
     temp_vars.reserve(n_buffers);
     body_stmts.reserve(n_buffers);
-    init_stmts.reserve(n_buffers);
 
     // - When there is only one buffer, we directly create a BufferStore which 
stores "combiner(lhs,
     //   rhs)" into the target buffer position.
@@ -270,8 +399,6 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& 
compute_op,
     //   then store the value of the variables into the target buffer 
positions.
     for (int i = 0; i < n_buffers; ++i) {
       const Buffer& buffer = buffers[i];
-      PrimExpr identity = 
f_transform_and_remap(reduce->combiner->identity_element[i]);
-      init_stmts.push_back(BufferStore(buffer, identity, indices));
       PrimExpr value{nullptr};
       if (n_buffers > 1) {
         temp_vars.push_back(Var("v_" + buffer->name, 
PrimType(lhs[i].dtype())));
@@ -282,8 +409,6 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& 
compute_op,
       }
       body_stmts.push_back(BufferStore(buffer, value, indices));
     }
-
-    init = SeqStmt::Flatten(init_stmts);
     body = SeqStmt::Flatten(body_stmts);
     if (n_buffers > 1) {
       // When there are multiple buffers, we wrap the body with LetStmts.
@@ -294,116 +419,198 @@ BlockRealize GenerateBlockFromTensors(const 
te::ComputeOp& compute_op,
     }
   } else {
     // Case 2. Data parallel compute
-    ICHECK_EQ(tensors.size(), 1);
-    block_name = info->FreshName(tensors[0]->GetNameHint());
+    ICHECK_EQ(buffers.size(), 1);
     const PrimExpr& compute_body = f_transform_and_remap(expr_body);
-    body = BufferStore(info->tensor2buffers[tensors[0]], 
analyzer->Simplify(compute_body), indices);
+    body = BufferStore(buffers[0], analyzer->Simplify(compute_body), indices);
   }
+  return std::move(body);
+}
 
-  // Step 5. Add script_parsing_detect_access attr for auto complete the whole 
IR.
-  Map<String, ObjectRef> annotations;
-  auto mutate_attr = [&info](const ObjectRef& value) -> ObjectRef {
-    if (auto tensor_value = value.as<te::Tensor>()) {
-      return info->tensor2buffers.at(tensor_value.value());
-    } else {
-      return value;
+/*! \brief Record loops, block vars and binding in the single level scope. */
+struct NestedScopeInfo {
+  // loop var and range in the scope.
+  std::vector<std::pair<Var, Range>> loop_vars;
+  // block iters for current level's block.
+  Array<IterVar> block_iters;
+  // block bindings for current level's block.
+  Array<PrimExpr> bindings;
+  // store indices for current level's block.
+  Array<PrimExpr> store_indices;
+  // mapping from original TE compute axes to new block vars.
+  Map<Var, PrimExpr> axes_remap;
+
+  // helper to add new block var
+  void AddBlockIter(const Optional<IterVar>& origin_axis, const IterVar& iter,
+                    const PrimExpr& value) {
+    block_iters.push_back(iter);
+    bindings.push_back(value);
+    if (origin_axis.defined()) {
+      if (iter->iter_type != IterVarType::kCommReduce) {
+        store_indices.push_back(iter->var);
+      }
+      axes_remap.Set(origin_axis.value()->var, iter->var);
     }
-  };
+  }
 
-  for (const auto& pair : compute_op->attrs) {
-    const String& key = pair.first;
-    const ObjectRef& value = pair.second;
-    // TensorIR will not allow Tensor data structure
-    if (value->IsInstance<ArrayNode>()) {
-      const auto array_value = Downcast<Array<ObjectRef>>(value);
-      annotations.Set(key, array_value.Map(mutate_attr));
-    } else {
-      annotations.Set(key, mutate_attr(value));
+  // helper to renew leaf block var defs to ensure SSA.
+  void Renew(const Array<IterVar>& origin_axes) {
+    block_iters.MutateByApply([](const IterVar& itervar) {
+      auto n = make_object<IterVarNode>(*itervar.get());
+      n->var = n->var.copy_with_suffix("");
+      return IterVar(n);
+    });
+    for (size_t i = 0; i < origin_axes.size(); ++i) {
+      Var block_var = block_iters[i]->var;
+      if (origin_axes[i]->iter_type != IterVarType::kCommReduce) {
+        store_indices.Set(i, block_var);
+      }
+      axes_remap.Set(origin_axes[i]->var, block_var);
     }
   }
-  // Set script_parsing_detect_access
-  annotations.Set(tir::attr::script_parsing_detect_access, 
IntImm(DataType::Int(32), 3));
-  if (iter_vars.empty()) {
-    IterVar iter(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), 
IterVarType::kDataPar);
-    PrimExpr binding(0);
-    iter_vars.push_back(iter);
-    bindings.push_back(binding);
-  }
-
-  // Step 6. Create Block and BlockRealize.
-  return BlockRealize(/*iter_values=*/std::move(bindings),
-                      /*predicate=*/Bool(true),
-                      /*block=*/
-                      Block(/*iter_vars=*/std::move(iter_vars),
-                            /*reads=*/{},
-                            /*writes=*/{},
-                            /*name_hint=*/block_name,
-                            /*body=*/std::move(body),
-                            /*init=*/std::move(init),
-                            /*alloc_buffers=*/{},
-                            /*match_buffers=*/{},
-                            /*annotations=*/std::move(annotations)));
-}
+};
 
 Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* 
info,
                              arith::Analyzer* analyzer) {
-  // Step 1. Creating loop vars for block bindings.
+  // Step 1. Collect all iter axes in original TE compute op
   Array<IterVar> axes = compute_op->axis;
   axes.insert(axes.end(), compute_op->reduce_axis.begin(), 
compute_op->reduce_axis.end());
 
-  Array<PrimExpr> bindings = axes.Map([&](IterVar iter_var) -> PrimExpr {
-    int bits = std::max(iter_var->dom->min.dtype().bits(), 
iter_var->dom->extent.dtype().bits());
-    return Var(iter_var->var->name_hint, runtime::DataType::Int(bits));
-  });
+  // Step 2. Prepare nested iteration scopes.
+  // For each axis, we generate loop and the first block binding at the level 
it belongs to.
+  // In lower levels, we just create new block var and bind it to the previous 
level block var.
+  auto axes_levels = GenerateNestedIterLevels(axes, analyzer);
+  ICHECK(!axes_levels.empty());
+  std::vector<NestedScopeInfo> scopes;
+  scopes.reserve(axes_levels.size());
+  std::unordered_set<Var> defined_axes;
+  for (size_t i = 0; i < axes_levels.size(); ++i) {
+    NestedScopeInfo cur_scope;
+    for (size_t j = 0; j < axes.size(); ++j) {
+      const IterVar& axis = axes[j];
+      DataType index_type =
+          DataType::Int(std::max(axis->dom->min.dtype().bits(), 
axis->dom->extent.dtype().bits()));
+      bool first_times_define =
+          std::find(axes_levels[i].begin(), axes_levels[i].end(), axis) != 
axes_levels[i].end();
+      if (first_times_define) {
+        Var loop_var = Var(axis->var->name_hint, index_type);
+        Var block_var("v_" + axis->var->name_hint, index_type);
+        PrimExpr min = axis->dom->min;
+        PrimExpr extent = axis->dom->extent;
+        if (i > 0) {
+          const auto& scope_repl = scopes[i - 1].axes_remap;
+          min = Substitute(min, scope_repl);
+          extent = Substitute(extent, scope_repl);
+        }
+        Range dom = Range::FromMinExtent(analyzer->Simplify(min), 
analyzer->Simplify(extent));
+        IterVar new_block_iter(dom, block_var, axis->iter_type, 
axis->thread_tag, axis->span);
+        cur_scope.loop_vars.emplace_back(loop_var, dom);
+        cur_scope.AddBlockIter(axis, new_block_iter, loop_var);
+        defined_axes.insert(axis->var);
+      } else if (defined_axes.count(axis->var)) {
+        ICHECK_GT(i, 0);
+        ICHECK(scopes[i - 1].axes_remap.count(axis->var));
+        PrimExpr prev_binding = scopes[i - 1].axes_remap.at(axis->var);
+        Var block_var("v_" + axis->var->name_hint, index_type);
+        Range dom = Range::FromMinExtent(prev_binding, make_const(index_type, 
1));
+        IterVar new_block_iter(dom, block_var, axis->iter_type, 
axis->thread_tag, axis->span);
+        cur_scope.AddBlockIter(axis, new_block_iter, prev_binding);
+      }
+    }
+    if (i == axes_levels.size() - 1 && cur_scope.block_iters.empty()) {
+      // for the leaf scope, we ensure at least one block var exists
+      IterVar dummy(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)),
+                    IterVarType::kDataPar);
+      cur_scope.AddBlockIter(NullOpt, dummy, 0);
+    }
+    scopes.push_back(cur_scope);
+  }
 
-  // Step 2. Generate block bodies.
-  Array<Stmt> seq_stmt;
-  if (compute_op->body[0]->IsInstance<ReduceNode>()) {
-    auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> 
bool {
-      StructuralEqual eq;
-      return eq(a->combiner, b->combiner) &&    //
-             eq(a->source, b->source) &&        //
-             eq(a->axis, b->axis) &&            //
-             eq(a->condition, b->condition) &&  //
-             eq(a->init, b->init);
-    };
+  // Step 3. Generate output buffers for each output tensor
+  Array<Buffer> buffers = GenerateOutputBuffers(compute_op, info);
 
+  // Step 4. Generate leaf block stmts.
+  Array<Stmt> seq_stmt;
+  auto leaf = scopes.back();
+  Map<String, ObjectRef> annotations = GenerateBlockAnnotations(compute_op, 
info);
+  const ReduceNode* reduce = compute_op->body[0].as<ReduceNode>();
+  if (reduce) {
     PrimExpr expr_body = compute_op->body[0];
-    Array<te::Tensor> tensors = {compute_op.output(0)};
-    const tir::ReduceNode* reduce = expr_body.as<tir::ReduceNode>();
-    // specially handle reduction inline for multiplre reductions.
-    for (size_t k = 1; k < compute_op->body.size(); ++k) {
-      const tir::ReduceNode* reduce_ = 
compute_op->body[k].as<tir::ReduceNode>();
-      ICHECK(reduce_);
-      ICHECK(f_reducer_equal(reduce_, reduce))
-          << "The Reduce inputs of ComputeOp should have the same attribute 
except value_index, "
-          << "but the first argument has body " << GetRef<PrimExpr>(reduce_) 
<< ", while the " << k
-          << "-th argument has body " << GetRef<PrimExpr>(reduce);
-      tensors.push_back(compute_op.output(k));
-    }
+    Stmt init = GenerateInitStmt(leaf.store_indices, buffers, reduce, 
leaf.axes_remap, info);
+    Stmt body =
+        GenerateBodyStmt(leaf.store_indices, buffers, leaf.axes_remap, 
expr_body, info, analyzer);
+    seq_stmt.push_back(BlockRealize(/*iter_values=*/leaf.bindings,
+                                    /*predicate=*/Bool(true),
+                                    /*block=*/
+                                    Block(/*iter_vars=*/leaf.block_iters,
+                                          /*reads=*/{},
+                                          /*writes=*/{},
+                                          
/*name_hint=*/info->FreshName(compute_op->name),
+                                          /*body=*/body,
+                                          /*init=*/init,
+                                          /*alloc_buffers=*/{},
+                                          /*match_buffers=*/{},
+                                          /*annotations=*/annotations)));
 
-    seq_stmt.push_back(GenerateBlockFromTensors(compute_op, tensors, bindings, 
std::move(expr_body),
-                                                info, analyzer));
   } else {
     for (int i = 0; i < compute_op->num_outputs(); ++i) {
-      const te::Tensor& tensor = compute_op.output(i);
+      if (i > 0) {
+        // Renew block var defs to ensure SSA
+        leaf.Renew(axes);
+      }
       PrimExpr expr_body = compute_op->body[i];
-      seq_stmt.push_back(GenerateBlockFromTensors(compute_op, {tensor}, 
bindings,
-                                                  std::move(expr_body), info, 
analyzer));
+      Stmt body = GenerateBodyStmt(leaf.store_indices, {buffers[i]}, 
leaf.axes_remap, expr_body,
+                                   info, analyzer);
+      seq_stmt.push_back(BlockRealize(/*iter_values=*/leaf.bindings,
+                                      /*predicate=*/Bool(true),
+                                      /*block=*/
+                                      Block(/*iter_vars=*/leaf.block_iters,
+                                            /*reads=*/{},
+                                            /*writes=*/{},
+                                            
/*name_hint=*/info->FreshName(buffers[i]->name),
+                                            /*body=*/body,
+                                            /*init=*/NullOpt,
+                                            /*alloc_buffers=*/{},
+                                            /*match_buffers=*/{},
+                                            /*annotations=*/annotations)));
     }
   }
-
   Stmt body = SeqStmt::Flatten(seq_stmt);
 
-  // Step 3. Generate loop nesting.
-  for (size_t i = axes.size(); i > 0; --i) {
-    const IterVar& axis = axes[i - 1];
-    PrimExpr dom_min = analyzer->Simplify(axis->dom->min);
-    PrimExpr dom_extent = analyzer->Simplify(axis->dom->extent);
-    const Var& loop_var = Downcast<Var>(bindings[i - 1]);
-    body = For(loop_var, dom_min, dom_extent, ForKind::kSerial, body);
-  }
+  // Step 5. Generate nested parent scopes.
+  for (size_t i = scopes.size(); i > 0; --i) {
+    const auto& cur = scopes[i - 1];
+    if (i < scopes.size()) {
+      auto block_name = info->FreshName(compute_op->name + "_l" + 
std::to_string(i));
+      const auto& block_iters = cur.block_iters;
+
+      Optional<Stmt> init{NullOpt};
+      if (reduce && std::any_of(block_iters.begin(), block_iters.end(), 
[](const IterVar& iter) {
+            return iter->iter_type == IterVarType::kCommReduce;
+          })) {
+        // if a reduce axis is defined in a non-leaf scope, the nested block is 
also
+        // a reduction block, thus we should also insert an init stmt in the 
parent level.
+        init = GenerateInitStmt(cur.store_indices, buffers, reduce, 
cur.axes_remap, info);
+      }
 
+      // wrap nested block
+      body = BlockRealize(/*iter_values=*/cur.bindings,
+                          /*predicate=*/Bool(true),
+                          /*block=*/
+                          Block(/*iter_vars=*/block_iters,
+                                /*reads=*/{},
+                                /*writes=*/{},
+                                /*name_hint=*/block_name,
+                                /*body=*/body,
+                                /*init=*/init,
+                                /*alloc_buffers=*/{},
+                                /*match_buffers=*/{},
+                                /*annotations=*/annotations));
+    }
+    for (size_t j = cur.loop_vars.size(); j > 0; --j) {
+      const auto& [loop_var, dom] = cur.loop_vars[j - 1];
+      body = For(loop_var, dom->min, dom->extent, ForKind::kSerial, body);
+    }
+  }
   return body;
 }
 
diff --git a/tests/python/te/test_te_create_primfunc.py 
b/tests/python/te/test_te_create_primfunc.py
index 1a7e03188a..0fb64e8d0f 100644
--- a/tests/python/te/test_te_create_primfunc.py
+++ b/tests/python/te/test_te_create_primfunc.py
@@ -45,8 +45,12 @@ def test_unique_name_reduction_block():
     assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef)
 
 
-def _check_workload(te_workload, tir_workload, index_dtype_override=None):
+def _check_workload(te_workload, tir_workload, index_dtype_override=None, 
do_simplify=False):
     func = te.create_prim_func(te_workload(), index_dtype_override)
+    if do_simplify:
+        simplify = tir.transform.Simplify()
+        func = simplify(tvm.IRModule.from_expr(func))["main"]
+        tir_workload = simplify(tvm.IRModule.from_expr(tir_workload))["main"]
     tvm.ir.assert_structural_equal(func, tir_workload)
     # make sure that we can create schedule from the func
     s = tir.Schedule(func, debug_mask="all")
@@ -887,5 +891,102 @@ def test_loop_aware_reducer_combiner():
     _check_workload(te_workload, tir_workload)
 
 
+def test_adaptive_pooling_window():
+    @T.prim_func
+    def tir_workload(
+        x: T.Buffer((1, 1024, 16, 40), "float32"),
+        adaptive_pool_avg: T.Buffer((1, 1024, 12, 30), "float32"),
+    ):
+        T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
+        # fmt: off
+        adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30))
+        for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
+            with T.block("adaptive_pool_sum_1"):
+                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, 
ax2, ax3])
+                T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + 
((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 
3 * 10 + 40) // 30 + 1)])
+                T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+                for rv0, rv1 in T.grid(T.Select((v_ax2 * 4 + 4) % 12 == 0, 
(v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, 
T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 
40) // 30 + 1) - v_ax3 * 40 // 30):
+                    with T.block("adaptive_pool_sum"):
+                        v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0)
+                        v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1)
+                        v_ax2_1 = T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2)
+                        v_ax3_1 = T.axis.spatial((v_ax3, v_ax3 + 1), v_ax3)
+                        v_rv0, v_rv1 = T.axis.remap("RR", [rv0, rv1])
+                        T.reads(x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + 
v_rv0, v_ax3_1 * 40 // 30 + v_rv1])
+                        T.writes(adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, 
v_ax3_1])
+                        with T.init():
+                            adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, 
v_ax3_1] = T.float32(0.0)
+                        adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] 
= adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] + x[v_ax0_1, v_ax1_1, 
v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1]
+        for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
+            with T.block("adaptive_pool_avg"):
+                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, 
ax2, ax3])
+                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
+                T.block_attr({"schedule_rule": 
"meta_schedule.adaptive_pool_avg"})
+                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = 
adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", 
T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) 
// 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 
30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 
30))
+        # fmt: on
+
+    def te_workload():
+        x = te.placeholder([1, 1024, 16, 40], "float32", "x")
+        y = topi.nn.adaptive_pool(x, [12, 30], pool_type="avg")
+        f = te.create_prim_func([x, y])
+        return [x, y]
+
+    _check_workload(te_workload, tir_workload)
+
+
+def test_nested_reduce_domain_dependency():
+    @T.prim_func
+    def tir_workload(
+        x: T.Buffer((8, 8, 8, 8, 8), "float32"), compute: T.Buffer((8, 8, 8), 
"float32")
+    ):
+        T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
+        for i0, i1, i2 in T.grid(8, 8, 8):
+            with T.block("compute_2"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(x[v_i0, v_i1, v_i2, 0:v_i1, 0 : v_i1 - 1])
+                T.writes(compute[v_i0, v_i1, v_i2])
+                for rv in range(v_i1):
+                    with T.block("compute_1"):
+                        v_i0_1 = T.axis.spatial((v_i0, v_i0 + 1), v_i0)
+                        v_i1_1 = T.axis.spatial((v_i1, v_i1 + 1), v_i1)
+                        v_i2_1 = T.axis.spatial((v_i2, v_i2 + 1), v_i2)
+                        v_rv = T.axis.reduce(v_i1, rv)
+                        T.reads(x[v_i0_1, v_i1_1, v_i2_1, v_rv, 0:v_rv])
+                        T.writes(compute[v_i0_1, v_i1_1, v_i2_1])
+                        with T.init():
+                            compute[v_i0_1, v_i1_1, v_i2_1] = T.float32(0.0)
+                        for rv_1 in range(v_rv):
+                            with T.block("compute"):
+                                v_i0_2 = T.axis.spatial((v_i0_1, v_i0_1 + 1), 
v_i0_1)
+                                v_i1_2 = T.axis.spatial((v_i1_1, v_i1_1 + 1), 
v_i1_1)
+                                v_i2_2 = T.axis.spatial((v_i2_1, v_i2_1 + 1), 
v_i2_1)
+                                v_rv_1 = T.axis.reduce((v_rv, v_rv + 1), v_rv)
+                                v_rv_2 = T.axis.reduce(v_rv, rv_1)
+                                T.reads(x[v_i0_2, v_i1_2, v_i2_2, v_rv_1, 
v_rv_2])
+                                T.writes(compute[v_i0_2, v_i1_2, v_i2_2])
+                                with T.init():
+                                    compute[v_i0_2, v_i1_2, v_i2_2] = 
T.float32(0.0)
+                                compute[v_i0_2, v_i1_2, v_i2_2] = (
+                                    compute[v_i0_2, v_i1_2, v_i2_2]
+                                    + x[v_i0_2, v_i1_2, v_i2_2, v_rv_1, v_rv_2]
+                                )
+
+    def te_workload():
+        x = te.placeholder([8, 8, 8, 8, 8], "float32", "x")
+
+        def fcompute(*axes):
+            r1 = te.reduce_axis(tvm.ir.Range.from_min_extent(0, axes[1]))
+            r2 = te.reduce_axis(tvm.ir.Range.from_min_extent(0, r1))
+            all_axes = [*axes, r1, r2]
+            return te.sum(x(*all_axes), [r1, r2])
+
+        y = te.compute([8, 8, 8], fcompute)
+        f = te.create_prim_func([x, y])
+        return [x, y]
+
+    _check_workload(te_workload, tir_workload)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to