This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dc626f33e3 [TVMScript] Unify `T.handle` and `T.Ptr` (#13969)
dc626f33e3 is described below
commit dc626f33e3489aa5f9e92f5ae5cbbcf2f71247ef
Author: Junru Shao <[email protected]>
AuthorDate: Sun Feb 12 17:13:32 2023 -0800
[TVMScript] Unify `T.handle` and `T.Ptr` (#13969)
While both represent a pointer type, `T.handle` was previously used to
refer to tir variables whose `type_annotation` is `PrimType`, while
`T.Ptr` instead specifically refers to `PointerType`. The divide is
unnecessary if we extend `T.handle` slightly.
---
include/tvm/script/ir_builder/tir/ir.h | 5 +--
python/tvm/script/ir_builder/tir/ir.py | 13 +++---
python/tvm/script/parser/tir/entry.py | 5 ++-
src/script/ir_builder/tir/ir.cc | 10 +++++
src/script/printer/tir/ir.cc | 6 +--
.../relay/aot/test_aot_create_executor_metadata.py | 2 +-
tests/python/relay/aot/test_pass_aot_lower_main.py | 4 +-
...sform_plan_update_buffer_allocation_location.py | 2 +-
.../unittest/test_tir_transform_storage_flatten.py | 2 +-
.../unittest/test_tir_transform_storage_rewrite.py | 4 +-
...ransform_convert_pool_allocations_to_offsets.py | 36 ++++++++---------
tests/python/unittest/test_tvmscript_parser_tir.py | 4 +-
.../python/unittest/test_tvmscript_printer_tir.py | 2 +-
tests/python/unittest/test_tvmscript_roundtrip.py | 46 +++++++++++-----------
14 files changed, 77 insertions(+), 64 deletions(-)
diff --git a/include/tvm/script/ir_builder/tir/ir.h
b/include/tvm/script/ir_builder/tir/ir.h
index 5cba879205..d5cc1de5c6 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -415,12 +415,12 @@ void Prefetch(Buffer buffer, Array<Range> bounds);
void Evaluate(PrimExpr value);
/*!
- * \brief The pointer declaration function.
+ * \brief Create a TIR var that represents a pointer
* \param dtype The data type of the pointer.
* \param storage_scope The storage scope of the pointer.
* \return The pointer.
*/
-PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+Var Handle(runtime::DataType dtype = runtime::DataType::Void(), String
storage_scope = "global");
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)
\
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) {
\
@@ -455,7 +455,6 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float,
DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index fdb27df2a9..25d16b56dc 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1358,20 +1358,23 @@ def boolean(expr: Optional[PrimExpr] = None) ->
PrimExpr:
return _ffi_api.Boolean(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
-def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type handle or cast expression to type
handle.
+def handle(dtype: str = "void", storage_scope: str = "global") -> Var:
+ """Create a TIR var that represents a pointer.
Parameters
----------
- expr: PrimExpr
- The expression to be cast.
+ dtype: str
+ The data type of the pointer.
+
+ storage_scope: str
+ The storage scope of the pointer.
Returns
-------
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
- return _ffi_api.Handle(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
+ return _ffi_api.Handle(dtype, storage_scope) # type: ignore[attr-defined]
# pylint: disable=no-member
def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
diff --git a/python/tvm/script/parser/tir/entry.py
b/python/tvm/script/parser/tir/entry.py
index bacf92c142..51743e6b50 100644
--- a/python/tvm/script/parser/tir/entry.py
+++ b/python/tvm/script/parser/tir/entry.py
@@ -79,7 +79,7 @@ class BufferProxy:
axis_separators=axis_separators,
)
- @deprecated("T.Buffer(...)", "T.Buffer(...)")
+ @deprecated("T.Buffer[...]", "T.Buffer(...)")
def __getitem__(self, keys) -> Buffer:
if not isinstance(keys, tuple):
return self(keys)
@@ -93,12 +93,13 @@ class PtrProxy:
Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
"""
+ @deprecated("T.Ptr(...)", "T.handle(...)")
def __call__(self, dtype, storage_scope="global"):
if callable(dtype):
dtype = dtype().dtype
return ptr(dtype, storage_scope) # pylint: disable=no-member # type:
ignore
- @deprecated("T.Ptr(...)", "T.Ptr(...)")
+ @deprecated("T.Ptr[...]", "T.handle(...)")
def __getitem__(self, keys):
if not isinstance(keys, tuple):
return self(keys)
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index c586e81f1b..9ab19b2e28 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -545,6 +545,16 @@ PrimExpr Ptr(runtime::DataType dtype, String
storage_scope) {
return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope));
}
+Var Handle(runtime::DataType dtype, String storage_scope) {
+ Type type_annotation{nullptr};
+ if (dtype.is_void() && storage_scope == "global") {
+ type_annotation = PrimType(runtime::DataType::Handle());
+ } else {
+ type_annotation = PointerType(PrimType(dtype), storage_scope);
+ }
+ return tvm::tir::Var("", type_annotation);
+}
+
using tvm::script::ir_builder::details::Namer;
TVM_STATIC_IR_FUNCTOR(Namer, vtable)
diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc
index ce10ff6816..78e50a5eb5 100644
--- a/src/script/printer/tir/ir.cc
+++ b/src/script/printer/tir/ir.cc
@@ -73,10 +73,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
element_type = d->AsDoc<ExprDoc>(ty->element_type,
ty_p->Attr("element_type"));
}
if (ty->storage_scope == "") {
- return TIR(d, "Ptr")->Call({element_type});
+ return TIR(d, "handle")->Call({element_type});
} else {
- return TIR(d, "Ptr")->Call(
- {element_type, LiteralDoc::Str(ty->storage_scope,
ty_p->Attr("storage_scope"))});
+ return TIR(d, "handle")
+ ->Call({element_type, LiteralDoc::Str(ty->storage_scope,
ty_p->Attr("storage_scope"))});
}
});
diff --git a/tests/python/relay/aot/test_aot_create_executor_metadata.py
b/tests/python/relay/aot/test_aot_create_executor_metadata.py
index 1bc79fe2a6..804738a786 100644
--- a/tests/python/relay/aot/test_aot_create_executor_metadata.py
+++ b/tests/python/relay/aot/test_aot_create_executor_metadata.py
@@ -53,7 +53,7 @@ def test_create_executor_metadata_single_func():
class Module:
@T.prim_func
def __tvm_main__(
- a: T.handle, output: T.handle, workspace: T.Ptr(T.uint8),
constants: T.Ptr(T.uint8)
+ a: T.handle, output: T.handle, workspace: T.handle("uint8"),
constants: T.handle("uint8")
) -> None:
# function attr dict
T.func_attr({"global_symbol": "test_mod___tvm_main__",
"runner_function": True, "target": T.target({"kind": "llvm", "tag": "", "keys":
["cpu"]}), "input_vars": [a], "output_vars": [output], "devices":
["test_device"]})
diff --git a/tests/python/relay/aot/test_pass_aot_lower_main.py
b/tests/python/relay/aot/test_pass_aot_lower_main.py
index f2455e97a0..bc58812cd6 100644
--- a/tests/python/relay/aot/test_pass_aot_lower_main.py
+++ b/tests/python/relay/aot/test_pass_aot_lower_main.py
@@ -178,13 +178,13 @@ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7),
float32] {
def func(a: T.handle, output: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "test_mod___tvm_main__",
"runner_function": True, "target": T.target({"kind":"llvm", "tag":"",
"keys":["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": []})
- tmp_read = T.Ptr("uint8", "")
+ tmp_read = T.handle("uint8", "")
# buffer definition
tmp_read_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_read)
a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16)
output_buffer = T.match_buffer(output, [5, 7], dtype="float32",
align=16)
# body
- tmp_write: T.Ptr(T.uint8) = output_buffer.data
+ tmp_write: T.handle("uint8") = output_buffer.data
tmp_write_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_write)
for i in T.serial(140):
tmp_write_1[i] = T.let(tmp_read, a_buffer.data, tmp_read_1[i])
diff --git
a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
index 05d71de5bc..758a395da6 100644
---
a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
+++
b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -424,7 +424,7 @@ def test_buffer_conditional_lowering():
"""
@T.prim_func
- def before(A: T.Ptr("float32")):
+ def before(A: T.handle("float32")):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i in range(1):
A_1 = T.Buffer((1,), data=A)
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py
b/tests/python/unittest/test_tir_transform_storage_flatten.py
index 29623b498f..39009164e7 100644
--- a/tests/python/unittest/test_tir_transform_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -139,7 +139,7 @@ def test_flatten_let_buffer():
T.func_attr({"from_legacy_te_schedule": True})
# If a pointer defined using a LetStmt,
- A_data: T.Ptr("int32") = T.call_extern("dummy_extern_function",
dtype="handle")
+ A_data: T.handle("int32") = T.call_extern("dummy_extern_function",
dtype="handle")
# and a buffer is backed by that pointer,
A = T.decl_buffer([1], dtype="float32", data=A_data)
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py
b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index 4766022121..c46754fb17 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -689,12 +689,12 @@ class TestLetBufferRewrite(BaseCompare):
"""
def before() -> None:
- A_data: T.Ptr("int32") = T.call_extern("dummy_func", dtype="handle")
+ A_data: T.handle("int32") = T.call_extern("dummy_func", dtype="handle")
A = T.Buffer([8], "int32", data=A_data)
A[0:8] = T.broadcast(42, 8)
def expected() -> None:
- A_data: T.Ptr("int32x8") = T.call_extern("dummy_func", dtype="handle")
+ A_data: T.handle("int32x8") = T.call_extern("dummy_func",
dtype="handle")
A = T.Buffer([1], "int32x8", data=A_data)
A[0] = T.broadcast(42, 8)
diff --git
a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
index 5bbedd3492..58f37f0496 100644
---
a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
+++
b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
@@ -144,20 +144,20 @@ class LinearStructure:
@tvm.script.ir_module
class LinearStructurePlanned:
@T.prim_func
- def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr("uint8"),
slow_memory_1_var: T.Ptr("uint8"), output: T.handle) -> None:
+ def __tvm_main__(input: T.handle, fast_memory_0_var: T.handle("uint8"),
slow_memory_1_var: T.handle("uint8"), output: T.handle) -> None:
fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704],
dtype="uint8", strides=[1], elem_offset=0, align=16)
slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var,
[1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
# body
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
- sid_9_let: T.Ptr("int8") =
T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle")
- sid_8_let: T.Ptr("int8") = T.address_of(slow_memory_1_buffer_var[0],
dtype="handle")
+ sid_9_let: T.handle("int8") =
T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle")
+ sid_8_let: T.handle("int8") =
T.address_of(slow_memory_1_buffer_var[0], dtype="handle")
T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input,
T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data,
slow_memory_1_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast",
sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2",
dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data,
slow_memory_1_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast",
sid_8_let, output, fast_memory_0_buffer_var.data,
slow_memory_1_buffer_var.data, dtype="int32"))
@T.prim_func
- def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle,
T_cast_6: T.handle, fast_memory_6_var: T.Ptr("uint8"), slow_memory_7_var:
T.Ptr("uint8")) -> None:
+ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle,
T_cast_6: T.handle, fast_memory_6_var: T.handle("uint8"), slow_memory_7_var:
T.handle("uint8")) -> None:
placeholder_29 = T.match_buffer(placeholder_28, [802816],
dtype="uint8")
T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16")
fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704],
dtype="uint8", strides=[1], elem_offset=0, align=16)
@@ -174,7 +174,7 @@ class LinearStructurePlanned:
T_cast_7[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3] =
T.cast(tensor_2_let[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3], "int16")
@T.prim_func
- def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle,
placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var:
T.Ptr("uint8"), slow_memory_3_var: T.Ptr("uint8")) -> None:
+ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle,
placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var:
T.handle("uint8"), slow_memory_3_var: T.handle("uint8")) -> None:
placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8")
placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16")
T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16")
@@ -185,7 +185,7 @@ class LinearStructurePlanned:
T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] =
T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16")
- placeholder_5[0]
@T.prim_func
- def
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62:
T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20:
T.handle, fast_memory_4_var: T.Ptr("uint8"), slow_memory_5_var: T.Ptr("uint8"))
-> None:
+ def
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62:
T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20:
T.handle, fast_memory_4_var: T.handle("uint8"), slow_memory_5_var:
T.handle("uint8")) -> None:
placeholder_65 = T.match_buffer(placeholder_62, [150528],
dtype="int16")
placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16")
placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32")
@@ -380,7 +380,7 @@ class ResnetStructure:
@tvm.script.ir_module
class ResnetStructurePlanned:
@T.prim_func
- def
tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder:
T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var:
T.Ptr("uint8")) -> None:
+ def
tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder:
T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var:
T.handle("uint8")) -> None:
placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16")
@@ -390,7 +390,7 @@ class ResnetStructurePlanned:
T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 +
ax3_inner] =
T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused
* 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31,
1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0),
"uint8"), "int16")
@T.prim_func
- def
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22:
T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25:
T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr("uint8")) -> None:
+ def
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22:
T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25:
T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle("uint8")) ->
None:
placeholder_29 = T.match_buffer(placeholder_22, [360000],
dtype="int16")
placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16")
placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32")
@@ -414,7 +414,7 @@ class ResnetStructurePlanned:
T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 +
ax3_outer_2 * 64 + ax3_inner_4] =
T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3_let[ax3_inner_4]
+ placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8,
dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1,
dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 *
64 + ax3_inner_4], 255), 0), "uint8")
@T.prim_func
- def
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16:
T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle,
global_workspace_4_var: T.Ptr("uint8")) -> None:
+ def
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16:
T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle,
global_workspace_4_var: T.handle("uint8")) -> None:
placeholder_19 = T.match_buffer(placeholder_16, [360000],
dtype="int16")
placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
@@ -437,7 +437,7 @@ class ResnetStructurePlanned:
T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 +
ax3_outer_1 * 64 + ax3_inner_3] =
T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2_let[ax3_inner_3]
+ placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8,
dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2,
dtype="int32") + 136
@T.prim_func
- def
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4:
T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2:
T.handle, global_workspace_2_var: T.Ptr("uint8")) -> None:
+ def
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4:
T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2:
T.handle, global_workspace_2_var: T.handle("uint8")) -> None:
placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
@@ -459,7 +459,7 @@ class ResnetStructurePlanned:
T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] =
T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_let[ax3_inner_1] +
placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0),
"uint8"), "int16")
@T.prim_func
- def
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10:
T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4:
T.handle, global_workspace_3_var: T.Ptr("uint8")) -> None:
+ def
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10:
T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4:
T.handle, global_workspace_3_var: T.handle("uint8")) -> None:
placeholder_13 = T.match_buffer(placeholder_10, [360000],
dtype="int16")
placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
@@ -481,15 +481,15 @@ class ResnetStructurePlanned:
T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2]
= T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1_let[ax3_inner_2]
+ placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0),
"uint8"), "int16")
@T.prim_func
- def __tvm_main__(input: T.handle, global_workspace_0_var: T.Ptr("uint8"),
output: T.handle) -> None:
+ def __tvm_main__(input: T.handle, global_workspace_0_var:
T.handle("uint8"), output: T.handle) -> None:
global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var,
[7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
# body
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
- sid_2_let: T.Ptr("int8") =
T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle")
- sid_6_let: T.Ptr("int8") =
T.address_of(global_workspace_0_buffer_var[0], dtype="handle")
- sid_7_let: T.Ptr("int8") =
T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle")
- sid_8_let: T.Ptr("int8") =
T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle")
+ sid_2_let: T.handle("int8") =
T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle")
+ sid_6_let: T.handle("int8") =
T.address_of(global_workspace_0_buffer_var[0], dtype="handle")
+ sid_7_let: T.handle("int8") =
T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle")
+ sid_8_let: T.handle("int8") =
T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle")
T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast",
input, T.lookup_param("p0", dtype="handle"), sid_2_let,
global_workspace_0_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast",
sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4",
dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32"))
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1",
sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6",
dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32"))
@@ -557,7 +557,7 @@ class TensorIntrinStructure:
@tvm.script.ir_module
class TensorIntrinStructurePlanned:
@T.prim_func
- def tensor_intrin_primfunc(global_workspace_1_var: T.Ptr("uint8")) -> None:
+ def tensor_intrin_primfunc(global_workspace_1_var: T.handle("uint8")) ->
None:
global_workspace_1_buffer_var = T.match_buffer(
global_workspace_1_var, [40], dtype="uint8", strides=[1],
elem_offset=0, align=16
)
@@ -576,7 +576,7 @@ class TensorIntrinStructurePlanned:
@T.prim_func
def __tvm_main__(
- input: T.handle, global_workspace_1_var: T.Ptr("uint8"), output:
T.handle
+ input: T.handle, global_workspace_1_var: T.handle("uint8"), output:
T.handle
) -> None:
global_workspace_1_buffer_var = T.match_buffer(
global_workspace_1_var, [40], dtype="uint8", strides=[1],
elem_offset=0, align=16
diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py
b/tests/python/unittest/test_tvmscript_parser_tir.py
index e96ae4da8c..20be6d1498 100644
--- a/tests/python/unittest/test_tvmscript_parser_tir.py
+++ b/tests/python/unittest/test_tvmscript_parser_tir.py
@@ -40,7 +40,7 @@ def test_tir_buffer_proxy():
def test_tir_ptr_proxy():
- ptr_0 = T.Ptr("int32", "global")
+ ptr_0 = T.handle("int32", "global")
assert (
isinstance(ptr_0, tir.Var)
and ptr_0.dtype == "handle"
@@ -49,7 +49,7 @@ def test_tir_ptr_proxy():
and ptr_0.type_annotation.storage_scope == "global"
)
- ptr_1 = T.Ptr("float32", "shared")
+ ptr_1 = T.handle("float32", "shared")
assert (
isinstance(ptr_1, tir.Var)
and ptr_1.dtype == "handle"
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py
b/tests/python/unittest/test_tvmscript_printer_tir.py
index 6f96b3a3dd..a04544152e 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -674,7 +674,7 @@ def test_prim_type():
def test_pointer_type():
obj = ir.PointerType(ir.PrimType("int32"), "global")
- _assert_print(obj, 'T.Ptr("int32", "global")')
+ _assert_print(obj, 'T.handle("int32", "global")')
def test_tuple_type():
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py
b/tests/python/unittest/test_tvmscript_roundtrip.py
index 1ec8f49b4b..db21223366 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -204,30 +204,30 @@ def opt_gemm_mod_host():
arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle")
arg2_code: T.int32 = buf_type_ids[2]
- A_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 1,
dtype="handle")
+ A_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 1,
dtype="handle")
T.attr(A_data, "storage_alignment", 128)
A = T.Buffer([1024 * 1024], dtype="int32", data=A_data)
- buf0_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 2,
dtype="handle")
+ buf0_shape_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 2,
dtype="handle")
buf0_shape = T.Buffer([2], dtype="int32", data=buf0_shape_data)
- buf0_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 3,
dtype="handle")
+ buf0_strides_data: T.handle("int32") = T.tvm_struct_get(arg0, 0,
3, dtype="handle")
buf0_strides = T.Buffer([2], dtype="int32", data=buf0_strides_data)
dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32")
- B_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 1,
dtype="handle")
+ B_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 1,
dtype="handle")
T.attr(B_data, "storage_alignment", 128)
B = T.Buffer([1024 * 1024], dtype="int32", data=B_data)
- buf1_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 2,
dtype="handle")
+ buf1_shape_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 2,
dtype="handle")
buf1_shape = T.Buffer([2], dtype="int32", data=buf1_shape_data)
- buf1_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 3,
dtype="handle")
+ buf1_strides_data: T.handle("int32") = T.tvm_struct_get(arg1, 0,
3, dtype="handle")
buf1_strides = T.Buffer([2], dtype="int32", data=buf1_strides_data)
- C_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 1,
dtype="handle")
+ C_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 1,
dtype="handle")
T.attr(C_data, "storage_alignment", 128)
C = T.Buffer([1024 * 1024], dtype="int32", data=C_data)
- buf2_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 2,
dtype="handle")
+ buf2_shape_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 2,
dtype="handle")
buf2_shape = T.Buffer([2], dtype="int32", data=buf2_shape_data)
- buf2_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 3,
dtype="handle")
+ buf2_strides_data: T.handle("int32") = T.tvm_struct_get(arg2, 0,
3, dtype="handle")
buf2_strides = T.Buffer([2], dtype="int32", data=buf2_strides_data)
assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code ==
7)) or (
@@ -2238,7 +2238,7 @@ def opt_conv_tensorcore_mod_host():
}
)
# body
- stack_tcode_data: T.Ptr("int32") = T.tvm_stack_alloca("arg_tcode", 10,
dtype="handle")
+ stack_tcode_data: T.handle("int32") = T.tvm_stack_alloca("arg_tcode",
10, dtype="handle")
stack_tcode = T.Buffer([9], "int32", data=stack_tcode_data)
stack_value: T.handle = T.tvm_stack_alloca("arg_value", 10,
dtype="handle")
assert num_args == 3, "default_function: num_args should be 3"
@@ -2251,25 +2251,25 @@ def opt_conv_tensorcore_mod_host():
A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle")
T.attr(A, "storage_alignment", 128)
- arg0_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg0, 0, 2,
dtype="handle")
+ arg0_shape_data: T.handle("int64") = T.tvm_struct_get(arg0, 0, 2,
dtype="handle")
arg0_shape = T.Buffer([6], "int64", data=arg0_shape_data)
- arg0_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg0, 0, 3,
dtype="handle")
+ arg0_strides_data: T.handle("int64") = T.tvm_struct_get(arg0, 0, 3,
dtype="handle")
arg0_strides = T.Buffer([6], "int64", data=arg0_strides_data)
dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32")
W: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle")
T.attr(W, "storage_alignment", 128)
- arg1_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg1, 0, 2,
dtype="handle")
+ arg1_shape_data: T.handle("int64") = T.tvm_struct_get(arg1, 0, 2,
dtype="handle")
arg1_shape = T.Buffer([6], "int64", data=arg1_shape_data)
- arg1_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg1, 0, 3,
dtype="handle")
+ arg1_strides_data: T.handle("int64") = T.tvm_struct_get(arg1, 0, 3,
dtype="handle")
arg1_strides = T.Buffer([6], "int64", data=arg1_strides_data)
Conv: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle")
T.attr(Conv, "storage_alignment", 128)
- arg2_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg2, 0, 2,
dtype="handle")
+ arg2_shape_data: T.handle("int64") = T.tvm_struct_get(arg2, 0, 2,
dtype="handle")
arg2_shape = T.Buffer([6], "int64", data=arg2_shape_data)
- arg2_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg2, 0, 3,
dtype="handle")
+ arg2_strides_data: T.handle("int64") = T.tvm_struct_get(arg2, 0, 3,
dtype="handle")
arg2_strides = T.Buffer([6], "int64", data=arg2_strides_data)
assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7))
or (
@@ -3145,7 +3145,7 @@ def func_nested_root_block():
def func_T_ptr_let_statement():
@T.prim_func
def func_T_ptr_let_statement(
- args: T.handle, arg_type_ids_handle: T.Ptr("int32"), num_args: T.int32
+ args: T.handle, arg_type_ids_handle: T.handle("int32"), num_args:
T.int32
) -> None:
# The T.Ptr declaration in the parameter list should parse
# correctly, and should be usable as the data pointer in a buffer.
@@ -3157,14 +3157,14 @@ def func_T_ptr_let_statement():
# Functions that return a "handle" can be assigned to a T.Ptr
# variable. A variable annotated with T.Ptr still has dtype of
# T.handle, but has type annotation as a pointer type.
- A_data: T.Ptr("float32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle")
+ A_data: T.handle("float32") = T.tvm_struct_get(arg0, 0, 1,
dtype="handle")
# The buffer declaration has a data pointer defined earlier in
# this function. It should only be defined after the data pointer
# has been defined, and should not be hoisted into the header of
# the function as other buffer_decl statements can be.
A = T.Buffer([1024], dtype="float32", data=A_data)
- B_data: T.Ptr("float32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle")
+ B_data: T.handle("float32") = T.tvm_struct_get(arg1, 0, 1,
dtype="handle")
B = T.Buffer([1024], dtype="float32", data=B_data)
B[0] = A[0]
@@ -3266,13 +3266,13 @@ def string_annotation_escaping():
def pointer_type():
@T.prim_func
- def func_with_ptr_type_annotations(x: T.Ptr("int32"), y: T.Ptr("int32",
"shared")):
+ def func_with_ptr_type_annotations(x: T.handle("int32"), y:
T.handle("int32", "shared")):
xx_data = T.allocate([16], "int32", "global")
xx = T.Buffer(shape=[16], dtype="int32", scope="global", data=xx_data)
yy_data = T.allocate([16], "int32", "shared")
yy = T.Buffer(shape=[16], dtype="int32", scope="shared", data=yy_data)
- a: T.Ptr("int32") = T.address_of(xx[0], dtype="handle")
- b: T.Ptr("int32", "shared") = T.address_of(yy[0], dtype="handle")
+ a: T.handle("int32") = T.address_of(xx[0], dtype="handle")
+ b: T.handle("int32", "shared") = T.address_of(yy[0], dtype="handle")
T.evaluate(T.call_extern("copy", a, b, dtype=""))
return func_with_ptr_type_annotations
@@ -3324,7 +3324,7 @@ def let_expression():
def void_ptr():
@T.prim_func
- def func(out_ret_value: T.Ptr("void")):
+ def func(out_ret_value: T.handle("void")):
T.evaluate(out_ret_value)
return func