This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 435f641862 [Codegen] Support codegen for vectorized tir.ShuffleNode
(#17748)
435f641862 is described below
commit 435f641862e7ee307534e0dce04f26528c453b60
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 15 09:34:57 2025 -0400
[Codegen] Support codegen for vectorized tir.ShuffleNode (#17748)
This PR introduces support for vectorizing tir.ShuffleNode,
which is useful for extracting bits and converting them to float4,
since float4 is a sub-byte type.
Prior to this PR, ShuffleNode was not supported in vectorization.
This PR allows vectorizing ShuffleNode subject to specific index patterns,
and still throws an error for ShuffleNodes that don't meet the pattern
requirements.
---
src/target/source/codegen_c.cc | 36 ++++++++--
src/target/source/literal/cuda_half_t.h | 44 ++++++------
src/tir/ir/expr_functor.cc | 5 +-
src/tir/transforms/vectorize_loop.cc | 68 ++++++++++++++++++-
.../python/codegen/test_target_codegen_cuda_fp4.py | 79 ++++++++++++++++++++++
5 files changed, 204 insertions(+), 28 deletions(-)
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 575f52e225..a67cb80b91 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -943,22 +943,43 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op,
std::ostream& os) { // NOLINT(
// NOTE: important to print expr first
// in case each expr have their own nested expressions
// print each elements
- for (const PrimExpr& vec : op->vectors) {
- std::string vec_value = this->PrintExpr(vec);
- if (vec.dtype().lanes() == 1) {
+ if (op->vectors.size() > 1) {
+ for (const PrimExpr& vec : op->vectors) {
+ std::string vec_value = this->PrintExpr(vec);
+ if (vec.dtype().lanes() == 1) {
+ concat_vec.push_back(vec_value);
+ } else {
+ // print out each element
+ for (int i = 0; i < vec.dtype().lanes(); ++i) {
+ // access i-th element of each vector
+ std::ostringstream vec_elem_strm;
+ vec_elem_strm << vec_value << "[" << i << "]";
+ concat_vec.push_back(vec_elem_strm.str());
+ }
+ }
+ }
+ } else {
+ // Extract elements from a single vector-type value.
+ std::string vec_value = "(" + this->PrintExpr(op->vectors[0]) + ")";
+ if (op->vectors[0].dtype().lanes() == 1) {
concat_vec.push_back(vec_value);
} else {
// print out each element
- for (int i = 0; i < vec.dtype().lanes(); ++i) {
+ for (int i = 0; i < op->vectors[0].dtype().lanes(); ++i) {
// access i-th element of each vector
std::ostringstream vec_elem_strm;
- vec_elem_strm << vec_value << "[" << i << "]";
+ PrintVecElemLoad(vec_value, op->vectors[0].dtype(), i, vec_elem_strm);
concat_vec.push_back(vec_elem_strm.str());
}
}
}
if (op->indices.size() == 1) {
// This is an extract element
+ CHECK(op->indices[0]->IsInstance<IntImmNode>())
+ << "The ShuffleNode indices are expected to be constants at codegen
time. However, "
+ << "a non-constant index is " << op->indices[0]
+ << ". Please avoid using ShuffleNode or eliminate the ShuffleNode with
loop unroll or "
+ << "vectorize.";
int64_t idx = Downcast<IntImm>(op->indices[0])->value;
ICHECK_LT(idx, concat_vec.size());
os << concat_vec[idx];
@@ -969,6 +990,11 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op,
std::ostream& os) { // NOLINT(
os << '(';
for (size_t i = 0; i < op->indices.size(); ++i) {
if (i != 0) os << ", ";
+ CHECK(op->indices[i]->IsInstance<IntImmNode>())
+ << "The ShuffleNode indices are expected to be constants at codegen
time. However, "
+ << "a non-constant index is " << op->indices[i]
+ << ". Please avoid using ShuffleNode or eliminate the ShuffleNode
with loop unroll or "
+ << "vectorize.";
os << concat_vec[Downcast<IntImm>(op->indices[i])->value];
}
os << ')';
diff --git a/src/target/source/literal/cuda_half_t.h
b/src/target/source/literal/cuda_half_t.h
index b095f5b8cf..039d89b93f 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -454,26 +454,6 @@ struct __align__(8) half4_bfloat164 {
(static_cast<__uint32_t>(lo_part.__x) |
(static_cast<__uint32_t>(hi_part.__x) << 16));
return result;
}
- __device__ __nv_fp8x2_e5m2 make_fp8x2_e5m2(__nv_fp8_storage_t x,
__nv_fp8_storage_t y) {
- __nv_fp8x2_e5m2 result;
- result.__x = (x) | (y << 8);
- return result;
- }
- __device__ __nv_fp8x4_e5m2 make_fp8x4_e5m2(__nv_fp8_storage_t a,
__nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
- __nv_fp8x4_e5m2 result;
- result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
- return result;
- }
- __device__ __nv_fp8x2_e4m3 make_fp8x2_e4m3(__nv_fp8_storage_t x,
__nv_fp8_storage_t y) {
- __nv_fp8x2_e4m3 result;
- result.__x = (x) | (y << 8);
- return result;
- }
- __device__ __nv_fp8x4_e4m3 make_fp8x4_e4m3(__nv_fp8_storage_t a,
__nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
- __nv_fp8x4_e4m3 result;
- result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
- return result;
- }
)";
}
if (enable_fp4) {
@@ -542,6 +522,30 @@ __host__ __device__ nv_bfloat162
cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& fp8
)";
}
}
+ if (enable_fp8) {
+ stream << R"(
+__device__ __nv_fp8x2_e5m2 make___nv_fp8x2_e5m2(__nv_fp8_e5m2 x, __nv_fp8_e5m2
y) {
+ __nv_fp8x2_e5m2 result;
+ result.__x = (x.__x) | (y.__x << 8);
+ return result;
+}
+__device__ __nv_fp8x4_e5m2 make___nv_fp8x4_e5m2(__nv_fp8_e5m2 a, __nv_fp8_e5m2
b, __nv_fp8_e5m2 c, __nv_fp8_e5m2 d) {
+ __nv_fp8x4_e5m2 result;
+ result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24);
+ return result;
+}
+__device__ __nv_fp8x2_e4m3 make___nv_fp8x2_e4m3(__nv_fp8_e4m3 x, __nv_fp8_e4m3
y) {
+ __nv_fp8x2_e4m3 result;
+ result.__x = (x.__x) | (y.__x << 8);
+ return result;
+}
+__device__ __nv_fp8x4_e4m3 make___nv_fp8x4_e4m3(__nv_fp8_e4m3 a, __nv_fp8_e4m3
b, __nv_fp8_e4m3 c, __nv_fp8_e4m3 d) {
+ __nv_fp8x4_e4m3 result;
+ result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24);
+ return result;
+}
+)";
+ }
if (enable_fp4) {
stream << R"(
__device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1
y) {
diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc
index 34b46583d5..3c117b58a7 100644
--- a/src/tir/ir/expr_functor.cc
+++ b/src/tir/ir/expr_functor.cc
@@ -279,10 +279,11 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op)
{
PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) {
auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
auto vectors = op->vectors.Map(fexpr);
- if (vectors.same_as(op->vectors)) {
+ auto indices = op->indices.Map(fexpr);
+ if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) {
return GetRef<PrimExpr>(op);
} else {
- return Shuffle(vectors, op->indices);
+ return Shuffle(vectors, indices);
}
}
diff --git a/src/tir/transforms/vectorize_loop.cc
b/src/tir/transforms/vectorize_loop.cc
index ec290e48d4..58ce6d6174 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -503,7 +503,11 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (value.dtype().is_scalable_vector()) {
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{value});
} else {
- return Call(op->dtype.with_lanes(lanes), op->op, {value});
+ int new_lanes = (op->dtype != DataType::NVFloat4E2M1FN() &&
+ op->args[0].dtype() != DataType::NVFloat4E2M1FN())
+ ? (value.dtype().bits() * value.dtype().lanes()) /
op->dtype.bits()
+ : value.dtype().lanes();
+ return Call(op->dtype.with_lanes(new_lanes), op->op, {value});
}
}
}
@@ -624,6 +628,68 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
}
}
}
+ PrimExpr VisitExpr_(const ShuffleNode* op) final {
+ CHECK(op->vectors.size() == 1 && op->indices.size() == 1)
+ << "Cannot vectorize ShuffleNode with multiple vectors or indices: the
vector size is "
+ << op->vectors.size() << " and the index size is " <<
op->indices.size();
+ int lane_vectors = 0;
+ int lane_indices = 0;
+ Array<PrimExpr> vectors = MutateArray(op->vectors, &lane_vectors);
+ Array<PrimExpr> indices = MutateArray(op->indices, &lane_indices);
+ if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) {
+ return GetRef<PrimExpr>(op);
+ }
+
+ int new_vec_length = Downcast<IntImm>(var_lanes_)->value /
op->vectors[0].dtype().lanes();
+ PrimExpr updated_index = indices[0];
+ // Check that the indices satisfy the specific patterns.
+ auto f_check_index = [this, op](const PrimExpr& index) {
+ // Allowing Ramp(0, 1, var_lanes_)
+ if (const auto* ramp = index.as<RampNode>()) {
+ if (ramp->base->IsInstance<IntImmNode>() &&
Downcast<IntImm>(ramp->base)->value == 0 &&
+ ramp->stride->IsInstance<IntImmNode>() &&
Downcast<IntImm>(ramp->stride)->value == 1 &&
+ ramp->lanes->IsInstance<IntImmNode>() &&
+ Downcast<IntImm>(ramp->lanes)->value ==
Downcast<IntImm>(var_lanes_)->value) {
+ return true;
+ }
+ }
+ // Allowing FloorMod(Ramp(0, 1, var_lanes_),
Broadcast(op->vectors[0]->lanes, var_lanes_))
+ if (const auto* floordiv = index.as<FloorModNode>()) {
+ if (const auto* ramp = floordiv->a.as<RampNode>()) {
+ if (const auto* broadcast = floordiv->b.as<BroadcastNode>()) {
+ if (ramp->base->IsInstance<IntImmNode>() &&
Downcast<IntImm>(ramp->base)->value == 0 &&
+ ramp->stride->IsInstance<IntImmNode>() &&
+ Downcast<IntImm>(ramp->stride)->value == 1 &&
+ ramp->lanes->IsInstance<IntImmNode>() &&
+ Downcast<IntImm>(ramp->lanes)->value ==
Downcast<IntImm>(var_lanes_)->value &&
+ broadcast->value->IsInstance<IntImmNode>() &&
+ Downcast<IntImm>(broadcast->value)->value ==
op->vectors[0]->dtype.lanes() &&
+ broadcast->lanes->IsInstance<IntImmNode>() &&
+ Downcast<IntImm>(broadcast->lanes)->value ==
Downcast<IntImm>(var_lanes_)->value) {
+ return true;
+ }
+ }
+ }
+ }
+
+ return false;
+ };
+ CHECK(f_check_index(updated_index));
+
+ if (new_vec_length == 1) {
+ return tir::Substitute(op->vectors[0], {{var_, tvm::IntImm(var_->dtype,
0)}});
+ } else {
+ PrimExpr prev_ramp = ramp_;
+ PrimExpr prev_var_lanes = var_lanes_;
+ ramp_ = Ramp(IntImm(var_->dtype, 0), IntImm(var_->dtype, 2),
new_vec_length);
+ var_lanes_ = tvm::IntImm(var_lanes_.dtype(), new_vec_length);
+ lane_vectors = 0;
+ vectors = MutateArray(op->vectors, &lane_vectors);
+ ramp_ = prev_ramp;
+ var_lanes_ = prev_var_lanes;
+ return vectors[0];
+ }
+ }
// BufferStore
Stmt VisitStmt_(const BufferStoreNode* op) final {
auto store = GetRef<BufferStore>(op);
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py
b/tests/python/codegen/test_target_codegen_cuda_fp4.py
index 0a170026c9..14820ec34f 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp4.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py
@@ -211,5 +211,84 @@ def test_e2m1_reinterpret():
)
+@tvm.testing.requires_cuda_compute_version(10)
+def test_e2m1_dequantize():
+ n = 128
+
+ dev = tvm.device("cuda", 0)
+ target = tvm.target.Target.from_device(dev)
+ num_elem_per_storage = 32 // 4
+
+ def get_reinterpret_mod(func_type, vector_length):
+ @T.prim_func
+ def shuffle_reinterpret(
+ A: T.Buffer((n // num_elem_per_storage,), "uint32"),
+ B: T.Buffer((n,), "float16"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i in range(n):
+ with T.block("C"):
+ v_i = T.axis.spatial(n, i)
+ T.reads(A[v_i])
+ T.writes(B[v_i])
+ B[v_i] = T.Shuffle(
+ [
+ T.reinterpret(
+ "float4_e2m1fnx2",
+ T.bitwise_and(
+ T.shift_right(
+ A[v_i // num_elem_per_storage],
+ ((v_i % num_elem_per_storage) // 2 * 4
* 2).astype(
+ "uint32"
+ ),
+ ),
+ T.uint32((1 << (4 * 2)) - 1),
+ ).astype("uint8"),
+ ).astype("float16x2")
+ ],
+ indices=[v_i % 2],
+ )
+
+ @T.prim_func
+ def scalar_reinterpret(
+ A: T.Buffer((n // num_elem_per_storage,), "uint32"),
+ B: T.Buffer((n,), "float16"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i in range(n):
+ with T.block("C"):
+ v_i = T.axis.spatial(n, i)
+ T.reads(A[v_i])
+ T.writes(B[v_i])
+ B[v_i] = T.reinterpret(
+ "float4_e2m1fn",
+ T.bitwise_and(
+ T.shift_right(
+ A[v_i // num_elem_per_storage],
+ (v_i % num_elem_per_storage *
4).astype("uint32"),
+ ),
+ T.uint32((1 << 4) - 1),
+ ).astype("uint8"),
+ ).astype("float16")
+
+ func = shuffle_reinterpret if func_type == "shuffle" else
scalar_reinterpret
+ sch = tvm.tir.Schedule(func)
+ block = sch.get_block("C")
+ b = sch.get_loops(block)
+ bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length])
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec)
+ return sch.mod
+
+ # We only test whether the code can be compiled.
+ for func_type, vector_length in product(["shuffle", "scalar"], [1, 2, 4]):
+ if func_type == "shuffle" and vector_length == 1:
+ # Vectorize is necessary for shuffle.
+ continue
+ mod = get_reinterpret_mod(func_type, vector_length)
+ tvm.compile(mod, target=target)
+
+
if __name__ == "__main__":
tvm.testing.main()