This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a60b815fe5 [TIR] Support cross-thread reduction lowering with 
thread-broadcasting rewrite (#15192)
a60b815fe5 is described below

commit a60b815fe5f97806cbedad79712fb6e618425afe
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jul 3 11:36:56 2023 -0700

    [TIR] Support cross-thread reduction lowering with thread-broadcasting 
rewrite (#15192)
    
    This PR enhances the LowerCrossThreadReduction pass with the
    thread-broadcasting block rewrite.
    
    Specifically, previously whenever a TIR block had thread-broadcast
    behavior (i.e., there exists some thread var which is free for the
    block), we never inserted a predicate for the block, and therefore the
    final generated code had a race condition, which sometimes led to
    wrong computation results.
    
    This PR enhances the pass by collecting thread var information during
    the transformation, and rewriting the thread-broadcast TIR block with
    additional predicate clauses which bound the thread vars and
    effectively state that "only execute the block when `thread_var == 0`".
    Therefore, the race condition issue in such blocks is resolved.
---
 src/tir/transforms/lower_cross_thread_reduction.cc | 158 +++++++++++++++++--
 ...t_tir_transform_lower_cross_thread_reduction.py | 174 +++++++++++++++++++++
 2 files changed, 319 insertions(+), 13 deletions(-)

diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc 
b/src/tir/transforms/lower_cross_thread_reduction.cc
index cc402017e6..79dbfb2a02 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -25,12 +25,31 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
+#include "../../runtime/thread_storage_scope.h"
+#include "../../support/utils.h"
 #include "../schedule/analysis.h"
 #include "./ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::ThreadScope;
+using support::StartsWith;
+
+// Hash functor for ThreadScope. Together with the ThreadScopeEqual functor, it
+// lets ThreadScope serve as the key type of a hash-based map.
+struct ThreadScopeHash {
+  size_t operator()(const ThreadScope& scope) const {
+    return static_cast<size_t>(scope.rank * 30 + scope.dim_index);  // fold rank and dim_index into one value
+  }
+};
+
+struct ThreadScopeEqual {  // Equality functor for ThreadScope map keys.
+  bool operator()(const ThreadScope& a, const ThreadScope& b) const {
+    return a.rank == b.rank && a.dim_index == b.dim_index;  // equal iff both rank and dim_index match
+  }
+};
+
 /*!
  * \brief Checks if a loop is bound to threadIdx.x/y/z
  * \brief loop The loop to be checked
@@ -478,6 +497,27 @@ class CrossThreadReductionTransformer : public StmtMutator 
{
     return need ? reduction_loops : std::vector<const ForNode*>{};
   }
 
+  // Check if the input block needs thread broadcast rewrite.
+  // One block needs broadcast rewrite when there exists one or more thread
+  // vars which are free variables to this block. Returns those (scope, range) pairs; empty if none.
+  std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
+      const BlockRealizeNode* realize) {  // NOTE: `realize` is not read in this body
+    std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> 
unbound_thread2range =
+        thread2range_;  // start from a copy of every thread scope recorded so far
+    for (const ForNode* loop : loop_stack_) {
+      if (loop->thread_binding.defined()) {
+        ThreadScope scope = 
ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+        unbound_thread2range.erase(scope);  // bound by a loop on the stack, hence not free
+      }
+    }
+
+    std::vector<std::pair<ThreadScope, Range>> unbound_thread2range_list;  // remaining entries as a list
+    for (auto [scope, range] : unbound_thread2range) {
+      unbound_thread2range_list.emplace_back(scope, range);
+    }
+    return unbound_thread2range_list;
+  }
+
   /*!
    * \brief Given that the input block needs cross-thread reduction, check if 
cross-thread reduction
    * can be applied to the block (i.e., the block satisfies all necessary 
conditions of cross-thread
@@ -578,9 +618,39 @@ class CrossThreadReductionTransformer : public StmtMutator 
{
   Stmt VisitStmt_(const ForNode* loop) final {
     loop_stack_.push_back(loop);
     loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, 
loop->extent));
+
+    // Collect loop-thread information:
+    // - when encountering a threadIdx loop, we keep note of its domain and
+    // the "loop var -> thread scope" relation, in order to collect all 
existing
+    // threads within a thread block.
+    // - we are careful about thread block boundary for safety.
+    bool is_block_idx = false;
+    bool is_thread_idx = false;
+    if (loop->kind == ForKind::kThreadBinding) {
+      ThreadScope scope = 
ThreadScope::Create(loop->thread_binding.value()->thread_tag);
+      if (scope.rank == 1 && scope.dim_index >= 0) {
+        is_thread_idx = true;
+        ++thread_idx_depth;
+        thread2range_[scope] = Range::FromMinExtent(loop->min, loop->extent);
+        thread_loop_var2scope_[loop->loop_var.get()] = scope;
+      } else if (scope.rank == 0) {
+        is_block_idx = true;
+        ++block_idx_depth;
+      }
+    }
+
     Stmt result = StmtMutator::VisitStmt_(loop);
     loop_stack_.pop_back();
     loop_range_map_.erase(loop->loop_var);
+    if (is_thread_idx) {
+      --thread_idx_depth;
+    }
+    if (is_block_idx) {
+      --block_idx_depth;
+    }
+    if (is_block_idx || (is_thread_idx && thread_idx_depth == 0 && 
block_idx_depth == 0)) {
+      thread2range_.clear();
+    }
 
     // Replace `result` with the pre-stored result if `loop` appears as a key 
in `loop2new_stmt_`.
     auto it = loop2new_stmt_.find(loop);
@@ -613,14 +683,11 @@ class CrossThreadReductionTransformer : public 
StmtMutator {
     return std::move(new_block);
   }
 
-  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+  void MakeCrossThreadReduction(const BlockRealizeNode* realize,
+                                const std::vector<const ForNode*> 
reduction_loops) {
     const BlockNode* block = realize->block.get();
-    // Step 1. Check whether cross-thread reduction is needed. If no, skip 
this block.
-    std::vector<const ForNode*> reduction_loops = 
NeedCrossThreadReduction(realize);
-    if (reduction_loops.empty()) {
-      return StmtMutator::VisitStmt_(realize);
-    }
-    // Step 2. Check whether cross-thread reduction can be applied. If no, 
throw an exception on
+
+    // Step 1. Check whether cross-thread reduction can be applied. If no, 
throw an exception on
     // which condition the block violates.
     int n_bound_reduction_loops = 0;
     CommReducer reducer{nullptr};
@@ -629,13 +696,13 @@ class CrossThreadReductionTransformer : public 
StmtMutator {
     Array<PrimExpr> wb_indices{nullptr};
     std::tie(n_bound_reduction_loops, reducer, reduction_buffers, 
combiner_rhs, wb_indices) =
         CheckCanApplyCrossThreadReduction(block, reduction_loops);
-    // Step 3. Before doing the cross-thread reduction, in-thread reduction is 
needed when
+    // Step 2. Before doing the cross-thread reduction, in-thread reduction is 
needed when
     //  - not all the reduction-related loops are bound to thread axes, or
     //  - the block-realize has a non-constant-true predicate.
     bool need_in_thread_reduction =
         n_bound_reduction_loops < static_cast<int>(reduction_loops.size()) ||
         !is_one(realize->predicate);
-    // Step 4. Create intermediate buffers, storing them in `ct_buffers` and
+    // Step 3. Create intermediate buffers, storing them in `ct_buffers` and
     // `it_buffers`. Let the scope block allocate these new buffers.
     Array<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
     Array<Buffer> ct_buffers = MakeScratchpads(reduction_buffers, 
/*is_cross_thread_buffer=*/true);
@@ -645,16 +712,76 @@ class CrossThreadReductionTransformer : public 
StmtMutator {
       it_buffers = MakeScratchpads(reduction_buffers, 
/*is_cross_thread_buffer=*/false);
       new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), 
it_buffers.value().end());
     }
