This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e32d47e  [Arith] Inverse affine map (#8384)
e32d47e is described below

commit e32d47e9f5fca5772e027a9595620436285e295d
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Jul 3 20:59:16 2021 -0400

    [Arith] Inverse affine map (#8384)
    
    * [Arith] Inverse affine map
    
    * [Arith] Inverse affine map
    
    * Update iter_affine_map.h
    
    * Update iter_affine_map.h
    
    * Update iter_affine_map.py
    
    * Topology order visit
    
    * doc
    
    * fix
    
    * address comments
    
    * lint
    
    * remove print
---
 include/tvm/arith/iter_affine_map.h                |  21 +++
 python/tvm/arith/__init__.py                       |   7 +-
 python/tvm/arith/iter_affine_map.py                |  27 ++++
 src/arith/iter_affine_map.cc                       | 142 +++++++++++++++++++++
 .../python/unittest/test_arith_iter_affine_map.py  |  61 +++++++++
 5 files changed, 257 insertions(+), 1 deletion(-)

diff --git a/include/tvm/arith/iter_affine_map.h 
b/include/tvm/arith/iter_affine_map.h
index 641d0e0..d671339 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -284,6 +284,27 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& 
indices, const Map<Var,
                                  arith::Analyzer* analyzer);
 
 /*!
+ * \brief Apply the inverse of the affine transformation to the outputs.
+ *
+ * Similar to back-propagation, it starts from the outputs, visits the DAG of the expressions
+ * in reverse topological order and applies the inverse of each affine step until it reaches
+ * the inputs. The affine iter map is required to be bijective.
+ *
+ * For example, with iter_map = [l0 // 16, l0 % 16] and outputs = [output_0, output_1],
+ * the inverse of the affine transformation specified by `iter_map` is applied to `outputs`
+ * and the result is {l0: ((output_0*16) + output_1)}.
+ *
+ * \sa DetectIterMap
+ *
+ * \param iter_map The bijective affine iter map.
+ * \param outputs The outputs of the affine transformation; one per element of `iter_map`.
+ *
+ * \return The map from each input variable to its expression in terms of `outputs`.
+ */
+Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
+                                        const Array<PrimExpr> outputs);
+
+/*!
  * \brief Detect if bindings can be written as
  * [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
  *
diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py
index d1e4431..f5a0478 100644
--- a/python/tvm/arith/__init__.py
+++ b/python/tvm/arith/__init__.py
@@ -22,4 +22,9 @@ from .bound import deduce_bound
 from .pattern import detect_linear_equation, detect_clip_bound
 from .int_solver import solve_linear_equations, solve_linear_inequalities
 from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
-from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, 
subspace_divide
+from .iter_affine_map import (
+    detect_iter_map,
+    normalize_iter_map_to_expr,
+    subspace_divide,
+    inverse_affine_iter_map,
+)
diff --git a/python/tvm/arith/iter_affine_map.py 
b/python/tvm/arith/iter_affine_map.py
index bfd5dfa..85513ec 100644
--- a/python/tvm/arith/iter_affine_map.py
+++ b/python/tvm/arith/iter_affine_map.py
@@ -173,3 +173,30 @@ def subspace_divide(bindings, input_iters, sub_iters, 
predicate=True, require_bi
         Empty array if no match can be found.
     """
     return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, 
predicate, require_bijective)
+
+
+def inverse_affine_iter_map(iter_map, outputs):
+    """Apply the inverse of the affine transformation to the outputs.
+    Similar to back-propagation, it starts from the outputs, visits the DAG of the expressions
+    in reverse topological order and applies the inverse of each affine step until it reaches
+    the inputs. The affine iter map is required to be bijective.
+
+    For example, with iter_map = [l0 // 16, l0 % 16] and outputs = [output_0, output_1],
+    the inverse of the affine transformation specified by `iter_map` is applied to `outputs`
+    and the result is {l0: ((output_0*16) + output_1)}.
+
+    See also :any:`detect_iter_map`.
+
+    Parameters
+    ----------
+    iter_map : List[IterSumExpr]
+        The bijective affine iter map.
+    outputs : List[PrimExpr]
+        The outputs of the affine transformation. Must have the same length as `iter_map`.
+
+    Returns
+    -------
+    results : Map[Var, PrimExpr]
+        The map from each input variable to its expression in terms of `outputs`.
+    """
+    return _ffi_api.InverseAffineIterMap(iter_map, outputs)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index c1daae9..e885195 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1385,5 +1385,147 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
       return SubspaceDivide(bindings, root_iters, sub_iters, predicate, 
require_bijective, &ana);
     });
 
+class InverseAffineIterMapTransformer {  // inverts a bijective iter map by back-propagation
+ public:
+  explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : analyzer_(analyzer) {}
+
+  Map<Var, PrimExpr> operator()(const Array<IterSumExpr>& iter_map,
+                                const Array<PrimExpr>& outputs) {
+    ICHECK_EQ(iter_map.size(), outputs.size()) << "Expect exactly one output per iter_map entry";
+    std::vector<const IterMapExprNode*> topo_order = ReverseTopologyOrder(iter_map);
+
+    // every node's accumulator starts at zero; each iter_map root is seeded with its output
+    for (const IterMapExprNode* node : topo_order) {
+      backprop_.Set(GetRef<IterMapExpr>(node), Integer(0));
+    }
+    for (size_t i = 0; i < iter_map.size(); i++) {
+      backprop_.Set(iter_map[i], outputs[i]);
+    }
+
+    // users precede their sources in topo_order, so each accumulator is final when visited
+    for (const IterMapExprNode* node : topo_order) {
+      if (node->IsInstance<IterSumExprNode>()) {
+        Visit_(Downcast<IterSumExpr>(GetRef<IterMapExpr>(node)));
+      } else {
+        ICHECK(node->IsInstance<IterSplitExprNode>());
+        Visit_(Downcast<IterSplitExpr>(GetRef<IterMapExpr>(node)));
+      }
+    }
+    return std::move(inverse_);
+  }
+
+ private:
+  void Visit_(const IterSumExpr& iter_map_expr) {
+    PrimExpr input = backprop_.at(iter_map_expr) - iter_map_expr->base;
+
+    // Case 1: a single component; undo its scale and propagate to it directly.
+    if (iter_map_expr->args.size() == 1) {
+      const auto& source = iter_map_expr->args[0];
+      backprop_.Set(source, backprop_.at(source) + floordiv(input, source->scale));
+      return;
+    }
+
+    // Case 2: If the sum expression has multiple components, match the fuse pattern and then split
+    // the sum expression for each component.
+    // For example, consider the iterator i1[dom = (0, 16)], i2[dom = (0, 8)], fusing i1 and i2
+    // we will have i1_i2_fused[dom = (0, 128)]. During back propagation, we need to split the
+    // propagated value to get the corresponding components of i1 and i2, which are
+    // floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively.
+    Array<IterSplitExpr> splits = MatchFusePattern(iter_map_expr);
+    ICHECK(!splits.empty());
+
+    for (const IterSplitExpr& split : splits) {
+      backprop_.Set(split,
+                    backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent));
+    }
+  }
+
+  std::vector<const IterMapExprNode*> ReverseTopologyOrder(const Array<IterSumExpr>& iter_map) {
+    std::vector<const IterMapExprNode*> post_dfs_order;
+    std::unordered_map<IterMapExpr, bool, ObjectPtrHash, ObjectPtrEqual> visited;
+
+    std::function<void(const IterMapExpr&)> fvisit = [&](const IterMapExpr& expr) {
+      if (visited[expr]) {
+        return;
+      }
+      visited[expr] = true;
+      if (const auto* sum_expr = expr.as<IterSumExprNode>()) {
+        for (const IterSplitExpr& child : sum_expr->args) {
+          fvisit(child);
+        }
+      } else {
+        const auto* split_expr = expr.as<IterSplitExprNode>();
+        ICHECK(split_expr);
+        if (const auto* source = split_expr->source->source.as<IterMapExprNode>()) {
+          fvisit(GetRef<IterMapExpr>(source));
+        }
+      }
+      post_dfs_order.push_back(expr.get());
+    };
+    for (const IterSumExpr& expr : iter_map) {
+      fvisit(expr);
+    }
+    std::reverse(post_dfs_order.begin(), post_dfs_order.end());  // users before sources
+    return post_dfs_order;
+  }
+
+  void Visit_(const IterSplitExpr& iter_map_expr) {  // scatter into the mark's source
+    PrimExpr input = backprop_.at(iter_map_expr) * iter_map_expr->lower_factor;
+    const IterMark& source = iter_map_expr->source;
+    if (source->source.as<IterSumExprNode>()) {
+      IterSumExpr source_expr = Downcast<IterSumExpr>(source->source);
+      backprop_.Set(source_expr, backprop_.at(source_expr) + input);
+    } else {
+      Var source_var = Downcast<Var>(source->source);
+      if (inverse_.count(source_var)) {
+        inverse_.Set(source_var, inverse_.at(source_var) + input);
+      } else {
+        inverse_.Set(source_var, input);
+      }
+    }
+  }
+
+  Array<IterSplitExpr> MatchFusePattern(const IterSumExpr& sum_expr) {
+    IntImm base_scale(nullptr);
+    size_t base_index = 0;
+    for (size_t i = 0; i < sum_expr->args.size(); ++i) {
+      if (const auto* op = sum_expr->args[i]->scale.as<IntImmNode>()) {
+        if (!base_scale.defined() || op->value < base_scale->value) {
+          base_scale = GetRef<IntImm>(op);
+          base_index = i;
+        }
+      }
+    }
+    ICHECK(base_scale.defined()) << "Expect at least one component with a constant scale";
+    std::vector<IterSplitExpr> iters;  // ordered from the smallest scale upward
+    std::vector<bool> visited(sum_expr->args.size(), false);
+    PrimExpr expected_scale = base_scale;
+    for (size_t i = 0; i < sum_expr->args.size(); i++) {
+      size_t j = i == 0 ? base_index : 0;
+      for (; j < sum_expr->args.size(); ++j) {
+        if (!visited[j] && analyzer_->CanProveEqual(sum_expr->args[j]->scale, expected_scale))
+          break;
+      }
+      ICHECK(j != sum_expr->args.size()) << "Cannot match the fuse pattern";
+      visited[j] = true;
+      iters.push_back(sum_expr->args[j]);
+      expected_scale *= sum_expr->args[j]->extent;
+    }
+    return iters;
+  }
+
+  Analyzer* analyzer_;                   // not owned; used to prove scale equalities
+  Map<IterMapExpr, PrimExpr> backprop_;  // the accumulator of back-propagation
+  Map<Var, PrimExpr> inverse_;           // result: input var -> inverse expression
+};
+
+Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
+                                        const Array<PrimExpr> outputs) {
+  Analyzer analyzer;  // owned here; the transformer only borrows a pointer to it
+  return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs);
+}
+
+TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap);
+
 }  // namespace arith
 }  // namespace tvm
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py 
b/tests/python/unittest/test_arith_iter_affine_map.py
index 7bfdfc6..b34acb9 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -643,6 +643,66 @@ def test_normalize_iter_map_to_expr():
     
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), 
flm(x[0], 5))
 
 
+def test_inverse_affine_iter_map():
+    analyzer = tvm.arith.Analyzer()
+    l0 = create_iter("l0", 64)
+    l1 = create_iter("l1", 64)
+    l2 = create_iter("l2", 64)
+
+    # simple case: split l0 and l1, fuse the two inner parts, keep the outer parts as outputs
+    l0_0, l0_1 = isplit(l0, 16)
+    l1_0, l1_1 = isplit(l1, 4)
+    l0_1_l1_1_fused = ifuse([l0_1, l1_1])
+
+    iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1]))
+    outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
+    res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
+    assert len(res) == 2
+    l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16
+    l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4
+    assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
+    assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0
+
+    # compound case: l2 is split twice; pieces of l0, l1 and l2 are interleaved in one fusion
+    l0_0, l0_1 = isplit(l0, 16)
+    l1_0, l1_1 = isplit(l1, 4)
+    l2_1, l2_2 = isplit(l2, 4)
+    l2_0, l2_1 = isplit(l2_1, 4)
+
+    l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0])
+
+    iter_map = tvm.arith.detect_iter_map(
+        [l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2])
+    )
+    outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
+    res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
+    assert len(res) == 3
+    l0_inverse = floormod(floordiv(outputs[0], 64), 16) + outputs[1] * 16
+    l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4
+    l2_inverse = (
+        floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2]
+    )
+
+    assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
+    assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0
+    assert analyzer.simplify(res[l2[0]] - l2_inverse) == 0
+
+    # diamond-shape DAG: l0's two splits are re-fused (swapped) into l1, split and fused again
+    l0_0, l0_1 = isplit(l0, 16)
+    l1 = ifuse([l0_1, l0_0])
+    l1_0, l1_1 = isplit(l1, 8)
+    l2 = ifuse([l1_1, l1_0])
+
+    iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0]))
+    outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
+    res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
+    assert len(res) == 1
+    l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0], 8), 8)
+    l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse, 4), 16)
+
+    assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
+
+
 if __name__ == "__main__":
     test_split()
     test_trivial()
@@ -652,3 +712,4 @@ if __name__ == "__main__":
     test_normalize_iter_map_to_expr()
     test_subspace_division()
     test_complex()
+    test_inverse_affine_iter_map()

Reply via email to