This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2260bba731 [Unity] Allocate workspace for all functions (#15118)
2260bba731 is described below

commit 2260bba731e4cfd90b667ce50854e70b816978c5
Author: Lite Ye <[email protected]>
AuthorDate: Mon Jun 19 15:19:02 2023 -0400

    [Unity] Allocate workspace for all functions (#15118)
    
    Allocate workspace for all functions
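
    Previously the WorkspaceProvider pass rewrote only the "main"
    function; this change iterates over every Relax function in the
    module, skipping functions attributed with kCodegen or kComposite.

    A minimal usage sketch (Module is the multi-entry test module from
    the updated test below; printing via script() is illustrative, not
    part of this commit):

        from tvm import relax

        # Run the pass; after this change every plain Relax function
        # (not just "main") gets its own workspace allocation.
        rewritten = relax.transform.AllocateWorkspace()(Module)
        print(rewritten.script())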
---
 src/relax/transform/allocate_workspace.cc          | 16 +++++++---
 .../relax/test_transform_allocate_workspace.py     | 36 ++++++++++++++++++++--
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc
index 95bbfbee7c..4b26b590ef 100644
--- a/src/relax/transform/allocate_workspace.cc
+++ b/src/relax/transform/allocate_workspace.cc
@@ -125,11 +125,17 @@ class WorkspaceProvider : ExprMutator {
       builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
     }
 
-    auto gvar = mod_->GetGlobalVar("main");
-    auto func = Downcast<Function>(mod_->Lookup(gvar));
-    auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
-                             func->is_pure, func->attrs);
-    builder_->UpdateFunction(gvar, new_func);
+    for (const auto& [gvar, f] : mod_->functions) {
+      workspace_var_main_ = Var();
+      if (!f->IsInstance<relax::FunctionNode>() || f->GetAttr<String>(attr::kCodegen) ||
+          f->GetAttr<String>(attr::kComposite)) {
+        continue;
+      }
+      auto func = Downcast<Function>(mod_->Lookup(gvar));
+      auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
+                               func->is_pure, func->attrs);
+      builder_->UpdateFunction(gvar, new_func);
+    }
     return builder_->GetContextIRModule();
   }
 
diff --git a/tests/python/relax/test_transform_allocate_workspace.py b/tests/python/relax/test_transform_allocate_workspace.py
index 7ffbd01b05..aca6ea2fe8 100644
--- a/tests/python/relax/test_transform_allocate_workspace.py
+++ b/tests/python/relax/test_transform_allocate_workspace.py
@@ -55,7 +55,7 @@ class Module:
         return gv1
 
     @R.function
-    def main(
+    def entry_a(
         q: R.Tensor((32, 8, 16, 8), dtype="float16"),
         k: R.Tensor((32, 8, 16, 8), dtype="float16"),
         v: R.Tensor((32, 8, 16, 8), dtype="float16"),
@@ -68,6 +68,20 @@ class Module:
             R.output(gv)
         return gv
 
+    @R.function
+    def entry_b(
+        q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+    ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+        cls = Module
+        with R.dataflow():
+            gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass(
+                q, k, v
+            ) + R.const(1, dtype="float16")
+            R.output(gv)
+        return gv
+
 
 @I.ir_module
 class Expected:
@@ -105,7 +119,7 @@ class Expected:
         return gv1
 
     @R.function
-    def main(
+    def entry_a(
         q: R.Tensor((32, 8, 16, 8), dtype="float16"),
         k: R.Tensor((32, 8, 16, 8), dtype="float16"),
         v: R.Tensor((32, 8, 16, 8), dtype="float16"),
@@ -122,6 +136,24 @@ class Expected:
             R.output(gv)
         return gv
 
+    @R.function
+    def entry_b(
+        q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+    ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+        cls = Expected
+        with R.dataflow():
+            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
+            workspace_main: R.Tensor((65536,), dtype="uint8") = R.vm.alloc_tensor(
+                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            )
+            gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass1(
+                q, k, v, workspace_main
+            ) + R.const(1, dtype="float16")
+            R.output(gv)
+        return gv
+
 
 def test_single_attention():
     rewritten = relax.transform.AllocateWorkspace()(Module)