-    // Step 5. Transform.
+    // Step 4. Transform.
     loop2new_stmt_[reduction_loops[0]] =
         TransformReductionBlock(realize, it_buffers, ct_buffers, 
reduction_buffers, wb_indices,
                                 reducer, combiner_rhs, reduction_loops);
-    // Step 6. Return an empty statement, because the transformation result 
will be inserted when
-    // returning to the first reduction-related loop.
-    return Stmt{nullptr};
+  }
+
+  Stmt MakeCrossThreadBroadcast(  // Returns `realize` guarded by, and wrapped in, one new thread loop per unbound thread.
+      const BlockRealizeNode* realize,
+      const std::vector<std::pair<ThreadScope, Range>>& unbound_thread2range) {
+    // Step 1. Generate one fresh loop var for each unbound thread.
+    // Extend the block predicate with one clause `thread_var == min` per unbound thread.
+    PrimExpr predicate = realize->predicate;
+    Array<Var> loop_vars;
+    loop_vars.reserve(unbound_thread2range.size());
+    for (auto [scope, range] : unbound_thread2range) {
+      std::string dim_index(1, static_cast<char>(scope.dim_index + 'x'));  // dim_index 0/1/2 -> "x"/"y"/"z"
+      Var loop_var("t" + dim_index, range->min->dtype);  // e.g. "tx"; takes the dtype of the range min
+      loop_vars.push_back(loop_var);
+      predicate = (loop_var == range->min) && predicate;  // new clause goes in front of the existing predicate
+    }
+
+    // Step 2. Copy the BlockRealize node and give the copy the new predicate.
+    ObjectPtr<BlockRealizeNode> p_realize = 
make_object<BlockRealizeNode>(*realize);
+    p_realize->predicate = std::move(predicate);
+
+    // Step 3. Wrap the updated BlockRealize with the new thread-binding loops (later entries end up outermost).
+    Stmt body(p_realize);
+    for (int i = 0; i < static_cast<int>(unbound_thread2range.size()); ++i) {
+      std::string dim_index(1, 
static_cast<char>(unbound_thread2range[i].first.dim_index + 'x'));
+      body = For(
+          /*loop_var=*/loop_vars[i],                          //
+          /*min=*/unbound_thread2range[i].second->min,        //
+          /*extent=*/unbound_thread2range[i].second->extent,  //
+          /*kind=*/ForKind::kThreadBinding,                   //
+          /*body=*/body,                                      //
+          /*thread_binding=*/
+          IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
+                  "threadIdx." + dim_index));
+    }
+    return body;
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    // Part 1. Check if the block needs cross-thread reduction rewrite (non-empty loop list means yes).
+    std::vector<const ForNode*> reduction_loops = 
NeedCrossThreadReduction(realize);
+    if (!reduction_loops.empty()) {
+      // Return an empty statement, because the transformation result will
+      // be inserted when returning to the first reduction-related loop.
+      has_cross_thread_reduction_ = true;  // enables the Part 2 check for blocks visited afterwards
+      MakeCrossThreadReduction(realize, reduction_loops);
+      return Stmt{nullptr};
+    }
+
+    if (!has_cross_thread_reduction_) {  // no reduction rewritten so far: use the default mutator visit
+      return StmtMutator::VisitStmt_(realize);
+    }
+
+    // Part 2. Check if the block needs all-thread broadcasting rewrite.
+    // We only check this once a cross-thread reduction has been detected earlier in the traversal.
+    std::vector<std::pair<ThreadScope, Range>> unbound_thread2range =
+        NeedCrossThreadBroadcast(realize);
+    if (!unbound_thread2range.empty()) {
+      return MakeCrossThreadBroadcast(realize, unbound_thread2range);  // returned directly, without the default visit
+    }
+
+    return StmtMutator::VisitStmt_(realize);
+  }
 
  private:
