This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 35f4896bb1 [Unity][Relax] Set Shape Function to Be Host Function (#14090)
35f4896bb1 is described below

commit 35f4896bb1fa6563a7384908d5b41c63a41b8507
Author: Xiyou Zhou <[email protected]>
AuthorDate: Wed Feb 22 12:50:15 2023 -0800

    [Unity][Relax] Set Shape Function to Be Host Function (#14090)
    
    Set the shape function to be a host function.
---
 src/relax/backend/vm/vm_shape_lower.cc                   | 5 +++++
 tests/python/relax/test_backend_transform_shape_lower.py | 1 +
 2 files changed, 6 insertions(+)

diff --git a/src/relax/backend/vm/vm_shape_lower.cc 
b/src/relax/backend/vm/vm_shape_lower.cc
index 090bcf01b5..f4b272979b 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -531,6 +531,11 @@ class VMShapeLowerMutator
     // the shape_func to indicate that this is a host function
     // This could require us to attach target to the relax function here.
     tir::PrimFunc shape_func(params, body, ret_type, buffer_map);
+    if (shape_func->attrs.GetAttr<tvm::Target>(tvm::attr::kTarget) == nullptr) 
{
+      // kTarget and kIsHostFunc are mutually exclusive
+      shape_func =
+          WithAttr<tir::PrimFunc>(std::move(shape_func), 
tvm::tir::attr::kIsHostFunc, Integer(1));
+    }
     GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func");
     builder_->Emit(Call(shape_func_var, {shape_heap_}), "_");
     return to_compute.size();
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py 
b/tests/python/relax/test_backend_transform_shape_lower.py
index 5cd104dd01..9c11b352c8 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -178,6 +178,7 @@ def test_symbolic_compute():
         @T.prim_func
         def shape_func(H: T.Buffer(T.int64(4), "int64")):
             # generated compute function
+            T.func_attr({"tir.is_host_func": 1})
             H[T.int64(sindex["k+1"])] = H[T.int64(sindex["k"])] + T.int64(1)
 
         @R.function

Reply via email to