This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d9c2b9c966 [TIR][BugFix]Ensure the Var's scope is correct (#15406)
d9c2b9c966 is described below
commit d9c2b9c966d1068b3b5d2113e442902dd0f7d7bf
Author: sisleyli <[email protected]>
AuthorDate: Fri Jul 28 10:01:39 2023 +0800
[TIR][BugFix]Ensure the Var's scope is correct (#15406)
* [TIR][BugFix]Ensure Var's scope is correct
* add new testcases
* fix lint
---------
Co-authored-by: Bin Li <[email protected]>
---
src/tir/transforms/unsupported_dtype_legalize.cc | 15 ++++-
.../unittest/test_tir_transform_bf16_legalize.py | 69 ++++++++++++++++++++++
2 files changed, 81 insertions(+), 3 deletions(-)
diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc
index 0040951961..030dbd01ba 100644
--- a/src/tir/transforms/unsupported_dtype_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -79,7 +79,11 @@ class ComputeLegalizePlanner : public StmtExprVisitor {
// remap all intermediate constant buffer to promote data types (fp16/fp32)
if (MatchDType(op->dtype) && op->ConstantAllocationSize() != 0) {
DataType dtype = promote_dtype_.with_lanes(op->dtype.lanes());
- Var buffer_var = Var(op->buffer_var->name_hint,
PointerType(PrimType(dtype)));
+ String storage_scope = "global";
+ if (auto* ptr_type =
op->buffer_var->type_annotation.as<PointerTypeNode>()) {
+ storage_scope = ptr_type->storage_scope;
+ }
+ Var buffer_var = Var(op->buffer_var->name_hint,
PointerType(PrimType(dtype), storage_scope));
(*var_remap_)[op->buffer_var] = buffer_var;
}
return StmtExprVisitor::VisitStmt_(op);
@@ -496,7 +500,11 @@ class StorageLegalizer : public StmtExprMutator {
Stmt VisitStmt_(const AllocateNode* op) final {
if (MatchDType(op->dtype)) {
DataType dtype = GetStorageUIntDType(op->dtype);
- Var buffer_var = Var(op->buffer_var->name_hint,
PointerType(PrimType(dtype)));
+ String storage_scope = "global";
+ if (auto* ptr_type =
op->buffer_var->type_annotation.as<PointerTypeNode>()) {
+ storage_scope = ptr_type->storage_scope;
+ }
+ Var buffer_var = Var(op->buffer_var->name_hint,
PointerType(PrimType(dtype), storage_scope));
var_remap_[op->buffer_var] = buffer_var;
return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition,
op->body));
} else {
@@ -637,7 +645,8 @@ class StorageLegalizer : public StmtExprMutator {
if (auto* elem_type = ptr_type->element_type.as<PrimTypeNode>()) {
if (MatchDType(elem_type->dtype)) {
Var new_var =
- Var(var->name_hint,
PointerType(PrimType(GetStorageUIntDType(elem_type->dtype))));
+ Var(var->name_hint,
PointerType(PrimType(GetStorageUIntDType(elem_type->dtype)),
+ ptr_type->storage_scope));
var_remap_[var] = new_var;
return new_var;
}
diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py
index 20de9dc594..e2752e8bbb 100644
--- a/tests/python/unittest/test_tir_transform_bf16_legalize.py
+++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py
@@ -114,5 +114,74 @@ def test_bf16_storage_legalize():
tvm.ir.assert_structural_equal(after, expected)
+def test_bf16_storage_scope():
+ def get_before():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ Bptr: T.handle("bfloat16", storage_scope="local"),
+ Dptr: T.handle("bfloat16"),
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+ B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+ D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+ C = T.decl_buffer((100,), "bfloat16")
+ for i in T.grid(100):
+ C[i] = A[i] + B[i]
+ D[i] = T.exp(C[i])
+
+ return Before
+
+ def after_compute_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ Bptr: T.handle("bfloat16", storage_scope="local"),
+ Dptr: T.handle("bfloat16"),
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+ B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+ D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+ C = T.decl_buffer((100,), "float32")
+ for i in T.grid(100):
+ C[i] = bf16tof32(A[i]) + bf16tof32(B[i])
+ D[i] = f32tobf16(T.exp(C[i]))
+
+ return After
+
+ def after_storage_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("uint16", storage_scope="shared"),
+ Bptr: T.handle("uint16", storage_scope="local"),
+ Dptr: T.handle("uint16"),
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "uint16", data=Aptr)
+ B = T.decl_buffer((100,), "uint16", data=Bptr)
+ D = T.decl_buffer((100,), "uint16", data=Dptr)
+ C = T.decl_buffer((100,), "float32")
+ for i in T.grid(100):
+ C[i] = u16tof32(A[i]) + u16tof32(B[i])
+ D[i] = f32tou16(T.exp(C[i]))
+
+ return After
+
+ before = get_before()
+ after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
+ after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
+ tvm.ir.assert_structural_equal(after_compute, after_compute_legalize())
+ tvm.ir.assert_structural_equal(after_storage, after_storage_legalize())
+
+
if __name__ == "__main__":
test_bf16_storage_legalize()
+ test_bf16_storage_scope()