This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 37053e1 [Tensorize] Support conds that depend on outer loop vars inside
tensorize scope (#7497)
37053e1 is described below
commit 37053e1708c6565c8a82c31c0ffc78e594bfe3b0
Author: lee <[email protected]>
AuthorDate: Thu Mar 4 01:30:31 2021 +0800
[Tensorize] Support conds that depend on outer loop vars inside tensorize scope
(#7497)
* [Tensorize] Support conds that depend on outer loop vars inside tensorize scope
* Reformat
---
src/te/operation/op_utils.cc | 8 +++
src/te/operation/op_utils.h | 10 +++-
src/te/operation/tensorize.cc | 6 ++-
.../python/unittest/test_te_schedule_tensorize.py | 57 ++++++++++++++++++----
4 files changed, 69 insertions(+), 12 deletions(-)
diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc
index 32ffccb..b3897e1 100644
--- a/src/te/operation/op_utils.cc
+++ b/src/te/operation/op_utils.cc
@@ -243,6 +243,14 @@ Stmt Substitute(Stmt s, const std::unordered_map<IterVar,
PrimExpr>& value_map)
return tir::Substitute(s, init);
}
+PrimExpr Substitute(PrimExpr s, const std::unordered_map<IterVar, PrimExpr>&
value_map) {
+ std::unordered_map<const VarNode*, PrimExpr> init;
+ for (const auto& kv : value_map) {
+ init[kv.first->var.get()] = kv.second;
+ }
+ return tir::Substitute(s, init);
+}
+
IterVarType ForKindToIterVarType(tir::ForKind kind) {
switch (kind) {
case ForKind::kSerial:
diff --git a/src/te/operation/op_utils.h b/src/te/operation/op_utils.h
index e6bf2ca..02f4a86 100644
--- a/src/te/operation/op_utils.h
+++ b/src/te/operation/op_utils.h
@@ -73,7 +73,7 @@ std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>&
predicates);
*/
Stmt ReplaceTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>&
replace);
/*!
- * \brief Replace the tensor reference (especially in Call's) in stmt by the
replace map.
+ * \brief Replace the tensor reference (especially in Call's) in primExpr by
the replace map.
* \param expr The expression to be processed.
* \param replace The replacement rule.
*/
@@ -88,6 +88,14 @@ PrimExpr ReplaceTensor(PrimExpr expr, const
std::unordered_map<Tensor, Tensor>&
Stmt Substitute(Stmt stmt, const std::unordered_map<IterVar, PrimExpr>&
value_map);
/*!
+ * \brief Substitute the variables of primExpr by value map.
+ * \param expr the expression to be processed.
+ * \param value_map The value map.
+ * \return Substituted result.
+ */
+PrimExpr Substitute(PrimExpr expr, const std::unordered_map<IterVar,
PrimExpr>& value_map);
+
+/*!
* \brief Converts Halide ForKind to its corresponding IterVarType
* \param kind The ForKind to be converted
*/
diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc
index bfd1ec5..ea71322 100644
--- a/src/te/operation/tensorize.cc
+++ b/src/te/operation/tensorize.cc
@@ -311,6 +311,7 @@ Array<PrimExpr> MatchTensorizeBody(const ComputeOpNode*
self, const Stage& stage
}
void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage,
+ const std::unordered_map<IterVar, PrimExpr>&
value_map,
const std::unordered_map<IterVar, Range>& dom_map,
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >&
in_region,
@@ -327,7 +328,8 @@ void VerifyTensorizeBody(const ComputeOpNode* self, const
Stage& stage,
for (size_t i = 0; i < body.size(); ++i) {
PrimExpr lhs = ana.Simplify(body[i]);
- PrimExpr rhs = ana.Simplify(intrin_compute->body[i]);
+ // run substitution because the intrin body could depend on outer loop
vars.
+ PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i],
value_map));
if (lhs.dtype() != rhs.dtype()) {
LOG(FATAL) << "Failed to match the data type with TensorIntrin " <<
intrin->name
<< "'s declaration "
@@ -349,7 +351,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage&
stage,
ICHECK(intrin.defined());
ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map,
debug_keep_trivial_loop);
VerifyTensorizeLoopNest(self, stage, n, tloc);
- VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
+ VerifyTensorizeBody(self, stage, n.main_vmap, dom_map, out_dom, in_region,
intrin);
// Start bind data.
Stmt nop = Evaluate(0);
std::vector<Stmt> input_bind_nest, output_bind_nest;
diff --git a/tests/python/unittest/test_te_schedule_tensorize.py
b/tests/python/unittest/test_te_schedule_tensorize.py
index 83a5d30..fdafdb7 100644
--- a/tests/python/unittest/test_te_schedule_tensorize.py
+++ b/tests/python/unittest/test_te_schedule_tensorize.py
@@ -18,14 +18,22 @@ import tvm
from tvm import te
-def intrin_vadd(n):
+def intrin_vadd(xo, m, n):
x = te.placeholder((n,), name="vx")
y = te.placeholder((n,), name="vy")
- z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
+ if m % n == 0:
+ body = lambda i: x[i] + y[i]
+ else:
+ body = lambda i: tvm.tir.Select(
+ xo * n + i < m, x[i] + y[i], tvm.tir.const(0, dtype=x.dtype)
+ )
+ z = te.compute(x.shape, body, name="z")
def intrin_func(ins, outs):
xx, yy = ins
zz = outs[0]
+ # special handle needed to tackle tail loop part when m % n != 0
+ # here is tvm.min(n, m - xo * n)
return tvm.tir.call_packed("vadd", xx, yy, zz)
buffer_params = {"offset_factor": 16}
@@ -84,15 +92,17 @@ def intrin_gemv_no_reset(m, n):
def test_tensorize_vadd():
- m = 128
- x = te.placeholder((m,), name="x")
- y = te.placeholder((m,), name="y")
- z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
+ def add(m):
+ x = te.placeholder((m,), name="x")
+ y = te.placeholder((m,), name="y")
+ z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
+ return x, y, z
- def check(factor):
+ def check(m, factor):
+ x, y, z = add(m)
s = te.create_schedule(z.op)
xo, xi = s[z].split(z.op.axis[0], factor=factor)
- vadd = intrin_vadd(factor)
+ vadd = intrin_vadd(xo, m, factor)
s[z].tensorize(xi, vadd)
s = s.normalize()
dom_map = tvm.te.schedule.InferBound(s)
@@ -108,7 +118,36 @@ def test_tensorize_vadd():
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [x, y, z])
- check(16)
+ def check_cache_write(m, factor):
+ x, y, z = add(m)
+ s = te.create_schedule(z.op)
+ _, _ = s[z].split(z.op.axis[0], factor=factor)
+
+ z_global = s.cache_write(z, "global")
+ xo, xi = z_global.op.axis
+
+ vadd = intrin_vadd(xo, m, factor)
+ s[z_global].tensorize(xi, vadd)
+ s = s.normalize()
+ dom_map = tvm.te.schedule.InferBound(s)
+ finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
+ out_dom, in_dom = finfer(s[z_global], dom_map)
+ # outer loop var will be rebased, so min value is the new loop var and
extent is 1
+ assert tvm.ir.structural_equal(out_dom[xo].extent, 1)
+ assert isinstance(out_dom[xo].min, tvm.tir.Var)
+ assert xo.var.name == out_dom[xo].min.name
+
+ fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
+ body = fmatch(s[z_global], out_dom, in_dom, vadd)[0]
+ ana = tvm.arith.Analyzer()
+ vars = tvm.runtime.convert({xo.var: out_dom[xo].min})
+ vadd_body = tvm.tir.stmt_functor.substitute(vadd.op.body[0], vars)
+ assert tvm.ir.structural_equal(ana.simplify(body),
ana.simplify(vadd_body))
+ stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
+ tvm.lower(s, [x, y, z])
+
+ check(128, 16)
+ check_cache_write(129, 16)
def test_tensorize_matmul():