This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f83cebb54c [TVMScript] Normalize T.Bind to T.bind for statement builder convention (#18889)
f83cebb54c is described below
commit f83cebb54c50718fa5f97835172680d5ee25d6a8
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Mar 8 18:29:24 2026 -0400
[TVMScript] Normalize T.Bind to T.bind for statement builder convention (#18889)
## Summary
- Rename `T.Bind` (capitalized) to `T.bind` (lowercase) to match the
TVMScript naming convention: statement builders are lowercase
(`T.evaluate`, `T.buffer_store`, `T.bind`), while expression constructors
are capitalized (`T.Cast`, `T.Select`, `T.Let`)
- Keep `Bind = bind` as a backward-compatible alias
- Update parser, printer references, and all test files
## Test plan
- [x] tvmscript tests (771 passed)
- [x] tir-transform tests (346 passed)
- [x] tir-base tests (224 passed)
- [x] pre-commit lint passes
---
python/tvm/script/ir_builder/tir/ir.py | 8 ++++----
python/tvm/script/parser/tir/parser.py | 6 +++---
.../transform/test_s_tir_transform_thread_sync.py | 16 +++++++--------
.../test_tir_analysis_verify_well_formed.py | 24 +++++++++++-----------
.../test_tir_inline_private_functions.py | 4 ++--
.../test_tir_transform_convert_ssa.py | 6 +++---
.../tvmscript/test_tvmscript_ir_builder_tir.py | 4 ++--
.../python/tvmscript/test_tvmscript_printer_tir.py | 2 +-
tests/python/tvmscript/test_tvmscript_roundtrip.py | 4 ++--
.../tvmscript/test_tvmscript_syntax_sugar.py | 4 ++--
10 files changed, 39 insertions(+), 39 deletions(-)
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index ccdfe3fd67..ccc730f805 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -981,7 +981,7 @@ def Assert(condition: PrimExpr, message, error_kind: str =
"RuntimeError") -> fr
return _ffi_api.Assert(condition, error_kind, message) # type:
ignore[attr-defined] # pylint: disable=no-member
-def Bind( # pylint: disable=invalid-name
+def bind(
value: PrimExpr,
type_annotation: Type | None = None, # pylint:
disable=redefined-outer-name
*,
@@ -1024,7 +1024,7 @@ def Let( # pylint: disable=invalid-name
return tir.Let(var, value, expr)
-bind = Bind
+Bind = bind # backward-compat alias
def let(
@@ -1055,9 +1055,9 @@ def let(
def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr:
return tir.Let(v, value, body)
- @deprecated("T.let", "T.Bind")
+ @deprecated("T.let", "T.bind")
def let_stmt(v: Var, value: PrimExpr) -> Var:
- return Bind(value, var=v)
+ return bind(value, var=v)
if body is None:
return let_stmt(v, value)
diff --git a/python/tvm/script/parser/tir/parser.py
b/python/tvm/script/parser/tir/parser.py
index 660085ba3c..b4d6f88edd 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -145,7 +145,7 @@ def bind_assign_value(self: Parser, node: doc.expr,
var_name: str, value: Any) -
return value
else:
value = tvm.runtime.convert(value)
- var = T.Bind(value)
+ var = T.bind(value)
IRBuilder.name(var_name, var)
return var
@@ -349,7 +349,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) ->
None:
if not isinstance(ann_var, Var):
self.report_error(node.annotation, "Annotation should be Var")
self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value)
- T.Bind(rhs, var=ann_var)
+ T.bind(rhs, var=ann_var)
@dispatch.register(token="tir", type_name="With")
@@ -467,7 +467,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
res.add_callback(partial(res.__exit__, None, None, None))
res.__enter__()
elif isinstance(res, Var):
- # Standalone Var expression (e.g. from T.Bind(value, var=v)) --
+ # Standalone Var expression (e.g. from T.bind(value, var=v)) --
# the Bind statement was already emitted to the parent frame by the
FFI call,
# so just discard the returned Var.
pass
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
index ec4b5afe0c..08a51d2655 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
@@ -113,13 +113,13 @@ def test_sync_bind():
A_shared_1[ax0] = A[blockIdx_x * 512 + ax0]
in_thread_A_temp_1 = T.decl_buffer((1,), data=in_thread_A_temp.data,
scope="local")
in_thread_A_temp_1[0] = T.float32(0)
- A_temp_1 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x])
+ A_temp_1 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x])
in_thread_A_temp_1[0] = A_temp_1
- A_temp_2 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x +
128])
+ A_temp_2 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x +
128])
in_thread_A_temp_1[0] = A_temp_2
- A_temp_3 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x +
256])
+ A_temp_3 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x +
256])
in_thread_A_temp_1[0] = A_temp_3
- A_temp_4 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x +
384])
+ A_temp_4 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x +
384])
in_thread_A_temp_1[0] = A_temp_4
cross_thread_A_temp_1 = T.decl_buffer((1,),
data=cross_thread_A_temp.data, scope="local")
with T.attr(
@@ -148,13 +148,13 @@ def test_sync_bind():
in_thread_A_temp_1_1 = T.decl_buffer((1,),
data=in_thread_A_temp_1.data, scope="local")
in_thread_A_temp_1_1[0] = T.float32(0)
T.tvm_storage_sync("shared")
- A_temp_1 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x])
+ A_temp_1 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x])
in_thread_A_temp_1_1[0] = A_temp_1
- A_temp_2 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x +
128])
+ A_temp_2 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x +
128])
in_thread_A_temp_1_1[0] = A_temp_2
- A_temp_3 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x +
256])
+ A_temp_3 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x +
256])
in_thread_A_temp_1_1[0] = A_temp_3
- A_temp_4 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x +
384])
+ A_temp_4 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x +
384])
in_thread_A_temp_1_1[0] = A_temp_4
cross_thread_A_temp_1_1 = T.decl_buffer(
(1,), data=cross_thread_A_temp_1.data, scope="local"
diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
index b89a1cb9c7..d6c1dae3b6 100644
--- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
+++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
@@ -92,8 +92,8 @@ def test_error_for_nested_rebind_usage():
@T.prim_func(check_well_formed=False)
def func():
i = T.int32()
- T.Bind(42, var=i)
- T.Bind(42, var=i)
+ T.bind(42, var=i)
+ T.bind(42, var=i)
T.evaluate(i)
with pytest.raises(
@@ -113,9 +113,9 @@ def test_error_for_repeated_binding():
@T.prim_func(check_well_formed=False)
def func():
i = T.int32()
- T.Bind(42, var=i)
+ T.bind(42, var=i)
T.evaluate(i)
- T.Bind(17, var=i)
+ T.bind(17, var=i)
T.evaluate(i)
with pytest.raises(ValueError, match="multiple nested definitions of
variable i"):
@@ -131,12 +131,12 @@ def test_error_for_cross_function_reuse():
class mod:
@T.prim_func
def func1():
- T.Bind(42, var=i)
+ T.bind(42, var=i)
T.evaluate(i)
@T.prim_func
def func2():
- T.Bind(42, var=i)
+ T.bind(42, var=i)
T.evaluate(i)
with pytest.raises(ValueError, match="multiple definitions of variable i"):
@@ -295,10 +295,10 @@ def
test_error_message_without_previous_definition_location():
def func():
x = T.int32()
- T.Bind(42, var=x)
+ T.bind(42, var=x)
T.evaluate(x)
- T.Bind(99, var=x) # This should trigger the error
+ T.bind(99, var=x) # This should trigger the error
T.evaluate(x)
with pytest.raises(ValueError) as exc_info:
@@ -322,8 +322,8 @@ def test_error_message_with_previous_definition_location():
def func():
x = T.int32()
- T.Bind(42, var=x)
- T.Bind(99, var=x) # This should trigger the error
+ T.bind(42, var=x)
+ T.bind(99, var=x) # This should trigger the error
T.evaluate(x)
with pytest.raises(ValueError) as exc_info:
@@ -351,10 +351,10 @@ def test_sequential_redefinition_with_location():
def func():
x = T.int32()
- T.Bind(1, var=x)
+ T.bind(1, var=x)
T.evaluate(x)
- T.Bind(2, var=x) # This should trigger the error
+ T.bind(2, var=x) # This should trigger the error
T.evaluate(x)
with pytest.raises(ValueError) as exc_info:
diff --git a/tests/python/tir-transform/test_tir_inline_private_functions.py
b/tests/python/tir-transform/test_tir_inline_private_functions.py
index e681073fa6..e2f41fda16 100644
--- a/tests/python/tir-transform/test_tir_inline_private_functions.py
+++ b/tests/python/tir-transform/test_tir_inline_private_functions.py
@@ -150,7 +150,7 @@ class TestDeduplicateBlockName(BaseTestCase):
class Expected:
@T.prim_func
def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16],
"float32")):
- A_data_1 = T.Bind(T.address_of(A[0, 0]), T.handle("float32"))
+ A_data_1 = T.bind(T.address_of(A[0, 0]), T.handle("float32"))
A_1 = T.decl_buffer(16, "float32", data=A_data_1)
B_data_1: T.handle("float32") = T.address_of(B[0, 0])
B_1 = T.decl_buffer(16, "float32", data=B_data_1)
@@ -158,7 +158,7 @@ class TestDeduplicateBlockName(BaseTestCase):
with T.sblock("scalar_mul_1"):
B_1[i] = A_1[i] * 2.0
- A_data_2 = T.Bind(T.address_of(A[1, 0]), T.handle("float32"))
+ A_data_2 = T.bind(T.address_of(A[1, 0]), T.handle("float32"))
A_2 = T.decl_buffer(16, "float32", data=A_data_2)
B_data_2: T.handle("float32") = T.address_of(B[1, 0])
B_2 = T.decl_buffer(16, "float32", data=B_data_2)
diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py
b/tests/python/tir-transform/test_tir_transform_convert_ssa.py
index df69bc384d..625001bf9f 100644
--- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py
+++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py
@@ -42,9 +42,9 @@ def test_reuse_in_sequential_bind():
@T.prim_func(private=True)
def expected():
- var1 = T.Bind(T.int32(16))
+ var1 = T.bind(T.int32(16))
T.evaluate(var1)
- var2 = T.Bind(T.int32(32))
+ var2 = T.bind(T.int32(32))
T.evaluate(var2)
mod = tvm.IRModule.from_expr(before)
@@ -108,7 +108,7 @@ def test_reused_var_across_module():
@T.prim_func(private=True)
def func():
- var = T.Bind(10)
+ var = T.bind(10)
T.evaluate(var)
before = tvm.IRModule(
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index ee45aebeda..460457601a 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -315,9 +315,9 @@ def test_ir_builder_tir_assert():
def test_ir_builder_tir_bind():
- # Test that T.Bind emits a flat Bind statement and returns the Var.
+ # Test that T.bind emits a flat Bind statement and returns the Var.
with IRBuilder() as ib:
- v = T.Bind(tir.IntImm("int32", 2))
+ v = T.bind(tir.IntImm("int32", 2))
# the let binding generated by IRBuilder
let_actual = ib.get()
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index 7bf0c9f1d0..406a8c6a79 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -255,7 +255,7 @@ for i, j, k in T.grid(128, 128, 128):
def test_bind():
with IRBuilder() as ib:
with T.prim_func():
- v = T.Bind(T.float32(10))
+ v = T.bind(T.float32(10))
ib.name("v", v)
T.evaluate(1)
obj = ib.get()
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py
b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 040932720a..ab64737ce1 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -2729,8 +2729,8 @@ def intrinsic_pow():
def bind_var():
@T.prim_func
def func():
- x = T.Bind(0)
- y = T.Bind(0)
+ x = T.bind(0)
+ y = T.bind(0)
T.evaluate(0)
T.evaluate(0)
diff --git a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py
b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py
index f3d19f8eba..cc707a8ccf 100644
--- a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py
+++ b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py
@@ -410,7 +410,7 @@ def test_preserve_trivial_let_binding():
@T.prim_func
def explicit(i: T.int32):
j = T.int32()
- T.Bind(i, var=j)
+ T.bind(i, var=j)
T.evaluate(j)
@T.prim_func
@@ -425,7 +425,7 @@ def test_preserve_trivial_let_binding_of_value():
@T.prim_func
def explicit(i: T.int32):
j = T.int32()
- T.Bind(42, var=j)
+ T.bind(42, var=j)
T.evaluate(j)
@T.prim_func