This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 37053e1  [Tensorize] Support conds depend on outer loop vars inside 
tensorize scope (#7497)
37053e1 is described below

commit 37053e1708c6565c8a82c31c0ffc78e594bfe3b0
Author: lee <[email protected]>
AuthorDate: Thu Mar 4 01:30:31 2021 +0800

    [Tensorize] Support conds depend on outer loop vars inside tensorize scope 
(#7497)
    
    * [Tensorize] Support conds depend on outer loop vars inside tensorize scope
    
    * Reformat
---
 src/te/operation/op_utils.cc                       |  8 +++
 src/te/operation/op_utils.h                        | 10 +++-
 src/te/operation/tensorize.cc                      |  6 ++-
 .../python/unittest/test_te_schedule_tensorize.py  | 57 ++++++++++++++++++----
 4 files changed, 69 insertions(+), 12 deletions(-)

diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc
index 32ffccb..b3897e1 100644
--- a/src/te/operation/op_utils.cc
+++ b/src/te/operation/op_utils.cc
@@ -243,6 +243,14 @@ Stmt Substitute(Stmt s, const std::unordered_map<IterVar, 
PrimExpr>& value_map)
   return tir::Substitute(s, init);
 }
 
+PrimExpr Substitute(PrimExpr s, const std::unordered_map<IterVar, PrimExpr>& 
value_map) {
+  std::unordered_map<const VarNode*, PrimExpr> init;
+  for (const auto& kv : value_map) {
+    init[kv.first->var.get()] = kv.second;
+  }
+  return tir::Substitute(s, init);
+}
+
 IterVarType ForKindToIterVarType(tir::ForKind kind) {
   switch (kind) {
     case ForKind::kSerial:
diff --git a/src/te/operation/op_utils.h b/src/te/operation/op_utils.h
index e6bf2ca..02f4a86 100644
--- a/src/te/operation/op_utils.h
+++ b/src/te/operation/op_utils.h
@@ -73,7 +73,7 @@ std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& 
predicates);
  */
 Stmt ReplaceTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& 
replace);
 /*!
- * \brief Replace the tensor reference (especially in Call's) in stmt by the 
replace map.
+ * \brief Replace the tensor reference (especially in Call's) in primExpr by 
the replace map.
  * \param expr The expression to be processed.
  * \param replace The replacement rule.
  */
@@ -88,6 +88,14 @@ PrimExpr ReplaceTensor(PrimExpr expr, const 
std::unordered_map<Tensor, Tensor>&
 Stmt Substitute(Stmt stmt, const std::unordered_map<IterVar, PrimExpr>& 
value_map);
 
 /*!
+ * \brief Substitute the variables of primExpr by value map.
+ * \param expr the expression to be processed.
+ * \param value_map The value map.
+ * \return Substituted result.
+ */
+PrimExpr Substitute(PrimExpr expr, const std::unordered_map<IterVar, 
PrimExpr>& value_map);
+
+/*!
  * \brief Converts Halide ForKind to its corresponding IterVarType
  * \param kind The ForKind to be converted
  */
diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc
index bfd1ec5..ea71322 100644
--- a/src/te/operation/tensorize.cc
+++ b/src/te/operation/tensorize.cc
@@ -311,6 +311,7 @@ Array<PrimExpr> MatchTensorizeBody(const ComputeOpNode* 
self, const Stage& stage
 }
 
 void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage,
+                         const std::unordered_map<IterVar, PrimExpr>& 
value_map,
                          const std::unordered_map<IterVar, Range>& dom_map,
                          const std::unordered_map<IterVar, Range>& out_dom,
                          const std::unordered_map<Tensor, Array<Range> >& 
in_region,
@@ -327,7 +328,8 @@ void VerifyTensorizeBody(const ComputeOpNode* self, const 
Stage& stage,
 
   for (size_t i = 0; i < body.size(); ++i) {
     PrimExpr lhs = ana.Simplify(body[i]);
-    PrimExpr rhs = ana.Simplify(intrin_compute->body[i]);
+    // run substitution because the intrin body could depend on outer loop 
vars.
+    PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], 
value_map));
     if (lhs.dtype() != rhs.dtype()) {
       LOG(FATAL) << "Failed to match the data type with TensorIntrin " << 
intrin->name
                  << "'s declaration "
@@ -349,7 +351,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& 
stage,
   ICHECK(intrin.defined());
   ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, 
debug_keep_trivial_loop);
   VerifyTensorizeLoopNest(self, stage, n, tloc);
-  VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
+  VerifyTensorizeBody(self, stage, n.main_vmap, dom_map, out_dom, in_region, 
intrin);
   // Start bind data.
   Stmt nop = Evaluate(0);
   std::vector<Stmt> input_bind_nest, output_bind_nest;
diff --git a/tests/python/unittest/test_te_schedule_tensorize.py 
b/tests/python/unittest/test_te_schedule_tensorize.py
index 83a5d30..fdafdb7 100644
--- a/tests/python/unittest/test_te_schedule_tensorize.py
+++ b/tests/python/unittest/test_te_schedule_tensorize.py
@@ -18,14 +18,22 @@ import tvm
 from tvm import te
 
 
-def intrin_vadd(n):
+def intrin_vadd(xo, m, n):
     x = te.placeholder((n,), name="vx")
     y = te.placeholder((n,), name="vy")
-    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
+    if m % n == 0:
+        body = lambda i: x[i] + y[i]
+    else:
+        body = lambda i: tvm.tir.Select(
+            xo * n + i < m, x[i] + y[i], tvm.tir.const(0, dtype=x.dtype)
+        )
+    z = te.compute(x.shape, body, name="z")
 
     def intrin_func(ins, outs):
         xx, yy = ins
         zz = outs[0]
+        # special handle needed to tackle tail loop part when m % n != 0
+        # here is tvm.min(n, m - xo * n)
         return tvm.tir.call_packed("vadd", xx, yy, zz)
 
     buffer_params = {"offset_factor": 16}
@@ -84,15 +92,17 @@ def intrin_gemv_no_reset(m, n):
 
 
 def test_tensorize_vadd():
-    m = 128
-    x = te.placeholder((m,), name="x")
-    y = te.placeholder((m,), name="y")
-    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
+    def add(m):
+        x = te.placeholder((m,), name="x")
+        y = te.placeholder((m,), name="y")
+        z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
+        return x, y, z
 
-    def check(factor):
+    def check(m, factor):
+        x, y, z = add(m)
         s = te.create_schedule(z.op)
         xo, xi = s[z].split(z.op.axis[0], factor=factor)
-        vadd = intrin_vadd(factor)
+        vadd = intrin_vadd(xo, m, factor)
         s[z].tensorize(xi, vadd)
         s = s.normalize()
         dom_map = tvm.te.schedule.InferBound(s)
@@ -108,7 +118,36 @@ def test_tensorize_vadd():
         stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
         tvm.lower(s, [x, y, z])
 
-    check(16)
+    def check_cache_write(m, factor):
+        x, y, z = add(m)
+        s = te.create_schedule(z.op)
+        _, _ = s[z].split(z.op.axis[0], factor=factor)
+
+        z_global = s.cache_write(z, "global")
+        xo, xi = z_global.op.axis
+
+        vadd = intrin_vadd(xo, m, factor)
+        s[z_global].tensorize(xi, vadd)
+        s = s.normalize()
+        dom_map = tvm.te.schedule.InferBound(s)
+        finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
+        out_dom, in_dom = finfer(s[z_global], dom_map)
+        # outer loop var will be rebased, so min value is the new loop var and 
extent is 1
+        assert tvm.ir.structural_equal(out_dom[xo].extent, 1)
+        assert isinstance(out_dom[xo].min, tvm.tir.Var)
+        assert xo.var.name == out_dom[xo].min.name
+
+        fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
+        body = fmatch(s[z_global], out_dom, in_dom, vadd)[0]
+        ana = tvm.arith.Analyzer()
+        vars = tvm.runtime.convert({xo.var: out_dom[xo].min})
+        vadd_body = tvm.tir.stmt_functor.substitute(vadd.op.body[0], vars)
+        assert tvm.ir.structural_equal(ana.simplify(body), 
ana.simplify(vadd_body))
+        stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
+        tvm.lower(s, [x, y, z])
+
+    check(128, 16)
+    check_cache_write(129, 16)
 
 
 def test_tensorize_matmul():

Reply via email to the mailing list.