This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d9d9172cf9 [Unity][TARGET] Updates vulkan codegen for DeclBuffer
(#14641)
d9d9172cf9 is described below
commit d9d9172cf936f8490bb543dee8f013133704baba
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 17 09:33:59 2023 -0400
[Unity][TARGET] Updates vulkan codegen for DeclBuffer (#14641)
[TARGET] Updates vulkan codegen for DeclBuffer
This PR updates the Vulkan codegen to support DeclBuffer.
We also disabled subgroup operations.
This is because we found that on NVIDIA GPUs the subgroup size
was detected to be 8, and the generated code for softmax
produces wrong results.
Setting it to 32 (the typical size of a CUDA warp) also does
not produce the right result.
We directly hard-code the value to 1 to ensure correctness.
The team can revisit this behavior later.
---
src/runtime/vulkan/vulkan_device.cc | 6 +++++-
src/target/spirv/codegen_spirv.cc | 2 ++
src/target/spirv/codegen_spirv.h | 1 +
3 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/src/runtime/vulkan/vulkan_device.cc
b/src/runtime/vulkan/vulkan_device.cc
index b3e017d034..23a7631ad8 100644
--- a/src/runtime/vulkan/vulkan_device.cc
+++ b/src/runtime/vulkan/vulkan_device.cc
@@ -144,7 +144,11 @@ VulkanDeviceProperties::VulkanDeviceProperties(const
VulkanInstance& instance,
max_num_threads =
properties.properties.limits.maxComputeWorkGroupInvocations;
// Even if we can't query it, warp size must be at least 1.
- thread_warp_size = std::max(subgroup.subgroupSize, 1U);
+ // thread_warp_size = std::max(subgroup.subgroupSize, 1U);
+ // vulkan's subgroup may not directly map to warp and atm
+ // can cause issues in softmax allreduce in NVidia GPU
+ // disable warp setting to be safe.
+ thread_warp_size = 1U;
max_block_size_x = properties.properties.limits.maxComputeWorkGroupSize[0];
max_block_size_y = properties.properties.limits.maxComputeWorkGroupSize[1];
diff --git a/src/target/spirv/codegen_spirv.cc
b/src/target/spirv/codegen_spirv.cc
index e3ef5acb83..8840eb1f5d 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -687,6 +687,8 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
this->VisitStmt(op->body);
}
+void CodeGenSPIRV::VisitStmt_(const DeclBufferNode* op) {
this->VisitStmt(op->body); }
+
void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h
index 08b9db0ee5..1dae5ac3e8 100644
--- a/src/target/spirv/codegen_spirv.h
+++ b/src/target/spirv/codegen_spirv.h
@@ -107,6 +107,7 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const
PrimExpr&)>,
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
+ void VisitStmt_(const DeclBufferNode* op) override;
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const LetStmtNode* op) override;