This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 672ce33657 [TIR] Propagate storage scope of undefined vars in SplitHostDevice. (#11255)
672ce33657 is described below
commit 672ce336571d1a4cd914c3eff27597e9c7a52527
Author: Chris Sullivan <[email protected]>
AuthorDate: Mon May 16 10:22:16 2022 -0700
[TIR] Propagate storage scope of undefined vars in SplitHostDevice. (#11255)
* [TIR] Propagate storage scope of undefined vars in SplitHostDevice.
* Test global.texture for input, output, and intermediate buffers.
---
src/tir/transforms/split_host_device.cc | 7 ++-
tests/python/unittest/test_tir_texture_scope.py | 62 +++++++++++++++++++++++++
2 files changed, 68 insertions(+), 1 deletion(-)
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 1b8c150079..85845616f1 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -281,7 +281,12 @@ class HostDeviceSplitter : public StmtMutator {
// Create a new version of v.
auto it = handle_data_type_.find(var.get());
if (it != handle_data_type_.end()) {
- tir::Var new_var(var->name_hint,
PointerType(PrimType((*it).second->dtype)));
+ String storage_scope;
+ if (auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) {
+ storage_scope = ptr_type->storage_scope;
+ }
+ tir::Var new_var(var->name_hint,
+ PointerType(PrimType((*it).second->dtype),
storage_scope));
params.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
diff --git a/tests/python/unittest/test_tir_texture_scope.py b/tests/python/unittest/test_tir_texture_scope.py
new file mode 100644
index 0000000000..701a1fe77a
--- /dev/null
+++ b/tests/python/unittest/test_tir_texture_scope.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+from tvm.ir.module import IRModule
+from tvm import tir
+from tvm.script import tir as T
+
+
+def test_texture_scope():
+    """Build an OpenCL kernel whose A, B and C buffers all use "global.texture" scope.
+
+    A and C are bound from the function arguments and B is an intermediate
+    allocation. The test has no assertions: it passes when tvm.build completes
+    without raising.
+    """
+
+    @tvm.script.ir_module
+    class PlusOneMultTwo:
+        @T.prim_func
+        def main(a: T.handle, b: T.handle) -> None:
+            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            A = T.match_buffer(a, (128, 128, 4), dtype="float32", scope="global.texture")
+            B = T.alloc_buffer((128, 128, 4), dtype="float32", scope="global.texture")
+            C = T.match_buffer(b, (128, 128, 4), dtype="float32", scope="global.texture")
+            for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"):
+                for thread_idx in T.thread_binding(0, 128, thread="threadIdx.x"):
+                    for k in T.serial(4):
+                        with T.block("B"):
+                            vb, vt, vk = T.axis.remap("SSS", [block_idx, thread_idx, k])
+                            B[vb, vt, vk] = A[vb, vt, vk] + T.float32(1)
+            for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"):
+                for thread_idx in T.thread_binding(0, 128, thread="threadIdx.x"):
+                    for k in T.serial(4):
+                        with T.block("C"):
+                            vb, vt, vk = T.axis.remap("SSS", [block_idx, thread_idx, k])
+                            C[vb, vt, vk] = B[vb, vt, vk] * T.float32(2)
+
+    sch = tir.Schedule(PlusOneMultTwo, debug_mask="all")
+
+    def schedule_block(block):
+        # Each block sits under (block_idx, thread_idx, k); the innermost is the extent-4 k loop.
+        _, _, inner = sch.get_loops(block)
+        sch.vectorize(inner)
+
+    schedule_block(sch.get_block("B"))
+    schedule_block(sch.get_block("C"))
+
+    target = tvm.target.Target("opencl")
+    tvm.build(sch.mod["main"], target=target)
+
+
+if __name__ == "__main__":
+    import sys
+
+    # sys.argv[0] is this script's path, so pytest collects this file and
+    # forwards any extra command-line options.
+    sys.exit(pytest.main(sys.argv))