This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dc626f33e3 [TVMScript] Unify `T.handle` and `T.Ptr` (#13969)
dc626f33e3 is described below

commit dc626f33e3489aa5f9e92f5ae5cbbcf2f71247ef
Author: Junru Shao <[email protected]>
AuthorDate: Sun Feb 12 17:13:32 2023 -0800

    [TVMScript] Unify `T.handle` and `T.Ptr` (#13969)
    
    While both represent a pointer type, `T.handle` was previously used to
    refer to TIR variables whose `type_annotation` is a `PrimType`, while
    `T.Ptr` specifically refers to a `PointerType`. This divide is
    unnecessary if we extend `T.handle` slightly.
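
A minimal sketch of the unified spelling (illustrative only, mirroring the
updated tests below): a bare `T.handle` annotation still yields a plain handle
var, while `T.handle(dtype, storage_scope)` yields a pointer-typed var that
previously required `T.Ptr`:

    from tvm.script import tir as T

    @T.prim_func
    def func(
        a: T.handle,                            # PrimType("handle")
        workspace: T.handle("uint8"),           # PointerType(PrimType("uint8"), "global")
        shared: T.handle("float32", "shared"),  # PointerType(PrimType("float32"), "shared")
    ) -> None:
        T.evaluate(0)
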
---
 include/tvm/script/ir_builder/tir/ir.h             |  5 +--
 python/tvm/script/ir_builder/tir/ir.py             | 13 +++---
 python/tvm/script/parser/tir/entry.py              |  5 ++-
 src/script/ir_builder/tir/ir.cc                    | 10 +++++
 src/script/printer/tir/ir.cc                       |  6 +--
 .../relay/aot/test_aot_create_executor_metadata.py |  2 +-
 tests/python/relay/aot/test_pass_aot_lower_main.py |  4 +-
 ...sform_plan_update_buffer_allocation_location.py |  2 +-
 .../unittest/test_tir_transform_storage_flatten.py |  2 +-
 .../unittest/test_tir_transform_storage_rewrite.py |  4 +-
 ...ransform_convert_pool_allocations_to_offsets.py | 36 ++++++++---------
 tests/python/unittest/test_tvmscript_parser_tir.py |  4 +-
 .../python/unittest/test_tvmscript_printer_tir.py  |  2 +-
 tests/python/unittest/test_tvmscript_roundtrip.py  | 46 +++++++++++-----------
 14 files changed, 77 insertions(+), 64 deletions(-)

diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 5cba879205..d5cc1de5c6 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -415,12 +415,12 @@ void Prefetch(Buffer buffer, Array<Range> bounds);
 void Evaluate(PrimExpr value);
 
 /*!
- * \brief The pointer declaration function.
+ * \brief Create a TIR var that represents a pointer
  * \param dtype The data type of the pointer.
  * \param storage_scope The storage scope of the pointer.
  * \return The pointer.
  */
-PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+Var Handle(runtime::DataType dtype = runtime::DataType::Void(), String storage_scope = "global");
 
 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)                              \
   inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) {                         \
@@ -455,7 +455,6 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
 
 #undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index fdb27df2a9..25d16b56dc 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1358,20 +1358,23 @@ def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
     return _ffi_api.Boolean(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
-def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type handle or cast expression to type 
handle.
+def handle(dtype: str = "void", storage_scope: str = "global") -> Var:
+    """Create a TIR var that represents a pointer.
 
     Parameters
     ----------
-    expr: PrimExpr
-        The expression to be cast.
+    dtype: str
+        The data type of the pointer.
+
+    storage_scope: str
+        The storage scope of the pointer.
 
     Returns
     -------
     res : PrimExpr
         The new tir.Var with type handle or casted expression with type handle.
     """
-    return _ffi_api.Handle(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
+    return _ffi_api.Handle(dtype, storage_scope)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
 def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
index bacf92c142..51743e6b50 100644
--- a/python/tvm/script/parser/tir/entry.py
+++ b/python/tvm/script/parser/tir/entry.py
@@ -79,7 +79,7 @@ class BufferProxy:
             axis_separators=axis_separators,
         )
 
-    @deprecated("T.Buffer(...)", "T.Buffer(...)")
+    @deprecated("T.Buffer[...]", "T.Buffer(...)")
     def __getitem__(self, keys) -> Buffer:
         if not isinstance(keys, tuple):
             return self(keys)
@@ -93,12 +93,13 @@ class PtrProxy:
     Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr().
     """
 
+    @deprecated("T.Ptr(...)", "T.handle(...)")
     def __call__(self, dtype, storage_scope="global"):
         if callable(dtype):
             dtype = dtype().dtype
         return ptr(dtype, storage_scope)  # pylint: disable=no-member # type: ignore
 
-    @deprecated("T.Ptr(...)", "T.Ptr(...)")
+    @deprecated("T.Ptr[...]", "T.handle(...)")
     def __getitem__(self, keys):
         if not isinstance(keys, tuple):
             return self(keys)
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index c586e81f1b..9ab19b2e28 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -545,6 +545,16 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
   return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope));
 }
 
+Var Handle(runtime::DataType dtype, String storage_scope) {
+  Type type_annotation{nullptr};
+  if (dtype.is_void() && storage_scope == "global") {
+    type_annotation = PrimType(runtime::DataType::Handle());
+  } else {
+    type_annotation = PointerType(PrimType(dtype), storage_scope);
+  }
+  return tvm::tir::Var("", type_annotation);
+}
+
 using tvm::script::ir_builder::details::Namer;
 
 TVM_STATIC_IR_FUNCTOR(Namer, vtable)
diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc
index ce10ff6816..78e50a5eb5 100644
--- a/src/script/printer/tir/ir.cc
+++ b/src/script/printer/tir/ir.cc
@@ -73,10 +73,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
         element_type = d->AsDoc<ExprDoc>(ty->element_type, ty_p->Attr("element_type"));
       }
       if (ty->storage_scope == "") {
-        return TIR(d, "Ptr")->Call({element_type});
+        return TIR(d, "handle")->Call({element_type});
       } else {
-        return TIR(d, "Ptr")->Call(
-            {element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))});
+        return TIR(d, "handle")
+            ->Call({element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))});
       }
     });
 
diff --git a/tests/python/relay/aot/test_aot_create_executor_metadata.py b/tests/python/relay/aot/test_aot_create_executor_metadata.py
index 1bc79fe2a6..804738a786 100644
--- a/tests/python/relay/aot/test_aot_create_executor_metadata.py
+++ b/tests/python/relay/aot/test_aot_create_executor_metadata.py
@@ -53,7 +53,7 @@ def test_create_executor_metadata_single_func():
     class Module:
         @T.prim_func
         def __tvm_main__(
-            a: T.handle, output: T.handle, workspace: T.Ptr(T.uint8), 
constants: T.Ptr(T.uint8)
+            a: T.handle, output: T.handle, workspace: T.handle("uint8"), 
constants: T.handle("uint8")
         ) -> None:
             # function attr dict
             T.func_attr({"global_symbol": "test_mod___tvm_main__", 
"runner_function": True, "target": T.target({"kind": "llvm", "tag": "", "keys": 
["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": 
["test_device"]})
diff --git a/tests/python/relay/aot/test_pass_aot_lower_main.py b/tests/python/relay/aot/test_pass_aot_lower_main.py
index f2455e97a0..bc58812cd6 100644
--- a/tests/python/relay/aot/test_pass_aot_lower_main.py
+++ b/tests/python/relay/aot/test_pass_aot_lower_main.py
@@ -178,13 +178,13 @@ def @main(%a: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
     def func(a: T.handle, output: T.handle) -> None:
         # function attr dict
         T.func_attr({"global_symbol": "test_mod___tvm_main__", 
"runner_function": True, "target": T.target({"kind":"llvm", "tag":"", 
"keys":["cpu"]}), "input_vars": [a], "output_vars": [output], "devices": []})
-        tmp_read = T.Ptr("uint8", "")
+        tmp_read = T.handle("uint8", "")
         # buffer definition
         tmp_read_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_read)
         a_buffer = T.match_buffer(a, [5, 7], dtype="float32", align=16)
         output_buffer = T.match_buffer(output, [5, 7], dtype="float32", 
align=16)
         # body
-        tmp_write: T.Ptr(T.uint8) = output_buffer.data
+        tmp_write: T.handle("uint8") = output_buffer.data
         tmp_write_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_write)
         for i in T.serial(140):
             tmp_write_1[i] = T.let(tmp_read, a_buffer.data, tmp_read_1[i])
diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
index 05d71de5bc..758a395da6 100644
--- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
+++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -424,7 +424,7 @@ def test_buffer_conditional_lowering():
     """
 
     @T.prim_func
-    def before(A: T.Ptr("float32")):
+    def before(A: T.handle("float32")):
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         for i in range(1):
             A_1 = T.Buffer((1,), data=A)
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py
index 29623b498f..39009164e7 100644
--- a/tests/python/unittest/test_tir_transform_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -139,7 +139,7 @@ def test_flatten_let_buffer():
             T.func_attr({"from_legacy_te_schedule": True})
 
             # If a pointer defined using a LetStmt,
-            A_data: T.Ptr("int32") = T.call_extern("dummy_extern_function", 
dtype="handle")
+            A_data: T.handle("int32") = T.call_extern("dummy_extern_function", 
dtype="handle")
 
             # and a buffer is backed by that pointer,
             A = T.decl_buffer([1], dtype="float32", data=A_data)
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index 4766022121..c46754fb17 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -689,12 +689,12 @@ class TestLetBufferRewrite(BaseCompare):
     """
 
     def before() -> None:
-        A_data: T.Ptr("int32") = T.call_extern("dummy_func", dtype="handle")
+        A_data: T.handle("int32") = T.call_extern("dummy_func", dtype="handle")
         A = T.Buffer([8], "int32", data=A_data)
         A[0:8] = T.broadcast(42, 8)
 
     def expected() -> None:
-        A_data: T.Ptr("int32x8") = T.call_extern("dummy_func", dtype="handle")
+        A_data: T.handle("int32x8") = T.call_extern("dummy_func", 
dtype="handle")
         A = T.Buffer([1], "int32x8", data=A_data)
         A[0] = T.broadcast(42, 8)
 
diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
index 5bbedd3492..58f37f0496 100644
--- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
+++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
@@ -144,20 +144,20 @@ class LinearStructure:
 @tvm.script.ir_module
 class LinearStructurePlanned:
     @T.prim_func
-    def __tvm_main__(input: T.handle, fast_memory_0_var: T.Ptr("uint8"), 
slow_memory_1_var: T.Ptr("uint8"), output: T.handle) -> None:
+    def __tvm_main__(input: T.handle, fast_memory_0_var: T.handle("uint8"), 
slow_memory_1_var: T.handle("uint8"), output: T.handle) -> None:
         fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], 
dtype="uint8", strides=[1], elem_offset=0, align=16)
         slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, 
[1418528], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         T.attr("default", "device_id", 0)
         T.attr("default", "device_type", 1)
-        sid_9_let: T.Ptr("int8") = 
T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle")
-        sid_8_let: T.Ptr("int8") = T.address_of(slow_memory_1_buffer_var[0], 
dtype="handle")
+        sid_9_let: T.handle("int8") = 
T.address_of(slow_memory_1_buffer_var[1117472], dtype="handle")
+        sid_8_let: T.handle("int8") = 
T.address_of(slow_memory_1_buffer_var[0], dtype="handle")
         T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, 
T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, 
slow_memory_1_buffer_var.data, dtype="int32"))
         
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast",
 sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", 
dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, 
slow_memory_1_buffer_var.data, dtype="int32"))
         T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", 
sid_8_let, output, fast_memory_0_buffer_var.data, 
slow_memory_1_buffer_var.data, dtype="int32"))
 
     @T.prim_func
-    def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, 
T_cast_6: T.handle, fast_memory_6_var: T.Ptr("uint8"), slow_memory_7_var: 
T.Ptr("uint8")) -> None:
+    def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, 
T_cast_6: T.handle, fast_memory_6_var: T.handle("uint8"), slow_memory_7_var: 
T.handle("uint8")) -> None:
         placeholder_29 = T.match_buffer(placeholder_28, [802816], 
dtype="uint8")
         T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16")
         fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], 
dtype="uint8", strides=[1], elem_offset=0, align=16)
@@ -174,7 +174,7 @@ class LinearStructurePlanned:
                 T_cast_7[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3] = 
T.cast(tensor_2_let[ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3], "int16")
 
     @T.prim_func
-    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, 
placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: 
T.Ptr("uint8"), slow_memory_3_var: T.Ptr("uint8")) -> None:
+    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, 
placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: 
T.handle("uint8"), slow_memory_3_var: T.handle("uint8")) -> None:
         placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8")
         placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16")
         T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16")
@@ -185,7 +185,7 @@ class LinearStructurePlanned:
             T_subtract_1[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1] = 
T.cast(placeholder_4[ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1], "int16") 
- placeholder_5[0]
 
     @T.prim_func
-    def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62:
 T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: 
T.handle, fast_memory_4_var: T.Ptr("uint8"), slow_memory_5_var: T.Ptr("uint8")) 
-> None:
+    def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62:
 T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: 
T.handle, fast_memory_4_var: T.handle("uint8"), slow_memory_5_var: 
T.handle("uint8")) -> None:
         placeholder_65 = T.match_buffer(placeholder_62, [150528], 
dtype="int16")
         placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16")
         placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32")
@@ -380,7 +380,7 @@ class ResnetStructure:
 @tvm.script.ir_module
 class ResnetStructurePlanned:
     @T.prim_func
-    def 
tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder:
 T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: 
T.Ptr("uint8")) -> None:
+    def 
tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder:
 T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: 
T.handle("uint8")) -> None:
         placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
         placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
         T_cast_1 = T.match_buffer(T_cast, [215], dtype="int16")
@@ -390,7 +390,7 @@ class ResnetStructurePlanned:
             T_cast_1[ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + 
ax3_inner] = 
T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(placeholder_2[ax0_ax1_fused 
* 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner], "int32") - 94, 1843157232, 31, 
1, dtype="int32") + placeholder_3[ax3_outer * 16 + ax3_inner], 255), 0), 
"uint8"), "int16")
 
     @T.prim_func
-    def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22:
 T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: 
T.handle, T_cast_6: T.handle, global_workspace_5_var: T.Ptr("uint8")) -> None:
+    def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22:
 T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: 
T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle("uint8")) -> 
None:
         placeholder_29 = T.match_buffer(placeholder_22, [360000], 
dtype="int16")
         placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16")
         placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32")
@@ -414,7 +414,7 @@ class ResnetStructurePlanned:
                             T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + 
ax3_outer_2 * 64 + ax3_inner_4] = 
T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3_let[ax3_inner_4]
 + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, 
dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, 
dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 
64 + ax3_inner_4], 255), 0), "uint8")
 
     @T.prim_func
-    def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16:
 T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, 
global_workspace_4_var: T.Ptr("uint8")) -> None:
+    def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16:
 T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, 
global_workspace_4_var: T.handle("uint8")) -> None:
         placeholder_19 = T.match_buffer(placeholder_16, [360000], 
dtype="int16")
         placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
         placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
@@ -437,7 +437,7 @@ class ResnetStructurePlanned:
                             T_add_1[ax0_ax1_fused_ax2_fused_2 * 256 + 
ax3_outer_1 * 64 + ax3_inner_3] = 
T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_2_let[ax3_inner_3]
 + placeholder_21[ax3_outer_1 * 64 + ax3_inner_3], 1711626602, 31, -8, 
dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, 
dtype="int32") + 136
 
     @T.prim_func
-    def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4:
 T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: 
T.handle, global_workspace_2_var: T.Ptr("uint8")) -> None:
+    def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4:
 T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: 
T.handle, global_workspace_2_var: T.handle("uint8")) -> None:
         placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
         placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
         placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
@@ -459,7 +459,7 @@ class ResnetStructurePlanned:
                         T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = 
T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_let[ax3_inner_1] + 
placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), 
"uint8"), "int16")
 
     @T.prim_func
-    def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10:
 T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: 
T.handle, global_workspace_3_var: T.Ptr("uint8")) -> None:
+    def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10:
 T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: 
T.handle, global_workspace_3_var: T.handle("uint8")) -> None:
         placeholder_13 = T.match_buffer(placeholder_10, [360000], 
dtype="int16")
         placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
         placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
@@ -481,15 +481,15 @@ class ResnetStructurePlanned:
                         T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2] 
= T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_1_let[ax3_inner_2] 
+ placeholder_15[ax3_inner_2], 1608879842, 31, -7, dtype="int32"), 255), 0), 
"uint8"), "int16")
 
     @T.prim_func
-    def __tvm_main__(input: T.handle, global_workspace_0_var: T.Ptr("uint8"), 
output: T.handle) -> None:
+    def __tvm_main__(input: T.handle, global_workspace_0_var: 
T.handle("uint8"), output: T.handle) -> None:
         global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, 
[7920256], dtype="uint8", strides=[1], elem_offset=0, align=16)
         # body
         T.attr("default", "device_id", 0)
         T.attr("default", "device_type", 1)
-        sid_2_let: T.Ptr("int8") = 
T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle")
-        sid_6_let: T.Ptr("int8") = 
T.address_of(global_workspace_0_buffer_var[0], dtype="handle")
-        sid_7_let: T.Ptr("int8") = 
T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle")
-        sid_8_let: T.Ptr("int8") = 
T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle")
+        sid_2_let: T.handle("int8") = 
T.address_of(global_workspace_0_buffer_var[5760000], dtype="handle")
+        sid_6_let: T.handle("int8") = 
T.address_of(global_workspace_0_buffer_var[0], dtype="handle")
+        sid_7_let: T.handle("int8") = 
T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle")
+        sid_8_let: T.handle("int8") = 
T.address_of(global_workspace_0_buffer_var[6480000], dtype="handle")
         
T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast",
 input, T.lookup_param("p0", dtype="handle"), sid_2_let, 
global_workspace_0_buffer_var.data, dtype="int32"))
         
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast",
 sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", 
dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32"))
         
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1",
 sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", 
dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32"))
@@ -557,7 +557,7 @@ class TensorIntrinStructure:
 @tvm.script.ir_module
 class TensorIntrinStructurePlanned:
     @T.prim_func
-    def tensor_intrin_primfunc(global_workspace_1_var: T.Ptr("uint8")) -> None:
+    def tensor_intrin_primfunc(global_workspace_1_var: T.handle("uint8")) -> 
None:
         global_workspace_1_buffer_var = T.match_buffer(
             global_workspace_1_var, [40], dtype="uint8", strides=[1], 
elem_offset=0, align=16
         )
@@ -576,7 +576,7 @@ class TensorIntrinStructurePlanned:
 
     @T.prim_func
     def __tvm_main__(
-        input: T.handle, global_workspace_1_var: T.Ptr("uint8"), output: 
T.handle
+        input: T.handle, global_workspace_1_var: T.handle("uint8"), output: 
T.handle
     ) -> None:
         global_workspace_1_buffer_var = T.match_buffer(
             global_workspace_1_var, [40], dtype="uint8", strides=[1], 
elem_offset=0, align=16
diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py
index e96ae4da8c..20be6d1498 100644
--- a/tests/python/unittest/test_tvmscript_parser_tir.py
+++ b/tests/python/unittest/test_tvmscript_parser_tir.py
@@ -40,7 +40,7 @@ def test_tir_buffer_proxy():
 
 
 def test_tir_ptr_proxy():
-    ptr_0 = T.Ptr("int32", "global")
+    ptr_0 = T.handle("int32", "global")
     assert (
         isinstance(ptr_0, tir.Var)
         and ptr_0.dtype == "handle"
@@ -49,7 +49,7 @@ def test_tir_ptr_proxy():
         and ptr_0.type_annotation.storage_scope == "global"
     )
 
-    ptr_1 = T.Ptr("float32", "shared")
+    ptr_1 = T.handle("float32", "shared")
     assert (
         isinstance(ptr_1, tir.Var)
         and ptr_1.dtype == "handle"
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py
index 6f96b3a3dd..a04544152e 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -674,7 +674,7 @@ def test_prim_type():
 
 def test_pointer_type():
     obj = ir.PointerType(ir.PrimType("int32"), "global")
-    _assert_print(obj, 'T.Ptr("int32", "global")')
+    _assert_print(obj, 'T.handle("int32", "global")')
 
 
 def test_tuple_type():
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index 1ec8f49b4b..db21223366 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -204,30 +204,30 @@ def opt_gemm_mod_host():
             arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle")
             arg2_code: T.int32 = buf_type_ids[2]
 
-            A_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 1, 
dtype="handle")
+            A_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 1, 
dtype="handle")
             T.attr(A_data, "storage_alignment", 128)
             A = T.Buffer([1024 * 1024], dtype="int32", data=A_data)
-            buf0_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 2, 
dtype="handle")
+            buf0_shape_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 2, 
dtype="handle")
             buf0_shape = T.Buffer([2], dtype="int32", data=buf0_shape_data)
-            buf0_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg0, 0, 3, 
dtype="handle")
+            buf0_strides_data: T.handle("int32") = T.tvm_struct_get(arg0, 0, 
3, dtype="handle")
             buf0_strides = T.Buffer([2], dtype="int32", data=buf0_strides_data)
 
             dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32")
 
-            B_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 1, 
dtype="handle")
+            B_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 1, 
dtype="handle")
             T.attr(B_data, "storage_alignment", 128)
             B = T.Buffer([1024 * 1024], dtype="int32", data=B_data)
-            buf1_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 2, 
dtype="handle")
+            buf1_shape_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 2, 
dtype="handle")
             buf1_shape = T.Buffer([2], dtype="int32", data=buf1_shape_data)
-            buf1_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg1, 0, 3, 
dtype="handle")
+            buf1_strides_data: T.handle("int32") = T.tvm_struct_get(arg1, 0, 
3, dtype="handle")
             buf1_strides = T.Buffer([2], dtype="int32", data=buf1_strides_data)
 
-            C_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 1, 
dtype="handle")
+            C_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 1, 
dtype="handle")
             T.attr(C_data, "storage_alignment", 128)
             C = T.Buffer([1024 * 1024], dtype="int32", data=C_data)
-            buf2_shape_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 2, 
dtype="handle")
+            buf2_shape_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 2, 
dtype="handle")
             buf2_shape = T.Buffer([2], dtype="int32", data=buf2_shape_data)
-            buf2_strides_data: T.Ptr("int32") = T.tvm_struct_get(arg2, 0, 3, 
dtype="handle")
+            buf2_strides_data: T.handle("int32") = T.tvm_struct_get(arg2, 0, 
3, dtype="handle")
             buf2_strides = T.Buffer([2], dtype="int32", data=buf2_strides_data)
 
             assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 
7)) or (
@@ -2238,7 +2238,7 @@ def opt_conv_tensorcore_mod_host():
             }
         )
         # body
-        stack_tcode_data: T.Ptr("int32") = T.tvm_stack_alloca("arg_tcode", 10, 
dtype="handle")
+        stack_tcode_data: T.handle("int32") = T.tvm_stack_alloca("arg_tcode", 
10, dtype="handle")
         stack_tcode = T.Buffer([9], "int32", data=stack_tcode_data)
         stack_value: T.handle = T.tvm_stack_alloca("arg_value", 10, 
dtype="handle")
         assert num_args == 3, "default_function: num_args should be 3"
@@ -2251,25 +2251,25 @@ def opt_conv_tensorcore_mod_host():
 
         A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle")
         T.attr(A, "storage_alignment", 128)
-        arg0_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg0, 0, 2, 
dtype="handle")
+        arg0_shape_data: T.handle("int64") = T.tvm_struct_get(arg0, 0, 2, 
dtype="handle")
         arg0_shape = T.Buffer([6], "int64", data=arg0_shape_data)
-        arg0_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg0, 0, 3, 
dtype="handle")
+        arg0_strides_data: T.handle("int64") = T.tvm_struct_get(arg0, 0, 3, 
dtype="handle")
         arg0_strides = T.Buffer([6], "int64", data=arg0_strides_data)
 
         dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32")
 
         W: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle")
         T.attr(W, "storage_alignment", 128)
-        arg1_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg1, 0, 2, 
dtype="handle")
+        arg1_shape_data: T.handle("int64") = T.tvm_struct_get(arg1, 0, 2, 
dtype="handle")
         arg1_shape = T.Buffer([6], "int64", data=arg1_shape_data)
-        arg1_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg1, 0, 3, 
dtype="handle")
+        arg1_strides_data: T.handle("int64") = T.tvm_struct_get(arg1, 0, 3, 
dtype="handle")
         arg1_strides = T.Buffer([6], "int64", data=arg1_strides_data)
 
         Conv: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle")
         T.attr(Conv, "storage_alignment", 128)
-        arg2_shape_data: T.Ptr("int64") = T.tvm_struct_get(arg2, 0, 2, 
dtype="handle")
+        arg2_shape_data: T.handle("int64") = T.tvm_struct_get(arg2, 0, 2, 
dtype="handle")
         arg2_shape = T.Buffer([6], "int64", data=arg2_shape_data)
-        arg2_strides_data: T.Ptr("int64") = T.tvm_struct_get(arg2, 0, 3, 
dtype="handle")
+        arg2_strides_data: T.handle("int64") = T.tvm_struct_get(arg2, 0, 3, 
dtype="handle")
         arg2_strides = T.Buffer([6], "int64", data=arg2_strides_data)
 
         assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) 
or (
@@ -3145,7 +3145,7 @@ def func_nested_root_block():
 def func_T_ptr_let_statement():
     @T.prim_func
     def func_T_ptr_let_statement(
-        args: T.handle, arg_type_ids_handle: T.Ptr("int32"), num_args: T.int32
+        args: T.handle, arg_type_ids_handle: T.handle("int32"), num_args: 
T.int32
     ) -> None:
         # The T.Ptr declaration in the parameter list should parse
         # correctly, and should be usable as the data pointer in a buffer.
@@ -3157,14 +3157,14 @@ def func_T_ptr_let_statement():
         # Functions that return a "handle" can be assigned to a T.Ptr
         # variable.  A variable annotated with T.Ptr still has dtype of
         # T.handle, but has type annotation as a pointer type.
-        A_data: T.Ptr("float32") = T.tvm_struct_get(arg0, 0, 1, dtype="handle")
+        A_data: T.handle("float32") = T.tvm_struct_get(arg0, 0, 1, 
dtype="handle")
 
         # The buffer declaration has a data pointer defined earlier in
         # this function.  It should only be defined after the data pointer
         # has been defined, and should not be hoisted into the header of
         # the function as other buffer_decl statements can be.
         A = T.Buffer([1024], dtype="float32", data=A_data)
-        B_data: T.Ptr("float32") = T.tvm_struct_get(arg1, 0, 1, dtype="handle")
+        B_data: T.handle("float32") = T.tvm_struct_get(arg1, 0, 1, 
dtype="handle")
         B = T.Buffer([1024], dtype="float32", data=B_data)
 
         B[0] = A[0]
@@ -3266,13 +3266,13 @@ def string_annotation_escaping():
 
 def pointer_type():
     @T.prim_func
-    def func_with_ptr_type_annotations(x: T.Ptr("int32"), y: T.Ptr("int32", 
"shared")):
+    def func_with_ptr_type_annotations(x: T.handle("int32"), y: 
T.handle("int32", "shared")):
         xx_data = T.allocate([16], "int32", "global")
         xx = T.Buffer(shape=[16], dtype="int32", scope="global", data=xx_data)
         yy_data = T.allocate([16], "int32", "shared")
         yy = T.Buffer(shape=[16], dtype="int32", scope="shared", data=yy_data)
-        a: T.Ptr("int32") = T.address_of(xx[0], dtype="handle")
-        b: T.Ptr("int32", "shared") = T.address_of(yy[0], dtype="handle")
+        a: T.handle("int32") = T.address_of(xx[0], dtype="handle")
+        b: T.handle("int32", "shared") = T.address_of(yy[0], dtype="handle")
         T.evaluate(T.call_extern("copy", a, b, dtype=""))
 
     return func_with_ptr_type_annotations
@@ -3324,7 +3324,7 @@ def let_expression():
 
 def void_ptr():
     @T.prim_func
-    def func(out_ret_value: T.Ptr("void")):
+    def func(out_ret_value: T.handle("void")):
         T.evaluate(out_ret_value)
 
     return func
