This is an automated email from the ASF dual-hosted git repository.

cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eb25db1260 [TIRx][RISC-V] Use scalable RVV loops for fixed vectorize (#19776)
eb25db1260 is described below

commit eb25db1260144c101642285a13a8b5eb73a40538
Author: Zephyr <[email protected]>
AuthorDate: Tue Jun 16 19:34:44 2026 +0800

    [TIRx][RISC-V] Use scalable RVV loops for fixed vectorize (#19776)
    
    This PR improves TIRx vectorization for RISC-V RVV targets.
    
    Fixed-width `T.vectorized` loops can be lowered to fixed LLVM vectors
    such as `<16 x float>`, which LLVM/RVV may scalarize into repeated
    scalar `flw/fsub.s/fsw` instructions. This PR rewrites fixed-width
    vectorized loops on RVV targets into scalable `T.vscale() * 4` chunks
    with lane masks, allowing LLVM to generate RVV load/store instructions
    instead.
    
    The change is limited to RISC-V RVV and does not enable the same
    automatic rewrite for Arm SVE.
    
    Tested on a RISC-V K3 board:
    
    Before: flw/fsub.s/fsw = 16/16/16, vle32/vse32 = 0/0
    After:  flw/fsub.s/fsw = 0/0/0,    vle32/vse32 = 1/1
    
    Also added a RISC-V LLVM codegen regression test.
---
 src/tirx/transform/vectorize_loop.cc              | 97 ++++++++++++++++++++---
 tests/python/codegen/test_target_codegen_riscv.py | 43 ++++++++++
 2 files changed, 130 insertions(+), 10 deletions(-)

diff --git a/src/tirx/transform/vectorize_loop.cc b/src/tirx/transform/vectorize_loop.cc
index a1e954f951..fe6734863b 100644
--- a/src/tirx/transform/vectorize_loop.cc
+++ b/src/tirx/transform/vectorize_loop.cc
@@ -54,16 +54,27 @@ bool IsVScaleCall(const PrimExpr& expr) {
   return false;
 }
 
+bool TargetHasRVV(Target target) {
+  if (!target.defined()) return false;
+  static auto target_has_feature_fn =
+      tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
+  return target_has_feature_fn("v", target).cast<bool>();
+}
+
 // File-local helper: true if the target supports Variable-Length Array extensions
 // (AArch64 SVE or RISC-V V).
 bool TargetHasVLA(Target target) {
   if (!target.defined()) return false;
   bool has_vla = target->GetAttr<bool>("feature.has_sve").value_or(false);
-  static auto target_has_feature_fn =
-      tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
-  has_vla |= target_has_feature_fn("v", target).cast<bool>();
+  has_vla |= TargetHasRVV(target);
   return has_vla;
 }
+
+bool ContainsCallNode(const Stmt& stmt) {
+  return CheckContains::StmtContains(stmt, [](const PrimExpr& expr) {
+    return expr.as<CallNode>() != nullptr;
+  });
+}
 }  // namespace
 
 inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
@@ -132,7 +143,8 @@ bool EnableBufferLevelPredication(Target target) {
  */
 class TryPredicateBufferAccesses : public StmtExprMutator {
  public:
-  TryPredicateBufferAccesses() {}
+  explicit TryPredicateBufferAccesses(bool allow_offset_predication)
+      : allow_offset_predication_(allow_offset_predication) {}
 
   /*!
    * \brief Run the pass to try to exact predicates.
@@ -157,7 +169,10 @@ class TryPredicateBufferAccesses : public StmtExprMutator {
       return {false, stmt};
     }
 
-    base_ = Downcast<Ramp>(lt->a)->base;
+    Ramp pred_ramp = Downcast<Ramp>(lt->a);
+    base_ = pred_ramp->base;
+    stride_ = pred_ramp->stride;
+    lanes_ = pred_ramp->lanes;
     limit_ = Downcast<Broadcast>(lt->b)->value;
 
     // Now we can try to predicate
@@ -190,11 +205,21 @@ class TryPredicateBufferAccesses : public StmtExprMutator {
     }
     Ramp ramp = Downcast<Ramp>(node->indices[0]);
 
-    // The vectorized access pattern must match the base of the predicate
-    if (!ffi::StructuralEqual()(ramp->base, base_)) {
+    if (!ffi::StructuralEqual()(ramp->stride, stride_) ||
+        !ffi::StructuralEqual()(ramp->lanes, lanes_)) {
       return node;
     }
 
+    bool same_base = ffi::StructuralEqual()(ramp->base, base_);
+    if (!same_base) {
+      // The lane mask describes which lanes are active, independent of the
+      // memory base.  This covers accesses such as A[offset + i] guarded by
+      // a predicate over i.
+      if (!allow_offset_predication_) {
+        return node;
+      }
+    }
+
     DataType buf_predicate_dtype =
         DataType(DataType::kUInt, 1, ramp->dtype.get_lanes_or_vscale_factor(),
                  ramp->dtype.is_scalable_vector());
@@ -202,15 +227,27 @@ class TryPredicateBufferAccesses : public StmtExprMutator {
 
     num_accesses_rewritten_ += 1;
     auto writer = node.CopyOnWrite();
-    writer->predicate = lane_mask;
+    if (node->predicate.defined()) {
+      // Keep any existing predicate: merge it with the lane mask instead of
+      // overwriting it.  Buffer predicates are uint1 lane masks, so use & not &&.
+      writer->predicate = node->predicate.value() & lane_mask;
+    } else {
+      writer->predicate = lane_mask;
+    }
     return node;
   }
 
   /*! \brief The variable base expr of the predicate. */
   PrimExpr base_;
+  /*! \brief The lane stride of the predicate. */
+  PrimExpr stride_;
+  /*! \brief The lane count of the predicate. */
+  PrimExpr lanes_;
   /*! \brief The limit of the predicate. The expr specifies the upper bound of the base's
    * evaluated value. */
   PrimExpr limit_;
+  /*! \brief Whether to predicate offset buffer accesses that use the same lane layout. */
+  bool allow_offset_predication_;
   /*! \brief The number of buffer accesses in the stmt we will analyze. */
   size_t num_accesses_analyzed_ = 0;
   /*! \brief The number of buffer accesses rewritten with predicates. */
@@ -819,7 +856,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
     if (EnableBufferLevelPredication(target_) &&
         condition.dtype().is_scalable_or_fixed_length_vector() && !else_case.defined()) {
       std::pair<bool, Stmt> success_stmt_pair =
-          TryPredicateBufferAccesses().Run(then_case, condition);
+          TryPredicateBufferAccesses(TargetHasRVV(target_)).Run(then_case, condition);
       bool can_remove_if_then_else = success_stmt_pair.first;
       if (can_remove_if_then_else) {
         return success_stmt_pair.second;
@@ -975,12 +1012,19 @@ class LoopVectorizer : public StmtMutator {
     if (op->kind == ForKind::kVectorized) {
       auto* extent_as_int = op->extent.as<IntImmNode>();
 
+      TVM_FFI_ICHECK(is_zero(op->min));
+      // Only rewrite when buffer-level predication can drop the lane guard, and never for
+      // calls: their vectorization still queries a compile-time lane count.
+      if (extent_as_int && extent_as_int->value > 1 && TargetHasRVV(target_) &&
+          EnableBufferLevelPredication(target_) && !ContainsCallNode(op->body)) {
+        return VectorizeFixedLoopForRVV(op, extent_as_int->value);
+      }
+
       if (!extent_as_int || extent_as_int->value < 1) {
         bool is_scalable_expr = CheckContains::ExprContains(op->extent, IsVScaleCall);
         TVM_FFI_ICHECK(is_scalable_expr && TargetHasVLA(target_))
             << "Failed to vectorize loop with extent " << op->extent << " for target " << target_;
       }
-      TVM_FFI_ICHECK(is_zero(op->min));
       return Vectorizer(op->loop_var, op->extent, target_)(op->body);
     } else {
       return StmtMutator::VisitStmt_(op);
@@ -999,6 +1043,39 @@ class LoopVectorizer : public StmtMutator {
   }
 
  private:
+  Stmt VectorizeFixedLoopForRVV(const ForNode* op, int64_t extent) {
+    // Match the existing TIRx scalable-vector convention.  LLVM/RVV still
+    // selects the runtime vector length with vsetvli.
+    static constexpr int kDefaultVScaleFactor = 4;
+    DataType index_dtype = op->loop_var->dtype;
+    PrimExpr zero = make_const(index_dtype, 0);
+    PrimExpr fixed_extent = make_const(index_dtype, extent);
+    PrimExpr scalable_lanes = CreateNewLanes(/*is_scalable=*/true, kDefaultVScaleFactor);
+    DataType lane_dtype = scalable_lanes.dtype();
+    PrimExpr scalable_lanes_index = scalable_lanes;
+    if (scalable_lanes_index.dtype() != index_dtype) {
+      scalable_lanes_index = Cast(index_dtype, scalable_lanes_index);
+    }
+    PrimExpr num_chunks = ceildiv(fixed_extent, scalable_lanes_index);
+
+    Var outer(op->loop_var->name_hint + ".vla.o", index_dtype);
+    Var inner(op->loop_var->name_hint + ".vla.i", lane_dtype);
+    PrimExpr inner_index = inner;
+    if (inner_index.dtype() != index_dtype) {
+      inner_index = Cast(index_dtype, inner_index);
+    }
+    PrimExpr index = outer * scalable_lanes_index + inner_index;
+    Stmt body = Substitute(op->body, {{op->loop_var, index}});
+    Stmt guarded_body = IfThenElse(index < fixed_extent, body, std::nullopt, op->span);
+    Stmt vector_loop =
+        For(inner, make_const(lane_dtype, 0), scalable_lanes, ForKind::kVectorized, guarded_body,
+            std::nullopt, op->annotations, std::nullopt, op->span);
+    Stmt loop = For(outer, zero, num_chunks, ForKind::kSerial, vector_loop, std::nullopt, {},
+                    std::nullopt, op->span);
+
+    return this->VisitStmt(loop);
+  }
+
   Target target_ = Target::Current();
 };
 
diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py
index 5b9b1ecd77..3ac75dc337 100644
--- a/tests/python/codegen/test_target_codegen_riscv.py
+++ b/tests/python/codegen/test_target_codegen_riscv.py
@@ -16,6 +16,7 @@
 # under the License.
 # ruff: noqa: E501, F841
 
+import re
 import pytest
 
 import tvm
@@ -113,5 +114,47 @@ def test_rvv_vscale_llvm_dbginfo(target):
         f = tvm.tirx.build(rvv_with_vscale, target)
 
 
+@pytest.mark.skipif(not env.has_llvm_min_version(14), reason="need llvm >= 14")
+def test_rvv_fixed_width_vectorized_loop_uses_scalable_chunks():
+    @T.prim_func(s_tir=True)
+    def fixed16_negative(
+        A: T.Buffer((14, 23, 67, 99), "float32"),
+        B: T.Buffer((14, 23, 67, 99), "float32"),
+    ):
+        for n, c, h, wo in T.grid(14, 23, 67, 7):
+            for wi in T.vectorized(0, 16):
+                if wo * 16 + wi < 99:
+                    B[n, c, h, wo * 16 + wi] = T.float32(0) - A[n, c, h, wo * 16 + wi]
+
+    @T.prim_func(s_tir=True)
+    def fixed16_negative_int64(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+        for wi in T.vectorized(T.int64(0), T.int64(16)):
+            B[wi] = T.float32(0) - A[wi]
+
+    target = tvm.target.Target(
+        {
+            "kind": "llvm",
+            "device": "riscv_cpu",
+            "mtriple": "riscv64-linux-gnu",
+            "mcpu": "generic-rv64",
+            "mattr": ["+64bit", "+a", "+c", "+d", "+f", "+m", "+v"],
+        }
+    )
+
+    def check_codegen(func):
+        with target:
+            f = tvm.tirx.build(func, target)
+
+        assembly = f.inspect_source("asm")
+        assert "vle32.v" in assembly
+        assert "vse32.v" in assembly
+        assert not re.search(r"\bflw\b", assembly)
+        assert not re.search(r"\bfsub\.s\b", assembly)
+        assert not re.search(r"\bfsw\b", assembly)
+
+    check_codegen(fixed16_negative)
+    check_codegen(fixed16_negative_int64)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to