This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d7f5753  [SPIRV] Support Bool buffer argument (#7591)
d7f5753 is described below

commit d7f57532746680732e58ab028d8c3129b9140d3d
Author: masahi <[email protected]>
AuthorDate: Fri Mar 5 09:55:09 2021 +0900

    [SPIRV] Support Bool buffer argument (#7591)
---
 src/target/spirv/codegen_spirv.cc                  | 24 +++++--
 tests/python/unittest/test_target_codegen_spirv.py | 75 ++++++++++++++++++++++
 2 files changed, 93 insertions(+), 6 deletions(-)

diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 6311b43..24608eb 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -45,10 +45,15 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
       if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
         auto* prim = ptr->element_type.as<PrimTypeNode>();
         ICHECK(prim);
-        DataType value_type = prim->dtype;
+        DataType value_storage_type = prim->dtype;
+        if (value_storage_type == DataType::UInt(1)) {
+          // We need a physically addressable buffer type to support boolean tensors.
+          // The loaded byte is cast to bool inside the LoadNode visitor below.
+          value_storage_type = DataType::UInt(8);
+        }
         spirv::Value arg_value =
-            builder_->BufferArgument(builder_->GetSType(value_type), 0, num_buffer);
-        storage_info_[arg.get()].UpdateContentType(value_type);
+            builder_->BufferArgument(builder_->GetSType(value_storage_type), 0, num_buffer);
+        storage_info_[arg.get()].UpdateContentType(value_storage_type);
         var_map_[arg.get()] = arg_value;
       } else {
         LOG(FATAL) << "require all handles to be typed";
@@ -369,11 +374,18 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
     mask |= spv::MemoryAccessVolatileMask;
   }
   if (op->dtype.lanes() == 1) {
-    ICHECK_EQ(info.content_type, op->dtype)
-        << "Vulkan only allow one type access to the same buffer";
     spirv::Value index = MakeValue(op->index);
     spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
-    return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
+    spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
+    if (op->dtype == DataType::UInt(1)) {
+      // A bool tensor is backed by a byte buffer, we cast to bool here.
+      auto bool_ty = builder_->GetSType(DataType::UInt(1));
+      return builder_->Cast(bool_ty, loaded);
+    } else {
+      ICHECK_EQ(info.content_type, op->dtype)
+          << "Vulkan only allow one type access to the same buffer";
+      return loaded;
+    }
   } else {
     if (op->dtype.element_of() == info.content_type) {
       // because content type is element type, we can only do scalarize load.
diff --git a/tests/python/unittest/test_target_codegen_spirv.py b/tests/python/unittest/test_target_codegen_spirv.py
new file mode 100644
index 0000000..2cbf0be
--- /dev/null
+++ b/tests/python/unittest/test_target_codegen_spirv.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import te
+from tvm.topi.math import cast
+import numpy as np
+
+
+def test_bool_load():
+    def do_copy(A, B, n):
+        ib = tvm.tir.ir_builder.create()
+        A = ib.buffer_ptr(A)
+        B = ib.buffer_ptr(B)
+
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+
+        max_threads = 32
+        ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads))
+        ib.scope_attr(tx, "thread_extent", max_threads)
+        tid = bx * max_threads + tx
+
+        with ib.if_scope(tid < n):
+            B[tid] = cast(A[tid], "int32")
+
+        return ib.get()
+
+    n = 1024
+    A = te.placeholder((n,), name="A", dtype="bool")
+    B = te.placeholder((n,), name="B", dtype="int32")
+
+    target = "vulkan"
+
+    if not tvm.testing.device_enabled(target):
+        return
+
+    B = te.extern(
+        A.shape,
+        [A],
+        lambda ins, outs: do_copy(ins[0], outs[0], n),
+        name="bool_copy_ir",
+        dtype="int32",
+    )
+    s = te.create_schedule(B.op)
+
+    with tvm.transform.PassContext(opt_level=3):
+        func = tvm.build(s, [A, B], target)
+
+    ctx = tvm.context(target, 0)
+    a_np = np.random.uniform(size=n) > 0.5
+    b_np = np.zeros((n,), dtype="int32")
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    func(a, b)
+    ref = a_np.astype(np.int32)
+    tvm.testing.assert_allclose(b.asnumpy(), ref)
+
+
+if __name__ == "__main__":
+    test_bool_load()

Reply via email to