This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d7f5753 [SPIRV] Support Bool buffer argument (#7591)
d7f5753 is described below
commit d7f57532746680732e58ab028d8c3129b9140d3d
Author: masahi <[email protected]>
AuthorDate: Fri Mar 5 09:55:09 2021 +0900
[SPIRV] Support Bool buffer argument (#7591)
---
src/target/spirv/codegen_spirv.cc | 24 +++++--
tests/python/unittest/test_target_codegen_spirv.py | 75 ++++++++++++++++++++++
2 files changed, 93 insertions(+), 6 deletions(-)
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 6311b43..24608eb 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -45,10 +45,15 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
auto* prim = ptr->element_type.as<PrimTypeNode>();
ICHECK(prim);
- DataType value_type = prim->dtype;
+ DataType value_storage_type = prim->dtype;
+ if (value_storage_type == DataType::UInt(1)) {
+      // We need a physically addressable buffer type to support boolean tensors.
+ // The loaded byte is cast to bool inside the LoadNode visitor below.
+ value_storage_type = DataType::UInt(8);
+ }
spirv::Value arg_value =
-          builder_->BufferArgument(builder_->GetSType(value_type), 0, num_buffer);
- storage_info_[arg.get()].UpdateContentType(value_type);
+          builder_->BufferArgument(builder_->GetSType(value_storage_type), 0, num_buffer);
+ storage_info_[arg.get()].UpdateContentType(value_storage_type);
var_map_[arg.get()] = arg_value;
} else {
LOG(FATAL) << "require all handles to be typed";
@@ -369,11 +374,18 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
mask |= spv::MemoryAccessVolatileMask;
}
if (op->dtype.lanes() == 1) {
- ICHECK_EQ(info.content_type, op->dtype)
- << "Vulkan only allow one type access to the same buffer";
spirv::Value index = MakeValue(op->index);
spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
- return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
+    spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
+ if (op->dtype == DataType::UInt(1)) {
+ // A bool tensor is backed by a byte buffer, we cast to bool here.
+ auto bool_ty = builder_->GetSType(DataType::UInt(1));
+ return builder_->Cast(bool_ty, loaded);
+ } else {
+ ICHECK_EQ(info.content_type, op->dtype)
+ << "Vulkan only allow one type access to the same buffer";
+ return loaded;
+ }
} else {
if (op->dtype.element_of() == info.content_type) {
// because content type is element type, we can only do scalarize load.
diff --git a/tests/python/unittest/test_target_codegen_spirv.py b/tests/python/unittest/test_target_codegen_spirv.py
new file mode 100644
index 0000000..2cbf0be
--- /dev/null
+++ b/tests/python/unittest/test_target_codegen_spirv.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import te
+from tvm.topi.math import cast
+import numpy as np
+
+
+def test_bool_load():
+ def do_copy(A, B, n):
+ ib = tvm.tir.ir_builder.create()
+ A = ib.buffer_ptr(A)
+ B = ib.buffer_ptr(B)
+
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+
+ max_threads = 32
+        ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads))
+ ib.scope_attr(tx, "thread_extent", max_threads)
+ tid = bx * max_threads + tx
+
+ with ib.if_scope(tid < n):
+ B[tid] = cast(A[tid], "int32")
+
+ return ib.get()
+
+ n = 1024
+ A = te.placeholder((n,), name="A", dtype="bool")
+ B = te.placeholder((n,), name="B", dtype="int32")
+
+ target = "vulkan"
+
+ if not tvm.testing.device_enabled(target):
+ return
+
+ B = te.extern(
+ A.shape,
+ [A],
+ lambda ins, outs: do_copy(ins[0], outs[0], n),
+ name="bool_copy_ir",
+ dtype="int32",
+ )
+ s = te.create_schedule(B.op)
+
+ with tvm.transform.PassContext(opt_level=3):
+ func = tvm.build(s, [A, B], target)
+
+ ctx = tvm.context(target, 0)
+ a_np = np.random.uniform(size=n) > 0.5
+ b_np = np.zeros((n,), dtype="int32")
+ a = tvm.nd.array(a_np, ctx)
+ b = tvm.nd.array(b_np, ctx)
+ func(a, b)
+ ref = a_np.astype(np.int32)
+ tvm.testing.assert_allclose(b.asnumpy(), ref)
+
+
+if __name__ == "__main__":
+ test_bool_load()