This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 35f4896bb1 [Unity][Relax] Set Shape Function to Be Host Function (#14090)
35f4896bb1 is described below
commit 35f4896bb1fa6563a7384908d5b41c63a41b8507
Author: Xiyou Zhou <[email protected]>
AuthorDate: Wed Feb 22 12:50:15 2023 -0800
[Unity][Relax] Set Shape Function to Be Host Function (#14090)
Set the shape function to be a host function.
---
src/relax/backend/vm/vm_shape_lower.cc | 5 +++++
tests/python/relax/test_backend_transform_shape_lower.py | 1 +
2 files changed, 6 insertions(+)
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 090bcf01b5..f4b272979b 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -531,6 +531,11 @@ class VMShapeLowerMutator
// the shape_func to indicate that this is a host function
// This could require us to attach target to the relax function here.
tir::PrimFunc shape_func(params, body, ret_type, buffer_map);
+ if (shape_func->attrs.GetAttr<tvm::Target>(tvm::attr::kTarget) == nullptr)
{
+ // kTarget and kIsHostFunc are mutually exclusive
+ shape_func =
+ WithAttr<tir::PrimFunc>(std::move(shape_func),
tvm::tir::attr::kIsHostFunc, Integer(1));
+ }
GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func");
builder_->Emit(Call(shape_func_var, {shape_heap_}), "_");
return to_compute.size();
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
index 5cd104dd01..9c11b352c8 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -178,6 +178,7 @@ def test_symbolic_compute():
@T.prim_func
def shape_func(H: T.Buffer(T.int64(4), "int64")):
# generated compute function
+ T.func_attr({"tir.is_host_func": 1})
H[T.int64(sindex["k+1"])] = H[T.int64(sindex["k"])] + T.int64(1)
@R.function