This is an automated email from the ASF dual-hosted git repository.
cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 81e62ecbb2 [TIRx][LLVM] Support scalable Ramp lowering (#19866)
81e62ecbb2 is described below
commit 81e62ecbb27724794940ce13e551e43313cae0c7
Author: Zephyr <[email protected]>
AuthorDate: Fri Jun 26 11:47:31 2026 +0800
[TIRx][LLVM] Support scalable Ramp lowering (#19866)
The existing Ramp lowering path in CodeGenLLVM constructs fixed-width
vectors by inserting each lane explicitly. This does not work for
scalable vectors, whose runtime lane count is not known at compile time.
Previously, CodeGenLLVM rejected scalable-vector Ramp expressions, which
prevented vectorized TIR/TIRx programs from lowering induction
expressions to RVV/SVE-style scalable vectors.
This patch adds a separate lowering path for scalable integer Ramp
expressions using LLVM stepvector. A scalable Ramp expression:
Ramp(base, stride, lanes)
is lowered as:
splat(base) + stepvector() * splat(stride)
For LLVM >= 20, this uses llvm.stepvector. For older LLVM versions, this
uses llvm.experimental.stepvector.
A RISC-V RVV codegen test is added to verify that a vectorized induction
expression lowers to RVV lane-id and arithmetic instructions, such as
vid.v, vmul.v*, and vadd.v*.
---
src/target/llvm/codegen_llvm.cc | 29 ++++++++++++++++++++---
tests/python/codegen/test_target_codegen_riscv.py | 26 ++++++++++++++++++++
2 files changed, 52 insertions(+), 3 deletions(-)
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 0a5acd348a..f091015f22 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -1898,9 +1898,32 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
PrimType dtype(op->ty()->dtype);
- llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(dtype));
- // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
- TVM_FFI_ICHECK(!dtype.IsScalableVector());
+ llvm::Type* vec_type = DTypeToLLVMType(dtype);
+ if (dtype.IsScalableVector()) {
+ TVM_FFI_ICHECK(dtype.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt))
+ << "Scalable ramps require an integer dtype, but got " << dtype;
+ TVM_FFI_ICHECK_GE(dtype.bits(), 8)
+ << "Scalable ramps require at least 8-bit elements, but got " << dtype;
+
+#if TVM_LLVM_VERSION >= 200
+ constexpr llvm::Intrinsic::ID stepvector_id = llvm::Intrinsic::stepvector;
+#else
+ constexpr llvm::Intrinsic::ID stepvector_id = llvm::Intrinsic::experimental_stepvector;
+#endif
+ llvm::Function* stepvector = GetIntrinsicDecl(stepvector_id, vec_type, {});
+ llvm::Value* step = builder_->CreateCall(stepvector);
+ llvm::ElementCount lanes = llvm::ElementCount::getScalable(dtype.VScaleFactor());
+ PrimType elem_dtype = dtype.WithLanes(1);
+ llvm::Value* base_scalar =
+ CreateCast(PrimType(op->base.ty()->dtype), elem_dtype, MakeValue(op->base));
+ llvm::Value* stride_scalar =
+ CreateCast(PrimType(op->stride.ty()->dtype), elem_dtype, MakeValue(op->stride));
+ llvm::Value* base = builder_->CreateVectorSplat(lanes, base_scalar);
+ llvm::Value* stride = builder_->CreateVectorSplat(lanes, stride_scalar);
+ return builder_->CreateAdd(base, builder_->CreateMul(step, stride));
+ }
+
+ llvm::Value* vec = llvm::UndefValue::get(vec_type);
int lanes = dtype.lanes();
for (int i = 0; i < lanes; ++i) {
vec = builder_->CreateInsertElement(
diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py
index dba1b5e7ad..45010b9798 100644
--- a/tests/python/codegen/test_target_codegen_riscv.py
+++ b/tests/python/codegen/test_target_codegen_riscv.py
@@ -169,5 +169,31 @@ def test_rvv_fixed_width_vectorized_loop_uses_scalable_chunks():
check_codegen(fixed16_negative_int64)
+@pytest.mark.skipif(not env.has_llvm_min_version(14), reason="need llvm >= 14")
+def test_rvv_scalable_ramp_expression():
+ @T.prim_func(s_tir=True)
+ def ramp_compare(B: T.Buffer((16,), "int32")):
+ for i in T.vectorized(16):
+ B[i] = T.Select(i * 3 + 5 < 29, i * 3 + 5, -1)
+
+ target = tvm.target.Target(
+ {
+ "kind": "llvm",
+ "device": "riscv_cpu",
+ "mtriple": "riscv64-linux-gnu",
+ "mcpu": "generic-rv64",
+ "mattr": ["+64bit", "+a", "+c", "+d", "+f", "+m", "+v"],
+ }
+ )
+
+ with target:
+ f = tvm.tirx.build(ramp_compare, target)
+
+ assembly = f.inspect_source("asm")
+ assert "vid.v" in assembly
+ assert re.search(r"\bvmul\.v", assembly)
+ assert re.search(r"\bvadd\.v", assembly)
+
+
if __name__ == "__main__":
tvm.testing.main()