This is an automated email from the ASF dual-hosted git repository.
cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eb25db1260 [TIRx][RISC-V] Use scalable RVV loops for fixed vectorize (#19776)
eb25db1260 is described below
commit eb25db1260144c101642285a13a8b5eb73a40538
Author: Zephyr <[email protected]>
AuthorDate: Tue Jun 16 19:34:44 2026 +0800
[TIRx][RISC-V] Use scalable RVV loops for fixed vectorize (#19776)
This PR improves TIRx vectorization for RISC-V RVV targets.
Fixed-width `T.vectorized` loops can be lowered to fixed LLVM vectors
such as `<16 x float>`, which LLVM/RVV may scalarize into repeated
scalar `flw/fsub.s/fsw` instructions. This PR rewrites fixed-width
vectorized loops on RVV targets into scalable `T.vscale() * 4` chunks
with lane masks, allowing LLVM to generate RVV load/store instructions
instead.
The change is limited to RISC-V RVV and does not enable the same
automatic rewrite for Arm SVE. Vectorized loops whose bodies contain
calls also keep the existing fixed-width lowering for now.
Tested on a RISC-V K3 board (counts of each instruction in the generated
assembly):
Before: flw/fsub.s/fsw = 16/16/16, vle32/vse32 = 0/0
After: flw/fsub.s/fsw = 0/0/0, vle32/vse32 = 1/1
Also added a RISC-V LLVM codegen regression test.
---
src/tirx/transform/vectorize_loop.cc | 97 ++++++++++++++++++++---
tests/python/codegen/test_target_codegen_riscv.py | 43 ++++++++++
2 files changed, 130 insertions(+), 10 deletions(-)
diff --git a/src/tirx/transform/vectorize_loop.cc b/src/tirx/transform/vectorize_loop.cc
index a1e954f951..fe6734863b 100644
--- a/src/tirx/transform/vectorize_loop.cc
+++ b/src/tirx/transform/vectorize_loop.cc
@@ -54,16 +54,27 @@ bool IsVScaleCall(const PrimExpr& expr) {
return false;
}
+// File-local helper: true if the target reports the RISC-V vector ("v")
+// feature. Returns false for an undefined target. Unlike TargetHasVLA, this
+// does not consider AArch64 SVE. The global feature-query function is looked
+// up once and cached in a function-local static.
+bool TargetHasRVV(Target target) {
+ if (!target.defined()) return false;
+ static auto target_has_feature_fn =
+ tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
+ return target_has_feature_fn("v", target).cast<bool>();
+}
+
// File-local helper: true if the target supports Variable-Length Array extensions
// (AArch64 SVE or RISC-V V).
bool TargetHasVLA(Target target) {
if (!target.defined()) return false;
bool has_vla = target->GetAttr<bool>("feature.has_sve").value_or(false);
- static auto target_has_feature_fn =
- tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
- has_vla |= target_has_feature_fn("v", target).cast<bool>();
+ has_vla |= TargetHasRVV(target);
return has_vla;
}
+
+// File-local helper: true if any expression inside `stmt` is a CallNode.
+// LoopVectorizer uses it to keep loop bodies that contain calls on the
+// existing fixed-width vectorization path instead of the scalable RVV rewrite.
+bool ContainsCallNode(const Stmt& stmt) {
+ return CheckContains::StmtContains(stmt, [](const PrimExpr& expr) {
+ return expr.as<CallNode>() != nullptr;
+ });
+}
} // namespace
inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
@@ -132,7 +143,8 @@ bool EnableBufferLevelPredication(Target target) {
*/
class TryPredicateBufferAccesses : public StmtExprMutator {
public:
- TryPredicateBufferAccesses() {}
+ explicit TryPredicateBufferAccesses(bool allow_offset_predication)
+ : allow_offset_predication_(allow_offset_predication) {}
/*!
* \brief Run the pass to try to exact predicates.
@@ -157,7 +169,10 @@ class TryPredicateBufferAccesses : public StmtExprMutator {
return {false, stmt};
}
- base_ = Downcast<Ramp>(lt->a)->base;
+ Ramp pred_ramp = Downcast<Ramp>(lt->a);
+ base_ = pred_ramp->base;
+ stride_ = pred_ramp->stride;
+ lanes_ = pred_ramp->lanes;
limit_ = Downcast<Broadcast>(lt->b)->value;
// Now we can try to predicate
@@ -190,11 +205,21 @@ class TryPredicateBufferAccesses : public StmtExprMutator {
}
Ramp ramp = Downcast<Ramp>(node->indices[0]);
- // The vectorized access pattern must match the base of the predicate
- if (!ffi::StructuralEqual()(ramp->base, base_)) {
+ if (!ffi::StructuralEqual()(ramp->stride, stride_) ||
+ !ffi::StructuralEqual()(ramp->lanes, lanes_)) {
return node;
}
+ bool same_base = ffi::StructuralEqual()(ramp->base, base_);
+ if (!same_base) {
+ // The lane mask describes which lanes are active, independent of the
+ // memory base. This covers accesses such as A[offset + i] guarded by
+ // a predicate over i.
+ if (!allow_offset_predication_) {
+ return node;
+ }
+ }
+
DataType buf_predicate_dtype =
DataType(DataType::kUInt, 1, ramp->dtype.get_lanes_or_vscale_factor(),
ramp->dtype.is_scalable_vector());
@@ -202,15 +227,27 @@ class TryPredicateBufferAccesses : public StmtExprMutator {
num_accesses_rewritten_ += 1;
auto writer = node.CopyOnWrite();
- writer->predicate = lane_mask;
+ if (node->predicate.defined() && allow_offset_predication_) {
+ // Buffer predicates are uint1 lane masks, so mask merging uses bitwise
+ // and rather than logical &&.
+ writer->predicate = node->predicate.value() & lane_mask;
+ } else {
+ writer->predicate = lane_mask;
+ }
return node;
}
/*! \brief The variable base expr of the predicate. */
PrimExpr base_;
+ /*! \brief The lane stride of the predicate. */
+ PrimExpr stride_;
+ /*! \brief The lane count of the predicate. */
+ PrimExpr lanes_;
/*! \brief The limit of the predicate. The expr specifies the upper bound of
the base's
* evaluated value. */
PrimExpr limit_;
+ /*! \brief Whether to predicate offset buffer accesses that use the same
lane layout. */
+ bool allow_offset_predication_;
/*! \brief The number of buffer accesses in the stmt we will analyze. */
size_t num_accesses_analyzed_ = 0;
/*! \brief The number of buffer accesses rewritten with predicates. */
@@ -819,7 +856,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
if (EnableBufferLevelPredication(target_) &&
condition.dtype().is_scalable_or_fixed_length_vector() &&
!else_case.defined()) {
std::pair<bool, Stmt> success_stmt_pair =
- TryPredicateBufferAccesses().Run(then_case, condition);
+ TryPredicateBufferAccesses(TargetHasRVV(target_)).Run(then_case,
condition);
bool can_remove_if_then_else = success_stmt_pair.first;
if (can_remove_if_then_else) {
return success_stmt_pair.second;
@@ -975,12 +1012,19 @@ class LoopVectorizer : public StmtMutator {
if (op->kind == ForKind::kVectorized) {
auto* extent_as_int = op->extent.as<IntImmNode>();
+ TVM_FFI_ICHECK(is_zero(op->min));
+ // General calls still have vectorization paths that query a compile-time
+ // lane count, so keep them on the existing fixed-width path for now.
+ if (extent_as_int && extent_as_int->value > 1 && TargetHasRVV(target_) &&
+ !ContainsCallNode(op->body)) {
+ return VectorizeFixedLoopForRVV(op, extent_as_int->value);
+ }
+
if (!extent_as_int || extent_as_int->value < 1) {
bool is_scalable_expr = CheckContains::ExprContains(op->extent,
IsVScaleCall);
TVM_FFI_ICHECK(is_scalable_expr && TargetHasVLA(target_))
<< "Failed to vectorize loop with extent " << op->extent << " for
target " << target_;
}
- TVM_FFI_ICHECK(is_zero(op->min));
return Vectorizer(op->loop_var, op->extent, target_)(op->body);
} else {
return StmtMutator::VisitStmt_(op);
@@ -999,6 +1043,39 @@ class LoopVectorizer : public StmtMutator {
}
private:
+  /*!
+   * \brief Rewrite a constant-extent vectorized loop into scalable RVV chunks.
+   *
+   * `for i in vectorized(extent)` becomes
+   *   for o in serial(ceildiv(extent, vscale * 4)):
+   *     for v in vectorized(vscale * 4):
+   *       if o * (vscale * 4) + v < extent:
+   *         body[i := o * (vscale * 4) + v]
+   * The guard keeps the trailing lanes of the last chunk inactive.
+   *
+   * \param op The vectorized loop to rewrite; its min is expected to be zero.
+   * \param extent The constant extent of \p op.
+   * \return The rewritten loop nest, visited again by this mutator so the inner
+   *         scalable loop takes the regular scalable-vectorization path.
+   */
+  Stmt VectorizeFixedLoopForRVV(const ForNode* op, int64_t extent) {
+    // Match the existing TIRx scalable-vector convention. LLVM/RVV still
+    // selects the runtime vector length with vsetvli.
+    static constexpr int kDefaultVScaleFactor = 4;
+    DataType index_dtype = op->loop_var->dtype;
+    PrimExpr zero = make_const(index_dtype, 0);
+    PrimExpr fixed_extent = make_const(index_dtype, extent);
+    PrimExpr scalable_lanes = CreateNewLanes(/*is_scalable=*/true, kDefaultVScaleFactor);
+    DataType lane_dtype = scalable_lanes.dtype();
+    // The lane count and the inner loop variable use `lane_dtype`; cast them
+    // whenever the original loop variable uses a different index type.
+    PrimExpr scalable_lanes_index = scalable_lanes;
+    if (scalable_lanes_index.dtype() != index_dtype) {
+      scalable_lanes_index = Cast(index_dtype, scalable_lanes_index);
+    }
+    PrimExpr num_chunks = ceildiv(fixed_extent, scalable_lanes_index);
+
+    Var outer(op->loop_var->name_hint + ".vla.o", index_dtype);
+    Var inner(op->loop_var->name_hint + ".vla.i", lane_dtype);
+    PrimExpr inner_index = inner;
+    if (inner_index.dtype() != index_dtype) {
+      inner_index = Cast(index_dtype, inner_index);
+    }
+    PrimExpr index = outer * scalable_lanes_index + inner_index;
+    Stmt body = Substitute(op->body, {{op->loop_var, index}});
+    // ceildiv rounds up, so the last chunk may run past `extent`; mask those lanes.
+    Stmt guarded_body = IfThenElse(index < fixed_extent, body, std::nullopt, op->span);
+    Stmt vector_loop =
+        For(inner, make_const(lane_dtype, 0), scalable_lanes, ForKind::kVectorized, guarded_body,
+            std::nullopt, op->annotations, std::nullopt, op->span);
+    Stmt loop = For(outer, zero, num_chunks, ForKind::kSerial, vector_loop, std::nullopt, {},
+                    std::nullopt, op->span);
+
+    return this->VisitStmt(loop);
+  }
+
Target target_ = Target::Current();
};
diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py
index 5b9b1ecd77..3ac75dc337 100644
--- a/tests/python/codegen/test_target_codegen_riscv.py
+++ b/tests/python/codegen/test_target_codegen_riscv.py
@@ -16,6 +16,7 @@
# under the License.
# ruff: noqa: E501, F841
+import re
import pytest
import tvm
@@ -113,5 +114,47 @@ def test_rvv_vscale_llvm_dbginfo(target):
f = tvm.tirx.build(rvv_with_vscale, target)
[email protected](not env.has_llvm_min_version(14), reason="need llvm >= 14")
+def test_rvv_fixed_width_vectorized_loop_uses_scalable_chunks():
+ @T.prim_func(s_tir=True)
+ def fixed16_negative(
+ A: T.Buffer((14, 23, 67, 99), "float32"),
+ B: T.Buffer((14, 23, 67, 99), "float32"),
+ ):
+ for n, c, h, wo in T.grid(14, 23, 67, 7):
+ for wi in T.vectorized(0, 16):
+ if wo * 16 + wi < 99:
+ B[n, c, h, wo * 16 + wi] = T.float32(0) - A[n, c, h, wo *
16 + wi]
+
+ @T.prim_func(s_tir=True)
+ def fixed16_negative_int64(A: T.Buffer((16,), "float32"), B:
T.Buffer((16,), "float32")):
+ for wi in T.vectorized(T.int64(0), T.int64(16)):
+ B[wi] = T.float32(0) - A[wi]
+
+ target = tvm.target.Target(
+ {
+ "kind": "llvm",
+ "device": "riscv_cpu",
+ "mtriple": "riscv64-linux-gnu",
+ "mcpu": "generic-rv64",
+ "mattr": ["+64bit", "+a", "+c", "+d", "+f", "+m", "+v"],
+ }
+ )
+
+ def check_codegen(func):
+ with target:
+ f = tvm.tirx.build(func, target)
+
+ assembly = f.inspect_source("asm")
+ assert "vle32.v" in assembly
+ assert "vse32.v" in assembly
+ assert not re.search(r"\bflw\b", assembly)
+ assert not re.search(r"\bfsub\.s\b", assembly)
+ assert not re.search(r"\bfsw\b", assembly)
+
+ check_codegen(fixed16_negative)
+ check_codegen(fixed16_negative_int64)
+
+
if __name__ == "__main__":
tvm.testing.main()