This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 01268ac089 [LLVM][Codegen] Enable SVE/VLA for RISCV targets
01268ac089 is described below
commit 01268ac089560ad0e59666db6c6e2cc9571b593a
Author: Balint Cristian <[email protected]>
AuthorDate: Tue May 13 07:34:05 2025 +0300
[LLVM][Codegen] Enable SVE/VLA for RISCV targets
---
src/arith/analyzer.cc | 9 +-
src/arith/const_int_bound.cc | 7 +-
src/arith/scalable_expression.cc | 33 ++++-
src/arith/scalable_expression.h | 15 +-
src/target/llvm/codegen_aarch64.cc | 4 +-
src/tir/transforms/vectorize_loop.cc | 6 +-
tests/python/arith/test_arith_simplify.py | 2 +-
.../python/codegen/test_target_codegen_aarch64.py | 163 ++++++---------------
.../python/codegen/test_target_codegen_llvm_vla.py | 149 +++++++++++++++++++
.../tir-schedule/test_tir_schedule_split_fuse.py | 2 +-
.../tir-transform/test_tir_transform_vectorize.py | 2 +-
11 files changed, 251 insertions(+), 141 deletions(-)
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 602a198a2b..10c7676282 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -231,17 +231,18 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
// Current analysis may not be powerful enough to prove expressions containing
// the same symbolic value multiple times. However, when the symbolic values are
// "T.vscale" and the compile target uses a scalable architecture extension like
- // SVE, we can make some assumptions about the value of vscale and iterate over a
+ // VLA, we can make some assumptions about the value of vscale and iterate over a
// space of pre-defined values to attempt to prove the expression.
Target curr_target = Target::Current();
if (ContainsVscaleCall(simplified)) {
- if (TargetHasSVE(curr_target)) {
- return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues);
+ if (TargetHasVLA(curr_target)) {
+ auto kVScaleValues = GetVScaleValues(curr_target);
+ return CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues);
}
LOG(WARNING)
<< "The expression contains scalable values. An attempt to prove by
substituting "
"with known values of vscale was not performed. This proof
currently only supports "
- "AArch64 SVE targets, but the target was "
+ "VLA targets, but the target was "
<< curr_target;
}
return false;
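
As a quick illustration of what this hunk enables (a rough Python sketch, not part of the commit; the RISC-V target string is copied from the new tests further below):

import tvm
from tvm import arith, tir

# With a VLA-capable target in scope, CanProve substitutes the candidate
# vscale values instead of bailing out with the warning above.
analyzer = arith.Analyzer()
rvv_target = tvm.target.Target(
    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 "
    "-mattr=+64bit,+a,+c,+d,+f,+m,+v"
)
with rvv_target:
    # 4 * vscale is positive for every candidate value of vscale, so this
    # should now be provable on RVV targets as well as on AArch64 SVE.
    print(analyzer.can_prove(4 * tir.vscale() > 0))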
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index ac8ac91711..7409ecc6f3 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -364,15 +364,16 @@ class ConstIntBoundAnalyzer::Impl
// only special handle >> and & which can be
// used for index calculation.
+ auto curr_target = Target::Current();
if (op->op.same_as(tir::builtin::shift_right())) {
return VisitRightShift(op);
} else if (op->op.same_as(tir::builtin::shift_left())) {
return VisitLeftShift(op);
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
- } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(Target::Current())) {
- unsigned int max_val =
- *std::max_element(kAArch64VScaleValues.begin(), kAArch64VScaleValues.end());
+ } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasVLA(curr_target)) {
+ auto kVScaleValues = GetVScaleValues(curr_target);
+ unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
return MakeBound(1, max_val);
} else {
return Everything(op->dtype);
diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc
index beb75c1f3e..1937b9c34e 100644
--- a/src/arith/scalable_expression.cc
+++ b/src/arith/scalable_expression.cc
@@ -86,14 +86,41 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
return can_prove_expr;
}
-bool TargetHasSVE(Optional<Target> target) {
+bool TargetHasVLA(Optional<Target> target) {
if (!target.defined()) {
target = Target::Current();
}
+ bool has_vla{false};
if (target.defined()) {
- return Downcast<Target>(target)->GetFeature<Bool>("has_sve").value_or(Bool(false));
+ // aarch64
+ has_vla = Downcast<Target>(target)->GetFeature<Bool>("has_sve").value_or(Bool(false));
+ // riscv{32,64}
+ static auto target_has_feature_fn =
+ tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
+ has_vla |= target_has_feature_fn("v", target).cast<bool>();
}
- return false;
+ return has_vla;
+}
+
+const std::vector<unsigned int> GetVScaleValues(Optional<Target> target) {
+ unsigned int vector_width = 0;
+ std::vector<unsigned int> kVScaleValues;
+ if (!target.defined()) {
+ target = Target::Current();
+ }
+ if (target.defined()) {
+ static auto llvm_get_vector_width_fn =
+ tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
+ vector_width = llvm_get_vector_width_fn(target).cast<int>();
+ }
+ // scale list with powers of two
+ for (unsigned int i = 0;; ++i) {
+ auto power = static_cast<unsigned int>(std::pow(2, i));
+ if (power > (vector_width / 8)) break;
+ kVScaleValues.push_back(power);
+ }
+
+ return kVScaleValues;
}
} // namespace arith
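
For reference, the RVV half of the new TargetHasVLA check can be reproduced from Python through the same registered global the C++ code calls (a hedged sketch, not part of the commit; rvv_target is an illustrative name and the target string comes from the new tests):

import tvm

rvv_target = tvm.target.Target(
    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 "
    "-mattr=+64bit,+a,+c,+d,+f,+m,+v"
)
# Same query TargetHasVLA performs for riscv{32,64}: is the "v" (vector)
# extension enabled on this target?
target_has_feature = tvm.get_global_func("target.target_has_feature")
print(bool(target_has_feature("v", rvv_target)))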
diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h
index 70d753a299..2470d5dcd8 100644
--- a/src/arith/scalable_expression.h
+++ b/src/arith/scalable_expression.h
@@ -35,9 +35,6 @@
namespace tvm {
namespace arith {
-/*! \brief A list of known vscale values to try for an AArch64 SVE target. */
-static const std::vector<unsigned int> kAArch64VScaleValues = {1, 2, 4, 8, 16};
-
/*!
* \brief Check if an expr is a call to the vscale intrinsic.
* \param expr The expr to check
@@ -80,10 +77,18 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
/*!
* \brief Check whether the compilation target supports SVE
+ * \brief Check whether the compilation target supports VLA
+ * \param target The target to check.
+ * \return Whether VLA is supported
+ */
+bool TargetHasVLA(Optional<Target> target = std::nullopt);
+
+/*!
+ * \brief Get a list of known vscale values to try for a VLA target.
* \param target The target to check.
- * \return Whether SVE is supported
+ * \return A list of vscale values as std::vector<unsigned int>
*/
-bool TargetHasSVE(Optional<Target> target = std::nullopt);
+const std::vector<unsigned int> GetVScaleValues(Optional<Target> target = std::nullopt);
} // namespace arith
} // namespace tvm
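
The enumeration behind GetVScaleValues is every power of two up to vector_width / 8, with vector_width obtained from the target.llvm_get_vector_width global. A small sketch of that arithmetic (assuming a 128-bit base vector width, as on AArch64 SVE):

def vscale_candidates(vector_width):
    # powers of two not exceeding vector_width / 8, mirroring GetVScaleValues
    values, power = [], 1
    while power <= vector_width // 8:
        values.append(power)
        power *= 2
    return values

print(vscale_candidates(128))  # [1, 2, 4, 8, 16], the removed kAArch64VScaleValues list

The maximum of this list is what codegen_aarch64.cc below feeds into the vscale_range(1, max) function attribute.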
diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc
index 1399fc083a..b690c0fc28 100644
--- a/src/target/llvm/codegen_aarch64.cc
+++ b/src/target/llvm/codegen_aarch64.cc
@@ -57,8 +57,8 @@ void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
#if TVM_LLVM_VERSION >= 130
// Add vscale_range() function attribute when appropriate.
if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) {
- unsigned int max_val =
- *std::max_element(arith::kAArch64VScaleValues.begin(), arith::kAArch64VScaleValues.end());
+ auto kVScaleValues = arith::GetVScaleValues(Target::Current());
+ unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
func->addFnAttr(
llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 1, max_val));
}
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index 3df73f0edb..54b2daf836 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -80,8 +80,8 @@ bool EnableBufferLevelPredication(Target target) {
return enable_buffer_predication.value();
}
- // Use buffer-level predication by default for AArch64 SVE targets
- return arith::TargetHasSVE(target);
+ // Use buffer-level predication by default for VLA targets
+ return arith::TargetHasVLA(target);
}
/*!
@@ -972,7 +972,7 @@ class LoopVectorizer : public StmtMutator {
if (!extent_as_int || extent_as_int->value < 1) {
bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
- ICHECK(is_scalable_expr && arith::TargetHasSVE(target_))
+ ICHECK(is_scalable_expr && arith::TargetHasVLA(target_))
<< "Failed to vectorize loop with extent " << op->extent << " for
target " << target_;
}
ICHECK(is_zero(op->min));
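
With this relaxation, loops whose extent contains T.vscale() can be vectorized for any VLA target, not only SVE. A hedged sketch of the kind of module that is expected to lower on RVV (illustrative only; the target string and buffer patterns mirror the new tests):

import tvm
from tvm.script import tir as T

@T.prim_func
def scalable_copy(a: T.handle, b: T.handle):
    A = T.match_buffer(a, (128,), "float32")
    B = T.match_buffer(b, (128,), "float32")
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # scalable extent: accepted because TargetHasVLA(target_) holds for RVV
    for i in T.vectorized(4 * T.vscale()):
        B[i] = A[i]

rvv_target = tvm.target.Target(
    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 "
    "-mattr=+64bit,+a,+c,+d,+f,+m,+v"
)
with rvv_target:
    mod = tvm.tir.build(scalable_copy)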
diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py
index 3b02377400..4971acbd45 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -113,7 +113,7 @@ def test_simplify_vscale_comparison_without_sve_target(capfd):
warning_msg = (
"Warning: The expression contains scalable values. An attempt to prove
by substituting "
"with known values of vscale was not performed. This proof currently
only supports "
- "AArch64 SVE targets, but the target was llvm -keys=arm_cpu,cpu
-mtriple=aarch64-linux-gnu"
+ "VLA targets, but the target was llvm -keys=arm_cpu,cpu
-mtriple=aarch64-linux-gnu"
)
capture = capfd.readouterr().err
assert warning_msg in capture
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 43870044d5..2c8f185d8e 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -43,7 +43,9 @@ def test_mul(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: A[i] * B[i], name="C")
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and mul instructions using z registers
assembly = f.get_source("asm")
@@ -73,7 +75,9 @@ def test_add(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: A[i] + B[i], name="C")
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and add instructions using z registers
assembly = f.get_source("asm")
@@ -103,7 +107,9 @@ def test_sub(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: A[i] - B[i], name="C")
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and sub instructions using z registers
assembly = f.get_source("asm")
@@ -134,7 +140,9 @@ def test_muladd(dtype):
B = te.placeholder(m, dtype=type, name="B")
C = te.placeholder(m, dtype=type, name="C")
D = te.compute((m), lambda i: A[i] * B[i] + C[i], name="D")
- f = tvm.tir.build(te.create_prim_func([A, B, C, D]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C, D]))
# Verify we see SVE load instructions and either mad or mla instructions using z registers
assembly = f.get_source("asm")
@@ -164,7 +172,9 @@ def test_max(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: tvm.te.max(A[i], B[i]))
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and cmgt + sel instructions or a max instruction, all using z registers
assembly = f.get_source("asm")
@@ -198,7 +208,9 @@ def test_min(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: tvm.te.min(A[i], B[i]))
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and cmgt + sel instructions or a min instruction, all using z registers
assembly = f.get_source("asm")
@@ -232,7 +244,9 @@ def test_div(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: tvm.te.div(A[i], B[i]))
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and div instructions using z registers
assembly = f.get_source("asm")
@@ -261,7 +275,9 @@ def test_mod(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: tvm.te.floormod(A[i], B[i]), name="C")
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and mls instructions using z registers
assembly = f.get_source("asm")
@@ -291,7 +307,9 @@ def test_eq(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: A[i] == B[i], name="C")
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and cmpeq or cmeq instructions using z registers
assembly = f.get_source("asm")
@@ -321,7 +339,9 @@ def test_neq(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: A[i] != B[i], name="C")
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and cmpgt, cmgt, cmpne or cmne instructions, all using z registers
assembly = f.get_source("asm")
@@ -350,7 +370,9 @@ def test_or(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: A[i] | B[i], name="C")
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and orr instructions using z registers
assembly = f.get_source("asm")
@@ -379,7 +401,9 @@ def test_and(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype=type, name="B")
C = te.compute((m), lambda i: A[i] & B[i], name="C")
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see SVE load instructions and and instructions using z registers
assembly = f.get_source("asm")
@@ -407,7 +431,9 @@ def test_not(dtype):
m = te.var("m")
A = te.placeholder(m, dtype=type, name="A")
C = te.compute((m), lambda i: ~A[i], name="C")
- f = tvm.tir.build(te.create_prim_func([A, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, C]))
# Verify we see SVE load instructions and eor instructions using z registers
assembly = f.get_source("asm")
@@ -440,7 +466,9 @@ def test_memcpy(dtype):
A = te.placeholder(m, dtype=type, name="A")
B = te.placeholder(m, dtype="int32", name="B")
C = te.compute((m), lambda i: A[B[i]], name="C")
- f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, B, C]))
# Verify we see gather instructions in the assembly
assembly = f.get_source("asm")
@@ -451,65 +479,6 @@ def test_memcpy(dtype):
check_correct_assembly(type=dtype)
-@pytest.mark.skipif(
- llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
-)
-def test_codegen_vscale():
- target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
- vscale = tvm.tir.vscale()
-
- @T.prim_func
- def main(A: T.Buffer((5,), "int32")):
- for i in range(5):
- A[i] = 2 * vscale
-
- build_mod = tvm.tir.build(main, target=target)
- llvm = build_mod.get_source()
-
- assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM."
-
-
-@pytest.mark.skipif(
- llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
-)
-def test_scalable_buffer_load_store():
- target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
-
- @T.prim_func
- def my_func(a: T.handle, b: T.handle):
- A = T.match_buffer(a, (128,), "float32")
- B = T.match_buffer(b, (128,), "float32")
- T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
- B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
-
- mod = tvm.tir.build(my_func, target=target)
- llvm = mod.get_source("ll")
-
- assert re.findall(r"load <vscale x 4 x float>", llvm), "No scalable load
in generated LLVM."
- assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable
store in generated LLVM."
-
-
-@pytest.mark.skipif(
- llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
-)
-def test_scalable_broadcast():
- target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
-
- @T.prim_func
- def my_func(a: T.handle):
- A = T.match_buffer(a, (128,), "float32")
- T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
- A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
-
- mod = tvm.tir.build(my_func, target=target)
- llvm = mod.get_source("ll")
-
- assert re.findall(
- r"shufflevector \(<vscale x 4 x float> insertelement \(<vscale x 4 x
float>", llvm
- ), "No scalable broadcast in generated LLVM."
- assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable
store in generated LLVM."
-
-
@pytest.mark.skipif(
llvm_version_major() < 13,
reason="Function attribute vscale_range() is not supported in earlier
versions of LLVM",
@@ -529,7 +498,9 @@ def test_vscale_range_function_attribute(mattr, expect_attr):
m = te.var("m")
A = te.placeholder(m, dtype="float32", name="A")
C = te.compute((m), lambda i: A[i] + 1, name="C")
- f = tvm.tir.build(te.create_prim_func([A, C]), target=target)
+
+ with tvm.target.Target(target):
+ f = tvm.tir.build(te.create_prim_func([A, C]))
# Check if the vscale_range() attribute exists
ll = f.get_source("ll")
@@ -545,49 +516,5 @@ def test_vscale_range_function_attribute(mattr, expect_attr):
), f"Unexpected function attribute vscale_range() was found in generated LLVM IR"
-@pytest.mark.skipif(
- reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
-)
-def test_get_active_lane_mask():
- target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
-
- @T.prim_func
- def before(a: T.handle):
- A = T.match_buffer(a, (30,), "int1")
- for i in range(T.ceildiv(30, T.vscale() * 4)):
- A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30)
-
- with tvm.target.Target(target):
- out = tvm.tir.build(before)
-
- ll = out.get_source("ll")
- assert "get.active.lane.mask" in ll
-
-
-@pytest.mark.skipif(
- reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
-)
-def test_predicated_scalable_buffer():
- target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
-
- @T.prim_func
- def before(a: T.handle, b: T.handle):
- A = T.match_buffer(a, (16,), "float32")
- B = T.match_buffer(b, (16,), "float32")
- T.func_attr({"global_symbol": "main", "tir.noalias": True})
- for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())):
- for i_1 in T.vectorized(4 * T.vscale()):
- if i_0 * 4 * T.vscale() + i_1 < 14:
- B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0
-
- with tvm.target.Target(target):
- out = tvm.tir.build(before)
-
- ll = out.get_source("ll")
- assert "get.active.lane.mask" in ll
- assert "llvm.masked.load" in ll
- assert "llvm.masked.store" in ll
-
-
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/codegen/test_target_codegen_llvm_vla.py b/tests/python/codegen/test_target_codegen_llvm_vla.py
new file mode 100644
index 0000000000..7ca3083dd5
--- /dev/null
+++ b/tests/python/codegen/test_target_codegen_llvm_vla.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Codegen tests for VLA extensions
+"""
+
+import re
+import pytest
+
+import tvm
+from tvm import te
+from tvm.script import tir as T
+from tvm.target.codegen import llvm_version_major
+
+
+@pytest.mark.skipif(
+ llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
+)
+@tvm.testing.parametrize_targets(
+ "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
+ "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
+)
+def test_codegen_vscale(target):
+ vscale = tvm.tir.vscale()
+
+ @T.prim_func
+ def main(A: T.Buffer((5,), "int32")):
+ for i in range(5):
+ A[i] = 2 * vscale
+
+ with tvm.target.Target(target):
+ build_mod = tvm.tir.build(main)
+
+ llvm = build_mod.get_source()
+ assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM."
+
+
+@pytest.mark.skipif(
+ llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
+)
+@tvm.testing.parametrize_targets(
+ "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
+ "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
+)
+def test_scalable_buffer_load_store(target):
+ @T.prim_func
+ def my_func(a: T.handle, b: T.handle):
+ A = T.match_buffer(a, (128,), "float32")
+ B = T.match_buffer(b, (128,), "float32")
+ T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+ B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
+
+ with tvm.target.Target(target):
+ mod = tvm.tir.build(my_func)
+
+ llvm = mod.get_source("ll")
+ assert re.findall(r"load <vscale x 4 x float>", llvm), "No scalable load
in generated LLVM."
+ assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable
store in generated LLVM."
+
+
+@pytest.mark.skipif(
+ llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
+)
+@tvm.testing.parametrize_targets(
+ "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
+ "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
+)
+def test_scalable_broadcast(target):
+ @T.prim_func
+ def my_func(a: T.handle):
+ A = T.match_buffer(a, (128,), "float32")
+ T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+ A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
+
+ with tvm.target.Target(target):
+ mod = tvm.tir.build(my_func)
+
+ llvm = mod.get_source("ll")
+ assert re.findall(
+ r"shufflevector \(<vscale x 4 x float> insertelement \(<vscale x 4 x
float>", llvm
+ ), "No scalable broadcast in generated LLVM."
+ assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable
store in generated LLVM."
+
+
+@pytest.mark.skipif(
+ reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
+)
+@tvm.testing.parametrize_targets(
+ "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
+ "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
+)
+def test_get_active_lane_mask(target):
+ @T.prim_func
+ def before(a: T.handle):
+ A = T.match_buffer(a, (30,), "int1")
+ for i in range(T.ceildiv(30, T.vscale() * 4)):
+ A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30)
+
+ with tvm.target.Target(target):
+ out = tvm.tir.build(before)
+
+ ll = out.get_source("ll")
+ assert "get.active.lane.mask" in ll
+
+
+@pytest.mark.skipif(
+ reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
+)
+@tvm.testing.parametrize_targets(
+ "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
+ "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
+)
+def test_predicated_scalable_buffer(target):
+ @T.prim_func
+ def before(a: T.handle, b: T.handle):
+ A = T.match_buffer(a, (16,), "float32")
+ B = T.match_buffer(b, (16,), "float32")
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())):
+ for i_1 in T.vectorized(4 * T.vscale()):
+ if i_0 * 4 * T.vscale() + i_1 < 14:
+ B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0
+
+ with tvm.target.Target(target):
+ out = tvm.tir.build(before)
+
+ ll = out.get_source("ll")
+ assert "get.active.lane.mask" in ll
+ assert "llvm.masked.load" in ll
+ assert "llvm.masked.store" in ll
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py
index 22344acfe1..f09f7417ba 100644
--- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py
+++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py
@@ -816,7 +816,7 @@ def test_unsupported_target_scalable_split(capfd):
warning_msg = (
"Warning: The expression contains scalable values. An attempt to prove
by substituting "
"with known values of vscale was not performed. This proof currently
only supports "
- "AArch64 SVE targets, but the target was "
+ "VLA targets, but the target was "
)
captured = capfd.readouterr().err
assert warning_msg in captured
diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py
index 9b61255285..13bb1c60cb 100644
--- a/tests/python/tir-transform/test_tir_transform_vectorize.py
+++ b/tests/python/tir-transform/test_tir_transform_vectorize.py
@@ -670,7 +670,7 @@ def test_vectorize_and_predicate_invalid_conditions():
def test_vectorize_with_explicitly_disabled_buffer_level_predication():
- # Since the target has the SVE feature, buffer level predication is enabled
+ # Since the target has the VLA feature, buffer level predication is enabled
# by default. However, it has been explicitly disabled by the pass context
# option, so no buffer-level predicates should be added.
@T.prim_func