This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 589d919c59 [Unity] Fix TVMScript Issues in Testcases (#15920)
589d919c59 is described below
commit 589d919c599b9a1a0d7d89457182ea0e8bf3e5ae
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Oct 12 20:50:51 2023 +0800
[Unity] Fix TVMScript Issues in Testcases (#15920)
* [Unity] Fix TVMScript Issues in Testcases
Due to frequent syncs with upstream, some of the test cases are broken
because of changes in TVMScript. This PR fixes the broken test cases.
---
python/tvm/script/parser/tir/parser.py | 2 +-
src/script/printer/relax/function.cc | 2 +-
src/script/printer/tir/function.cc | 2 +-
.../test_tir_transform_force_narrow_index_to_i32.py | 16 ++++++++--------
tests/python/unittest/test_tvmscript_printer_tir.py | 2 +-
5 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 33b42b3436..89673d291b 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -454,7 +454,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
T.evaluate(res)
elif isinstance(res, (int, bool)):
T.evaluate(tvm.tir.const(res))
- elif isinstance(res, tvm.relay.Call) and not res.args:
+ elif isinstance(res, (tvm.relay.Call, tvm.relax.Call)) and not res.args:
# Using GlobalVar.__call__ with no arguments is ambiguous, as
# each IR has a different function Call representation. If
# this occurs, convert to the TIR representation.
diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc
index 5fb54c793d..458eb3766d 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -50,7 +50,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (Optional<String> name = GetBindingName(d)) {
func_name = std::move(IdDoc(name.value()));
} else {
- func_name = std::move(d->Define(n, f(), FindFunctionName(d,
n).value_or("main")));
+ func_name = std::move(IdDoc(FindFunctionName(d, n).value_or("main")));
}
(*f)->AddDispatchToken(d, "relax");
(*f)->is_func = true;
diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc
index 95f30b843c..0108a50be8 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -68,7 +68,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p,
IRDocsifier d) -> Doc {
With<TIRFrame> f(d, func);
(*f)->AddDispatchToken(d, "tir");
- IdDoc func_name = d->Define(func, f(), FindFunctionName(d,
func).value_or("main"));
+ IdDoc func_name = IdDoc(FindFunctionName(d, func).value_or("main"));
d->SetCommonPrefix(func, [](const ObjectRef& obj) {
return obj->IsInstance<tir::VarNode>() ||
obj->IsInstance<tir::BufferNode>();
});
diff --git a/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py b/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
index 8a2a286671..c1b81853de 100644
--- a/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
+++ b/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
@@ -22,7 +22,7 @@ import tvm.testing
def test_thread_axis1():
- @T.prim_func
+ @T.prim_func(private=True)
def before(A: T.Buffer((T.int64(64),), "float32"), B:
T.Buffer((T.int64(64),), "float32")):
blockIdx_x = T.env_thread("blockIdx.x")
T.launch_thread(blockIdx_x, T.int64(2))
@@ -32,7 +32,7 @@ def test_thread_axis1():
T.Cast("int64", blockIdx_x) * T.int64(32) + T.Cast("int64",
threadIdx_x)
] + T.float32(1)
- @T.prim_func
+ @T.prim_func(private=True)
def expected(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")):
blockIdx_x = T.env_thread("blockIdx.x")
T.launch_thread(blockIdx_x, 2)
@@ -161,7 +161,7 @@ def test_thread_axis2():
def test_block():
- @T.prim_func
+ @T.prim_func(private=True)
def before(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
for i in T.serial(0, T.int64(16)):
for j in T.serial(0, T.int64(8)):
@@ -169,7 +169,7 @@ def test_block():
vi = T.axis.spatial(T.int64(128), i * T.int64(8) + j)
B[vi] = A[vi] + T.float32(1)
- @T.prim_func
+ @T.prim_func(private=True)
def expected(A: T.Buffer((128,), "float32"), B: T.Buffer((128,),
"float32")):
for i in T.serial(0, T.int32(16)):
for j in T.serial(0, T.int32(8)):
@@ -183,7 +183,7 @@ def test_block():
def test_i16_buffer():
- @T.prim_func
+ @T.prim_func(private=True)
def before(A: T.Buffer((128,), "int16"), B: T.Buffer((128,), "int16")):
for i in T.serial(0, T.int64(16)):
for j in T.serial(0, T.int64(16)):
@@ -191,7 +191,7 @@ def test_i16_buffer():
vi = T.axis.spatial(T.int64(128), i * 8 + j)
B[vi] = A[vi] + T.int16(1)
- @T.prim_func
+ @T.prim_func(private=True)
def expected(A: T.Buffer((128,), "int16"), B: T.Buffer((128,), "int16")):
for i in T.serial(0, 16):
for j in T.serial(0, 16):
@@ -205,7 +205,7 @@ def test_i16_buffer():
def test_fail_on_buffer_map():
- @T.prim_func
+ @T.prim_func(private=True)
def func(A: T.Buffer((128,), "int64"), B: T.Buffer((128,), "int64")):
for i in T.serial(0, 16):
for j in T.serial(0, 8):
@@ -219,7 +219,7 @@ def test_fail_on_buffer_map():
def test_fail_on_buffer_map():
- @T.prim_func
+ @T.prim_func(private=True)
def func(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")):
C = T.alloc_buffer((128,), "int64")
for i in T.serial(0, 16):
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py
index 70d56e6903..76281f5950 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -704,7 +704,7 @@ def test_range():
_assert_print(
obj,
"""
-T.Range(0, 10)
+I.Range(0, 10)
""",
)