This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8613c79aa0 [TIR] Enable Host Func Attribute for PrimFunc (#14020)
8613c79aa0 is described below

commit 8613c79aa027afcf022766e7ed6288a2de779416
Author: Xiyou Zhou <[email protected]>
AuthorDate: Fri Feb 17 23:20:49 2023 -0800

    [TIR] Enable Host Func Attribute for PrimFunc (#14020)
---
 include/tvm/tir/function.h                  |  7 +++
 src/tir/transforms/primfunc_utils.cc        |  4 ++
 tests/python/unittest/test_tir_host_func.py | 79 +++++++++++++++++++++++++++++
 3 files changed, 90 insertions(+)

diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index e135c26199..48328263fb 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -325,6 +325,13 @@ constexpr const char* kIsEntryFunc = "tir.is_entry_func";
  */
 constexpr const char* kIsGlobalFunc = "tir.is_global_func";
 
+/*!
+ * \brief Mark the function as run on the host, mutually exclusive with kTarget.
+ *
+ * Type: Integer
+ */
+constexpr const char* kIsHostFunc = "tir.is_host_func";
+
 }  // namespace attr
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc
index d2bb259f99..208077b492 100644
--- a/src/tir/transforms/primfunc_utils.cc
+++ b/src/tir/transforms/primfunc_utils.cc
@@ -30,6 +30,10 @@ namespace tir {
 namespace transform {
 transform::Pass BindTarget(Target target) {
   auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
+    if (f->GetAttr<Integer>(tvm::tir::attr::kIsHostFunc) == 1) {
+      return WithAttr(std::move(WithoutAttr(std::move(f), tvm::tir::attr::kIsHostFunc)),
+                      tvm::attr::kTarget, target->host.value_or(Target("llvm")));
+    }
     return WithAttr(std::move(f), tvm::attr::kTarget, target);
   };
   return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.BindTarget", {});
diff --git a/tests/python/unittest/test_tir_host_func.py b/tests/python/unittest/test_tir_host_func.py
new file mode 100644
index 0000000000..ea0ad7ba4a
--- /dev/null
+++ b/tests/python/unittest/test_tir_host_func.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.meta_schedule.testing import te_workload
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring
+# fmt: off
+
+@I.ir_module
+class Module:
+    @T.prim_func
+    def main(
+        A: T.Buffer((729, 729), "float32"),
+        B: T.Buffer((729, 729), "float32"),
+        C: T.Buffer((729, 729), "float32"),
+    ):
+        T.func_attr(
+            {
+                "global_symbol": "test",
+                "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
+                "tir.noalias": True,
+            }
+        )
+        # with T.block("root"):
+        for i, j, k in T.grid(729, 729, 729):
+            with T.block("C"):
+                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                T.reads(A[v_i, v_k], B[v_k, v_j])
+                T.writes(C[v_i, v_j])
+                with T.init():
+                    C[v_i, v_j] = T.float32(0)
+                C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
+
+# fmt: on
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring
+
+
+def test_host_func():
+    """Test that host functions are not split."""
+    # te schedule copied from test_tir_transform_split_host_device.py
+
+    func = tvm.te.create_prim_func(
+        te_workload.matmul(729, 729, 729, in_dtype="float32", out_dtype="float32")
+    )
+    mod = tvm.ir.IRModule({"main": func})
+    target = tvm.target.Target("cuda")
+    mod = tvm.tir.transform.Apply(
+        lambda f: f.with_attr(
+            {
+                "global_symbol": "test",
+                "tir.is_host_func": 1,
+            }
+        )
+    )(mod)
+    mod = tvm.tir.transform.BindTarget(target)(mod)
+    tvm.ir.assert_structural_equal(mod, Module)
+    assert (
+        "tir.is_host_func" not in mod["main"].attrs
+    ), """Target and is_host_func attributes should be mutually exclusive"""
+
+
+if __name__ == "__main__":
+    test_host_func()

Reply via email to