+  bool has_cross_thread_reduction_ = false;
   std::vector<const StmtNode*> statement_stack_;
   std::vector<const ForNode*> loop_stack_;
   std::vector<const BlockNode*> block_stack_;
@@ -662,6 +789,11 @@ class CrossThreadReductionTransformer : public StmtMutator 
{
   std::unordered_map<const ForNode*, Stmt> loop2new_stmt_;
   Map<Var, Range> loop_range_map_;
   arith::Analyzer analyzer_;
+
+  int block_idx_depth = 0;
+  int thread_idx_depth = 0;
+  std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> 
thread2range_;
+  std::unordered_map<const VarNode*, ThreadScope> thread_loop_var2scope_;
 };
 
 PrimFunc LowerCrossThreadReduction(PrimFunc f) {
diff --git 
a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py 
b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
index 8b5c212241..2334fe5350 100644
--- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -1274,6 +1274,172 @@ def lowered_layer_norm_tuple_sum(
                 ]
 
 
+@T.prim_func
+def thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), 
"float32")):
+    temp_local = T.alloc_buffer((256,), scope="local")
+    for i in T.thread_binding(256, thread="blockIdx.x"):
+        for k in T.thread_binding(256, thread="threadIdx.x"):
+            with T.block("sum"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads(A[vi, vk])
+                T.writes(temp_local[vi])
+                with T.init():
+                    temp_local[vi] = T.float32(0)
+                temp_local[vi] = temp_local[vi] + A[vi, vk]
+        with T.block("add"):
+            vi = T.axis.spatial(256, i)
+            T.reads(temp_local[vi])
+            T.writes(B[vi])
+            B[vi] = temp_local[vi] + T.float32(1)
+
+
+@T.prim_func
+def lowered_thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: 
T.Buffer((256,), "float32")):
+    temp_local = T.alloc_buffer((256,), scope="local")
+    cross_thread_temp_local = T.alloc_buffer((1,), strides=(1,), scope="local")
+    for i in T.thread_binding(256, thread="blockIdx.x"):
+        for k in T.thread_binding(256, thread="threadIdx.x"):
+            with T.block("sum_cross_thread"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads(A[vi, vk])
+                T.writes(cross_thread_temp_local[0])
+                T.attr(
+                    T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret("handle", T.uint64(0)),
+                )
+                T.tvm_thread_allreduce(
+                    T.uint32(1), A[vi, vk], T.bool(True), 
cross_thread_temp_local[0], k
+                )
+            with T.block("sum_write_back"):
+                vi = T.axis.spatial(256, i)
+                T.where(k == 0)
+                T.reads(cross_thread_temp_local[0])
+                T.writes(temp_local[vi])
+                temp_local[vi] = cross_thread_temp_local[0]
+        for tx in T.thread_binding(256, thread="threadIdx.x"):
+            with T.block("add"):
+                vi = T.axis.spatial(256, i)
+                T.where(tx == 0)
+                T.reads(temp_local[vi])
+                T.writes(B[vi])
+                B[vi] = temp_local[vi] + T.float32(1)
+
+
+# fmt: off
+@T.prim_func
+def thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), 
T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: 
T.handle):
+    n = T.int64()
+    lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, 
T.int64(128)), "float16")
+    lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), 
"float16")
+    var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), 
T.int64(32), T.int64(1), n))
+    var_NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), 
T.int64(32), T.int64(1), n), "float16", scope="local")
+    var_NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(256), 
T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local")
+    for ax0_ax1_fused in T.thread_binding(n * T.int64(32), 
thread="blockIdx.x"):
+        for ax2_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+            with T.block("NT_matmul_rf_init"):
+                vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1)
+                v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                T.reads()
+                T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1])
+                var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), 
v0, T.int64(0), v1] = T.float16(0)
+            for ax2_fused_0 in range(T.int64(1)):
+                with T.block("NT_matmul_rf_update"):
+                    vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1)
+                    v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    vax2_fused_0 = T.axis.reduce(T.int64(1), ax2_fused_0)
+                    T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < 
T.int64(128))
+                    T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1], lv1605[T.int64(0), v0, T.int64(0), 
vax2_fused_0 * T.int64(256) + vax2_fused_1], lv1606[T.int64(0), v0, v1, 
vax2_fused_0 * T.int64(256) + vax2_fused_1])
+                    T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1])
+                    var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1] = 
var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), 
v1] + lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + 
vax2_fused_1] * lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + 
vax2_fused_1]
+        for ax1_ax2_fused in range(T.int64(1)):
+            for ax0_fused in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                with T.block("NT_matmul"):
+                    vax2_fused_1 = T.axis.reduce(T.int64(256), ax0_fused)
+                    v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused 
// n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < 
n)
+                    T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1])
+                    T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0, 
T.int64(0), v1])
+                    with T.init():
+                        var_NT_matmul_intermediate_local[T.int64(0), v0, 
T.int64(0), v1] = T.float16(0)
+                    var_NT_matmul_intermediate_local[T.int64(0), v0, 
T.int64(0), v1] = var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), 
v1] + var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, 
T.int64(0), v1]
+        with T.block("compute"):
+            v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+            v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+            T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < 
T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n)
+            T.reads(var_NT_matmul_intermediate_local[T.int64(0), v0, 
T.int64(0), v1], lv1582[T.int64(0), T.int64(0), T.int64(0), v1])
+            T.writes(var_compute_intermediate[T.int64(0), v0, T.int64(0), v1])
+            var_compute_intermediate[T.int64(0), v0, T.int64(0), v1] = 
T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[T.int64(0), v0, 
T.int64(0), v1] * T.float16(0.088397790055248615), T.float16(-65504)), 
lv1582[T.int64(0), T.int64(0), T.int64(0), v1]))
+
+
+@T.prim_func
+def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), 
T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, 
p_output0: T.handle):
+    n = T.int64()
+    lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, 
T.int64(128)), "float16")
+    lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), 
"float16")
+    var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), 
T.int64(32), T.int64(1), n))
+    var_NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), 
T.int64(32), T.int64(1), n), "float16", scope="local")
+    var_NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(256), 
T.int64(1), T.int64(32), T.int64(1), n), "float16", scope="local")
+    cross_thread_var_NT_matmul_intermediate_local = T.alloc_buffer((1,), 
"float16", strides=(1,), scope="local")
+    in_thread_var_NT_matmul_intermediate_local = T.alloc_buffer((1,), 
"float16", strides=(1,), scope="local")
+    for ax0_ax1_fused in T.thread_binding(n * T.int64(32), 
thread="blockIdx.x"):
+        for ax2_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+            with T.block("NT_matmul_rf_init"):
+                vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1)
+                v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                T.reads()
+                T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1])
+                var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), 
v0, T.int64(0), v1] = T.float16(0)
+            for ax2_fused_0 in range(T.int64(1)):
+                with T.block("NT_matmul_rf_update"):
+                    vax2_fused_1 = T.axis.spatial(T.int64(256), ax2_fused_1)
+                    v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    vax2_fused_0 = T.axis.reduce(T.int64(1), ax2_fused_0)
+                    T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < 
T.int64(128))
+                    T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1], lv1605[T.int64(0), v0, T.int64(0), 
vax2_fused_0 * T.int64(256) + vax2_fused_1], lv1606[T.int64(0), v0, v1, 
vax2_fused_0 * T.int64(256) + vax2_fused_1])
+                    T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1])
+                    var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1] = 
var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), 
v1] + lv1605[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(256) + 
vax2_fused_1] * lv1606[T.int64(0), v0, v1, vax2_fused_0 * T.int64(256) + 
vax2_fused_1]
+        for ax1_ax2_fused in range(T.int64(1)):
+            for ax0_fused in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                with T.block("NT_matmul_in_thread_init"):
+                    T.reads()
+                    T.writes(in_thread_var_NT_matmul_intermediate_local[0])
+                    in_thread_var_NT_matmul_intermediate_local[0] = 
T.float16(0)
+                with T.block("NT_matmul_in_thread"):
+                    vax2_fused_1 = T.axis.reduce(T.int64(256), ax0_fused)
+                    v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused 
// n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < 
n)
+                    T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 
T.int64(0), v0, T.int64(0), v1])
+                    T.writes(in_thread_var_NT_matmul_intermediate_local[0])
+                    in_thread_var_NT_matmul_intermediate_local[0] = 
in_thread_var_NT_matmul_intermediate_local[0] + 
var_NT_matmul_intermediate_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), 
v1]
+                with T.block("NT_matmul_cross_thread"):
+                    T.reads(in_thread_var_NT_matmul_intermediate_local[0])
+                    T.writes(cross_thread_var_NT_matmul_intermediate_local[0])
+                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, 
[T.float16(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
+                    T.tvm_thread_allreduce(T.uint32(1), 
in_thread_var_NT_matmul_intermediate_local[0], T.bool(True), 
cross_thread_var_NT_matmul_intermediate_local[0], ax0_fused)
+                with T.block("NT_matmul_write_back"):
+                    v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                    v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                    T.where(ax0_fused == T.int64(0))
+                    T.reads(cross_thread_var_NT_matmul_intermediate_local[0])
+                    T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0, 
T.int64(0), v1])
+                    var_NT_matmul_intermediate_local[T.int64(0), v0, 
T.int64(0), v1] = cross_thread_var_NT_matmul_intermediate_local[0]
+        for tx in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+            with T.block("compute"):
+                v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
+                v1 = T.axis.spatial(n, ax0_ax1_fused % n)
+                T.where(tx == T.int64(0) and (T.int64(0) <= ax0_ax1_fused // n 
and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and 
ax0_ax1_fused % n < n))
+                T.reads(var_NT_matmul_intermediate_local[T.int64(0), v0, 
T.int64(0), v1], lv1582[T.int64(0), T.int64(0), T.int64(0), v1])
+                T.writes(var_compute_intermediate[T.int64(0), v0, T.int64(0), 
v1])
+                var_compute_intermediate[T.int64(0), v0, T.int64(0), v1] = 
T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[T.int64(0), v0, 
T.int64(0), v1] * T.float16(0.088397790055248615), T.float16(-65504)), 
lv1582[T.int64(0), T.int64(0), T.int64(0), v1]))
+# fmt: on
+
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
 
@@ -1358,6 +1524,14 @@ def test_argmin_split_init_update_reordered():
     _check(argmin_split_init_update_reordered, 
lowered_argmin_split_init_update_reordered)
 
 
+def test_thread_broadcast_rewrite_1():
+    _check(thread_broadcast_1, lowered_thread_broadcast_1)
+
+
+def test_thread_broadcast_rewrite_2():
+    _check(thread_broadcast_2, lowered_thread_broadcast_2)
+
+
 def test_lower_te():
     a = te.placeholder((32, 2, 2))
     k1 = te.reduce_axis((0, 2), "k1")

Reply via email to