This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 82cf9f72d6 [TVMScript] Simplify TIR Var Definition (#13970)
82cf9f72d6 is described below
commit 82cf9f72d68903f1f36921cd1e7ae4435eced5d3
Author: Junru Shao <[email protected]>
AuthorDate: Sun Feb 12 18:32:41 2023 -0800
[TVMScript] Simplify TIR Var Definition (#13970)
This PR introduces a small tweak to the TVMScript printer that simplifies
variable definition in TIR.
Previously, a TIR var was defined with `T.var(dtype)`, e.g.
```python
a = T.var("int32")
```
This PR encourages shortening the definition to:
```python
a = T.int32()
```
There is no breaking change in this PR: the legacy `T.var(dtype)` form
still works, although `T.var` is now marked as deprecated in favor of
`T.{dtype}`.
---
python/tvm/script/ir_builder/tir/ir.py | 1 +
python/tvm/script/parser/tir/parser.py | 2 +-
python/tvm/tir/tensor_intrin/cuda.py | 40 ++++++------
python/tvm/utils/roofline/cuda.py | 2 +-
python/tvm/utils/roofline/x86.py | 2 +-
src/script/printer/tir/expr.cc | 25 +++++--
.../test_ethosu/test_copy_compute_reordering.py | 76 +++++++++++-----------
.../contrib/test_ethosu/test_merge_constants.py | 40 ++++++------
tests/python/integration/test_lower.py | 12 ++--
.../unittest/test_aot_legalize_packed_call.py | 16 ++---
tests/python/unittest/test_arith_domain_touched.py | 4 +-
.../test_meta_schedule_postproc_verify_gpu_code.py | 12 ++--
.../unittest/test_meta_schedule_trace_apply.py | 40 ++++++------
tests/python/unittest/test_te_create_primfunc.py | 16 ++---
tests/python/unittest/test_tir_analysis_oob.py | 2 +-
tests/python/unittest/test_tir_intrin.py | 10 +--
.../python/unittest/test_tir_lower_match_buffer.py | 26 ++++----
tests/python/unittest/test_tir_renew_defs.py | 6 +-
tests/python/unittest/test_tir_schedule_rfactor.py | 2 +-
.../python/unittest/test_tir_schedule_tensorize.py | 24 +++----
tests/python/unittest/test_tir_specialize.py | 18 ++---
.../test_tir_transform_common_subexpr_elim.py | 4 +-
.../test_tir_transform_hoist_expression.py | 4 +-
.../python/unittest/test_tvmscript_error_report.py | 4 +-
.../unittest/test_tvmscript_ir_builder_tir.py | 32 +++++----
.../python/unittest/test_tvmscript_printer_tir.py | 52 +++++++--------
tests/python/unittest/test_tvmscript_roundtrip.py | 22 +++----
.../python/unittest/test_tvmscript_syntax_sugar.py | 6 +-
28 files changed, 257 insertions(+), 243 deletions(-)
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index 25d16b56dc..2c5a848e4a 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1393,6 +1393,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
return _ffi_api.Void(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
+@deprecated("T.var", "T.{dtype}")
def var(dtype: str, name: str = "") -> Var:
"""Construct a new tir.Var.
diff --git a/python/tvm/script/parser/tir/parser.py
b/python/tvm/script/parser/tir/parser.py
index 0e74114ba2..fbef1a9691 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -143,7 +143,7 @@ def bind_assign_value(self: Parser, node: doc.expr,
var_name: str, value: Any) -
IRBuilder.name(var_name, value)
return value
elif isinstance(value, PrimExpr):
- var = T.var(value.dtype)
+ var = Var("", value.dtype)
IRBuilder.name(var_name, var)
frame = T.let(var, value)
frame.add_callback(partial(frame.__exit__, None, None, None))
diff --git a/python/tvm/tir/tensor_intrin/cuda.py
b/python/tvm/tir/tensor_intrin/cuda.py
index 0703811ea7..6483b99454 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -146,8 +146,8 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed,
shared_scope="shared"):
@T.prim_func
def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
- s0 = T.var("int32")
- s1 = T.var("int32")
+ s0 = T.int32()
+ s1 = T.int32()
shared = T.match_buffer(
shared_handle,
shmem_shape,
@@ -385,8 +385,8 @@ def get_mma_store_intrin(dtype, local_size, scope="global"):
@T.prim_func
def mma_store_impl(a: T.handle, c: T.handle) -> None:
- s0 = T.var("int32")
- s1 = T.var("int32")
+ s0 = T.int32()
+ s1 = T.int32()
C_warp = T.match_buffer(
a, [WARP_SIZE, local_size], dtype=dtype, scope="warp",
offset_factor=1
@@ -530,10 +530,10 @@ def get_wmma_load_intrin(
@T.prim_func
def wmma_load_impl(a: T.handle, c: T.handle) -> None:
- s1 = T.var("int32")
- s0 = T.var("int32")
- d1 = T.var("int32")
- d0 = T.var("int32")
+ s1 = T.int32()
+ s0 = T.int32()
+ d1 = T.int32()
+ d0 = T.int32()
A = T.match_buffer(
a,
(m_dim, n_dim),
@@ -593,8 +593,8 @@ def get_wmma_fill_intrin(
@T.prim_func
def wmma_fill_impl(c: T.handle) -> None:
- d1 = T.var("int32")
- d0 = T.var("int32")
+ d1 = T.int32()
+ d0 = T.int32()
C = T.match_buffer(
c,
(m_dim, n_dim),
@@ -643,10 +643,10 @@ def get_wmma_store_intrin(
@T.prim_func
def wmma_store_impl(a: T.handle, c: T.handle) -> None:
- s1 = T.var("int32")
- s0 = T.var("int32")
- d1 = T.var("int32")
- d0 = T.var("int32")
+ s1 = T.int32()
+ s0 = T.int32()
+ d1 = T.int32()
+ d0 = T.int32()
A = T.match_buffer(
a,
(m_dim, n_dim),
@@ -726,12 +726,12 @@ def get_wmma_sync_intrin(
@T.prim_func
def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
- a1 = T.var("int32")
- a0 = T.var("int32")
- b1 = T.var("int32")
- b0 = T.var("int32")
- c1 = T.var("int32")
- c0 = T.var("int32")
+ a1 = T.int32()
+ a0 = T.int32()
+ b1 = T.int32()
+ b0 = T.int32()
+ c1 = T.int32()
+ c0 = T.int32()
A = T.match_buffer(
a,
diff --git a/python/tvm/utils/roofline/cuda.py
b/python/tvm/utils/roofline/cuda.py
index 5d80c80880..b83a902b7f 100644
--- a/python/tvm/utils/roofline/cuda.py
+++ b/python/tvm/utils/roofline/cuda.py
@@ -299,7 +299,7 @@ def estimate_peak_flops(
@T.prim_func
def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size:
T.int32) -> None:
# pylint: disable=invalid-name, missing-function-docstring
- N = T.var("int32")
+ N = T.int32()
A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32")
B = T.match_buffer(b, [blocks, 4, warp_size], "float32")
for i in T.thread_binding(blocks, "blockIdx.x"):
diff --git a/python/tvm/utils/roofline/x86.py b/python/tvm/utils/roofline/x86.py
index 37a666d252..5d2dd27e52 100644
--- a/python/tvm/utils/roofline/x86.py
+++ b/python/tvm/utils/roofline/x86.py
@@ -216,7 +216,7 @@ def estimate_peak_fma_flops(
@T.prim_func
def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width:
T.int32) -> None:
# pylint: disable=invalid-name, missing-function-docstring
- N = T.var("int32")
+ N = T.int32()
A = T.match_buffer(a, [threads, N, 4, vec_width], "float32")
B = T.match_buffer(b, [threads, 4, vec_width], "float32")
# Parallelism is necessary to hit all cores/nodes
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index a5d5d492ff..d860eeb2a7 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -29,14 +29,29 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& var_p,
const IRDocsifier& d)
if (Optional<Frame> opt_f = FindLowestVarDef(var, d)) {
ExprDoc lhs = DefineVar(var, opt_f.value(), d);
Type type = var->type_annotation;
+ ObjectPath type_p = var_p->Attr("type_annotation");
+ ExprDoc rhs{nullptr};
if (const auto* ptr_type = type.as<PointerTypeNode>()) {
- ICHECK(ptr_type->element_type->IsInstance<PrimTypeNode>());
- ExprDoc rhs = d->AsDoc<ExprDoc>(type, var_p->Attr("type_annotation"));
- opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+ const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
+ ICHECK(prim_type);
+ ExprDoc element_type =
+ LiteralDoc::DataType(prim_type->dtype,
type_p->Attr("element_type")->Attr("dtype"));
+ rhs = TIR(d, "handle");
+ rhs->source_paths.push_back(var_p->Attr("dtype"));
+ if (ptr_type->storage_scope == "") {
+ rhs = rhs->Call({element_type});
+ } else {
+ rhs = rhs->Call({element_type,
+ LiteralDoc::Str(ptr_type->storage_scope, //
+ type_p->Attr("storage_scope"))});
+ }
} else {
- ExprDoc rhs = TIR(d, "var")->Call({LiteralDoc::DataType(var->dtype,
var_p->Attr("dtype"))});
- opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+ rhs = TIR(d, DType2Str(var->dtype));
+ rhs->source_paths.push_back(var_p->Attr("dtype"));
+ rhs = rhs->Call({});
}
+ rhs->source_paths.push_back(type_p);
+ opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
} else {
LOG(WARNING) << "Didn't find variable definition for: " <<
var->name_hint;
}
diff --git a/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
b/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
index 99bd273115..1a00e01b60 100644
--- a/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
+++ b/tests/python/contrib/test_ethosu/test_copy_compute_reordering.py
@@ -476,16 +476,16 @@ def test_reordering_based_on_cycles():
def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded:
T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"),
placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3:
T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None:
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main",
"from_legacy_te_schedule": True})
- ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
- ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
- ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
- ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
- nn = T.var("int32")
- nn_1 = T.var("int32")
- nn_2 = T.var("int32")
- nn_3 = T.var("int32")
- nn_4 = T.var("int32")
- nn_5 = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+ ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+ ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+ ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.int32()
+ nn = T.int32()
+ nn_1 = T.int32()
+ nn_2 = T.int32()
+ nn_3 = T.int32()
+ nn_4 = T.int32()
+ nn_5 = T.int32()
# body
placeholder_d_global = T.decl_buffer([208], "uint8")
placeholder_d_global_1 = T.decl_buffer([112], "uint8")
@@ -524,16 +524,16 @@ def test_reordering_based_on_cycles():
def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded:
T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"),
placeholder_encoded_2: T.Buffer(96, "uint8"), placeholder_encoded_3:
T.Buffer(112, "uint8"), ethosu_write: T.Buffer(43672, "int8")) -> None:
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main",
"from_legacy_te_schedule": True})
- ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
- ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
- ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
- ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
- nn = T.var("int32")
- nn_1 = T.var("int32")
- nn_2 = T.var("int32")
- nn_3 = T.var("int32")
- nn_4 = T.var("int32")
- nn_5 = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+ ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+ ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+ ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.int32()
+ nn = T.int32()
+ nn_1 = T.int32()
+ nn_2 = T.int32()
+ nn_3 = T.int32()
+ nn_4 = T.int32()
+ nn_5 = T.int32()
# body
placeholder_d_global = T.decl_buffer([208], "uint8")
placeholder_d_global_1 = T.decl_buffer([112], "uint8")
@@ -579,15 +579,15 @@ def test_reordering_based_on_cycles_luts_present():
def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded:
T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"),
placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96,
"uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256,
"int8"), ethosu_write: T.Buffer(46200, "int8")) -> None:
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main",
"from_legacy_te_schedule": True})
- ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
- ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
- ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
- nn = T.var("int32")
- nn_1 = T.var("int32")
- nn_2 = T.var("int32")
- nn_3 = T.var("int32")
- nn_4 = T.var("int32")
- nn_5 = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+ ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+ ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+ nn = T.int32()
+ nn_1 = T.int32()
+ nn_2 = T.int32()
+ nn_3 = T.int32()
+ nn_4 = T.int32()
+ nn_5 = T.int32()
# body
placeholder_d_d_global = T.decl_buffer([208], "uint8")
placeholder_d_d_global_1 = T.decl_buffer([112], "uint8")
@@ -629,15 +629,15 @@ def test_reordering_based_on_cycles_luts_present():
def main(placeholder: T.Buffer(97156, "int8"), placeholder_encoded:
T.Buffer(208, "uint8"), placeholder_encoded_1: T.Buffer(112, "uint8"),
placeholder_1: T.Buffer(256, "int8"), placeholder_encoded_2: T.Buffer(96,
"uint8"), placeholder_2: T.Buffer(256, "int8"), placeholder_3: T.Buffer(256,
"int8"), ethosu_write: T.Buffer(46200, "int8")) -> None:
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main",
"from_legacy_te_schedule": True})
- ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
- ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
- ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
- nn = T.var("int32")
- nn_1 = T.var("int32")
- nn_2 = T.var("int32")
- nn_3 = T.var("int32")
- nn_4 = T.var("int32")
- nn_5 = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused = T.int32()
+ ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.int32()
+ ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.int32()
+ nn = T.int32()
+ nn_1 = T.int32()
+ nn_2 = T.int32()
+ nn_3 = T.int32()
+ nn_4 = T.int32()
+ nn_5 = T.int32()
# body
placeholder_d_d_global = T.decl_buffer([208], "uint8")
placeholder_d_d_global_1 = T.decl_buffer([112], "uint8")
diff --git a/tests/python/contrib/test_ethosu/test_merge_constants.py
b/tests/python/contrib/test_ethosu/test_merge_constants.py
index 909f9fe673..624bef00c7 100644
--- a/tests/python/contrib/test_ethosu/test_merge_constants.py
+++ b/tests/python/contrib/test_ethosu/test_merge_constants.py
@@ -650,18 +650,18 @@ def test_cycle_count():
def main(buffer2: T.Buffer((128,), "uint8"), buffer3: T.Buffer((32,),
"uint8"), buffer4: T.Buffer((112,), "uint8"), buffer5: T.Buffer((32,),
"uint8"), buffer6: T.Buffer((112,), "uint8"), buffer7: T.Buffer((32,),
"uint8"), buffer8: T.Buffer((112,), "uint8"), buffer9: T.Buffer((32,),
"uint8")) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol":
"main", "tir.noalias": True})
- v1a = T.var("int32")
- v1b = T.var("int32")
- v1c = T.var("int32")
- v2a = T.var("int32")
- v2b = T.var("int32")
- v2c = T.var("int32")
- v3a = T.var("int32")
- v3b = T.var("int32")
- v3c = T.var("int32")
- v4a = T.var("int32")
- v4b = T.var("int32")
- v4c = T.var("int32")
+ v1a = T.int32()
+ v1b = T.int32()
+ v1c = T.int32()
+ v2a = T.int32()
+ v2b = T.int32()
+ v2c = T.int32()
+ v3a = T.int32()
+ v3b = T.int32()
+ v3c = T.int32()
+ v4a = T.int32()
+ v4b = T.int32()
+ v4c = T.int32()
buffer1 = T.Buffer([8192], "int8")
buffer10 = T.Buffer([2048], "int8")
# body
@@ -713,14 +713,14 @@ def test_cycle_count():
def main(buffer2: T.Buffer((160,), "uint8"), buffer4: T.Buffer((144,),
"uint8"), buffer6: T.Buffer((144,), "uint8"), buffer8: T.Buffer((144,),
"uint8")) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol":
"main", "tir.noalias": True})
- v1a = T.var("int32")
- v1c = T.var("int32")
- v2a = T.var("int32")
- v2c = T.var("int32")
- v3a = T.var("int32")
- v3c = T.var("int32")
- v4a = T.var("int32")
- v4c = T.var("int32")
+ v1a = T.int32()
+ v1c = T.int32()
+ v2a = T.int32()
+ v2c = T.int32()
+ v3a = T.int32()
+ v3c = T.int32()
+ v4a = T.int32()
+ v4c = T.int32()
buffer1 = T.Buffer([8192], "int8")
buffer10 = T.Buffer([2048], "int8")
# body
diff --git a/tests/python/integration/test_lower.py
b/tests/python/integration/test_lower.py
index 1ccdde8b13..965ab80beb 100644
--- a/tests/python/integration/test_lower.py
+++ b/tests/python/integration/test_lower.py
@@ -136,8 +136,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle,
handle_c: T.handle)
axis_vk * 16 : axis_vk * 16 +
16,
]
)
- stride0 = T.var("int32")
- stride1 = T.var("int32")
+ stride0 = T.int32()
+ stride1 = T.int32()
match_buffer_a0 = T.match_buffer(
shared_a[
new_axis_vi * 16 : new_axis_vi
* 16 + 16,
@@ -198,8 +198,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle,
handle_c: T.handle)
axis_vk * 16 : axis_vk * 16 +
16,
]
)
- stride0 = T.var("int32")
- stride1 = T.var("int32")
+ stride0 = T.int32()
+ stride1 = T.int32()
match_buffer_b0 = T.match_buffer(
shared_b[
new_axis_vj * 16 : new_axis_vj
* 16 + 16,
@@ -335,8 +335,8 @@ def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle,
handle_c: T.handle)
new_axis_vj * 16 : new_axis_vj * 16 +
16,
]
)
- stride0 = T.var("int32")
- stride1 = T.var("int32")
+ stride0 = T.int32()
+ stride1 = T.int32()
wmma_c2 = T.match_buffer(
wmma_c[
new_axis_vi * 16 : new_axis_vi * 16 +
16,
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py
b/tests/python/unittest/test_aot_legalize_packed_call.py
index ad970d52c0..6f66f3a432 100644
--- a/tests/python/unittest/test_aot_legalize_packed_call.py
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -35,10 +35,10 @@ class Module:
@T.prim_func
def tir_packed_call() -> None:
- A = T.var("handle")
- B = T.var("handle")
- C = T.var("handle")
- device_context = T.var("handle")
+ A = T.handle()
+ B = T.handle()
+ C = T.handle()
+ device_context = T.handle()
# body
T.evaluate(
T.tvm_call_cpacked(
@@ -65,10 +65,10 @@ class Expected:
@T.prim_func
def tir_packed_call() -> None:
- A = T.var("handle")
- B = T.var("handle")
- C = T.var("handle")
- device_context = T.var("handle")
+ A = T.handle()
+ B = T.handle()
+ C = T.handle()
+ device_context = T.handle()
# body
T.evaluate(
diff --git a/tests/python/unittest/test_arith_domain_touched.py
b/tests/python/unittest/test_arith_domain_touched.py
index 9f7eee0963..e19991b3b8 100644
--- a/tests/python/unittest/test_arith_domain_touched.py
+++ b/tests/python/unittest/test_arith_domain_touched.py
@@ -21,7 +21,7 @@ from tvm.script import tir as T
@T.prim_func
def scalar_func(a: T.handle, b: T.handle):
- m = T.var("int32")
+ m = T.int32()
n = 100
A = T.match_buffer(a, (n, m))
B = T.match_buffer(b, (n, m))
@@ -73,7 +73,7 @@ def test_domain_touched_vector():
@T.prim_func
def func(a: T.handle, b: T.handle):
- n = T.var("int32")
+ n = T.int32()
A = T.match_buffer(a, (n * m,))
B = T.match_buffer(b, (n * m,))
diff --git
a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
index 59de0b0c57..0facc9b961 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py
@@ -399,12 +399,12 @@ def GMMCUDATensorCore(
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- s0 = T.var("int32")
- s0_1 = T.var("int32")
- s0_2 = T.var("int32")
- s1 = T.var("int32")
- s1_1 = T.var("int32")
- s1_2 = T.var("int32")
+ s0 = T.int32()
+ s0_1 = T.int32()
+ s0_2 = T.int32()
+ s1 = T.int32()
+ s1_1 = T.int32()
+ s1_2 = T.int32()
# body
# with T.block("root")
Z_wmma_accumulator = T.alloc_buffer([1024, 1024], dtype="float32",
scope="wmma.accumulator")
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py
b/tests/python/unittest/test_meta_schedule_trace_apply.py
index ae65cc1a81..d09f2a226c 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -637,26 +637,26 @@ class Conv2dInt8_tensorcore_scheduled:
def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1,
64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1,
256), "int32"), p4: T.Buffer((1, 1, 1, 256), "int64"), p5: T.Buffer((1, 1, 1,
256), "int64"), p6: T.Buffer((1, 1, 1, 256), "int64"), p7: T.Buffer((),
"int32"), p8: T.Buffer(1, "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"),
compute: T.Buffer((16, 56, 56, 256), "uint8")) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- A_s0 = T.var("int32")
- A_s0_1 = T.var("int32")
- A_s0_2 = T.var("int32")
- A_s0_3 = T.var("int32")
- A_s1 = T.var("int32")
- A_s1_1 = T.var("int32")
- A_s1_2 = T.var("int32")
- A_s1_3 = T.var("int32")
- B_s0 = T.var("int32")
- B_s1 = T.var("int32")
- C_s0 = T.var("int32")
- C_s0_1 = T.var("int32")
- C_s0_2 = T.var("int32")
- C_s0_3 = T.var("int32")
- C_s0_4 = T.var("int32")
- C_s1 = T.var("int32")
- C_s1_1 = T.var("int32")
- C_s1_2 = T.var("int32")
- C_s1_3 = T.var("int32")
- C_s1_4 = T.var("int32")
+ A_s0 = T.int32()
+ A_s0_1 = T.int32()
+ A_s0_2 = T.int32()
+ A_s0_3 = T.int32()
+ A_s1 = T.int32()
+ A_s1_1 = T.int32()
+ A_s1_2 = T.int32()
+ A_s1_3 = T.int32()
+ B_s0 = T.int32()
+ B_s1 = T.int32()
+ C_s0 = T.int32()
+ C_s0_1 = T.int32()
+ C_s0_2 = T.int32()
+ C_s0_3 = T.int32()
+ C_s0_4 = T.int32()
+ C_s1 = T.int32()
+ C_s1_1 = T.int32()
+ C_s1_2 = T.int32()
+ C_s1_3 = T.int32()
+ C_s1_4 = T.int32()
# body
# with T.block("root")
conv2d_nhwc_reindex_shared = T.alloc_buffer([50176, 256],
dtype="int32", scope="shared")
diff --git a/tests/python/unittest/test_te_create_primfunc.py
b/tests/python/unittest/test_te_create_primfunc.py
index 0b6f87b833..2598d620ba 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -199,8 +199,8 @@ def te_multi_output():
@T.prim_func
def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle)
-> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- m = T.var("int32")
- n = T.var("int32")
+ m = T.int32()
+ n = T.int32()
A0 = T.match_buffer(a0, (m, n))
A1 = T.match_buffer(a1, (m, n))
B0 = T.match_buffer(b0, (m, n))
@@ -491,8 +491,8 @@ def tir_argmax_idx_val(
var_idx: T.handle, var_val: T.handle, var_argmax_v0: T.handle,
var_argmax_v1: T.handle
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- m = T.var("int32")
- n = T.var("int32")
+ m = T.int32()
+ n = T.int32()
idx = T.match_buffer(var_idx, [m, n], dtype="int32")
val = T.match_buffer(var_val, [m, n], dtype="float32")
argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="int32")
@@ -538,8 +538,8 @@ def tir_argmax_val_idx(
var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle,
var_argmax_v1: T.handle
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- m = T.var("int32")
- n = T.var("int32")
+ m = T.int32()
+ n = T.int32()
val = T.match_buffer(var_val, [m, n], dtype="float32")
idx = T.match_buffer(var_idx, [m, n], dtype="int32")
argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
@@ -711,8 +711,8 @@ def tir_resize2d_symbolic(
var_resize: T.handle,
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
- oh = T.var("int64")
- ow = T.var("int64")
+ oh = T.int64()
+ ow = T.int64()
resize = T.match_buffer(var_resize, [T.int64(2), T.int64(3), oh, ow],
dtype="float32")
for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), oh, ow):
with T.block("resize"):
diff --git a/tests/python/unittest/test_tir_analysis_oob.py
b/tests/python/unittest/test_tir_analysis_oob.py
index 83c0294176..7c8ceed36e 100644
--- a/tests/python/unittest/test_tir_analysis_oob.py
+++ b/tests/python/unittest/test_tir_analysis_oob.py
@@ -44,7 +44,7 @@ def bad_store_loop(A: T.Buffer((2, 3), "float32"), B:
T.Buffer((3, 2), "float32"
@T.prim_func
def unknown_bounds(A: T.Buffer((2, 3), "float32"), B: T.Buffer((3, 2),
"float32")):
- N = T.var("int32")
+ N = T.int32()
for i in range(3):
B[0, N] = A[1, i]
diff --git a/tests/python/unittest/test_tir_intrin.py
b/tests/python/unittest/test_tir_intrin.py
index f887f8877a..1ee709191c 100644
--- a/tests/python/unittest/test_tir_intrin.py
+++ b/tests/python/unittest/test_tir_intrin.py
@@ -193,11 +193,11 @@ class Module:
def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) ->
None:
# function attr dict
T.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
- n = T.var("int32")
- stride = T.var("int32")
- stride_1 = T.var("int32")
- stride_2 = T.var("int32")
- stride_3 = T.var("int32")
+ n = T.int32()
+ stride = T.int32()
+ stride_1 = T.int32()
+ stride_2 = T.int32()
+ stride_3 = T.int32()
A_1 = T.match_buffer(
A,
[n],
diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py
b/tests/python/unittest/test_tir_lower_match_buffer.py
index 535e0bb329..5bea77ffe3 100644
--- a/tests/python/unittest/test_tir_lower_match_buffer.py
+++ b/tests/python/unittest/test_tir_lower_match_buffer.py
@@ -93,8 +93,8 @@ def opaque_access(a: T.handle, b: T.handle) -> None:
)
for i, j, k in T.grid(64, 2, 8):
with T.block():
- Bs_0 = T.var("int32")
- Bs_1 = T.var("int32")
+ Bs_0 = T.int32()
+ Bs_1 = T.int32()
T.reads([])
T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8])
sub_B = T.match_buffer(
@@ -154,8 +154,8 @@ def high_dim_opaque_access(a: T.handle) -> None:
A = T.match_buffer(a, (16, 32, 64))
for i, j, k in T.grid(16, 2, 4):
with T.block():
- As_0 = T.var("int32")
- As_1 = T.var("int32")
+ As_0 = T.int32()
+ As_1 = T.int32()
T.reads([])
T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
sub_A = T.match_buffer(
@@ -200,8 +200,8 @@ def high_dim_opaque_access_with_source_strides(a: T.handle)
-> None:
A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
for i, j, k in T.grid(16, 2, 4):
with T.block():
- As_0 = T.var("int32")
- As_1 = T.var("int32")
+ As_0 = T.int32()
+ As_1 = T.int32()
T.reads([])
T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
sub_A = T.match_buffer(
@@ -254,8 +254,8 @@ def recursive_match(a: T.handle, b: T.handle) -> None:
B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
]
)
- As_0 = T.var("int32")
- As_1 = T.var("int32")
+ As_0 = T.int32()
+ As_1 = T.int32()
sub_A = T.match_buffer(
A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
(16, 16),
@@ -276,8 +276,8 @@ def recursive_match(a: T.handle, b: T.handle) -> None:
sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
]
)
- Ass_0 = T.var("int32")
- Ass_1 = T.var("int32")
+ Ass_0 = T.int32()
+ Ass_1 = T.int32()
sub_sub_A = T.match_buffer(
sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
(4, 4),
@@ -355,8 +355,8 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m:
T.int32) -> None:
with T.block():
T.reads([])
T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m *
4]])
- Bs_0 = T.var("int32")
- Bs_1 = T.var("int32")
+ Bs_0 = T.int32()
+ Bs_1 = T.int32()
sub_A = T.match_buffer(A[i * m : i * m + m, 0:m], (m, m),
offset_factor=1)
sub_B = T.match_buffer(
B[i * n : i * n + 2, 0 : m * 4], (2, m * 4), strides=[Bs_0,
Bs_1], offset_factor=1
@@ -470,7 +470,7 @@ def fail_buffer_bind(a: T.handle) -> None:
A = T.match_buffer(a, (8, 8))
for i, j in T.grid(8, 2):
with T.block():
- stride = T.var("int32")
+ stride = T.int32()
sub_A = T.match_buffer(
A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride],
offset_factor=1
)
diff --git a/tests/python/unittest/test_tir_renew_defs.py
b/tests/python/unittest/test_tir_renew_defs.py
index e14cd5a898..e01f5ecb12 100644
--- a/tests/python/unittest/test_tir_renew_defs.py
+++ b/tests/python/unittest/test_tir_renew_defs.py
@@ -88,8 +88,8 @@ def test_match_buffer():
# A and B should be remapped
def func_match_buffer(A: T.Buffer((128, 128), "float32"), B:
T.Buffer((128, 128), "float32")):
with T.block("root"):
- s = T.var("int32")
- e = T.var("int32")
+ s = T.int32()
+ e = T.int32()
# A0 should be remapped
A0 = T.match_buffer(
A[0:128, 0:128],
@@ -157,7 +157,7 @@ def test_undefined_buffer():
def test_symbolic_func():
@T.prim_func
def symbolic_func(a: T.handle, b: T.handle, n: T.int32):
- m = T.var("int32")
+ m = T.int32()
A = T.match_buffer(a, (n, m))
B = T.match_buffer(b, (n, m * 2))
for i, j in T.grid(n, m):
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py
b/tests/python/unittest/test_tir_schedule_rfactor.py
index 766cc3f867..199e822e84 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -954,7 +954,7 @@ def argmax_split_body_bufferstore_value_unbound_var(
argmax_v0: T.Buffer((128,), "int32"),
argmax_v1: T.Buffer((128,), "float32"),
) -> None:
- v_unbound = T.var("int32")
+ v_unbound = T.int32()
for i0, i1_0, i1_1 in T.grid(128, 4, 32):
with T.block("argmax"):
i = T.axis.spatial(128, i0)
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py
b/tests/python/unittest/test_tir_schedule_tensorize.py
index 143cf87d04..fcb4bacbba 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -195,9 +195,9 @@ def tensorized_matmul(a: T.handle, b: T.handle, c:
T.handle) -> None:
]
)
T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
- A_elem_offset = T.var("int32")
- B_elem_offset = T.var("int32")
- C_elem_offset = T.var("int32")
+ A_elem_offset = T.int32()
+ B_elem_offset = T.int32()
+ C_elem_offset = T.int32()
A_sub = T.match_buffer(
A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
[16, 16],
@@ -267,9 +267,9 @@ def tensorized_batch_matmul_mma(
B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 +
16],
)
T.writes(C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj *
16 + 16])
- A_elem_offset = T.var("int32")
- B_elem_offset = T.var("int32")
- C_elem_offset = T.var("int32")
+ A_elem_offset = T.int32()
+ B_elem_offset = T.int32()
+ C_elem_offset = T.int32()
A_sub = T.match_buffer(
A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 +
16],
(16, 16),
@@ -429,9 +429,9 @@ def annotated_tensorized_matmul(a: T.handle, b: T.handle,
c: T.handle) -> None:
]
)
T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
- A_elem_offset = T.var("int32")
- B_elem_offset = T.var("int32")
- C_elem_offset = T.var("int32")
+ A_elem_offset = T.int32()
+ B_elem_offset = T.int32()
+ C_elem_offset = T.int32()
A_sub = T.match_buffer(
A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
[16, 16],
@@ -745,9 +745,9 @@ def test_tensorize_matmul_mixed_dtype():
]
)
T.writes(C[vi * T.int64(16) : vi * T.int64(16) +
T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)])
- A_elem_offset = T.var("int64")
- B_elem_offset = T.var("int64")
- C_elem_offset = T.var("int64")
+ A_elem_offset = T.int64()
+ B_elem_offset = T.int64()
+ C_elem_offset = T.int64()
A_sub = T.match_buffer(
A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16),
vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
[T.int64(16), T.int64(16)],
diff --git a/tests/python/unittest/test_tir_specialize.py
b/tests/python/unittest/test_tir_specialize.py
index 72666a89eb..ebae827ef5 100644
--- a/tests/python/unittest/test_tir_specialize.py
+++ b/tests/python/unittest/test_tir_specialize.py
@@ -22,7 +22,7 @@ from tvm.script import tir as T
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle, n: T.int32) -> None:
- m = T.var("int32")
+ m = T.int32()
A = T.match_buffer(a, [m, n])
B = T.match_buffer(b, [m, n])
C = T.match_buffer(c, [m, m])
@@ -51,7 +51,7 @@ def matmul_128(a: T.handle, b: T.handle, c: T.handle) -> None:
@T.prim_func
def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None:
- m = T.var("int32")
+ m = T.int32()
A = T.match_buffer(a, [m, 128])
B = T.match_buffer(b, [m, 128])
C = T.match_buffer(c, [m, m])
@@ -66,8 +66,8 @@ def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) ->
None:
@T.prim_func
def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None:
- x = T.var("int32")
- m = T.var("int32")
+ x = T.int32()
+ m = T.int32()
A = T.match_buffer(a, [m, x * 8])
B = T.match_buffer(b, [m, x * 8])
C = T.match_buffer(c, [m, m])
@@ -82,8 +82,8 @@ def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) ->
None:
@T.prim_func
def element_wise(a: T.handle, c: T.handle) -> None:
- m = T.var("int32")
- n = T.var("int32")
+ m = T.int32()
+ n = T.int32()
A = T.match_buffer(a, (m, n), "float32")
C = T.match_buffer(c, (m, n), "float32")
@@ -119,7 +119,7 @@ def element_wise_128_64(a: T.handle, c: T.handle) -> None:
@T.prim_func
def element_wise_128_n(a: T.handle, c: T.handle) -> None:
- n = T.var("int32")
+ n = T.int32()
A = T.match_buffer(a, (128, n), "float32")
C = T.match_buffer(c, (128, n), "float32")
B = T.alloc_buffer((128, n), "float32")
@@ -170,7 +170,7 @@ def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int3
@T.prim_func
def param_in_arith_exprs(a: T.handle, b: T.handle) -> None:
- n = T.var("int32")
+ n = T.int32()
A = T.match_buffer(a, [n // 8, 8], "int32")
B = T.match_buffer(b, [n], "int32")
for i in range(n - 1):
@@ -181,7 +181,7 @@ def param_in_arith_exprs(a: T.handle, b: T.handle) -> None:
@T.prim_func
def param_in_arith_exprs_n_16(a: T.handle, b: T.handle) -> None:
- n = T.var("int32")
+ n = T.int32()
A = T.match_buffer(a, [2, 8], "int32")
B = T.match_buffer(b, [16], "int32")
for i in range(15):
diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
index 113d9f0474..1755a66ec9 100644
--- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
@@ -359,7 +359,7 @@ def func_distributivity_expected(
i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
) -> None:
B = T.Buffer((50,), "int32")
- cse_var_1 = T.var("int32")
+ cse_var_1 = T.int32()
with T.let(cse_var_1, x * y + x * z):
B[i1] = cse_var_1
B[i2] = cse_var_1
@@ -377,7 +377,7 @@ def func_associativity_expected(
i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
) -> None:
B = T.Buffer((50,), "int32")
- cse_var_1 = T.var("int32")
+ cse_var_1 = T.int32()
with T.let(cse_var_1, (x + y) + z):
B[i1] = cse_var_1
B[i2] = cse_var_1
diff --git a/tests/python/unittest/test_tir_transform_hoist_expression.py b/tests/python/unittest/test_tir_transform_hoist_expression.py
index 77862f64d6..ca37915597 100644
--- a/tests/python/unittest/test_tir_transform_hoist_expression.py
+++ b/tests/python/unittest/test_tir_transform_hoist_expression.py
@@ -447,7 +447,7 @@ class TestHoistLetExpr(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i, j in T.grid(4, 4):
- x = T.var("float32")
+ x = T.float32()
A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32"))
@T.prim_func
@@ -466,7 +466,7 @@ class TestSuppressHoistLetExpr(BaseBeforeAfter):
@T.prim_func
def before(A: T.Buffer((4, 4), "float32")):
for i, j in T.grid(4, 4):
- x = T.var("float32")
+ x = T.float32()
A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32"))
expected = before
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index d2f275ac3d..2713669bd3 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -511,7 +511,7 @@ def test_report_error_root_block():
def test_load_var():
def load_var_multiple() -> None:
- d = T.var("float32")
+ d = T.float32()
d[2] = d[2, 1] # error cannot provide two indices to load
check_error(load_var_multiple, 3)
@@ -519,7 +519,7 @@ def test_load_var():
def test_store_var():
def store_var_multiple() -> None:
- d = T.var("float32")
+ d = T.float32()
d[2, 1] = d[1] # error cannot provide two indices to store
check_error(store_var_multiple, 3)
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index 85d2e808b3..889f0c9eda 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -52,7 +52,7 @@ def test_ir_builder_tir_primfunc_complete():
with IRBuilder() as ib:
with T.prim_func():
T.arg("a", T.handle())
- T.arg("b", T.var("int64"))
+ T.arg("b", T.int64())
T.arg("c", T.Buffer((128, 128), "float32"))
d = T.arg("d", T.handle())
e = T.arg("e", T.Buffer((1024,), "int8"))
@@ -119,12 +119,12 @@ def test_ir_builder_tir_block_base():
def test_ir_builder_tir_block_complete():
with IRBuilder() as ib:
- a = T.var("int64", "a")
+ a = T.int64()
b = T.Buffer((128, 128), "float32")
c = T.Buffer((128, 128), "float32")
- d = T.var("int32", "d")
+ d = T.int32()
e = T.Buffer((128, 128), "float32")
- f = T.var("int32", "f")
+ f = T.int32()
with T.block("block"):
T.where(a > 1)
T.reads(b[0:16, 0:16])
@@ -169,10 +169,10 @@ def test_ir_builder_tir_block_complete():
def test_ir_builder_tir_axis():
with IRBuilder() as ib:
- a = T.var("int32", "a")
- b = T.var("int32", "b")
- c = T.var("int32", "c")
- d = T.var("int32", "d")
+ a = T.int32()
+ b = T.int32()
+ c = T.int32()
+ d = T.int32()
with T.block("block"):
T.axis.spatial(8, a)
T.axis.reduce(16, b)
@@ -269,15 +269,13 @@ def test_ir_builder_tir_for():
def test_ir_builder_tir_assert():
with IRBuilder() as ib:
- with T.Assert(T.var("int32", name="a") == 0, message="a is 0"):
+ with T.Assert(T.int32() == 0, message="a is 0"):
T.evaluate(0)
# the assert generated by IRBuilder
assert_actual = ib.get()
# the expected assert statement
- assert_expected = tir.AssertStmt(
- T.var("int32", name="a") == 0, tir.StringImm("a is 0"), tir.Evaluate(0)
- )
+ assert_expected = tir.AssertStmt(T.int32() == 0, tir.StringImm("a is 0"), tir.Evaluate(0))
# Check if the generated ir is expected
assert_structural_equal(assert_actual, assert_expected, map_free_vars=True)
@@ -285,13 +283,13 @@ def test_ir_builder_tir_assert():
def test_ir_builder_tir_let():
with IRBuilder() as ib:
- with T.let(T.var("int32", name="a"), tir.IntImm("int32", 2)):
+ with T.let(T.int32(), tir.IntImm("int32", 2)):
T.evaluate(0)
# the let binding generated by IRBuilder
let_actual = ib.get()
# the expected Let statement
- let_expected = tir.LetStmt(T.var("int32", name="a"), tir.IntImm("int32", 2), tir.Evaluate(0))
+ let_expected = tir.LetStmt(T.int32(), tir.IntImm("int32", 2), tir.Evaluate(0))
# Check if the generated ir is expected
assert_structural_equal(let_actual, let_expected, map_free_vars=True)
@@ -381,7 +379,7 @@ def test_ir_builder_tir_allocate_const():
def test_ir_builder_tir_while():
with IRBuilder() as ib:
- with T.While(T.var("int32", "x") > 0):
+ with T.While(T.int32() > 0):
T.evaluate(0)
# the while generated by IRBuilder
@@ -396,7 +394,7 @@ def test_ir_builder_tir_while():
def test_ir_builder_tir_if_then_else():
with IRBuilder() as ib:
- with T.If(T.var("int32", "c") < 12):
+ with T.If(T.int32() < 12):
with T.Then():
T.evaluate(T.int32(0))
with T.Else():
@@ -418,7 +416,7 @@ def test_ir_builder_tir_if_then_else():
def test_ir_builder_tir_buffer_store():
buffer_a = T.Buffer((10, 10), "float32")
- i = T.var("int32", "x")
+ i = T.int32()
with IRBuilder() as ib:
T.buffer_store(buffer_a, 0.1, [0, i])
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py
index a04544152e..13aaacb3b7 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -117,9 +117,9 @@ def test_block_realize():
_assert_print(
obj,
"""
-i = T.var("int32")
-j = T.var("int32")
-k = T.var("int32")
+i = T.int32()
+j = T.int32()
+k = T.int32()
with T.block("block"):
vi = T.axis.spatial(128, i)
vj = T.axis.spatial(64, j)
@@ -248,13 +248,13 @@ for i, j, k in T.grid(128, 128, 128):
def test_let_stmt():
with IRBuilder() as ib:
- with T.let(T.var("float32"), T.float32(10)):
+ with T.let(T.float32(), T.float32(10)):
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
-v = T.var("float32")
+v = T.float32()
with T.let(v, T.float32(10)):
T.evaluate(0)
""",
@@ -291,14 +291,14 @@ with T.Assert(1, "assertion"):
def test_while():
with IRBuilder() as ib:
- x = T.var("int32")
+ x = T.int32()
with T.While(x < 10):
T.evaluate(0)
obj = ib.get()
_assert_print(
obj,
"""
-v = T.var("int32")
+v = T.int32()
while v < 10:
T.evaluate(0)
""",
@@ -410,7 +410,7 @@ T.evaluate(1)
def test_if_then_else():
with IRBuilder() as ib:
- with T.If(T.var("int32") == 1):
+ with T.If(T.int32() == 1):
with T.Then():
T.evaluate(0)
@@ -418,7 +418,7 @@ def test_if_then_else():
_assert_print(
obj,
"""
-v = T.var("int32")
+v = T.int32()
if v == 1:
T.evaluate(0)
""",
@@ -458,7 +458,7 @@ def test_var():
_assert_print(
a,
"""
-a = T.var("float32")
+a = T.float32()
a""",
)
@@ -468,7 +468,7 @@ def test_size_var():
_assert_print(
a,
"""
-a = T.var("float32")
+a = T.float32()
a""",
)
@@ -478,7 +478,7 @@ def test_iter_var():
_assert_print(
a,
"""
-a = T.var("int32")
+a = T.int32()
T.iter_var(a, T.Range(0, 8), "DataPar", "")
""",
)
@@ -494,7 +494,7 @@ def test_cast():
_assert_print(
obj,
"""
-a = T.var("float32")
+a = T.float32()
T.Cast("float64", a)
""",
)
@@ -521,15 +521,15 @@ def test_binary_arith():
obj = op(a, b)
if sign.isalpha():
expected = """
-a = T.var("float32")
-b = T.var("float32")
+a = T.float32()
+b = T.float32()
T.{}(a, b)""".format(
sign
)
else:
expected = """
-a = T.var("float32")
-b = T.var("float32")
+a = T.float32()
+b = T.float32()
a {} b""".format(
sign
)
@@ -537,28 +537,28 @@ a {} b""".format(
def test_logical():
- a = T.var("bool", "a")
- b = T.var("bool", "b")
+ a = tir.Var("a", "bool")
+ b = tir.Var("b", "bool")
_assert_print(
tir.And(a, b),
"""
-a = T.var("bool")
-b = T.var("bool")
+a = T.bool()
+b = T.bool()
a and b
""",
)
_assert_print(
tir.Or(a, b),
"""
-a = T.var("bool")
-b = T.var("bool")
+a = T.bool()
+b = T.bool()
a or b
""",
)
_assert_print(
tir.Not(a),
"""
-a = T.var("bool")
+a = T.bool()
not a
""",
)
@@ -579,7 +579,7 @@ def test_ramp():
_assert_print(
obj,
"""
-a = T.var("int32")
+a = T.int32()
T.Ramp(a, 1, 32)
""",
)
@@ -601,7 +601,7 @@ def test_let_expr():
_assert_print(
obj,
"""
-x = T.var("int32")
+x = T.int32()
T.let(x, 1, x + 1)
""",
)
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index db21223366..48a5999469 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -2904,10 +2904,10 @@ def constant_folding():
def simplify_bracket():
@T.prim_func
def simplify_bracket() -> None:
- a = T.var("int32")
- b = T.var("int32")
- c = T.var("int32")
- d = T.var("int32")
+ a = T.int32()
+ b = T.int32()
+ c = T.int32()
+ d = T.int32()
T.evaluate(a + b * (c + d))
return simplify_bracket
@@ -3039,8 +3039,8 @@ def multiple_commreducer():
def func_div_mod():
@T.prim_func
def func_div_mod():
- a = T.var("int32")
- b = T.var("int32")
+ a = T.int32()
+ b = T.int32()
T.evaluate(a // b)
T.evaluate(a % b)
T.evaluate(T.truncmod(a, b))
@@ -3316,7 +3316,7 @@ def buffer_ramp_access_as_slice_index():
def let_expression():
@T.prim_func
def func():
- x = T.var("int32")
+ x = T.int32()
T.evaluate(T.let(x, 1, x + 1))
return func
@@ -3542,8 +3542,8 @@ def intrinsic_pow():
def let_stmt_var():
@T.prim_func
def func():
- x = T.var("int32")
- y = T.var("int32")
+ x = T.int32()
+ y = T.int32()
with T.let(x, 0):
with T.let(y, 0):
T.evaluate(0)
@@ -3555,8 +3555,8 @@ def let_stmt_var():
def let_stmt_value():
@T.prim_func
def func():
- x = T.var("int32")
- y = T.var("int32")
+ x = T.int32()
+ y = T.int32()
with T.let(x, y):
with T.let(y, 0):
T.evaluate(0)
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index a840722bea..e4ba1f7950 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -155,9 +155,9 @@ def test_match_buffer_1d():
# dynamic shape gemm
@T.prim_func
def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle):
- N = T.var("int32")
- M = T.var("int32")
- K = T.var("int32")
+ N = T.int32()
+ M = T.int32()
+ K = T.int32()
A = T.match_buffer(a, (N, K), "float32")
B = T.match_buffer(b, (K, M), "float32")
C = T.match_buffer(c, (N, M), "float32")