This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d9c2b9c966 [TIR][BugFix]Ensure the Var's scope is correct (#15406)
d9c2b9c966 is described below

commit d9c2b9c966d1068b3b5d2113e442902dd0f7d7bf
Author: sisleyli <[email protected]>
AuthorDate: Fri Jul 28 10:01:39 2023 +0800

    [TIR][BugFix]Ensure the Var's scope is correct (#15406)
    
    * [TIR][BugFix]Ensure Var's scope is correct
    
    * add new testcases
    
    * fix lint
    
    ---------
    
    Co-authored-by: Bin Li <[email protected]>
---
 src/tir/transforms/unsupported_dtype_legalize.cc   | 15 ++++-
 .../unittest/test_tir_transform_bf16_legalize.py   | 69 ++++++++++++++++++++++
 2 files changed, 81 insertions(+), 3 deletions(-)

diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc
index 0040951961..030dbd01ba 100644
--- a/src/tir/transforms/unsupported_dtype_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -79,7 +79,11 @@ class ComputeLegalizePlanner : public StmtExprVisitor {
     // remap all intermediate constant buffer to promote data types (fp16/fp32)
     if (MatchDType(op->dtype) && op->ConstantAllocationSize() != 0) {
       DataType dtype = promote_dtype_.with_lanes(op->dtype.lanes());
-      Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));
+      String storage_scope = "global";
+      if (auto* ptr_type = op->buffer_var->type_annotation.as<PointerTypeNode>()) {
+        storage_scope = ptr_type->storage_scope;
+      }
+      Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype), storage_scope));
       (*var_remap_)[op->buffer_var] = buffer_var;
     }
     return StmtExprVisitor::VisitStmt_(op);
@@ -496,7 +500,11 @@ class StorageLegalizer : public StmtExprMutator {
   Stmt VisitStmt_(const AllocateNode* op) final {
     if (MatchDType(op->dtype)) {
       DataType dtype = GetStorageUIntDType(op->dtype);
-      Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));
+      String storage_scope = "global";
+      if (auto* ptr_type = op->buffer_var->type_annotation.as<PointerTypeNode>()) {
+        storage_scope = ptr_type->storage_scope;
+      }
+      Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype), storage_scope));
       var_remap_[op->buffer_var] = buffer_var;
       return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body));
     } else {
@@ -637,7 +645,8 @@ class StorageLegalizer : public StmtExprMutator {
         if (auto* elem_type = ptr_type->element_type.as<PrimTypeNode>()) {
           if (MatchDType(elem_type->dtype)) {
             Var new_var =
-                Var(var->name_hint, PointerType(PrimType(GetStorageUIntDType(elem_type->dtype))));
+                Var(var->name_hint, PointerType(PrimType(GetStorageUIntDType(elem_type->dtype)),
+                                                ptr_type->storage_scope));
             var_remap_[var] = new_var;
             return new_var;
           }
diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py
index 20de9dc594..e2752e8bbb 100644
--- a/tests/python/unittest/test_tir_transform_bf16_legalize.py
+++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py
@@ -114,5 +114,74 @@ def test_bf16_storage_legalize():
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_bf16_storage_scope():
+    def get_before():
+        @tvm.script.ir_module
+        class Before:
+            @T.prim_func
+            def main(
+                Aptr: T.handle("bfloat16", storage_scope="shared"),
+                Bptr: T.handle("bfloat16", storage_scope="local"),
+                Dptr: T.handle("bfloat16"),
+            ):
+                T.func_attr({"global_symbol": "main"})
+                A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+                B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+                D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+                C = T.decl_buffer((100,), "bfloat16")
+                for i in T.grid(100):
+                    C[i] = A[i] + B[i]
+                    D[i] = T.exp(C[i])
+
+        return Before
+
+    def after_compute_legalize():
+        @tvm.script.ir_module
+        class After:
+            @T.prim_func
+            def main(
+                Aptr: T.handle("bfloat16", storage_scope="shared"),
+                Bptr: T.handle("bfloat16", storage_scope="local"),
+                Dptr: T.handle("bfloat16"),
+            ):
+                T.func_attr({"global_symbol": "main"})
+                A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+                B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+                D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+                C = T.decl_buffer((100,), "float32")
+                for i in T.grid(100):
+                    C[i] = bf16tof32(A[i]) + bf16tof32(B[i])
+                    D[i] = f32tobf16(T.exp(C[i]))
+
+        return After
+
+    def after_storage_legalize():
+        @tvm.script.ir_module
+        class After:
+            @T.prim_func
+            def main(
+                Aptr: T.handle("uint16", storage_scope="shared"),
+                Bptr: T.handle("uint16", storage_scope="local"),
+                Dptr: T.handle("uint16"),
+            ):
+                T.func_attr({"global_symbol": "main"})
+                A = T.decl_buffer((100,), "uint16", data=Aptr)
+                B = T.decl_buffer((100,), "uint16", data=Bptr)
+                D = T.decl_buffer((100,), "uint16", data=Dptr)
+                C = T.decl_buffer((100,), "float32")
+                for i in T.grid(100):
+                    C[i] = u16tof32(A[i]) + u16tof32(B[i])
+                    D[i] = f32tou16(T.exp(C[i]))
+
+        return After
+
+    before = get_before()
+    after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
+    after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
+    tvm.ir.assert_structural_equal(after_compute, after_compute_legalize())
+    tvm.ir.assert_structural_equal(after_storage, after_storage_legalize())
+
+
 if __name__ == "__main__":
     test_bf16_storage_legalize()
+    test_bf16_storage_scope()

Reply via email to