This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 01268ac089 [LLVM][Codegen] Enable SVE/VLA for RISCV targets
01268ac089 is described below

commit 01268ac089560ad0e59666db6c6e2cc9571b593a
Author: Balint Cristian <[email protected]>
AuthorDate: Tue May 13 07:34:05 2025 +0300

    [LLVM][Codegen] Enable SVE/VLA for RISCV targets
---
 src/arith/analyzer.cc                              |   9 +-
 src/arith/const_int_bound.cc                       |   7 +-
 src/arith/scalable_expression.cc                   |  33 ++++-
 src/arith/scalable_expression.h                    |  15 +-
 src/target/llvm/codegen_aarch64.cc                 |   4 +-
 src/tir/transforms/vectorize_loop.cc               |   6 +-
 tests/python/arith/test_arith_simplify.py          |   2 +-
 .../python/codegen/test_target_codegen_aarch64.py  | 163 ++++++---------------
 .../python/codegen/test_target_codegen_llvm_vla.py | 149 +++++++++++++++++++
 .../tir-schedule/test_tir_schedule_split_fuse.py   |   2 +-
 .../tir-transform/test_tir_transform_vectorize.py  |   2 +-
 11 files changed, 251 insertions(+), 141 deletions(-)

diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 602a198a2b..10c7676282 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -231,17 +231,18 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
   // Current analysis may not be powerful enough to prove expressions containing
   // the same symbolic value multiple times. However, when the symbolic values are
   // "T.vscale" and the compile target uses a scalable architecture extension like
-  // SVE, we can make some assumptions about the value of vscale and iterate over a
+  // VLA, we can make some assumptions about the value of vscale and iterate over a
   // space of pre-defined values to attempt to prove the expression.
   Target curr_target = Target::Current();
   if (ContainsVscaleCall(simplified)) {
-    if (TargetHasSVE(curr_target)) {
-      return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues);
+    if (TargetHasVLA(curr_target)) {
+      auto kVScaleValues = GetVScaleValues(curr_target);
+      return CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues);
     }
     LOG(WARNING)
         << "The expression contains scalable values. An attempt to prove by substituting "
            "with known values of vscale was not performed. This proof currently only supports "
-           "AArch64 SVE targets, but the target was "
+           "VLA targets, but the target was "
         << curr_target;
   }
   return false;
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index ac8ac91711..7409ecc6f3 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -364,15 +364,16 @@ class ConstIntBoundAnalyzer::Impl
     // only special handle >> and & which can be
     // used for index calculation.
 
+    auto curr_target = Target::Current();
     if (op->op.same_as(tir::builtin::shift_right())) {
       return VisitRightShift(op);
     } else if (op->op.same_as(tir::builtin::shift_left())) {
       return VisitLeftShift(op);
     } else if (op->op.same_as(tir::builtin::bitwise_and())) {
       return VisitBitwiseAnd(op);
-    } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(Target::Current())) {
-      unsigned int max_val =
-          *std::max_element(kAArch64VScaleValues.begin(), kAArch64VScaleValues.end());
+    } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasVLA(curr_target)) {
+      auto kVScaleValues = GetVScaleValues(curr_target);
+      unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
       return MakeBound(1, max_val);
     } else {
       return Everything(op->dtype);
diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc
index beb75c1f3e..1937b9c34e 100644
--- a/src/arith/scalable_expression.cc
+++ b/src/arith/scalable_expression.cc
@@ -86,14 +86,41 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
   return can_prove_expr;
 }
 
