This is an automated email from the ASF dual-hosted git repository.
expye pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 468bf2da79 [TIR][Transform] Introduce new `InjectPermutedLayout` pass (#16070)
468bf2da79 is described below
commit 468bf2da7902974047977a7f31ffbfcfaa422eac
Author: Yixin Dong <[email protected]>
AuthorDate: Sun Nov 12 01:25:40 2023 -0800
[TIR][Transform] Introduce new `InjectPermutedLayout` pass (#16070)
* 1104
1104
1106
* 1106
* try fix ci
---
src/tir/transforms/inject_permuted_layout.cc | 401 +++++++++++----------
.../test_tir_transform_inject_permuted_layout.py | 351 ++++++++++++++++++
2 files changed, 569 insertions(+), 183 deletions(-)
diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc
index a1afbeae6f..cccf2c505a 100644
--- a/src/tir/transforms/inject_permuted_layout.cc
+++ b/src/tir/transforms/inject_permuted_layout.cc
@@ -19,44 +19,55 @@
/*!
* \file inject_permuted_layout.cc
- * \brief The pass for inject permuted layout.
+ * \brief This pass injects a permuted layout for shared memory buffers to avoid bank conflicts.
*/
-
#include <tvm/arith/analyzer.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../runtime/thread_storage_scope.h"
#include "../../support/utils.h"
-#include "../ir/functor_common.h"
#include "ir_utils.h"
namespace tvm {
namespace tir {
-using tir::Block;
-using tir::BlockRealize;
-using tir::Call;
-using tir::For;
+using namespace arith;
+using namespace runtime;
-class PermutedLayoutInjector : public StmtExprMutator {
+class PermutedLayoutInjector : private IRMutatorWithAnalyzer {
public:
- PermutedLayoutInjector() {}
+ static PrimFunc Transform(PrimFunc func) {
+ Analyzer analyzer;
+
+ auto new_body = PermutedLayoutInjector(func, &analyzer)(func->body);
+ auto func_node = func.CopyOnWrite();
+ func_node->body = new_body;
+ return func;
+ }
private:
- Array<PrimExpr> GetNewIndices(PrimExpr s0, PrimExpr s1, int smem_width) {
- // index after vectorize(8)
- PrimExpr i = s0, j = floordiv(s1, 8), v = floormod(s1, 8);
- PrimExpr permuted_j;
- // In the following comments, each number represent a 8 * fp16 load
- // which is correspond to a index (i, j) in line 50's PrimExpr
- // Each 8 number correspond to 32 memory bank (every bank has 32 bit):
- // 8 * 8 * 16bit = 32 * 32bit
- // And we have 32 banks in total, so all loads in one column share
- // same memory bank
- if (smem_width % 64 == 0) {
- // use 8 * 8 permuted
+ explicit PermutedLayoutInjector(PrimFunc func, Analyzer* analyzer)
+ : IRMutatorWithAnalyzer(analyzer) {
+ buffer_map_.insert(func->buffer_map.begin(), func->buffer_map.end());
+ }
+
+ using IRMutatorWithAnalyzer::VisitExpr_;
+ using IRMutatorWithAnalyzer::VisitStmt_;
+
+  Array<PrimExpr> PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) {
+ ICHECK(permute_);
+ // Index after vectorizing by 8
+ PrimExpr col_idx_outer = floordiv(col_idx, VECTORIZE_FACTOR),
+ col_idx_inner = floormod(col_idx, VECTORIZE_FACTOR);
+ PrimExpr new_col_idx_outer;
+ if (row_size % 64 == 0) {
+ // Use 8 * 8 permuted layout
+      // Every number below corresponds to 8 consecutive fp16 numbers in shared mem, i.e. one read
+ // Every row below corresponds to 32 banks
// 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
// 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
// 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
@@ -65,10 +76,13 @@ class PermutedLayoutInjector : public StmtExprMutator {
// 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
// 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
// 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
- PrimExpr permuted_j_mod_8 = (floormod(j, 8) ^ floormod(i, 8));
- permuted_j = floordiv(j, 8) * 8 + permuted_j_mod_8;
+ auto row_idx_sub = floormod(row_idx, 8);
+ new_col_idx_outer = col_idx_outer ^ row_idx_sub;
} else {
- // use 8 * 4 permuted
+ ICHECK(row_size % 32 == 0);
+ // Use 8 * 4 permuted layout
+      // Every number below corresponds to 8 consecutive fp16 numbers in shared mem, i.e. one read
+ // Every row below corresponds to 16 banks
// 0 1 2 3 ==> 0 1 2 3
// 0 1 2 3 ==> 0 1 2 3
// 0 1 2 3 ==> 1 0 3 2
@@ -77,183 +91,204 @@ class PermutedLayoutInjector : public StmtExprMutator {
// 0 1 2 3 ==> 2 3 0 1
// 0 1 2 3 ==> 3 2 1 0
// 0 1 2 3 ==> 3 2 1 0
- // in 8 number each line view:
+ // View with 8 elements per row:
// 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
// 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
// 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
// 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
- permuted_j = floormod(j, 4) ^ floordiv(floormod(i, 8), 2);
+ auto row_idx_sub = floormod(row_idx, 8);
+ new_col_idx_outer = col_idx_outer ^ floordiv(row_idx_sub, 2);
}
- return {s0, permuted_j * 8 + v};
+    return {row_idx, analyzer_->Simplify(new_col_idx_outer * 8 + col_idx_inner)};
}
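The XOR trick implemented by PermuteIndices can be checked outside of TVM. Below is a minimal plain-Python sketch (not part of the patch; permute_indices is a hypothetical name) of the two permutations, with a quick check that within any group of 8 rows each vectorized column is mapped to 8 distinct columns, so the 8 accesses of a column land on different bank groups:

    def permute_indices(row_idx, col_idx, row_size):
        # Index after vectorizing by 8 (VECTORIZE_FACTOR in the pass)
        col_outer, col_inner = col_idx // 8, col_idx % 8
        if row_size % 64 == 0:
            # 8 * 8 permuted layout: xor the vectorized column with row % 8
            new_col_outer = col_outer ^ (row_idx % 8)
        else:
            assert row_size % 32 == 0
            # 8 * 4 permuted layout: xor with (row % 8) // 2
            new_col_outer = col_outer ^ (row_idx % 8 // 2)
        return row_idx, new_col_outer * 8 + col_inner

    # For row_size % 64 == 0, each original vectorized column maps to 8
    # distinct permuted columns across any 8 consecutive rows:
    for col_outer in range(8):
        permuted = {permute_indices(r, col_outer * 8, 64)[1] // 8 for r in range(8)}
        assert len(permuted) == 8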
- Stmt VisitStmt_(const BlockRealizeNode* _op) final {
- BlockRealize br = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(_op));
- BlockRealizeNode* op = br.CopyOnWrite();
- if (op->block->annotations.count("permuted_layout") == 0) {
- return br;
+ static bool CheckAnnotation(ObjectRef annotation) {
+ if (auto* node = annotation.as<StringObj>()) {
+ // Support string annotation for backward compatibility
+ return GetRef<String>(node) != "";
+ } else if (auto* node = annotation.as<IntImmNode>()) {
+ return node->value != 0;
+ } else {
+ LOG(FATAL) << "Invalid permuted layout annotation: " << annotation;
}
-    String val = Downcast<String>(op->block->annotations.at("permuted_layout"));
- if (val.empty()) return br;
- Block blk = op->block;
- Stmt body = blk->body;
- if (support::StartsWith(val, "g2s")) {
- // Case 1. Rewrite global to share.dyn
-
- // Step 1.1. Handle case when have local stage
- // Block with local stage is like
- // body {
- // SeqStmt {
- // seq[0]: local <- global
- // seq[1]: shared.dyn <- local
- // }
- // }
- // We only need to rewrite seq[1]
- bool have_local_stage = (body.as<SeqStmtNode>() != nullptr);
- Stmt upper_loop;
- if (have_local_stage) {
- SeqStmt seq = Downcast<SeqStmt>(body);
- ICHECK(seq->size() == 2);
- upper_loop = seq->seq[0];
- body = seq->seq[1];
- }
-
- // Step 1.2. get inner loop body
- std::vector<const ForNode*> loops;
- while (const ForNode* loop = body.as<ForNode>()) {
- loops.push_back(loop);
- body = loop->body;
- }
- Optional<PrimExpr> if_then_else_condition = NullOpt;
- const BufferStoreNode* store = body.as<BufferStoreNode>();
- if (!store) {
- // Case 1.2.1. IfThenElse generated by reverse_compute_inline
- // It is always like
- // if condition:
- // loop_body
- // We just extract the inner loop body inside IfThenElseNode
- const IfThenElseNode* if_then_else = body.as<IfThenElseNode>();
- store = if_then_else->then_case.as<BufferStoreNode>();
- ICHECK(!if_then_else->else_case);
- if_then_else_condition = if_then_else->condition;
- }
- ICHECK(store) << body;
-
- // Step 1.3. Get smem width and refuse to make any difference if invalid
- auto smem_width = store->buffer->shape[1].as<IntImmNode>()->value;
- if (smem_width % 32 != 0) {
- LOG(WARNING) << "Permuted Layout for " << op->block->name_hint
- << " is not supported since its second dimension is not
divisible by 32";
- return br;
- }
- if (smem_width % 64 == 32) {
- if (store->buffer->shape[0].as<IntImmNode>()->value % 2 != 0) {
- LOG(WARNING) << "Permuted Layout for " << op->block->name_hint
- << " is not supported since its first dimension is not
divisible by 2"
- << " and second dimension is not divisible by 64";
- return br;
- }
- }
-
- // Step 1.4. Set corresponding member variable
- if (val.at(4) == 'A') {
- smem_width_A_ = smem_width;
- } else {
- smem_width_B_ = smem_width;
- }
-
- // Step 1.5. Rewrite index
- PrimExpr s0 = store->indices[0];
- PrimExpr s1 = store->indices[1];
- Array<PrimExpr> new_indices = GetNewIndices(s0, s1, smem_width);
- // Step 1.6. Create new BlockRealize
- Stmt new_body = BufferStore(store->buffer, store->value, new_indices);
- if (if_then_else_condition) {
- // Case 1.6.1. Add back IfThenElse
- new_body = IfThenElse(if_then_else_condition.value(), new_body);
- }
- for (int i = loops.size() - 1; i >= 0; i--) {
- const ForNode* loop = loops[i];
-        new_body = For(loop->loop_var, loop->min, loop->extent, loop->kind, new_body,
- loop->thread_binding, loop->annotations);
- }
- if (have_local_stage) {
- // Case 1.6.1. Add back local stage
- new_body = SeqStmt({upper_loop, new_body});
- }
-      Block new_blk = Block(blk->iter_vars, blk->reads, blk->writes, blk->name_hint, new_body,
-                            blk->init, blk->alloc_buffers, blk->match_buffers, blk->annotations);
-      BlockRealize new_br = BlockRealize(op->iter_values, op->predicate, new_blk);
- return new_br;
- } else if (support::StartsWith(val, "s2l")) {
- // Case 2. rewrite share.dyn to local
- // Step 2.1. Retrieve previous set member variable
- int smem_width = val.at(4) == 'A' ? smem_width_A_ : smem_width_B_;
- if (smem_width == -1) {
- return br;
- }
-
- // Step 2.2. Rewrite index
-      // Body of shared.dyn to local is always T.evaluate(T.ptx_ldmatrix(args...))
- // Please refer to the load tensor intrinsic
- Evaluate eval = Downcast<Evaluate>(body);
- Call ldmat_call = Downcast<Call>(eval->value);
- ICHECK(ldmat_call->args.size() == 7);
- Array<PrimExpr> new_ldmat_args;
- // Step 2.2.1. Add unchanged args
- for (int i = 0; i < 5; i++) {
- new_ldmat_args.push_back(ldmat_call->args[i]);
- }
- // 5th argument is always a T.tvm_access_ptr call
- // Please refer to the load tensor intrinsic
- Call accptr_call = Downcast<Call>(ldmat_call->args[5]);
- PrimExpr smem_offset = ldmat_call->args[6];
-
- // Step 2.2.2. Create new access ptr call
- Array<PrimExpr> new_accptr_args;
- for (int i = 0; i < 5; i++) {
-        // 2th args of T.tvm_access_ptr call is offset, we set it to 0 and calculate
- // total offset in ldmatrix call
- new_accptr_args.push_back(i == 2 ? 0 : accptr_call->args[i]);
- }
-      Call new_accptr_call = Call(accptr_call->dtype, accptr_call->op, new_accptr_args);
- new_ldmat_args.push_back(new_accptr_call);
-
- // Step 2.2.3. Calculate new offset
- // We convert offset to 2-dimension, reindex it and convert it back
- PrimExpr accptr_offset = accptr_call->args[2];
- PrimExpr offset = smem_offset + accptr_offset;
-      PrimExpr s0 = floordiv(offset, smem_width), s1 = floormod(offset, smem_width);
- Array<PrimExpr> new_indices = GetNewIndices(s0, s1, smem_width);
- PrimExpr new_offset = new_indices[0] * smem_width + new_indices[1];
- new_ldmat_args.push_back(new_offset);
- // Step 2.2.4. Rewrite the rest part
-      Call new_ldmat_call = Call(ldmat_call->dtype, ldmat_call->op, new_ldmat_args);
- Stmt new_body = Evaluate(new_ldmat_call);
-      Block new_blk = Block(blk->iter_vars, blk->reads, blk->writes, blk->name_hint, new_body,
-                            blk->init, blk->alloc_buffers, blk->match_buffers, blk->annotations);
-      BlockRealize new_br = BlockRealize(op->iter_values, op->predicate, new_blk);
- return new_br;
+ }
+
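Both annotation forms accepted by CheckAnnotation appear in the new tests below: legacy string values such as "g2s_A" or "s2l_B", and the new integer form. A rough Python equivalent of the check (a sketch for illustration only, with a hypothetical name, not an exported API):

    def check_annotation(ann):
        if isinstance(ann, str):
            # Legacy string annotation, kept for backward compatibility;
            # any non-empty value such as "g2s_A" or "s2l_B" enables permutation
            return ann != ""
        if isinstance(ann, int):
            # New integer form, e.g. T.block_attr({"permuted_layout": 1})
            return ann != 0
        raise ValueError(f"Invalid permuted layout annotation: {ann}")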
+ Stmt VisitStmt_(const BlockNode* op) final {
+ // Record the mapping from buffer data var to buffer for later lookup
+ for (auto buffer : op->alloc_buffers) {
+ buffer_map_.insert({buffer->data, buffer});
+ }
+ for (auto match_buffer : op->match_buffers) {
+ buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
+ }
+
+ if (op->annotations.count("permuted_layout") == 0 ||
+ !CheckAnnotation(op->annotations.at("permuted_layout"))) {
+ return IRMutatorWithAnalyzer::VisitStmt_(op);
}
- return StmtExprMutator::VisitStmt_(op);
+ auto prev_permute = permute_;
+ permute_ = true;
+
+ Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
+
+ permute_ = prev_permute;
+
+ // Erase the permuted_layout annotation after the pass
+ auto block_node = block.CopyOnWrite();
+ block_node->annotations.erase("permuted_layout");
+ return block;
}
- int smem_width_A_ = -1;
- int smem_width_B_ = -1;
-};
+ int CheckAndGetBufferRowSize(Buffer buffer) {
+ CHECK(buffer->shape.size() >= 2)
+ << "The dimension of Buffer \"" << buffer->name << "\" with shape " <<
buffer->shape
+ << " should be at least 2";
-PrimFunc InjectPermutedLayout(PrimFunc func) {
- auto fptr = func.CopyOnWrite();
- fptr->body = PermutedLayoutInjector()(std::move(fptr->body));
- return func;
-}
+ auto dim = buffer->shape.size();
+ auto buffer_row_size = buffer->shape[dim - 1].as<IntImmNode>()->value;
+ auto buffer_col_size = buffer->shape[dim - 2].as<IntImmNode>()->value;
+
+ if (buffer_row_size % 64 != 0) {
+ CHECK(buffer_row_size % 32 == 0)
+ << "Permuted Layout for Buffer \"" << buffer->name << "\" with shape
" << buffer->shape
+ << " is not supported since its second dimension is not divisible by
32";
+ CHECK(buffer_col_size % 2 == 0)
+ << "Permuted Layout for Buffer \"" << buffer->name << "\" with shape
" << buffer->shape
+ << " is not supported since its first dimension is not divisible by
2 and second "
+ "dimension is not divisible by 64";
+ }
+
+ return buffer_row_size;
+ }
+
+ Array<PrimExpr> HandleBufferIndices(Buffer buffer, Array<PrimExpr> indices) {
+ auto buffer_row_size = CheckAndGetBufferRowSize(buffer);
+
+ // Mutate the last two indices
+ auto indices_size = indices.size();
+ PrimExpr row_idx = indices[indices_size - 2];
+ PrimExpr col_idx = indices[indices_size - 1];
+ auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size);
+ indices.Set(indices_size - 2, new_indices[0]);
+ indices.Set(indices_size - 1, new_indices[1]);
+ return indices;
+ }
+
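Only the last two indices of an access are rewritten, so buffers with leading batch dimensions pass through unchanged. A sketch of the same logic in Python, reusing the hypothetical permute_indices sketch above:

    def handle_buffer_indices(indices, row_size):
        # Leading batch dimensions are kept; only the trailing (row, col) pair
        # is permuted
        *batch, row_idx, col_idx = indices
        new_row, new_col = permute_indices(row_idx, col_idx, row_size)
        return [*batch, new_row, new_col]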
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ // Rewrite write from global to shared.dyn or shared
+ // We assume the shape of the shared memory is [..., row_size, col_size],
+    // where row_size is divisible by 64, or divisible by 32 and col_size is divisible by 2.
+ auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
+
+ if (!permute_ || store->buffer->shape.size() < 2) {
+ return store;
+ }
+
+ auto scope = StorageScope::Create(GetPtrStorageScope(store->buffer->data));
+ if (scope.rank != StorageRank::kShared) {
+ return store;
+ }
+
+ auto store_node = store.CopyOnWrite();
+    store_node->indices = HandleBufferIndices(store_node->buffer, store_node->indices);
+ return store;
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ // Rewrite load from shared or shared.dyn to global
+ auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
+
+ if (!permute_ || load->buffer->shape.size() < 2) {
+ return load;
+ }
+
+ auto scope = StorageScope::Create(GetPtrStorageScope(load->buffer->data));
+ if (scope.rank != StorageRank::kShared) {
+ return load;
+ }
+
+ auto load_node = load.CopyOnWrite();
+    load_node->indices = HandleBufferIndices(load_node->buffer, load_node->indices);
+ return load;
+ }
+
+  PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional<PrimExpr> offset = NullOpt) {
+    // The arg at index 2 of the T.tvm_access_ptr call is the offset; we set it to 0 and
+    // accumulate it into smem_offset
+ CHECK(access_ptr->IsInstance<CallNode>())
+ << "Invalid access ptr for permuted layout: " << access_ptr;
+ auto access_ptr_call = Downcast<Call>(access_ptr);
+ CHECK(access_ptr_call->op.same_as(builtin::tvm_access_ptr()))
+ << "Invalid access ptr for permuted layout: " << access_ptr;
+
+    auto buffer_map_iter = buffer_map_.find(Downcast<Var>(access_ptr_call->args[1]));
+ CHECK(buffer_map_iter != buffer_map_.end())
+ << "The buffer corresponding to data Var " << access_ptr_call->args[1]
<< " is not found";
+ int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
+
+    PrimExpr smem_offset = access_ptr_call->args[2] + (offset.defined() ? offset.value() : 0);
+
+ // Convert offset to 2-dimension, reindex it and convert it back
+ PrimExpr row_idx = floordiv(smem_offset, buffer_row_size);
+ PrimExpr col_idx = floormod(smem_offset, buffer_row_size);
+
+ auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size);
+    auto new_offset = analyzer_->Simplify(new_indices[0] * buffer_row_size + new_indices[1]);
+
+ auto new_access_ptr = access_ptr_call.CopyOnWrite();
+ new_access_ptr->args.Set(2, new_offset);
+ return access_ptr_call;
+ }
+
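For ldmatrix and mma_store the shared-memory address arrives as a flat offset rather than 2-D indices, so the rewrite goes through the same permutation by linearizing: fold the extra offset into the T.tvm_access_ptr offset, split into (row, col), permute, and flatten back. A sketch under the same assumptions, again reusing the hypothetical permute_indices above:

    def handle_access_ptr_and_offset(access_ptr_offset, extra_offset, row_size):
        smem_offset = access_ptr_offset + extra_offset
        # Convert the flat offset to 2-D, reindex it, and convert it back
        row_idx, col_idx = smem_offset // row_size, smem_offset % row_size
        new_row, new_col = permute_indices(row_idx, col_idx, row_size)
        # The result becomes the new access_ptr offset; the separate
        # ldmatrix offset argument is reset to 0
        return new_row * row_size + new_col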
+ PrimExpr VisitExpr_(const CallNode* op) final {
+    // Rewrite calls (ptx_ldmatrix, mma_store) that move data between shared or shared.dyn and local
+ auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
+
+ if (!permute_) {
+ return call;
+ }
+
+    if (!call->op.same_as(builtin::ptx_ldmatrix()) && !call->op.same_as(builtin::mma_store())) {
+ return call;
+ }
+
+ if (call->op.same_as(builtin::ptx_ldmatrix())) {
+ // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
+ // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
+ auto access_ptr = call->args[5];
+ PrimExpr smem_offset = call->args[6];
+ auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset);
+ auto new_call = call.CopyOnWrite();
+ new_call->args.Set(5, new_access_ptr);
+ new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
+ return call;
+ } else if (call->op.same_as(builtin::mma_store())) {
+ // TODO(yixin): mma_store is not fully tested yet
+      // because we now store the result directly to a Buffer instead of calling mma_store
+ auto access_ptr = call->args[2];
+ auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr);
+ auto new_call = call.CopyOnWrite();
+ new_call->args.Set(2, new_access_ptr);
+ return call;
+ } else {
+ LOG(FATAL) << "Invalid call node: " << call;
+ }
+ }
+
+ static constexpr size_t VECTORIZE_FACTOR = 8;
+ static constexpr size_t BANK_SIZE_BYTES = 128;
+
+ // Mapping from data Var of a Buffer to Buffer, for lookup
+ std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
+ bool permute_ = false;
+};
namespace transform {
Pass InjectPermutedLayout() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
- return InjectPermutedLayout(std::move(f));
+ return PermutedLayoutInjector::Transform(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectPermutedLayout", {});
}
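The new tests below drive the pass through the standard pass infrastructure. For reference, applying it to a standalone PrimFunc looks like this (a small helper mirroring _check_primfunc_transform in the test file; apply_inject_permuted_layout is a hypothetical name):

    import tvm
    from tvm import IRModule

    def apply_inject_permuted_layout(func):
        # Wrap the PrimFunc in an IRModule and run the pass over it
        mod = IRModule.from_expr(func)
        return tvm.tir.transform.InjectPermutedLayout()(mod)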
diff --git a/tests/python/unittest/test_tir_transform_inject_permuted_layout.py b/tests/python/unittest/test_tir_transform_inject_permuted_layout.py
new file mode 100644
index 0000000000..6495cdb2bd
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_inject_permuted_layout.py
@@ -0,0 +1,351 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import IRModule
+from tvm.script import tir as T
+from tvm.tir import PrimFunc
+
+
+def _check_primfunc_transform(before: PrimFunc, expected: PrimFunc):
+ before_module = IRModule.from_expr(before)
+ after_module = tvm.tir.transform.InjectPermutedLayout()(before_module)
+
+ after = after_module["before"].without_attr("global_symbol")
+ expected = expected.without_attr("global_symbol")
+
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+# This pass is adapted from an earlier pass, so we need to ensure backward compatibility here
+def test_backward_compatibility_shared_a():
+ # fmt: off
+    @T.prim_func
+    def before(X: T.Buffer((4096, 4096), "float16")):
+        # with T.block("root"):
+        for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"):
+            for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"):
+                    with T.block(""):
+                        T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 4072])
+                        T.writes()
+                        for ax2_0_0 in range(128):
+                            with T.block(""):
+                                T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8])
+                                T.writes()
+                                X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn")
+                                with T.block("X_reindex_shared.dyn"):
+                                    T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8])
+                                    T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8])
+                                    T.block_attr({"permuted_layout": "g2s_A"})
+                                    for ax0_ax1_fused_0 in range(4):
+                                        for ax0_ax1_fused_3 in T.vectorized(8):
+                                            X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, threadIdx_x % 4 * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3]
+                                for ax2_0_1 in range(4):
+                                    with T.block(""):
+                                        T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64:threadIdx_y // 2 * 64 + 64, ax2_0_1 * 8:ax2_0_1 * 8 + 8])
+                                        T.writes()
+                                        X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA")
+                                        for ax0_0, ax1_0 in T.grid(2, 1):
+                                            with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"):
+                                                T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8])
+                                                T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8])
+                                                T.block_attr({"permuted_layout": "s2l_A"})
+                                                T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + ax2_0_1 * 8, 1024, 1), threadIdx_x * 32)
+
+    @T.prim_func
+    def expected(X: T.Buffer((4096, 4096), "float16")):
+        for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"):
+            for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"):
+                    with T.block(""):
+                        for ax2_0_0 in T.serial(128):
+                            with T.block(""):
+                                X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn")
+                                with T.block("X_reindex_shared.dyn"):
+                                    # annotate the reads and writes because they cannot be inferred from tir.bitwise_xor
+                                    T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8])
+                                    T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8])
+                                    for ax0_ax1_fused_0 in range(4):
+                                        for ax0_ax1_fused_3 in T.vectorized(8):
+                                            X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, T.bitwise_xor(threadIdx_x % 4, threadIdx_x // 8) * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3]
+                                for ax2_0_1 in T.serial(4):
+                                    with T.block(""):
+                                        X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA")
+                                        for ax0_0, ax1_0 in T.grid(2, 1):
+                                            with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"):
+                                                T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8])
+                                                T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8])
+                                                T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + threadIdx_x * 32 + T.bitwise_xor(ax2_0_1, threadIdx_x % 8 // 2) * 8, 1024, 1), 0)
+ # fmt: on
+ _check_primfunc_transform(before, expected)
+
+
+def test_backward_compatibility_shared_a_and_b():
+ # fmt: off
+    @T.prim_func
+    def before(X: T.Buffer((4096, 4096), "float16"), Y: T.Buffer((4096, 4096), "float16")):
+        for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"):
+            for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"):
+                for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"):
+                    for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"):
+                        with T.block(""):
+                            for ax2_0_0 in T.serial(128):
+                                with T.block(""):
+                                    X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn")
+                                    Y_reindex_shared_dyn = T.alloc_buffer((32, 128), "float16", strides=(128, 1), scope="shared.dyn")
+                                    with T.block("X_reindex_shared.dyn"):
+                                        T.block_attr({"permuted_layout": "g2s_A"})
+                                        for ax0_ax1_fused_0 in range(4):
+                                            for ax0_ax1_fused_3 in T.vectorized(8):
+                                                X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, threadIdx_x % 4 * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3]
+                                    with T.block("Y_reindex_shared.dyn"):
+                                        T.block_attr({"permuted_layout": "g2s_B"})
+                                        for ax0_ax1_fused_0 in range(4):
+                                            for ax0_ax1_fused_3 in T.vectorized(8):
+                                                Y_reindex_shared_dyn[ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, threadIdx_x % 16 * 8 + ax0_ax1_fused_3] = Y[ax2_0_0 * 32 + ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + ax0_ax1_fused_3]
+                                    for ax2_0_1 in T.serial(4):
+                                        with T.block(""):
+                                            X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA")
+                                            Y_reindex_shared_dyn_m16n8k8_matrixB = T.alloc_buffer((8, 64), "float16", scope="m16n8k8.matrixB")
+                                            for ax0_0, ax1_0 in T.grid(2, 1):
+                                                with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"):
+                                                    T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8])
+                                                    T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8])
+                                                    T.block_attr({"permuted_layout": "s2l_A"})
+                                                    T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + ax2_0_1 * 8, 1024, 1), threadIdx_x * 32)
+                                            for ax0_0, ax1_0 in T.grid(1, 2):
+                                                with T.block("Y_reindex_shared.dyn_m16n8k8.matrixB_o"):
+                                                    T.reads(Y_reindex_shared_dyn[ax2_0_1 * 8:ax2_0_1 * 8 + 8, threadIdx_y % 2 * 64 + ax1_0 * 32:threadIdx_y % 2 * 64 + ax1_0 * 32 + 32])
+                                                    T.writes(Y_reindex_shared_dyn_m16n8k8_matrixB[0:8, ax1_0 * 32:ax1_0 * 32 + 32])
+                                                    T.block_attr({"permuted_layout": "s2l_B"})
+                                                    T.ptx_ldmatrix("float16", T.bool(True), 4, ".b16", Y_reindex_shared_dyn_m16n8k8_matrixB.data, ax1_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), Y_reindex_shared_dyn.data, ax2_0_1 * 1024 + threadIdx_y % 2 * 64 + ax1_0 * 32, 1024, 1), threadIdx_x % 8 * 128 + threadIdx_x // 8 * 8)
+
+    @T.prim_func
+    def expected(X: T.Buffer((4096, 4096), "float16"), Y: T.Buffer((4096, 4096), "float16")):
+        for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"):
+            for blockIdx_y in T.thread_binding(256, thread="blockIdx.y"):
+                for threadIdx_y in T.thread_binding(4, thread="threadIdx.y"):
+                    for threadIdx_x in T.thread_binding(32, thread="threadIdx.x"):
+                        with T.block(""):
+                            T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 4072], Y[threadIdx_y * 2 + threadIdx_x // 16:threadIdx_y * 2 + threadIdx_x // 16 + 4089, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + 8])
+                            T.writes()
+                            for ax2_0_0 in T.serial(128):
+                                with T.block(""):
+                                    T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8], Y[ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16:ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16 + 25, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + 8])
+                                    T.writes()
+                                    X_reindex_shared_dyn = T.alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn")
+                                    Y_reindex_shared_dyn = T.alloc_buffer((32, 128), "float16", strides=(128, 1), scope="shared.dyn")
+                                    with T.block("X_reindex_shared.dyn"):
+                                        T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8])
+                                        T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8])
+                                        for ax0_ax1_fused_0 in range(4):
+                                            for ax0_ax1_fused_3 in T.vectorized(8):
+                                                X_reindex_shared_dyn[ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, T.bitwise_xor(threadIdx_x % 4, threadIdx_x // 8) * 8 + ax0_ax1_fused_3] = X[blockIdx_y // 8 * 128 + ax0_ax1_fused_0 * 32 + threadIdx_y * 8 + threadIdx_x // 4, ax2_0_0 * 32 + threadIdx_x % 4 * 8 + ax0_ax1_fused_3]
+                                    with T.block("Y_reindex_shared.dyn"):
+                                        T.reads(Y[ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16:ax2_0_0 * 32 + threadIdx_y * 2 + threadIdx_x // 16 + 25, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8:blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + 8])
+                                        T.writes(Y_reindex_shared_dyn[threadIdx_y * 2 + threadIdx_x // 16:threadIdx_y * 2 + threadIdx_x // 16 + 25, threadIdx_x % 16 * 8:threadIdx_x % 16 * 8 + 8])
+                                        for ax0_ax1_fused_0 in range(4):
+                                            for ax0_ax1_fused_3 in T.vectorized(8):
+                                                Y_reindex_shared_dyn[ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, T.bitwise_xor(threadIdx_x % 16, threadIdx_y * 2 + threadIdx_x // 16) * 8 + ax0_ax1_fused_3] = Y[ax2_0_0 * 32 + ax0_ax1_fused_0 * 8 + threadIdx_y * 2 + threadIdx_x // 16, blockIdx_x * 1024 + blockIdx_y % 8 * 128 + threadIdx_x % 16 * 8 + ax0_ax1_fused_3]
+                                    for ax2_0_1 in T.serial(4):
+                                        with T.block(""):
+                                            X_reindex_shared_dyn_m16n8k8_matrixA = T.alloc_buffer((64, 8), "float16", scope="m16n8k8.matrixA")
+                                            Y_reindex_shared_dyn_m16n8k8_matrixB = T.alloc_buffer((8, 64), "float16", scope="m16n8k8.matrixB")
+                                            for ax0_0, ax1_0 in T.grid(2, 1):
+                                                with T.block("X_reindex_shared.dyn_m16n8k8.matrixA_o"):
+                                                    T.reads(X_reindex_shared_dyn[threadIdx_y // 2 * 64 + ax0_0 * 32:threadIdx_y // 2 * 64 + ax0_0 * 32 + 32, ax2_0_1 * 8:ax2_0_1 * 8 + 8])
+                                                    T.writes(X_reindex_shared_dyn_m16n8k8_matrixA[ax0_0 * 32:ax0_0 * 32 + 32, 0:8])
+                                                    T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", X_reindex_shared_dyn_m16n8k8_matrixA.data, ax0_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), X_reindex_shared_dyn.data, threadIdx_y // 2 * 2048 + ax0_0 * 1024 + threadIdx_x * 32 + T.bitwise_xor(ax2_0_1, threadIdx_x % 8 // 2) * 8, 1024, 1), 0)
+                                            for ax0_0, ax1_0 in T.grid(1, 2):
+                                                with T.block("Y_reindex_shared.dyn_m16n8k8.matrixB_o"):
+                                                    T.reads(Y_reindex_shared_dyn[ax2_0_1 * 8:ax2_0_1 * 8 + 8, threadIdx_y % 2 * 64 + ax1_0 * 32:threadIdx_y % 2 * 64 + ax1_0 * 32 + 32])
+                                                    T.writes(Y_reindex_shared_dyn_m16n8k8_matrixB[0:8, ax1_0 * 32:ax1_0 * 32 + 32])
+                                                    T.ptx_ldmatrix("float16", T.bool(True), 4, ".b16", Y_reindex_shared_dyn_m16n8k8_matrixB.data, ax1_0 * 8, T.tvm_access_ptr(T.type_annotation("float16"), Y_reindex_shared_dyn.data, ax2_0_1 * 1024 + threadIdx_x % 8 * 128 + T.bitwise_xor(threadIdx_y % 2 * 8 + ax1_0 * 4 + threadIdx_x // 8, threadIdx_x % 8) * 8, 1024, 1), 0)
+ # fmt: on
+ _check_primfunc_transform(before, expected)
+
+
+def test_buffer_a():
+ # fmt: off
+    @T.prim_func
+    def before(p_A: T.handle):
+        A = T.match_buffer(p_A, (T.int64(128), T.int64(32)), "float16")
+        A_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", scope="shared.dyn")
+        A_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), T.int64(8)), "float16", scope="warp")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                    for v0 in range(T.int64(4)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("A_reindex_shared.dyn"):
+                                T.block_attr({"permuted_layout": 1})
+                                A_shared_dyn[
+                                    v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4),
+                                    threadIdx_x % T.int64(4) * T.int64(8) + v1
+                                ] = A[
+                                    (v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4)) % T.int64(32),
+                                    threadIdx_x % T.int64(4) * T.int64(8) + v1
+                                ]
+                    for v0, v1 in T.grid(T.int64(2), T.int64(4)):
+                        with T.block("A_reindex_shared.dyn_warp_o"):
+                            T.block_attr({"permuted_layout": 1})
+                            with T.block("A_reindex_shared.dyn_warp_o"):
+                                T.reads(A_shared_dyn[threadIdx_z * T.int64(64) + v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                                T.writes(A_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                                T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16",
+                                    A_warp.data,
+                                    v1 * T.int64(256) + threadIdx_x * T.int64(8),
+                                    T.tvm_access_ptr(T.type_annotation("float16"),
+                                        A_shared_dyn.data,
+                                        threadIdx_z * T.int64(2048) + v1 * T.int64(512) + v0 * T.int64(16), T.int64(512),
+                                        1
+                                    ),
+                                    threadIdx_x % T.int64(16) * T.int64(32) + threadIdx_x // T.int64(16) * T.int64(8)
+                                )
+
+    @T.prim_func
+    def expected(A: T.Buffer((T.int64(128), T.int64(32)), "float16")):
+        A_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", scope="shared.dyn")
+        A_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), T.int64(8)), "float16", scope="warp")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                    for v0 in range(T.int64(4)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("A_reindex_shared.dyn"):
+                                T.reads(A[(v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4)) % T.int64(32), threadIdx_x % T.int64(4) * T.int64(8) + v1])
+                                T.writes(A_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1])
+                                A_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), T.bitwise_xor(threadIdx_x % T.int64(4), threadIdx_x // T.int64(8)) * T.int64(8) + v1] = A[(v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4)) % T.int64(32), threadIdx_x % T.int64(4) * T.int64(8) + v1]
+                    for v0, v1 in T.grid(T.int64(2), T.int64(4)):
+                        with T.block("A_reindex_shared.dyn_warp_o"):
+                            T.reads(A_shared_dyn[threadIdx_z * T.int64(64) + v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                            T.writes(A_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                            with T.block("A_reindex_shared.dyn_warp_o"):
+                                T.reads(A_shared_dyn[threadIdx_z * T.int64(64) + v1 * T.int64(16):threadIdx_z * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                                T.writes(A_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                                T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", A_warp.data, v1 * T.int64(256) + threadIdx_x * T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), A_shared_dyn.data, threadIdx_z * T.int64(2048) + v1 * T.int64(512) + threadIdx_x % T.int64(16) * T.int64(32) + T.bitwise_xor(v0 * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(8) // T.int64(2)) * T.int64(8), T.int64(512), 1), T.int64(0))
+
+ # fmt: on
+ _check_primfunc_transform(before, expected)
+
+
+def test_buffer_b():
+ # fmt: off
+    @T.prim_func
+    def before(B: T.Buffer((T.int64(128), T.int64(32)), "float16")):
+        B_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", scope="shared.dyn")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                    for v0 in range(T.int64(4)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("B_reindex_shared.dyn"):
+                                T.block_attr({"permuted_layout": 1})
+                                B_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1] = B[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1]
+                    for v0 in range(T.int64(2)):
+                        with T.block(""):
+                            B_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), T.int64(8)), "float16", scope="warp")
+                            for v1 in range(T.int64(4)):
+                                with T.block("B_reindex_shared.dyn_warp_o"):
+                                    T.block_attr({"permuted_layout": 1})
+                                    with T.block("B_reindex_shared.dyn_warp_o"):
+                                        T.reads(B_shared_dyn[threadIdx_y * T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                                        T.writes(B_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                                        T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", B_warp.data, v1 * T.int64(256) + threadIdx_x * T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), B_shared_dyn.data, threadIdx_y * T.int64(2048) + v1 * T.int64(512) + v0 * T.int64(16), T.int64(512), 1), threadIdx_x // T.int64(16) * T.int64(256) + threadIdx_x % T.int64(8) * T.int64(32) + threadIdx_x % T.int64(16) // T.int64(8) * T.int64(8))
+
+    @T.prim_func
+    def expected(B: T.Buffer((T.int64(128), T.int64(32)), "float16")):
+        B_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(32)), "float16", scope="shared.dyn")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                    for v0 in range(T.int64(4)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("B_reindex_shared.dyn"):
+                                T.reads(B[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1])
+                                T.writes(B_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1])
+                                B_shared_dyn[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), T.bitwise_xor(threadIdx_x % T.int64(4), threadIdx_x // T.int64(8)) * T.int64(8) + v1] = B[v0 * T.int64(32) + threadIdx_z * T.int64(16) + threadIdx_y * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_x % T.int64(4) * T.int64(8) + v1]
+                    for v0 in range(T.int64(2)):
+                        with T.block(""):
+                            B_warp = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(32), T.int64(8)), "float16", scope="warp")
+                            for v1 in range(T.int64(4)):
+                                with T.block("B_reindex_shared.dyn_warp_o"):
+                                    T.reads(B_shared_dyn[threadIdx_y * T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                                    T.writes(B_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                                    with T.block("B_reindex_shared.dyn_warp_o"):
+                                        T.reads(B_shared_dyn[threadIdx_y * T.int64(64) + v1 * T.int64(16):threadIdx_y * T.int64(64) + v1 * T.int64(16) + T.int64(16), v0 * T.int64(16):v0 * T.int64(16) + T.int64(16)])
+                                        T.writes(B_warp[v1, T.int64(0), T.int64(0):T.int64(32), T.int64(0):T.int64(8)])
+                                        T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", B_warp.data, v1 * T.int64(256) + threadIdx_x * T.int64(8), T.tvm_access_ptr(T.type_annotation("float16"), B_shared_dyn.data, threadIdx_y * T.int64(2048) + v1 * T.int64(512) + threadIdx_x // T.int64(16) * T.int64(256) + threadIdx_x % T.int64(8) * T.int64(32) + T.bitwise_xor(v0 * T.int64(2) + threadIdx_x % T.int64(16) // T.int64(8), threadIdx_x % T.int64(8) // T.int64(2)) * T.int64(8), T.int64(512), [...]
+
+ # fmt: on
+ _check_primfunc_transform(before, expected)
+
+
+def test_buffer_c_fp32():
+ # fmt: off
+    @T.prim_func
+    def before(p_O: T.handle):
+        O = T.match_buffer(p_O, (T.int64(128), T.int64(128)), "float16")
+        O_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(128)), scope="shared.dyn")
+        O_warp = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(32), T.int64(8)), scope="warp")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                    for v0, v1 in T.grid(T.int64(4), T.int64(4)):
+                        with T.block("O.dyn_warp_o"):
+                            T.block_attr({"permuted_layout": 1})
+                            with T.block("O.dyn_warp_o"):
+                                for local_id in range(T.int64(8)):
+                                    O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + local_id % T.int64(4) // T.int64(2) * T.int64(8) + threadIdx_x // T.int64(4), threadIdx_y * T.int64(64) + v1 * T.int64(16) + local_id // T.int64(4) * T.int64(8) + threadIdx_x % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = O_warp[v0, v1, threadIdx_x, local_id]
+                    for v0 in range(T.int64(16)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("O.dyn"):
+                                T.block_attr({"permuted_layout": 1})
+                                O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1] = T.Cast("float16", O_shared_dyn[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1])
+
+
+    @T.prim_func
+    def expected(O: T.Buffer((T.int64(128), T.int64(128)), "float16")):
+        # with T.block("root"):
+        O_shared_dyn = T.alloc_buffer((T.int64(128), T.int64(128)), scope="shared.dyn")
+        O_warp = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(32), T.int64(8)), scope="warp")
+        for threadIdx_z in T.thread_binding(T.int64(2), thread="threadIdx.z"):
+            for threadIdx_y in T.thread_binding(T.int64(2), thread="threadIdx.y"):
+                for threadIdx_x in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                    for v0, v1 in T.grid(T.int64(4), T.int64(4)):
+                        with T.block("O.dyn_warp_o"):
+                            T.reads(O_warp[v0, v1, threadIdx_x, T.int64(0):T.int64(8)])
+                            T.writes(O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4):threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4) + T.int64(9), threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2):threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2) + T.int64(10)])
+                            with T.block("O.dyn_warp_o"):
+                                T.reads(O_warp[v0, v1, threadIdx_x, T.int64(0):T.int64(8)])
+                                T.writes(O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4):threadIdx_z * T.int64(64) + v0 * T.int64(16) + threadIdx_x // T.int64(4) + T.int64(9), threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2):threadIdx_y * T.int64(64) + v1 * T.int64(16) + threadIdx_x % T.int64(4) * T.int64(2) + T.int64(10)])
+                                for local_id in range(T.int64(8)):
+                                    O_shared_dyn[threadIdx_z * T.int64(64) + v0 * T.int64(16) + local_id % T.int64(4) // T.int64(2) * T.int64(8) + threadIdx_x // T.int64(4), T.bitwise_xor(threadIdx_y * T.int64(8) + v1 * T.int64(2) + local_id // T.int64(4), threadIdx_x // T.int64(4)) * T.int64(8) + threadIdx_x % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = O_warp[v0, v1, threadIdx_x, local_id]
+                    for v0 in range(T.int64(16)):
+                        for v1 in T.vectorized(T.int64(8)):
+                            with T.block("O.dyn"):
+                                T.reads(O_shared_dyn[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1])
+                                T.writes(O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1])
+                                O[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), threadIdx_x % T.int64(16) * T.int64(8) + v1] = T.Cast("float16", O_shared_dyn[v0 * T.int64(8) + threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16), T.bitwise_xor(threadIdx_x % T.int64(16), threadIdx_z * T.int64(4) + threadIdx_y * T.int64(2) + threadIdx_x // T.int64(16)) * T.int64(8) + v1])
+
+ # fmt: on
+ _check_primfunc_transform(before, expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()