This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e59d1efc68 [Fix][TVMScript]TVMScript BinOP printing refactor (#14200)
e59d1efc68 is described below
commit e59d1efc68a0e1b6f1936ab09f5635386b26616e
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Mar 7 18:38:37 2023 -0800
[Fix][TVMScript]TVMScript BinOP printing refactor (#14200)
This PR fixes the printed output for `T.Div(int, int)`: the printer now emits
`T.Div(int, int)` instead of `int / int`, avoiding the integer-division
ambiguity in the parser.
This PR also refactors the binary-operator printing logic in TVMScript. The
updated printer emits an explicit call (e.g. `T.Add(a, b)`) whenever applying
the operator would constant-fold the expression, so the expression is preserved
when the printed script is parsed back.
---
python/tvm/script/parser/tir/operation.py | 28 ++++++------
src/script/printer/tir/expr.cc | 47 ++++++++++++-------
tests/python/unittest/test_inject_ptx_ldg32.py | 2 +-
...schedule_feature_extractor_per_store_feature.py | 2 +-
.../test_meta_schedule_schedule_rule_mlt_tc.py | 2 +-
.../unittest/test_meta_schedule_space_cuda.py | 6 +--
.../test_tir_transform_inject_virtual_thread.py | 8 ++--
.../python/unittest/test_tvmscript_printer_tir.py | 52 +++++++++++++++++++---
8 files changed, 100 insertions(+), 47 deletions(-)
diff --git a/python/tvm/script/parser/tir/operation.py
b/python/tvm/script/parser/tir/operation.py
index f0c04f47cd..3e120339a6 100644
--- a/python/tvm/script/parser/tir/operation.py
+++ b/python/tvm/script/parser/tir/operation.py
@@ -46,17 +46,17 @@ def _register_expr_op(ty: Type): # pylint:
disable=invalid-name
for i in [0, 1]:
# Case 1. binop
- r(doc.Add, i, tir.Add)
- r(doc.Sub, i, tir.Sub)
- r(doc.Mult, i, tir.Mul)
- r(doc.Div, i, tir.Div)
- r(doc.FloorDiv, i, tir.FloorDiv)
- r(doc.Mod, i, tir.FloorMod)
- r(doc.LShift, i, lambda a, b: a << b)
- r(doc.RShift, i, lambda a, b: a >> b)
- r(doc.BitOr, i, lambda a, b: a | b)
- r(doc.BitXor, i, lambda a, b: a ^ b)
- r(doc.BitAnd, i, lambda a, b: a & b)
+ # doc.Add <-- is overloaded
+ # doc.Sub <-- is overloaded
+ # doc.Mult <-- is overloaded
+ # doc.Div <-- is overloaded
+ # doc.FloorDiv <-- is overloaded
+ # doc.Mod <-- is overloaded
+ # doc.LShift <-- is overloaded
+ # doc.RShift <-- is overloaded
+ # doc.BitOr <-- is overloaded
+ # doc.BitXor <-- is overloaded
+ # doc.BitAnd <-- is overloaded
# doc.MatMult <-- not implemented
# doc.Pow <-- not implemented
# Case 2. cmpop
@@ -75,10 +75,10 @@ def _register_expr_op(ty: Type): # pylint:
disable=invalid-name
r(doc.Or, i, _or)
for i in [0]:
# Case 4. unaryop
- r(doc.Invert, i, lambda a: ~a)
+ # doc.Invert <-- is overloaded
r(doc.Not, i, tir.Not)
- r(doc.UAdd, i, lambda a: +a)
- r(doc.USub, i, lambda a: -a)
+ # doc.UAdd <-- is overloaded
+ # doc.USub <-- is overloaded
_register_expr_op(tir.PrimExpr)
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index 655f69c32d..dda2c73b8e 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -302,32 +302,47 @@ bool IsNumber(const ExprDoc& e) {
return false;
}
-#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind)
\
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<tir::Div>("", [](tir::Div node, ObjectPath p, IRDocsifier d)
-> Doc {
+ ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a"));
+ ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b"));
+ PrimExpr ret = tvm::div(node->a, node->b);
+ if (!ret->IsInstance<tir::DivNode>()) {
+ return TIR(d, "Div")->Call({a, b});
+ }
+ if ((node->a->dtype.is_int() || node->a->dtype.is_uint()) &&
+ (node->b->dtype.is_int() || node->b->dtype.is_uint())) {
+ return TIR(d, "Div")->Call({a, b});
+ }
+ return OperationDoc(OperationDocNode::Kind::kDiv, {a, b});
+ });
+
+#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc,
OpString, OpKind) \
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
\
.set_dispatch<tir::NodeType>("",
\
[](tir::NodeType node, ObjectPath p,
IRDocsifier d) -> Doc { \
ExprDoc a = d->AsDoc<ExprDoc>(node->a,
p->Attr("a")); \
ExprDoc b = d->AsDoc<ExprDoc>(node->b,
p->Attr("b")); \
- if (IsNumber(a) && IsNumber(b)) {
\
+ PrimExpr ret = tvm::NodeFunc(node->a,
node->b); \
+ if (!ret->IsInstance<tir::NodeObj>()) {
\
return TIR(d, OpString)->Call({a, b});
\
}
\
return
OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \
});
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add", kAdd);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub", kSub);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, "Mul", kMult);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Div, "Div", kDiv);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, "FloorDiv", kFloorDiv);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, "FloorMod", kMod);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, "LT", kLt);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, "LE", kLtE);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, "EQ", kEq);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, "NE", kNotEq);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, "GT", kGt);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, "GE", kGtE);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, "And", kAnd);
-TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, "Or", kOr);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, AddNode, add, "Add", kAdd);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, SubNode, sub, "Sub", kSub);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, MulNode, mul, "Mul", kMult);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, FloorDivNode, floordiv,
"FloorDiv", kFloorDiv);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, FloorModNode, floormod,
"FloorMod", kMod);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, LTNode, less, "LT", kLt);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, LENode, less_equal, "LE", kLtE);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, EQNode, equal, "EQ", kEq);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, NENode, not_equal, "NE", kNotEq);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, GTNode, greater, "GT", kGt);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, GENode, greater_equal, "GE",
kGtE);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, AndNode, logical_and, "And",
kAnd);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, OrNode, logical_or, "Or", kOr);
TVM_SCRIPT_PRINTER_DEF_BINARY(Mod, "truncmod");
TVM_SCRIPT_PRINTER_DEF_BINARY(Min, "min");
diff --git a/tests/python/unittest/test_inject_ptx_ldg32.py
b/tests/python/unittest/test_inject_ptx_ldg32.py
index 81c6e89ad9..8e8547c572 100644
--- a/tests/python/unittest/test_inject_ptx_ldg32.py
+++ b/tests/python/unittest/test_inject_ptx_ldg32.py
@@ -32,7 +32,7 @@ def vector_add(A: T.Buffer((16), "float32"), B:
T.Buffer((32), "float32")) -> No
with T.block():
T.reads(A[0:16])
T.writes(A_local[0:32])
- A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx / 2], T.float32(0),
dtype="float32")
+ A_local[tx] = T.if_then_else(tx % 2 == 0, A[tx // 2], T.float32(0),
dtype="float32")
B[tx] = A_local[tx] + 1.0
diff --git
a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
index 88947962d6..c62ac788d7 100644
---
a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
+++
b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
@@ -70,7 +70,7 @@ class LayoutTransform:
ax4 = T.axis.spatial(512, i0_i1_i2_i3_i4_fused % 512)
T.reads(placeholder[0, (ax4 * 49 + ax2 * 7 + ax3) % 25088 //
1568, (ax2 * 7 + ax3) % 49 // 7, ax3 % 7, (ax4 * 49 + ax2 * 7 + ax3) % 1568 //
49], placeholder_1[(ax4 * 49 + ax2 * 7 + ax3) % 25088])
T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
- T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(ax0 <
1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7, T.Select(T.float32(0) <
T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 %
25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 //
7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49
// 32, ((ax1 * 512 + ax4) * 49 + ax [...]
+ T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(ax0 <
1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7, T.Select(T.float32(0) <
T.if_then_else(T.LT(0, 1) and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088
% 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49
// 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49
// 32, ((ax1 * 512 + ax4) * 49 [...]
# fmt: on
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
index 1cab2554e8..97ee53f4e4 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -851,7 +851,7 @@ def test_padded_matmul_relu():
C_reindex_shared[v0, v1, v2, v3, v4_i,
v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]
for ax0_ax1_ax3_ax4_ax5_fused in range(512):
with T.block("C_reindex_shared"):
- v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2
+ 0)
+ v0 = T.axis.spatial(4, T.Add(ax0_0_0_ax1_0_0_fused
// 2, 0))
v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 *
4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256)
v2 = T.axis.spatial(2, ax2)
v3 = T.axis.spatial(1, 0)
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py
b/tests/python/unittest/test_meta_schedule_space_cuda.py
index bc674064d1..ef662ed5b1 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -315,7 +315,7 @@ def test_cuda_cap():
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(18,
i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 4 + i6_0 +
ax0_ax1_ax2_ax3_ax4_ax5_fused % 48 // 16)
- v2 = T.axis.spatial(18,
i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0 + 0)
+ v2 = T.axis.spatial(18,
T.Add(i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0, 0))
v3 = T.axis.spatial(4,
i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 +
ax0_ax1_ax2_ax3_ax4_ax5_fused % 16 // 8)
v4 = T.axis.spatial(4, i8_0 * 2 +
ax0_ax1_ax2_ax3_ax4_ax5_fused % 8 // 4)
v5 = T.axis.spatial(32, i9_0 * 4 +
ax0_ax1_ax2_ax3_ax4_ax5_fused % 4)
@@ -493,9 +493,9 @@ def test_cuda_dil():
for ax0_ax1_ax2_ax3_fused in T.serial(217):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(230,
i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2 + 0)
+ v1 = T.axis.spatial(230,
T.Add(i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2, 0))
v2 = T.axis.spatial(230, i5_0 * 2 +
ax0_ax1_ax2_ax3_fused % 217)
- v3 = T.axis.spatial(3, i6_0 + 0)
+ v3 = T.axis.spatial(3, T.Add(i6_0, 0))
T.reads(inputs[v0, v1 - 3, v2 - 3, v3])
T.writes(PadInput_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":2})
diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
index d327149384..beb20fd43b 100644
--- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
+++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
@@ -182,10 +182,10 @@ def test_vthread_vectorized():
def expected_func():
B_data = T.allocate([4], "int32x4", "shared")
B = T.Buffer([4], "int32x4", data=B_data, scope="shared")
- B[T.Mul(0, 4) / 4] = T.broadcast(0, 4)
- B[T.Mul(1, 4) / 4] = T.broadcast(1, 4)
- B[T.Mul(2, 4) / 4] = T.broadcast(2, 4)
- B[T.Mul(3, 4) / 4] = T.broadcast(3, 4)
+ B[T.Div(T.Mul(0, 4), 4)] = T.broadcast(0, 4)
+ B[T.Div(T.Mul(1, 4), 4)] = T.broadcast(1, 4)
+ B[T.Div(T.Mul(2, 4), 4)] = T.broadcast(2, 4)
+ B[T.Div(T.Mul(3, 4), 4)] = T.broadcast(3, 4)
before_mod = tvm.IRModule.from_expr(before_func)
intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py
b/tests/python/unittest/test_tvmscript_printer_tir.py
index e74f69dcae..87ec98e9a2 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -501,13 +501,12 @@ T.Cast("float64", a)
def test_binary_arith():
- a = tir.Var("a", "float32")
- b = tir.Var("b", "float32")
+ a = tir.Var("a", "int32")
+ b = tir.Var("b", "int32")
for op, sign in [
(tir.Add, "+"),
(tir.Sub, "-"),
(tir.Mul, "*"),
- (tir.Div, "/"),
(tir.Mod, "truncmod"),
(tir.FloorDiv, "//"),
(tir.FloorMod, "%"),
@@ -521,21 +520,60 @@ def test_binary_arith():
obj = op(a, b)
if sign.isalpha():
expected = """
-a = T.float32()
-b = T.float32()
+a = T.int32()
+b = T.int32()
T.{}(a, b)""".format(
sign
)
else:
expected = """
-a = T.float32()
-b = T.float32()
+a = T.int32()
+b = T.int32()
a {} b""".format(
sign
)
_assert_print(obj, expected)
+def test_binary_arith_const():
+ a = tir.IntImm("int64", 3)
+ b = tir.IntImm("int64", 4)
+ for op, name in [
+ (tir.Add, "Add"),
+ (tir.Sub, "Sub"),
+ (tir.Mul, "Mul"),
+ (tir.Div, "Div"),
+ (tir.Mod, "truncmod"),
+ (tir.FloorDiv, "FloorDiv"),
+ (tir.FloorMod, "FloorMod"),
+ (tir.LT, "LT"),
+ (tir.LE, "LE"),
+ (tir.EQ, "EQ"),
+ (tir.NE, "NE"),
+ (tir.GT, "GT"),
+ (tir.GE, "GE"),
+ ]:
+ obj = op(a, b)
+ expected = """
+T.{}({}, {})""".format(
+ name, str(a), str(b)
+ )
+ _assert_print(obj, expected)
+
+
+def test_int_div():
+ a = tir.Var("a", "int32")
+ b = tir.Var("b", "int32")
+ _assert_print(
+ tir.Div(a, b),
+ """
+a = T.int32()
+b = T.int32()
+T.Div(a, b)
+""",
+ )
+
+
def test_logical():
a = tir.Var("a", "bool")
b = tir.Var("b", "bool")