-bool TargetHasSVE(Optional<Target> target) {
+bool TargetHasVLA(Optional<Target> target) {
   if (!target.defined()) {
     target = Target::Current();
   }
+  bool has_vla{false};
   if (target.defined()) {
-    return Downcast<Target>(target)->GetFeature<Bool>("has_sve").value_or(Bool(false));
+    // aarch64
+    has_vla = Downcast<Target>(target)->GetFeature<Bool>("has_sve").value_or(Bool(false));
+    // riscv{32,64}
+    static auto target_has_feature_fn =
+        tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
+    has_vla |= target_has_feature_fn("v", target).cast<bool>();
   }
-  return false;
+  return has_vla;
+}
+
+const std::vector<unsigned int> GetVScaleValues(Optional<Target> target) {
+  unsigned int vector_width = 0;
+  std::vector<unsigned int> kVScaleValues;
+  if (!target.defined()) {
+    target = Target::Current();
+  }
+  if (target.defined()) {
+    static auto llvm_get_vector_width_fn =
+        tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
+    vector_width = llvm_get_vector_width_fn(target).cast<int>();
+  }
+  // scale list with powers of two
+  for (unsigned int i = 0;; ++i) {
+    auto power = static_cast<unsigned int>(std::pow(2, i));
+    if (power > (vector_width / 8)) break;
+    kVScaleValues.push_back(power);
+  }
+
+  return kVScaleValues;
 }
 
 }  // namespace arith
diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h
index 70d753a299..2470d5dcd8 100644
--- a/src/arith/scalable_expression.h
+++ b/src/arith/scalable_expression.h
@@ -35,9 +35,6 @@
 namespace tvm {
 namespace arith {
 
-/*! \brief A list of known vscale values to try for an AArch64 SVE target. */
-static const std::vector<unsigned int> kAArch64VScaleValues = {1, 2, 4, 8, 16};
-
 /*!
  * \brief Check if an expr is a call to the vscale intrinsic.
  * \param expr The expr to check
@@ -80,10 +77,18 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
 
 /*!
  * \brief Check whether the compilation target supports SVE
+ * \brief Check whether the compilation target supports VLA
+ * \param target The target to check.
+ * \return Whether VLA is supported
+ */
+bool TargetHasVLA(Optional<Target> target = std::nullopt);
+
+/*!
+ * \brief Get a list of known vscale values to try for an VLA target.
  * \param target The target to check.
- * \return Whether SVE is supported
+ * \return A list of vscale values as std::vector<usigned int>
  */
-bool TargetHasSVE(Optional<Target> target = std::nullopt);
+const std::vector<unsigned int> GetVScaleValues(Optional<Target> target = std::nullopt);
 
 }  // namespace arith
 }  // namespace tvm
diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc
index 1399fc083a..b690c0fc28 100644
--- a/src/target/llvm/codegen_aarch64.cc
+++ b/src/target/llvm/codegen_aarch64.cc
@@ -57,8 +57,8 @@ void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
 #if TVM_LLVM_VERSION >= 130
   // Add vscale_range() function attribute when appropriate.
   if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) {
-    unsigned int max_val =
-        *std::max_element(arith::kAArch64VScaleValues.begin(), arith::kAArch64VScaleValues.end());
+    auto kVScaleValues = arith::GetVScaleValues(Target::Current());
+    unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
     func->addFnAttr(
         llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 1, max_val));
   }
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index 3df73f0edb..54b2daf836 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -80,8 +80,8 @@ bool EnableBufferLevelPredication(Target target) {
     return enable_buffer_predication.value();
   }
 
