This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 93945ce640 [Unity][CODEGEN] Fix metal codegen when with only single
working dim (#14627)
93945ce640 is described below
commit 93945ce6400d8002d353546ca408a5a015d94e67
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Apr 14 15:56:33 2023 -0400
[Unity][CODEGEN] Fix metal codegen when with only single working dim
(#14627)
[CODEGEN] Fix metal codegen when with only single working dim
This PR fixes the Metal codegen issue introduced by a previous commit,
whose removal of the workdim remapping caused issues for kernels with
only threadIdx.x and blockIdx.x
---
src/target/source/codegen_metal.cc | 10 +++++++++-
src/target/source/codegen_metal.h | 1 +
2 files changed, 10 insertions(+), 1 deletion(-)
diff --git a/src/target/source/codegen_metal.cc
b/src/target/source/codegen_metal.cc
index 36ef44bc48..767311cb5a 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -147,6 +147,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
stream << " threadIdx [[thread_position_in_threadgroup]]\n";
}
+ thread_work_dim_ = work_dim;
+
// the function scope.
stream << ") {\n";
int func_scope = this->BeginScope();
@@ -158,8 +160,14 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
+ // if we only have threadIdx.x
+ // metal will directly print as threadIdx
+ std::string vname = iv->thread_tag;
+ if (thread_work_dim_ <= 1) {
+ vname = vname.substr(0, iv->thread_tag.length() - 2);
+ }
var_idmap_[iv->var.get()] =
- CastFromTo(iv->thread_tag, DataType::UInt(thread_index_bits_),
iv->var.dtype());
+ CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype());
}
void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
diff --git a/src/target/source/codegen_metal.h
b/src/target/source/codegen_metal.h
index 99332e0046..2564389609 100644
--- a/src/target/source/codegen_metal.h
+++ b/src/target/source/codegen_metal.h
@@ -59,6 +59,7 @@ class CodeGenMetal final : public CodeGenC {
private:
int thread_index_bits_{32};
+ int thread_work_dim_{0};
Target target_;
};
} // namespace codegen