vinx13 commented on a change in pull request #8384:
URL: https://github.com/apache/tvm/pull/8384#discussion_r663136293



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -1385,5 +1385,147 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
       return SubspaceDivide(bindings, root_iters, sub_iters, predicate, 
require_bijective, &ana);
     });
 
+class InverseAffineIterMapTransformer {
+ public:
+  explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : 
analyzer_(analyzer) {}
+
+  Map<Var, PrimExpr> operator()(const Array<IterSumExpr>& iter_map,
+                                const Array<PrimExpr>& outputs) {
+    ICHECK(iter_map.size() == outputs.size());
+    std::vector<const IterMapExprNode*> post_dfs_order = 
ReverseTopologyOrder(iter_map);
+
+    // initialize back propagation accumulator
+    for (const IterMapExprNode* node : post_dfs_order) {
+      backprop_.Set(GetRef<IterMapExpr>(node), Integer(0));
+    }
+    for (size_t i = 0; i < iter_map.size(); i++) {
+      backprop_.Set(iter_map[i], outputs[i]);
+    }
+
+    // run back propagation
+    for (const IterMapExprNode* node : post_dfs_order) {
+      if (node->IsInstance<IterSumExprNode>()) {
+        Visit_(Downcast<IterSumExpr>(GetRef<IterMapExpr>(node)));
+      } else {
+        ICHECK(node->IsInstance<IterSplitExprNode>());
+        Visit_(Downcast<IterSplitExpr>(GetRef<IterMapExpr>(node)));
+      }
+    }
+    return std::move(inverse_);
+  }
+
+ private:
+  void Visit_(const IterSumExpr& iter_map_expr) {
+    PrimExpr input = backprop_.at(iter_map_expr) - iter_map_expr->base;
+
+    // Case 1: Propagate to the input node directly when the sum expression 
has only one components
+    if (iter_map_expr->args.size() == 1) {
+      const auto& source = iter_map_expr->args[0];
+      backprop_.Set(source, backprop_.at(source) + input);
+      return;
+    }
+
+    // Case 2: If the sum expression has multiple components, match the fuse 
pattern and then split
+    // the sum expression for each components.
+    // For example, consider the iterator i1[dom = (0, 16)], i2[dom = (0, 8)], 
fusing i1 and i2
+    // we will have i1_i2_fused[dom = (0, 64)]. During back propagation, we 
need to split the
+    // propagated value to get the corresponding components of i1 and i2, 
which are
+    // floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively.
+    Array<IterSplitExpr> splits = MatchFusePattern(iter_map_expr);
+    ICHECK(!splits.empty());
+
+    for (const IterSplitExpr& split : splits) {
+      backprop_.Set(split,
+                    backprop_.at(split) + floormod(floordiv(input, 
split->scale), split->extent));
+    }
+  }
+
+  std::vector<const IterMapExprNode*> ReverseTopologyOrder(const 
Array<IterSumExpr>& iter_map) {
+    std::vector<const IterMapExprNode*> post_dfs_order;
+    std::unordered_map<IterMapExpr, bool, ObjectPtrHash, ObjectPtrEqual> 
visited;
+
+    std::function<void(const IterMapExpr&)> fvisit = [&](const IterMapExpr& 
expr) {
+      if (visited[expr]) {
+        return;
+      }
+      visited[expr] = true;
+      if (const auto* sum_expr = expr.as<IterSumExprNode>()) {
+        for (const IterSplitExpr& child : sum_expr->args) {
+          fvisit(child);
+        }
+      } else {
+        const auto* split_expr = expr.as<IterSplitExprNode>();
+        ICHECK(split_expr);
+        if (const auto* source = 
split_expr->source->source.as<IterMapExprNode>()) {
+          fvisit(GetRef<IterMapExpr>(source));
+        }
+      }
+      post_dfs_order.push_back(expr.get());
+    };
+    for (const IterSumExpr& expr : iter_map) {
+      fvisit(expr);
+    }
+    std::reverse(post_dfs_order.begin(), post_dfs_order.end());
+    return post_dfs_order;
+  }
+
+  void Visit_(const IterSplitExpr& iter_map_expr) {
+    PrimExpr input = backprop_.at(iter_map_expr) * iter_map_expr->lower_factor;
+    const IterMark& source = iter_map_expr->source;
+    if (source->source.as<IterSumExprNode>()) {
+      IterSumExpr source_expr = Downcast<IterSumExpr>(source->source);
+      backprop_.Set(source_expr, backprop_.at(source_expr) + input);
+    } else {
+      Var source_var = Downcast<Var>(source->source);
+      if (inverse_.count(source_var)) {
+        inverse_.Set(source_var, inverse_.at(source_var) + input);
+      } else {
+        inverse_.Set(source_var, input);
+      }
+    }
+  }
+
+  Array<IterSplitExpr> MatchFusePattern(const IterSumExpr sum_expr) {
+    IntImm base_scale(nullptr);
+    size_t base_index = 0;
+    for (size_t i = 0; i < sum_expr->args.size(); ++i) {
+      if (const auto* op = sum_expr->args[i]->scale.as<IntImmNode>()) {
+        if (!base_scale.defined() || op->value < base_scale->value) {
+          base_scale = GetRef<IntImm>(op);
+          base_index = i;
+        }
+      }
+    }
+    ICHECK(base_scale.defined());
+    std::vector<IterSplitExpr> iters;
+    std::vector<bool> visited(sum_expr->args.size(), false);
+    PrimExpr expected_scale = base_scale;

Review comment:
       If the sum expression is produced by `DetectIterMap`, this is true, and 
we can simplify this whole function. I can add some assertions to check this 
assumption. @tqchen 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to