-  // Use buffer-level predication by default for AArch64 SVE targets
-  return arith::TargetHasSVE(target);
+  // Use buffer-level predication by default for VLA targets
+  return arith::TargetHasVLA(target);
 }
 
 /*!
@@ -972,7 +972,7 @@ class LoopVectorizer : public StmtMutator {
 
       if (!extent_as_int || extent_as_int->value < 1) {
         bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
-        ICHECK(is_scalable_expr && arith::TargetHasSVE(target_))
+        ICHECK(is_scalable_expr && arith::TargetHasVLA(target_))
             << "Failed to vectorize loop with extent " << op->extent << " for target " << target_;
       }
       ICHECK(is_zero(op->min));
diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py
index 3b02377400..4971acbd45 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -113,7 +113,7 @@ def test_simplify_vscale_comparison_without_sve_target(capfd):
     warning_msg = (
         "Warning: The expression contains scalable values. An attempt to prove by substituting "
         "with known values of vscale was not performed. This proof currently only supports "
-        "AArch64 SVE targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu"
+        "VLA targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu"
     )
     capture = capfd.readouterr().err
     assert warning_msg in capture
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 43870044d5..2c8f185d8e 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -43,7 +43,9 @@ def test_mul(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: A[i] * B[i], name="C")
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and mul instructions using z registers
         assembly = f.get_source("asm")
@@ -73,7 +75,9 @@ def test_add(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: A[i] + B[i], name="C")
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and add instructions using z registers
         assembly = f.get_source("asm")
@@ -103,7 +107,9 @@ def test_sub(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: A[i] - B[i], name="C")
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and sub instructions using z registers
         assembly = f.get_source("asm")
@@ -134,7 +140,9 @@ def test_muladd(dtype):
         B = te.placeholder(m, dtype=type, name="B")
         C = te.placeholder(m, dtype=type, name="C")
         D = te.compute((m), lambda i: A[i] * B[i] + C[i], name="D")
-        f = tvm.tir.build(te.create_prim_func([A, B, C, D]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C, D]))
 
         # Verify we see SVE load instructions and either mad or mla instructions using z registers
         assembly = f.get_source("asm")
@@ -164,7 +172,9 @@ def test_max(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: tvm.te.max(A[i], B[i]))
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and cmgt + sel instructions or a max instruction, all using z registers
         assembly = f.get_source("asm")
@@ -198,7 +208,9 @@ def test_min(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: tvm.te.min(A[i], B[i]))
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and cmgt + sel instructions or a min instruction, all using z registers
         assembly = f.get_source("asm")
@@ -232,7 +244,9 @@ def test_div(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: tvm.te.div(A[i], B[i]))
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and div instructions using z registers
         assembly = f.get_source("asm")
@@ -261,7 +275,9 @@ def test_mod(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: tvm.te.floormod(A[i], B[i]), name="C")
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and mls instructions using z registers
         assembly = f.get_source("asm")
@@ -291,7 +307,9 @@ def test_eq(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: A[i] == B[i], name="C")
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and cmpeq or cmeq instructions using z registers
         assembly = f.get_source("asm")
@@ -321,7 +339,9 @@ def test_neq(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: A[i] != B[i], name="C")
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and cmpgt, cmgt, cmpne or cmne instructions, all using z registers
         assembly = f.get_source("asm")
@@ -350,7 +370,9 @@ def test_or(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: A[i] | B[i], name="C")
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and orr instructions using z registers
         assembly = f.get_source("asm")
@@ -379,7 +401,9 @@ def test_and(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype=type, name="B")
         C = te.compute((m), lambda i: A[i] & B[i], name="C")
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see SVE load instructions and and instructions using z registers
         assembly = f.get_source("asm")
@@ -407,7 +431,9 @@ def test_not(dtype):
         m = te.var("m")
         A = te.placeholder(m, dtype=type, name="A")
         C = te.compute((m), lambda i: ~A[i], name="C")
-        f = tvm.tir.build(te.create_prim_func([A, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, C]))
 
         # Verify we see SVE load instructions and eor instructions using z registers
         assembly = f.get_source("asm")
@@ -440,7 +466,9 @@ def test_memcpy(dtype):
         A = te.placeholder(m, dtype=type, name="A")
         B = te.placeholder(m, dtype="int32", name="B")
         C = te.compute((m), lambda i: A[B[i]], name="C")
-        f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target)
+
+        with tvm.target.Target(target):
+            f = tvm.tir.build(te.create_prim_func([A, B, C]))
 
         # Verify we see gather instructions in the assembly
         assembly = f.get_source("asm")
@@ -451,65 +479,6 @@ def test_memcpy(dtype):
     check_correct_assembly(type=dtype)
 
 
-@pytest.mark.skipif(
-    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
-)
-def test_codegen_vscale():
-    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
-    vscale = tvm.tir.vscale()
-
-    @T.prim_func
-    def main(A: T.Buffer((5,), "int32")):
-        for i in range(5):
-            A[i] = 2 * vscale
-
-    build_mod = tvm.tir.build(main, target=target)
-    llvm = build_mod.get_source()
-
-    assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM."
-
-
-@pytest.mark.skipif(
-    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
-)
-def test_scalable_buffer_load_store():
-    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
-
-    @T.prim_func
-    def my_func(a: T.handle, b: T.handle):
-        A = T.match_buffer(a, (128,), "float32")
-        B = T.match_buffer(b, (128,), "float32")
-        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
-        B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
-
-    mod = tvm.tir.build(my_func, target=target)
-    llvm = mod.get_source("ll")
-
-    assert re.findall(r"load <vscale x 4 x float>", llvm), "No scalable load in generated LLVM."
-    assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
-
-
-@pytest.mark.skipif(
-    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
-)
-def test_scalable_broadcast():
-    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
-
-    @T.prim_func
-    def my_func(a: T.handle):
-        A = T.match_buffer(a, (128,), "float32")
-        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
-        A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
-
-    mod = tvm.tir.build(my_func, target=target)
-    llvm = mod.get_source("ll")
-
-    assert re.findall(
-        r"shufflevector \(<vscale x 4 x float> insertelement \(<vscale x 4 x float>", llvm
-    ), "No scalable broadcast in generated LLVM."
-    assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
-
-
 @pytest.mark.skipif(
     llvm_version_major() < 13,
     reason="Function attribute vscale_range() is not supported in earlier versions of LLVM",
@@ -529,7 +498,9 @@ def test_vscale_range_function_attribute(mattr, expect_attr):
     m = te.var("m")
     A = te.placeholder(m, dtype="float32", name="A")
     C = te.compute((m), lambda i: A[i] + 1, name="C")
-    f = tvm.tir.build(te.create_prim_func([A, C]), target=target)
+
+    with tvm.target.Target(target):
+        f = tvm.tir.build(te.create_prim_func([A, C]))
 
     # Check if the vscale_range() attribute exists
     ll = f.get_source("ll")
@@ -545,49 +516,5 @@ def test_vscale_range_function_attribute(mattr, expect_attr):
         ), f"Unexpected function attribute vscale_range() was found in generated LLVM IR"
 
 
[email protected](
-    reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
-)
-def test_get_active_lane_mask():
-    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
-
-    @T.prim_func
-    def before(a: T.handle):
-        A = T.match_buffer(a, (30,), "int1")
-        for i in range(T.ceildiv(30, T.vscale() * 4)):
-            A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30)
-
-    with tvm.target.Target(target):
-        out = tvm.tir.build(before)
-
-    ll = out.get_source("ll")
-    assert "get.active.lane.mask" in ll
-
-
[email protected](
-    reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
-)
-def test_predicated_scalable_buffer():
-    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
-
-    @T.prim_func
-    def before(a: T.handle, b: T.handle):
-        A = T.match_buffer(a, (16,), "float32")
-        B = T.match_buffer(b, (16,), "float32")
-        T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())):
-            for i_1 in T.vectorized(4 * T.vscale()):
-                if i_0 * 4 * T.vscale() + i_1 < 14:
-                    B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0
-
-    with tvm.target.Target(target):
-        out = tvm.tir.build(before)
-
-    ll = out.get_source("ll")
-    assert "get.active.lane.mask" in ll
-    assert "llvm.masked.load" in ll
-    assert "llvm.masked.store" in ll
-
-
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/codegen/test_target_codegen_llvm_vla.py b/tests/python/codegen/test_target_codegen_llvm_vla.py
new file mode 100644
index 0000000000..7ca3083dd5
--- /dev/null
+++ b/tests/python/codegen/test_target_codegen_llvm_vla.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Codegen tests for VLA extensions
+"""
+
+import re
+import pytest
+
+import tvm
+from tvm import te
+from tvm.script import tir as T
+from tvm.target.codegen import llvm_version_major
+
+
+@pytest.mark.skipif(
+    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
+)
+@tvm.testing.parametrize_targets(
+    "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
+    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
+)
+def test_codegen_vscale(target):
+    vscale = tvm.tir.vscale()
+
+    @T.prim_func
+    def main(A: T.Buffer((5,), "int32")):
+        for i in range(5):
+            A[i] = 2 * vscale
+
+    with tvm.target.Target(target):
+        build_mod = tvm.tir.build(main)
+
+    llvm = build_mod.get_source()
+    assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM."
+
+
+@pytest.mark.skipif(
+    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
+)
+@tvm.testing.parametrize_targets(
+    "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
+    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
+)
+def test_scalable_buffer_load_store(target):
+    @T.prim_func
+    def my_func(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (128,), "float32")
+        B = T.match_buffer(b, (128,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
+
+    with tvm.target.Target(target):
+        mod = tvm.tir.build(my_func)
+
+    llvm = mod.get_source("ll")
+    assert re.findall(r"load <vscale x 4 x float>", llvm), "No scalable load in generated LLVM."
+    assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
+
+
+@pytest.mark.skipif(
+    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
+)
+@tvm.testing.parametrize_targets(
+    "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
+    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
+)
+def test_scalable_broadcast(target):
+    @T.prim_func
+    def my_func(a: T.handle):
+        A = T.match_buffer(a, (128,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
+
+    with tvm.target.Target(target):
+        mod = tvm.tir.build(my_func)
+
+    llvm = mod.get_source("ll")
+    assert re.findall(
+        r"shufflevector \(<vscale x 4 x float> insertelement \(<vscale x 4 x float>", llvm
+    ), "No scalable broadcast in generated LLVM."
+    assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
+
+
[email protected](
+    reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
+)
+@tvm.testing.parametrize_targets(
+    "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
+    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
+)
+def test_get_active_lane_mask(target):
+    @T.prim_func
+    def before(a: T.handle):
+        A = T.match_buffer(a, (30,), "int1")
+        for i in range(T.ceildiv(30, T.vscale() * 4)):
+            A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30)
+
+    with tvm.target.Target(target):
+        out = tvm.tir.build(before)
+
+    ll = out.get_source("ll")
+    assert "get.active.lane.mask" in ll
+
+
[email protected](
+    reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
+)
+@tvm.testing.parametrize_targets(
+    "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
+    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
+)
+def test_predicated_scalable_buffer(target):
+    @T.prim_func
+    def before(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (16,), "float32")
+        B = T.match_buffer(b, (16,), "float32")
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())):
+            for i_1 in T.vectorized(4 * T.vscale()):
+                if i_0 * 4 * T.vscale() + i_1 < 14:
+                    B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0
+
+    with tvm.target.Target(target):
+        out = tvm.tir.build(before)
+
+    ll = out.get_source("ll")
+    assert "get.active.lane.mask" in ll
+    assert "llvm.masked.load" in ll
+    assert "llvm.masked.store" in ll
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py
index 22344acfe1..f09f7417ba 100644
--- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py
+++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py
@@ -816,7 +816,7 @@ def test_unsupported_target_scalable_split(capfd):
     warning_msg = (
         "Warning: The expression contains scalable values. An attempt to prove by substituting "
         "with known values of vscale was not performed. This proof currently only supports "
-        "AArch64 SVE targets, but the target was "
+        "VLA targets, but the target was "
     )
     captured = capfd.readouterr().err
     assert warning_msg in captured
diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py
index 9b61255285..13bb1c60cb 100644
--- a/tests/python/tir-transform/test_tir_transform_vectorize.py
+++ b/tests/python/tir-transform/test_tir_transform_vectorize.py
@@ -670,7 +670,7 @@ def test_vectorize_and_predicate_invalid_conditions():
 
 
 def test_vectorize_with_explicitly_disabled_buffer_level_predication():
-    # Since the target has the SVE feature, buffer level predication is enabled
+    # Since the target has the VLA feature, buffer level predication is enabled
     # by default. However, it has been explicitly disabled by the pass context
     # option, so no buffer-level predicates should be added.
     @T.prim_func

Reply via email to