This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4a94a94dfc fix T.Ptr[T.void] for packed api roundtrip (#12118)
4a94a94dfc is described below
commit 4a94a94dfc9ba6e265e0847228272562989072c7
Author: wrongtest <[email protected]>
AuthorDate: Fri Jul 22 12:40:01 2022 +0800
fix T.Ptr[T.void] for packed api roundtrip (#12118)
---
python/tvm/_ffi/base.py | 2 +-
python/tvm/script/tir/__init__.py | 2 +-
python/tvm/script/tir/ty.py | 8 ++++++++
src/printer/tvmscript_printer.cc | 7 ++++++-
tests/python/unittest/test_tvmscript_roundtrip.py | 9 +++++++++
5 files changed, 25 insertions(+), 3 deletions(-)
diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py
index e4e1fb1bb8..744e4c93e1 100644
--- a/python/tvm/_ffi/base.py
+++ b/python/tvm/_ffi/base.py
@@ -255,7 +255,7 @@ def c2pyerror(err_msg):
message = []
for line in arr:
if trace_mode:
- if line.startswith(" "):
+ if line.startswith(" ") and len(stack_trace) > 0:
stack_trace[-1] += "\n" + line
elif line.startswith(" "):
stack_trace.append(line)
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py
index de40459131..2655f5bb33 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/tir/__init__.py
@@ -17,7 +17,7 @@
"""TVMScript for TIR"""
# Type system
-from .ty import uint8, int8, int16, int32, int64, float16, float32, float64
+from .ty import uint8, int8, int16, int32, int64, float16, float32, float64, void
from .ty import boolean, handle, Ptr, Tuple, Buffer
from .prim_func import prim_func
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
index 878f029e55..a64485b215 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/tir/ty.py
@@ -69,6 +69,13 @@ class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abst
return self.type
+class VoidType(ConcreteType): # pylint: disable=too-few-public-methods, abstract-method
+ """TVM script typing class for void type"""
+
+ def __init__(self):
+ super().__init__("")
+
+
class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
"""TVM script typing class generator for PtrType
@@ -202,6 +209,7 @@ float32 = ConcreteType("float32")
float64 = ConcreteType("float64")
boolean = ConcreteType("bool")
handle = ConcreteType("handle")
+void = VoidType()
Ptr = GenericPtrType()
Tuple = GenericTupleType()
# we don't have 'buffer' type on the cpp side
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 725e105c01..aaebc7409f 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -1236,7 +1236,12 @@ Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) {
Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {
Doc doc;
- doc << tir_prefix_ << "." << runtime::DLDataType2String(node->dtype);
+ doc << tir_prefix_ << ".";
+ if (node->dtype.is_void()) {
+ doc << "void";
+ } else {
+ doc << runtime::DLDataType2String(node->dtype);
+ }
return doc;
}
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 306f60f1b1..8e0561bb19 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3297,6 +3297,14 @@ def let_expression():
return func
+def void_ptr():
+ @T.prim_func
+ def func(out_ret_value: T.Ptr[T.void]):
+ T.evaluate(out_ret_value)
+
+ return func
+
+
ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
@@ -3335,6 +3343,7 @@ ir_generator = tvm.testing.parameter(
buffer_axis_separator,
buffer_ramp_access_as_slice_index,
let_expression,
+ void_ptr,
)