masahi commented on code in PR #11355:
URL: https://github.com/apache/tvm/pull/11355#discussion_r875734757
##########
src/target/source/codegen_cuda.cc:
##########
@@ -818,9 +819,79 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op,
std::ostream& os) {
std::string local_ptr = this->PrintExpr(op->args[3]);
std::string local_elem_offset = this->PrintExpr(op->args[4]);
std::string smem_ptr = this->PrintExpr(op->args[5]);
- std::string smem_elem_offset = this->PrintExpr(op->args[6]);
- this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
local_elem_offset,
- smem_ptr, smem_elem_offset);
+
+ if (trans && op->dtype.bits() == 8) {
+ // Since ldmatrix assumes that a matrix element is 16 bit, it cannot
properly transpose an
+ // int8 matrix.
+ std::string smem_stride = this->PrintExpr(op->args[6]);
+ ICHECK(num == 4);
+ os << "for (int i = 0; i < 16; ++i) {\n";
+ os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
+ << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 *
" + smem_stride +
+ "+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8)
* 8];\n";
+ os << "}\n";
+ } else {
+ std::string smem_elem_offset = this->PrintExpr(op->args[6]);
+ this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
local_elem_offset,
+ smem_ptr, smem_elem_offset);
+ }
+ } else if (op->op.same_as(builtin::mma_store())) {
+ int m = Downcast<Integer>(op->args[0])->value;
+ int n = Downcast<Integer>(op->args[1])->value;
+ std::string dst = this->PrintExpr(op->args[2]);
+ std::string src = this->PrintExpr(op->args[3]);
+ std::string src_offset = this->PrintExpr(op->args[4]);
+ PrimExpr stride = op->args[5];
+
+ ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for
now";
+
+ // Each thread in a warp holds a certain number of elements of an MMA
output.
+ // For example, if we compute a 16x16 tile using MMA, each thread holds 8
elements
+ // in its registers. So conceptually, a warp memory is organized as a 32x8
block.
+ // A map from a 16x16 tile to a 32x8 block of memory is specified by the
index map below.
+
+ // To store the 32x8 output back to a 16x16 tile in shared or global
memory, we invert this map
+ // to determine the output location for each 8 element.
+
+ const auto* index_map_func =
+
runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout");
+ ICHECK(index_map_func);
+
+ auto inverse_index_map =
+ IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0,
n)});
Review Comment:
I'm very excited about the use of `IndexMap::Inverse(...)` here. Initially,
I derived the index by hand, and later figured out how to use `DetectIterMap`
and `InverseAffineMap` to compute the same index automatically, only to realize
that I had basically reimplemented `IndexMap::Inverse(...)`. cc @Lunderberg
@vinx13
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]