This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 93945ce640 [Unity][CODEGEN] Fix metal codegen when with only single 
working dim (#14627)
93945ce640 is described below

commit 93945ce6400d8002d353546ca408a5a015d94e67
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Apr 14 15:56:33 2023 -0400

    [Unity][CODEGEN] Fix metal codegen when with only single working dim 
(#14627)
    
    [CODEGEN] Fix metal codegen when with only single working dim
    
    This PR fixes a Metal codegen regression introduced by a previous commit
    that removed the workdim remapping; that removal broke kernels that use
    only threadIdx.x and blockIdx.x.
---
 src/target/source/codegen_metal.cc | 10 +++++++++-
 src/target/source/codegen_metal.h  |  1 +
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/target/source/codegen_metal.cc 
b/src/target/source/codegen_metal.cc
index 36ef44bc48..767311cb5a 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -147,6 +147,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
     PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
     stream << " threadIdx [[thread_position_in_threadgroup]]\n";
   }
+  thread_work_dim_ = work_dim;
+
   // the function scope.
   stream << ") {\n";
   int func_scope = this->BeginScope();
@@ -158,8 +160,14 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
 
 void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
   ICHECK(!var_idmap_.count(iv->var.get()));
+  // if we only have threadIdx.x
+  // metal will directly print as threadIdx
+  std::string vname = iv->thread_tag;
+  if (thread_work_dim_ <= 1) {
+    vname = vname.substr(0, iv->thread_tag.length() - 2);
+  }
   var_idmap_[iv->var.get()] =
-      CastFromTo(iv->thread_tag, DataType::UInt(thread_index_bits_), 
iv->var.dtype());
+      CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype());
 }
 
 void CodeGenMetal::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
diff --git a/src/target/source/codegen_metal.h 
b/src/target/source/codegen_metal.h
index 99332e0046..2564389609 100644
--- a/src/target/source/codegen_metal.h
+++ b/src/target/source/codegen_metal.h
@@ -59,6 +59,7 @@ class CodeGenMetal final : public CodeGenC {
 
  private:
   int thread_index_bits_{32};
+  int thread_work_dim_{0};
   Target target_;
 };
 }  // namespace codegen

Reply via email to