This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a64d1f1cc3 [TIR] Make T.reinterpret nop when dtype is the same (#16879)
a64d1f1cc3 is described below
commit a64d1f1cc37da7f202d943c2bea7eb747e624599
Author: Wuwei Lin <[email protected]>
AuthorDate: Sun Apr 14 08:21:30 2024 -0700
[TIR] Make T.reinterpret nop when dtype is the same (#16879)
* [TIR] Make T.reinterpret nop when dtype is the same
* fix scalable vec handling
---
python/tvm/tir/op.py | 4 ++--
src/tir/op/op.cc | 8 ++++++--
tests/python/codegen/test_target_codegen_cuda.py | 2 +-
.../python/tvmscript/test_tvmscript_parser_tir.py | 22 ++++++++++++++++++++++
4 files changed, 31 insertions(+), 5 deletions(-)
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 8816880e7b..6b72e63f29 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1789,7 +1789,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any:
return _ffi_api.infinity(dtype, span) # type: ignore
-def reinterpret(dtype, value) -> Any:
+def reinterpret(dtype, value, span: Optional[Span] = None) -> Any:
"""infinity value of dtype
Parameters
@@ -1808,7 +1808,7 @@ def reinterpret(dtype, value) -> Any:
value : tvm.Expr
The reinterpret cast value of dtype.
"""
- return call_intrin(dtype, "tir.reinterpret", value)
+ return _ffi_api.reinterpret(dtype, value, span) # type: ignore
def exp(x):
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 7f47e66062..b613639786 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -409,8 +409,10 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) {
// reinterpret
PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
if (value.dtype() == t) return value;
- ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
- << "Bitcast requires size match " << t << " vs " << value.dtype();
+ if (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()) {
+ ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
+ << "Bitcast requires size match " << t << " vs " << value.dtype();
+ }
return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
}
@@ -1083,6 +1085,8 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
+TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret);
+
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 23ba0fc3ce..112c521d06 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -1120,7 +1120,7 @@ def test_invalid_reinterpret():
@T.prim_func
def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
for tx in T.thread_binding(4, "threadIdx.x"):
- B[tx] = T.reinterpret("uint8", A[tx])
+ B[tx] = T.call_intrin("uint8", "tir.reinterpret", A[tx])
with pytest.raises(tvm.error.TVMError):
tvm.build(func, target="cuda")
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 465ffa5cb6..530746a6fc 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -449,5 +449,27 @@ def test_inferred_sinfo_with_dynamic_buffer():
tvm.ir.assert_structural_equal(func.struct_info, expected)
+def test_reinterpret_nop():
+ """Test builtin reinterpret op"""
+
+ @T.prim_func
+ def func(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None:
+ T.func_attr({"global_symbol": "main"})
+ for i in T.serial(0, 32):
+ with T.block():
+ vi = T.axis.remap("S", [i])
+ B[vi] = T.reinterpret("float32", A[vi])
+
+ @T.prim_func
+ def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None:
+ T.func_attr({"global_symbol": "main"})
+ for i in T.serial(0, 32):
+ with T.block():
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi]
+
+ tvm.ir.assert_structural_equal(func, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()