This is an automated email from the ASF dual-hosted git repository.
xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8613c79aa0 [TIR] Enable Host Func Attribute for PrimFunc (#14020)
8613c79aa0 is described below
commit 8613c79aa027afcf022766e7ed6288a2de779416
Author: Xiyou Zhou <[email protected]>
AuthorDate: Fri Feb 17 23:20:49 2023 -0800
[TIR] Enable Host Func Attribute for PrimFunc (#14020)
---
include/tvm/tir/function.h | 7 +++
src/tir/transforms/primfunc_utils.cc | 4 ++
tests/python/unittest/test_tir_host_func.py | 79 +++++++++++++++++++++++++++++
3 files changed, 90 insertions(+)
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index e135c26199..48328263fb 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -325,6 +325,13 @@ constexpr const char* kIsEntryFunc = "tir.is_entry_func";
*/
constexpr const char* kIsGlobalFunc = "tir.is_global_func";
+/*!
+ * \brief Mark the function as run on the host, mutually exclusive with kTarget.
+ *
+ * Type: Integer
+ */
+constexpr const char* kIsHostFunc = "tir.is_host_func";
+
} // namespace attr
} // namespace tir
} // namespace tvm
diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc
index d2bb259f99..208077b492 100644
--- a/src/tir/transforms/primfunc_utils.cc
+++ b/src/tir/transforms/primfunc_utils.cc
@@ -30,6 +30,10 @@ namespace tir {
namespace transform {
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
+ if (f->GetAttr<Integer>(tvm::tir::attr::kIsHostFunc) == 1) {
+ return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)),
+ tvm::attr::kTarget, target->host.value_or(Target("llvm")));
+ }
return WithAttr(std::move(f), tvm::attr::kTarget, target);
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {});
diff --git a/tests/python/unittest/test_tir_host_func.py b/tests/python/unittest/test_tir_host_func.py
new file mode 100644
index 0000000000..ea0ad7ba4a
--- /dev/null
+++ b/tests/python/unittest/test_tir_host_func.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.meta_schedule.testing import te_workload
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring
+# fmt: off
+
+@I.ir_module
+class Module:
+    @T.prim_func
+    def main(
+        A: T.Buffer((729, 729), "float32"),
+        B: T.Buffer((729, 729), "float32"),
+        C: T.Buffer((729, 729), "float32"),
+    ):
+        T.func_attr(
+            {
+                "global_symbol": "test",
+                "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
+                "tir.noalias": True,
+            }
+        )
+        # with T.block("root"):
+        for i, j, k in T.grid(729, 729, 729):
+            with T.block("C"):
+                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                T.reads(A[v_i, v_k], B[v_k, v_j])
+                T.writes(C[v_i, v_j])
+                with T.init():
+                    C[v_i, v_j] = T.float32(0)
+                C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
+# fmt: on
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring
+
+
+def test_host_func():
+    """Test that host functions are not split."""
+    # te schedule copied from test_tir_transform_split_host_device.py
+
+    func = tvm.te.create_prim_func(
+        te_workload.matmul(729, 729, 729, in_dtype="float32", out_dtype="float32")
+    )
+    mod = tvm.ir.IRModule({"main": func})
+    target = tvm.target.Target("cuda")
+    mod = tvm.tir.transform.Apply(
+        lambda f: f.with_attr(
+            {
+                "global_symbol": "test",
+                "tir.is_host_func": 1,
+            }
+        )
+    )(mod)
+    mod = tvm.tir.transform.BindTarget(target)(mod)
+    tvm.ir.assert_structural_equal(mod, Module)
+    assert (
+        "tir.is_host_func" not in mod["main"].attrs
+    ), """Target and is_host_func attributes should be mutually exclusive"""
+
+
+if __name__ == "__main__":
+    test_host_func()