This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 672ce33657 [TIR] Propagate storage scope of undefined vars in SplitHostDevice. (#11255)
672ce33657 is described below

commit 672ce336571d1a4cd914c3eff27597e9c7a52527
Author: Chris Sullivan <[email protected]>
AuthorDate: Mon May 16 10:22:16 2022 -0700

    [TIR] Propagate storage scope of undefined vars in SplitHostDevice. (#11255)
    
    * [TIR] Propagate storage scope of undefined vars in SplitHostDevice.
    
    * Test global.texture for input, output, and intermediate buffers.
---
 src/tir/transforms/split_host_device.cc         |  7 ++-
 tests/python/unittest/test_tir_texture_scope.py | 62 +++++++++++++++++++++++++
 2 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 1b8c150079..85845616f1 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -281,7 +281,12 @@ class HostDeviceSplitter : public StmtMutator {
         // Create a new version of v.
         auto it = handle_data_type_.find(var.get());
         if (it != handle_data_type_.end()) {
-          tir::Var new_var(var->name_hint, PointerType(PrimType((*it).second->dtype)));
+          String storage_scope;
+          if (auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) {
+            storage_scope = ptr_type->storage_scope;
+          }
+          tir::Var new_var(var->name_hint,
+                           PointerType(PrimType((*it).second->dtype), storage_scope));
           params.push_back(new_var);
           remap_vars.Set(var, new_var);
         } else {
diff --git a/tests/python/unittest/test_tir_texture_scope.py b/tests/python/unittest/test_tir_texture_scope.py
new file mode 100644
index 0000000000..701a1fe77a
--- /dev/null
+++ b/tests/python/unittest/test_tir_texture_scope.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+from tvm.ir.module import IRModule
+from tvm import tir
+from tvm.script import tir as T
+
+
+def test_texture_scope():
+    @tvm.script.ir_module
+    class PlusOneMultTwo:
+        @T.prim_func
+        def main(a: T.handle, b: T.handle) -> None:
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            A = T.match_buffer(a, (128, 128, 4), dtype="float32", scope="global.texture")
+            B = T.alloc_buffer((128, 128, 4), dtype="float32", scope="global.texture")
+            C = T.match_buffer(b, (128, 128, 4), dtype="float32", scope="global.texture")
+            for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"):
+                for thread_idx in T.thread_binding(0, 128, thread="threadIdx.x"):
+                    for k in T.serial(4):
+                        with T.block("B"):
+                            vb, vt, vk = T.axis.remap("SSS", [block_idx, thread_idx, k])
+                            B[vb, vt, vk] = A[vb, vt, vk] + T.float32(1)
+            for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"):
+                for thread_idx in T.thread_binding(0, 128, thread="threadIdx.x"):
+                    for k in T.serial(4):
+                        with T.block("C"):
+                            vb, vt, vk = T.axis.remap("SSS", [block_idx, thread_idx, k])
+                            C[vb, vt, vk] = B[vb, vt, vk] * T.float32(2)
+
+    sch = tir.Schedule(PlusOneMultTwo, debug_mask="all")
+
+    def schedule_block(block):
+        _, _, inner = sch.get_loops(block)
+        sch.vectorize(inner)
+
+    schedule_block(sch.get_block("B"))
+    schedule_block(sch.get_block("C"))
+
+    target = tvm.target.Target("opencl")
+    mod = tvm.build(sch.mod["main"], target=target)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))

Reply via email to