This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6b4b866d65 [REFACTOR][ARITH] Phase out arith/scalable_expression;
arith no longer proves over scalable vectors (#19638)
6b4b866d65 is described below
commit 6b4b866d65f2afcc215bcf94895b8be2dba2b77b
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu May 28 22:02:42 2026 -0400
[REFACTOR][ARITH] Phase out arith/scalable_expression; arith no longer
proves over scalable vectors (#19638)
## Summary
Phase out `src/arith/scalable_expression.{h,cc}`. The arith layer no
longer attempts to prove anything about scalable vectors — proofs that
depended on `Target::Current()` are removed. Scalable vectors remain a
first-class concept; arith just doesn't reason about their lengths.
## Use-site summary
Per the table below, there are 22 call sites across 7 symbols (16 live, 6 proof-related), not counting uses internal to `scalable_expression.cc`.
| Symbol | Live callers (kept) | Proof callers (deleted) | New home |
|---|---|---|---|
| `ExtractVscaleFactor` | 4 × `arith/rewrite_simplify.cc` + 2 × `tirx/ir/expr.cc` | — | file-local in each |
| `IsVScaleCall` | 1 × `tirx/op/op.cc` + 1 × `tirx/transform/vectorize_loop.cc` | — | inline at use sites |
| `ContainsVscaleCall` | 4 × `arith/rewrite_simplify.cc` + 1 × `s_tir/schedule/ir_comparator.cc` | analyzer.cc | inline at use sites |
| `TargetHasVLA` | 2 × `tirx/transform/vectorize_loop.cc` | analyzer.cc + const_int_bound.cc | local in vectorize_loop.cc |
| `GetVScaleValues` | 1 × `target/llvm/codegen_aarch64.cc` | analyzer.cc + const_int_bound.cc | inlined at codegen_aarch64 |
| `CanProveVscaleExpressionFromKnownValues` | — | analyzer.cc | DELETE |
| `SubstituteVScaleWithKnownValue` | — | internal only | DELETE |
## Changes (6 commits)
1. Move `ExtractVscaleFactor` to file-local anonymous-namespace helpers
in `rewrite_simplify.cc` and `tirx/ir/expr.cc`. Function is small;
per-file duplication is cleaner than a shared header.
2. Replace the shared `IsVScaleCall` / `ContainsVscaleCall` / `TargetHasVLA` with small file-local predicates (a few lines each), defined in an anonymous namespace in each consuming `.cc` file.
3. Drop the scalable-vector proof scaffolding from `arith/analyzer.cc`
(substitution-proof loop) and `arith/const_int_bound.cc` (vscale
branch). `vscale()` calls fall back to `Everything()` — no special bound
narrowing.
4. Delete `scalable_expression.{h,cc}`. Inline the `GetVScaleValues`
body at `codegen_aarch64.cc` (computes `max_val = vector_width / 8`
floor-rounded to a power of two for the LLVM `vscale_range` attribute).
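5. Mark 19 tests that relied on the deleted substitution-proof loop with `pytest.mark.xfail`. Delete the SVE scalable-split tests in `test_tir_schedule_split_fuse.py` (`test_sve_scalable_split_predicated`, `test_sve_scalable_split_assume_exact_multiple`, `test_sve_split_over_scalable_loop`), along with the vscale-proof warning assertion in `test_unsupported_target_scalable_split`.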
6. `pre-commit` line-length cleanup.
## Compatibility / intentional regression
This is a hard break for any consumer of the deleted symbols. They were
already in a private header (`src/arith/scalable_expression.h`, not
under `include/`).
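19 tests that proved vscale-bearing inequalities on SVE / RVV are xfailed, and the SVE scalable-split tests in `test_tir_schedule_split_fuse.py` are deleted. The proofs were target-dependent, and the new policy is that arith does not attempt them.
Note also that the file-local `TargetHasVLA` in `vectorize_loop.cc` returns `false` for an undefined target. The deleted helper instead fell back to `Target::Current()`.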
---
src/arith/analyzer.cc | 18 ---
src/arith/const_int_bound.cc | 6 -
src/arith/rewrite_simplify.cc | 36 +++++-
src/arith/scalable_expression.cc | 127 ---------------------
src/arith/scalable_expression.h | 96 ----------------
src/s_tir/schedule/ir_comparator.cc | 20 +++-
src/target/llvm/codegen_aarch64.cc | 20 +++-
src/tirx/ir/expr.cc | 27 ++++-
src/tirx/op/op.cc | 14 ++-
src/tirx/transform/vectorize_loop.cc | 28 ++++-
tests/python/arith/test_arith_rewrite_simplify.py | 7 ++
tests/python/arith/test_arith_simplify.py | 8 ++
tests/python/s_tir/dlight/test_cpu_reduction.py | 5 +
.../s_tir/schedule/test_tir_schedule_split_fuse.py | 117 +------------------
14 files changed, 146 insertions(+), 383 deletions(-)
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 8bce80f4ef..38c699692e 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -28,7 +28,6 @@
#include <tvm/tirx/expr.h>
#include <tvm/tirx/op.h>
-#include "./scalable_expression.h"
#include "const_fold.h"
#include "product_normal_form.h"
@@ -231,23 +230,6 @@ bool Analyzer::CanProve(const PrimExpr& expr,
ProofStrength strength) {
}
}
- // Current analysis may not be powerful enough to prove expressions
containing
- // the same symbolic value multiple times. However, when the symbolic values
are
- // "T.vscale" and the compile target uses a scalable architecture extension
like
- // VLA, we can make some assumptions about the value of vscale and iterate
over a
- // space of pre-defined values to attempt to prove the expression.
- Target curr_target = Target::Current();
- if (ContainsVscaleCall(simplified)) {
- if (TargetHasVLA(curr_target)) {
- auto kVScaleValues = GetVScaleValues(curr_target);
- return CanProveVscaleExpressionFromKnownValues(this, simplified,
kVScaleValues);
- }
- LOG(WARNING)
- << "The expression contains scalable values. An attempt to prove by
substituting "
- "with known values of vscale was not performed. This proof
currently only supports "
- "VLA targets, but the target was "
- << curr_target;
- }
return false;
}
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 4de7ab4d40..c1dc4826f7 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -33,7 +33,6 @@
#include "constraint_extract.h"
#include "int_operator.h"
#include "pattern_match.h"
-#include "scalable_expression.h"
namespace tvm {
namespace arith {
@@ -417,17 +416,12 @@ class ConstIntBoundAnalyzer::Impl
// only special handle >> and & which can be
// used for index calculation.
- auto curr_target = Target::Current();
if (op->op.same_as(tirx::builtin::shift_right())) {
return VisitRightShift(op);
} else if (op->op.same_as(tirx::builtin::shift_left())) {
return VisitLeftShift(op);
} else if (op->op.same_as(tirx::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
- } else if (op->op.same_as(tirx::builtin::vscale()) &&
TargetHasVLA(curr_target)) {
- auto kVScaleValues = GetVScaleValues(curr_target);
- unsigned int max_val = *std::max_element(kVScaleValues.begin(),
kVScaleValues.end());
- return MakeBound(1, max_val);
} else {
return Everything(op->dtype);
}
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 804cb3cd97..06ac29f991 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -34,15 +34,41 @@
#include <utility>
#include "../target/datatype/registry.h"
+#include "../tirx/analysis/check_contains.h"
#include "conjunctive_normal_form.h"
#include "const_fold.h"
#include "constraint_extract.h"
#include "pattern_match.h"
-#include "scalable_expression.h"
namespace tvm {
namespace arith {
+namespace {
+// File-local helper: true if `expr` is a call to tirx::builtin::vscale().
+bool IsVScaleCall(const PrimExpr& expr) {
+ if (const auto* call = expr.as<tirx::CallNode>()) {
+ return call->op.same_as(tirx::builtin::vscale());
+ }
+ return false;
+}
+
+// File-local helper: true if `expr` contains a call to
tirx::builtin::vscale().
+bool ContainsVscaleCall(const PrimExpr& expr) {
+ return tirx::CheckContains::ExprContains(expr, IsVScaleCall);
+}
+
+// File-local helper: returns the vscale multiplier if `lanes` is of the form
+// `multiplier * vscale()` or `vscale() * multiplier`, nullopt otherwise.
+std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
+ PVar<IntImm> multiplier;
+ PCallExpr<PVscaleOp> vscale;
+ if (PMatchesOneOf(multiplier * vscale, vscale * multiplier).Match(lanes)) {
+ return multiplier.Eval()->value;
+ }
+ return std::nullopt;
+}
+} // namespace
+
using namespace tirx;
TVM_FFI_STATIC_INIT_BLOCK() {
RewriteSimplifierStatsNode::RegisterReflection(); }
@@ -789,7 +815,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode*
op) {
return ramp(div(b1, c2), div(c1, c2), lanes).Eval();
}
// If all possible indices in ramp are the same.
- if (CanProveGreaterEqual(b1.Eval(), 0) &&
!arith::ExtractVscaleFactor(lanes.Eval())) {
+ if (CanProveGreaterEqual(b1.Eval(), 0) &&
!ExtractVscaleFactor(lanes.Eval())) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = bmod->base / c2val;
auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
@@ -946,7 +972,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode*
op) {
// If all possible indices in ramp are the same.
if (CanProveGreaterEqual(b1.Eval(), 0)) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
- if (!arith::ExtractVscaleFactor(lanes.Eval())) {
+ if (!ExtractVscaleFactor(lanes.Eval())) {
auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
int64_t ramp_min = bmod->base / c2val;
int64_t ramp_max = (bmod->base + (lanes_int - 1) * c1val) / c2val;
@@ -1032,7 +1058,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
FloorDivNode* op) {
return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval();
}
// If all possible indices in ramp are the same.
- if (!arith::ExtractVscaleFactor(lanes.Eval())) {
+ if (!ExtractVscaleFactor(lanes.Eval())) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val);
auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
@@ -1186,7 +1212,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
FloorModNode* op) {
}
// If all possible indices in ramp are the same.
ModularSet bmod = analyzer_->modular_set(b1.Eval());
- if (!arith::ExtractVscaleFactor(lanes.Eval())) {
+ if (!ExtractVscaleFactor(lanes.Eval())) {
int64_t ramp_min = floordiv(bmod->base, c2val);
auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
int64_t ramp_max = floordiv(bmod->base + (lanes_int - 1) * c1val,
c2val);
diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc
deleted file mode 100644
index 005eea0e9c..0000000000
--- a/src/arith/scalable_expression.cc
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/arith/scalable_expression.cc
- * \brief Analyze scalable expressions.
- */
-
-#include "scalable_expression.h"
-
-#include <tvm/tirx/expr.h>
-#include <tvm/tirx/op.h>
-
-#include <vector>
-
-#include "../tirx/analysis/check_contains.h"
-#include "../tirx/transform/replace_selected_expr.h"
-#include "./pattern_match.h"
-
-namespace tvm {
-namespace arith {
-
-bool IsVScaleCall(const PrimExpr& expr) {
- if (auto call = expr.as<tirx::CallNode>()) {
- return call->op.same_as(tirx::builtin::vscale());
- }
- return false;
-}
-
-bool ContainsVscaleCall(const PrimExpr& expr) {
- return tirx::CheckContains::ExprContains(expr, IsVScaleCall);
-}
-
-PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int
vscale_value) {
- std::function<bool(const PrimExpr&)> predicate_selector = [](const PrimExpr&
current_expr) {
- return IsVScaleCall(current_expr);
- };
- std::function<bool(const PrimExpr&)> can_replace_inside = [](const PrimExpr&
current_expr) {
- return true;
- };
-
- return tirx::ReplaceSelectedExpr::ReplaceSelectedExprInExpr(
- expr, predicate_selector, tirx::MakeConstScalar(DataType::Int(32),
vscale_value),
- can_replace_inside);
-}
-
-std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
- PVar<IntImm> multiplier;
- PCallExpr<PVscaleOp> vscale;
-
- if (PMatchesOneOf(multiplier * vscale, vscale * multiplier).Match(lanes)) {
- return multiplier.Eval()->value;
- } else {
- return std::nullopt;
- }
-}
-
-bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const
PrimExpr& expr,
- const std::vector<unsigned int>&
vscale_values) {
- bool can_prove_expr = true;
- for (const unsigned int vscale_value : vscale_values) {
- PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value);
- result = analyzer->Simplify(result);
- const int64_t* as_int = tirx::as_const_int(result);
- if (!as_int || *as_int == 0) {
- can_prove_expr = false;
- break;
- }
- }
- return can_prove_expr;
-}
-
-bool TargetHasVLA(ffi::Optional<Target> target) {
- if (!target.defined()) {
- target = Target::Current();
- }
- bool has_vla{false};
- if (target.defined()) {
- // aarch64
- has_vla =
Downcast<Target>(target)->GetAttr<bool>("feature.has_sve").value_or(false);
- // riscv{32,64}
- static auto target_has_feature_fn =
- tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
- has_vla |= target_has_feature_fn("v", target).cast<bool>();
- }
- return has_vla;
-}
-
-const std::vector<unsigned int> GetVScaleValues(ffi::Optional<Target> target) {
- unsigned int vector_width = 0;
- std::vector<unsigned int> kVScaleValues;
- if (!target.defined()) {
- target = Target::Current();
- }
- if (target.defined()) {
- static auto llvm_get_vector_width_fn =
- tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
- vector_width = llvm_get_vector_width_fn(target).cast<int>();
- }
- // scale list with powers of two
- for (unsigned int i = 0;; ++i) {
- auto power = static_cast<unsigned int>(std::pow(2, i));
- if (power > (vector_width / 8)) break;
- kVScaleValues.push_back(power);
- }
-
- return kVScaleValues;
-}
-
-} // namespace arith
-} // namespace tvm
diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h
deleted file mode 100644
index 88c1402887..0000000000
--- a/src/arith/scalable_expression.h
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/arith/scalable_expression.h
- * \brief Analyze scalable expressions.
- */
-
-#ifndef TVM_ARITH_SCALABLE_EXPRESSION_H_
-#define TVM_ARITH_SCALABLE_EXPRESSION_H_
-
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir/expr.h>
-#include <tvm/target/target.h>
-
-#include <optional>
-#include <vector>
-
-namespace tvm {
-namespace arith {
-
-/*!
- * \brief Check if an expr is a call to the vscale intrinsic.
- * \param expr The expr to check
- * \return True if the expr is a call to the vscale intrinsic, false if not.
- */
-bool IsVScaleCall(const PrimExpr& expr);
-
-/*!
- * \brief Check if an expr contains a call to the vscale intrinsic.
- * \param expr The expr to check
- * \return True if the expr contains a call to the vscale intrinsic, false if
not.
- */
-bool ContainsVscaleCall(const PrimExpr& expr);
-
-/*!
- * \brief Substitute a vscale intrinsic call with a known scalar value.
- * \param expr The expr to apply substitutions to.
- * \param vscale_value The scalar value to replace vscale with.
- * \return A rewritten expression with vscale values replaced with a scalar
value.
- */
-PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int
vscale_value);
-
-/*!
- * \brief Returns the vscale multiplier as a nullable type
- * \param lanes The scalable lanes as a PrimExpr
- * \return vscale multiplier as std::optional<int>
- */
-std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes);
-
-/*!
- * \brief Check if the expression can be proven when evaluating it on all
possible values
- of vscale.
- * \param analyzer An analyzer instance.
- * \param expr The expression to try to prove.
- * \param vscale_values A list of values to substitute vscale with.
- * \return Whether or not the expression can be proven with this technique.
- */
-bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const
PrimExpr& expr,
- const std::vector<unsigned int>&
vscale_values);
-
-/*!
- * \brief Check whether the compilation target supports SVE
- * \brief Check whether the compilation target supports VLA
- * \param target The target to check.
- * \return Whether VLA is supported
- */
-bool TargetHasVLA(ffi::Optional<Target> target = std::nullopt);
-
-/*!
- * \brief Get a list of known vscale values to try for an VLA target.
- * \param target The target to check.
- * \return A list of vscale values as std::vector<usigned int>
- */
-const std::vector<unsigned int> GetVScaleValues(ffi::Optional<Target> target =
std::nullopt);
-
-} // namespace arith
-} // namespace tvm
-
-#endif // TVM_ARITH_SCALABLE_EXPRESSION_H_
diff --git a/src/s_tir/schedule/ir_comparator.cc
b/src/s_tir/schedule/ir_comparator.cc
index 06b8ea6d4a..1bb66a2381 100644
--- a/src/s_tir/schedule/ir_comparator.cc
+++ b/src/s_tir/schedule/ir_comparator.cc
@@ -19,11 +19,27 @@
#include "./ir_comparator.h"
#include <tvm/ffi/cast.h>
+#include <tvm/tirx/builtin.h>
-#include "../../arith/scalable_expression.h"
+#include "../../tirx/analysis/check_contains.h"
namespace tvm {
+namespace {
+// File-local helper: true if `expr` is a call to tirx::builtin::vscale().
+bool IsVScaleCall(const PrimExpr& expr) {
+ if (const auto* call = expr.as<tirx::CallNode>()) {
+ return call->op.same_as(tirx::builtin::vscale());
+ }
+ return false;
+}
+
+// File-local helper: true if `expr` contains a call to
tirx::builtin::vscale().
+bool ContainsVscaleCall(const PrimExpr& expr) {
+ return tirx::CheckContains::ExprContains(expr, IsVScaleCall);
+}
+} // namespace
+
namespace s_tir {
using namespace tvm::tirx;
@@ -80,7 +96,7 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const
PrimExpr& other) {
bool equal = n.same_as(other) ||
((n->type_index() == other->type_index()) &&
n.dtype().code() == other.dtype().code() &&
ExprComparator::VisitExpr(n, other)) ||
- (tvm::arith::ContainsVscaleCall(n) &&
analyzer_.CanProveEqual(n, other));
+ (ContainsVscaleCall(n) && analyzer_.CanProveEqual(n, other));
if (!equal && assert_mode_) {
std::ostringstream os;
diff --git a/src/target/llvm/codegen_aarch64.cc
b/src/target/llvm/codegen_aarch64.cc
index 18da2e66d7..3a0a365899 100644
--- a/src/target/llvm/codegen_aarch64.cc
+++ b/src/target/llvm/codegen_aarch64.cc
@@ -29,7 +29,6 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
-#include "../../arith/scalable_expression.h"
#include "codegen_cpu.h"
#include "llvm_instance.h"
@@ -58,9 +57,22 @@ void CodeGenAArch64::AddFunction(const GlobalVar& gvar,
const PrimFunc& f) {
void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
// Add vscale_range() function attribute when appropriate.
if (llvm_target_->TargetHasCPUFeature("sve") ||
llvm_target_->TargetHasCPUFeature("sme")) {
- auto kVScaleValues = arith::GetVScaleValues(Target::Current());
- if (!kVScaleValues.empty()) {
- unsigned int max_val = *std::max_element(kVScaleValues.begin(),
kVScaleValues.end());
+ // Compute max_val = largest power-of-two <= vector_width/8.
+ // Guard against calling llvm_get_vector_width_fn when no target is active
—
+ // Target::Current() returns an undefined Target outside a compilation
context.
+ static auto llvm_get_vector_width_fn =
+ tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
+ unsigned int max_val = 0;
+ if (auto target = Target::Current(); target.defined()) {
+ unsigned int vector_width =
+ static_cast<unsigned
int>(llvm_get_vector_width_fn(target).cast<int>());
+ for (unsigned int i = 0;; ++i) {
+ unsigned int power = 1u << i;
+ if (power > (vector_width / 8)) break;
+ max_val = power;
+ }
+ }
+ if (max_val > 0) {
func->addFnAttr(
llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(),
1, max_val));
}
diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc
index d458c09180..2fdb083085 100644
--- a/src/tirx/ir/expr.cc
+++ b/src/tirx/ir/expr.cc
@@ -29,13 +29,34 @@
#include <optional>
-#include "../../arith/scalable_expression.h"
#include "../../support/str_escape.h"
#include "buffer_common.h"
namespace tvm {
namespace tirx {
+namespace {
+// File-local helper: returns the vscale multiplier if `lanes` is of the form
+// `multiplier * vscale()` or `vscale() * multiplier`, nullopt otherwise.
+std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
+ auto is_vscale = [](const PrimExpr& e) -> bool {
+ if (const auto* call = e.as<CallNode>()) {
+ return call->op.same_as(tirx::builtin::vscale());
+ }
+ return false;
+ };
+ if (const auto* mul = lanes.as<MulNode>()) {
+ if (const auto* imm = mul->a.as<IntImmNode>(); imm && is_vscale(mul->b)) {
+ return static_cast<int>(imm->value);
+ }
+ if (const auto* imm = mul->b.as<IntImmNode>(); imm && is_vscale(mul->a)) {
+ return static_cast<int>(imm->value);
+ }
+ }
+ return std::nullopt;
+}
+} // namespace
+
TVM_FFI_STATIC_INIT_BLOCK() {
VarNode::RegisterReflection();
SizeVarNode::RegisterReflection();
@@ -514,7 +535,7 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes,
Span span) {
// Stick to int32 lanes for fixed length vectors
node->lanes = lanes;
} else { /* scalable vector */
- std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
+ std::optional<int> vscale_factor = ExtractVscaleFactor(lanes);
TVM_FFI_ICHECK(vscale_factor) << "Invalid expression for scalable lanes "
<< lanes;
node->dtype =
base.dtype().with_scalable_vscale_factor(vscale_factor.value());
@@ -548,7 +569,7 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span
span) {
// Stick to int32 lanes for fixed length vectors
node->lanes = lanes;
} else { /* scalable vector */
- std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
+ std::optional<int> vscale_factor = ExtractVscaleFactor(lanes);
TVM_FFI_ICHECK(vscale_factor) << "Invalid expression for scalable lanes "
<< lanes;
node->dtype =
value.dtype().with_scalable_vscale_factor(vscale_factor.value());
diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc
index b7cb5fe820..539c56eecb 100644
--- a/src/tirx/op/op.cc
+++ b/src/tirx/op/op.cc
@@ -34,13 +34,23 @@
#include <cmath>
// Centralized header for constant folders.
#include "../../arith/const_fold.h"
-#include "../../arith/scalable_expression.h"
#include "../../target/datatype/registry.h"
+#include "../analysis/check_contains.h"
namespace tvm {
using namespace tirx;
+namespace {
+// File-local helper: true if `expr` is a call to tirx::builtin::vscale().
+bool IsVScaleCall(const PrimExpr& expr) {
+ if (const auto* call = expr.as<CallNode>()) {
+ return call->op.same_as(builtin::vscale());
+ }
+ return false;
+}
+} // namespace
+
// macro to register an unary op
#define TVM_TIR_REGISTER_PURE_UNARY_OP(OpName) \
TVM_TIR_REGISTER_OP(OpName).set_num_inputs(1).set_attr<TCallEffectKind>( \
@@ -696,7 +706,7 @@ PrimExpr operator==(PrimExpr a, PrimExpr b) { return
equal(a, b); }
PrimExpr equal(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
if (auto ret = arith::TryConstFold<tirx::EQ>(a, b)) return ret.value();
- if (arith::IsVScaleCall(a) && arith::IsVScaleCall(b)) return true;
+ if (IsVScaleCall(a) && IsVScaleCall(b)) return true;
return tirx::EQ(a, b, span);
}
diff --git a/src/tirx/transform/vectorize_loop.cc
b/src/tirx/transform/vectorize_loop.cc
index 0ac9680d0a..540e641bdf 100644
--- a/src/tirx/transform/vectorize_loop.cc
+++ b/src/tirx/transform/vectorize_loop.cc
@@ -38,7 +38,6 @@
#include <unordered_map>
#include <vector>
-#include "../../src/arith/scalable_expression.h"
#include "../../tirx/analysis/check_contains.h"
#include "tvm/runtime/data_type.h"
#include "tvm/tirx/buffer.h"
@@ -46,6 +45,27 @@
namespace tvm {
namespace tirx {
+namespace {
+// File-local helper: true if `expr` is a call to tirx::builtin::vscale().
+bool IsVScaleCall(const PrimExpr& expr) {
+ if (const auto* call = expr.as<CallNode>()) {
+ return call->op.same_as(builtin::vscale());
+ }
+ return false;
+}
+
+// File-local helper: true if the target supports Variable-Length Array
extensions
+// (AArch64 SVE or RISC-V V).
+bool TargetHasVLA(Target target) {
+ if (!target.defined()) return false;
+ bool has_vla = target->GetAttr<bool>("feature.has_sve").value_or(false);
+ static auto target_has_feature_fn =
+ tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
+ has_vla |= target_has_feature_fn("v", target).cast<bool>();
+ return has_vla;
+}
+} // namespace
+
inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
if (is_scalable) {
return Mul(Call(DataType::Int(32), builtin::vscale(), {}),
lanes_or_vscale_factor);
@@ -86,7 +106,7 @@ bool EnableBufferLevelPredication(Target target) {
}
// Use buffer-level predication by default for VLA targets
- return arith::TargetHasVLA(target);
+ return TargetHasVLA(target);
}
/*!
@@ -956,8 +976,8 @@ class LoopVectorizer : public StmtMutator {
auto* extent_as_int = op->extent.as<IntImmNode>();
if (!extent_as_int || extent_as_int->value < 1) {
- bool is_scalable_expr = CheckContains::ExprContains(op->extent,
arith::IsVScaleCall);
- TVM_FFI_ICHECK(is_scalable_expr && arith::TargetHasVLA(target_))
+ bool is_scalable_expr = CheckContains::ExprContains(op->extent,
IsVScaleCall);
+ TVM_FFI_ICHECK(is_scalable_expr && TargetHasVLA(target_))
<< "Failed to vectorize loop with extent " << op->extent << " for
target " << target_;
}
TVM_FFI_ICHECK(is_zero(op->min));
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py
b/tests/python/arith/test_arith_rewrite_simplify.py
index ee3f67d60f..17a8397ce9 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -917,6 +917,13 @@ class TestMaxIndex(BaseCompare):
)
+# These simplifications relied on arith::CanProve being able to prove
+# vscale-bearing inequalities (e.g. vscale() > 0) by substituting known
+# vscale values for the current VLA target. That proof loop has been removed
+# from the arith layer -- arith no longer attempts to reason about scalable
+# vector lengths at the target level. The simplifications are correct in
+# principle but can no longer be proven without the substitution loop.
[email protected](reason="arith no longer proves vscale-bearing inequalities
via substitution")
class TestScalableIndex(BaseCompare):
x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32")
test_case = tvm.testing.parameter(
diff --git a/tests/python/arith/test_arith_simplify.py
b/tests/python/arith/test_arith_simplify.py
index d30109fc44..5202dcba2c 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -87,6 +87,11 @@ def test_simplify_symbolic_comparison():
assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1,
PS.SYMBOLIC_BOUND)
+# These tests exercised arith::CanProve's substitution-based proof loop for
+# vscale-bearing expressions (iterating over known vscale values for a VLA
target).
+# That loop has been removed -- arith no longer attempts target-dependent
proofs
+# about scalable-vector lengths. The LOG(WARNING) for non-VLA targets is also
gone.
[email protected](reason="arith no longer proves vscale-bearing inequalities
via substitution")
@pytest.mark.parametrize(
"expression",
[
@@ -103,6 +108,9 @@ def
test_simplify_vscale_comparison_with_sve_target(expression):
assert ana.can_prove(expression)
[email protected](
+ reason="arith no longer emits a LOG(WARNING) for vscale proofs on non-VLA
targets"
+)
def test_simplify_vscale_comparison_without_sve_target(capfd):
ana = tvm.arith.Analyzer()
vs = tvm.tirx.vscale()
diff --git a/tests/python/s_tir/dlight/test_cpu_reduction.py
b/tests/python/s_tir/dlight/test_cpu_reduction.py
index 9059efeb9f..28e60a1d44 100644
--- a/tests/python/s_tir/dlight/test_cpu_reduction.py
+++ b/tests/python/s_tir/dlight/test_cpu_reduction.py
@@ -191,6 +191,11 @@ def test_rvv_code_size_reduction(fast):
)
+# The arith analyzer no longer proves vscale-bearing inequalities via
+# substitution (CanProveVscaleExpressionFromKnownValues was deleted). This
+# weakens simplification of scalable-vector index expressions, which can
+# prevent the RVV vectorization schedule from producing scalable vector ops.
[email protected](reason="arith no longer proves vscale-bearing inequalities
via substitution")
def test_rvv_fast_softmax_vectorizes_exp():
"""fast_softmax + schedule should produce RVV vector instructions
for the polynomial exp approximation (no scalar exp calls)."""
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
b/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
index 58eff502d6..0e6cae7861 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
@@ -691,114 +691,7 @@ def test_split_int64_factors():
assert_structural_equal_ignore_global_symbol(elementwise_symbolic_split,
sch.mod["main"])
[email protected]("num_elements", [128, 115])
-def test_sve_scalable_split_predicated(num_elements):
- """
- By default, splitting with by vscale factors over a fixed-length loop will
- result in loop-level predication being inserted. This is because, at
- compile-time, we don't know if vscale is a multiple of the extent of the
- loop to be split.
- """
- with tvm.target.Target({"kind": "llvm", "mtriple": "aarch64-linux-gnu",
"mattr": ["+sve"]}):
- outer_extent = tvm.arith.Analyzer().simplify(T.ceildiv(num_elements, 4
* T.vscale()))
-
- @T.prim_func(s_tir=True)
- def before(a: T.handle):
- A = T.match_buffer(a, (num_elements,), "float32")
- T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
- for i in T.serial(num_elements):
- with T.sblock("A"):
- v_i = T.axis.remap("S", [i])
- A[v_i] = 1.0
-
- @T.prim_func(s_tir=True)
- def after(a: T.handle):
- A = T.match_buffer(a, (num_elements,), "float32")
- T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
- for i_0, i_1 in T.grid(outer_extent, T.vscale() * 4):
- with T.sblock("A"):
- v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4)
+ i_1)
- T.where(i_0 * (T.vscale() * 4) + i_1 < num_elements)
- A[v_i] = 1.0
-
- sch = tvm.s_tir.Schedule(before)
- (a,) = sch.get_loops("A")
- sch.split(a, factors=[outer_extent, 4 * T.vscale()])
-
- tvm.ir.assert_structural_equal(sch.mod["main"], after)
-
-
-def test_sve_scalable_split_assume_exact_multiple():
- """
- If the schedule writer knows the extent of the loop to be split will always
- be a multiple of vscale, they may use `disable_predication=True` to ensure
- a predicate is not created. This can be used to ensure predication is not
- inserted.
- """
- with tvm.target.Target({"kind": "llvm", "mtriple": "aarch64-linux-gnu",
"mattr": ["+sve"]}):
- outer_extent = tvm.arith.Analyzer().simplify(T.ceildiv(128, 4 *
T.vscale()))
-
- @T.prim_func(s_tir=True)
- def before(a: T.handle):
- A = T.match_buffer(a, (128,), "float32")
- T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
- for i in T.serial(128):
- with T.sblock("A"):
- v_i = T.axis.remap("S", [i])
- A[v_i] = 1.0
-
- @T.prim_func(s_tir=True)
- def after(a: T.handle):
- A = T.match_buffer(a, (128,), "float32")
- T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
- for i_0, i_1 in T.grid(outer_extent, T.vscale() * 4):
- with T.sblock("A"):
- v_i = T.axis.spatial(128, i_0 * (T.vscale() * 4) + i_1)
- A[v_i] = 1.0
-
- sch = tvm.s_tir.Schedule(before)
- (a,) = sch.get_loops("A")
- sch.split(
- a,
- factors=[outer_extent, 4 * T.vscale()],
- disable_predication=True,
- )
-
- tvm.ir.assert_structural_equal(sch.mod["main"], after)
-
-
-def test_sve_split_over_scalable_loop():
- @T.prim_func(s_tir=True)
- def before(a: T.handle):
- A = T.match_buffer(a, (128,), "float32")
- T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
- for i in T.serial(4 * T.vscale()):
- with T.sblock("A"):
- v_i = T.axis.remap("S", [i])
- A[v_i] = 1.0
-
- @T.prim_func(s_tir=True)
- def after(a: T.handle):
- A = T.match_buffer(a, (128,), "float32")
- T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
- for i_0, i_1 in T.grid(T.vscale() * 2, T.vscale() * 2):
- with T.sblock("A"):
- v_i = T.axis.spatial(T.vscale() * 4, i_0 * (T.vscale() * 2) +
i_1)
- T.where(i_0 * (T.vscale() * 2) + i_1 < T.vscale() * 4)
- A[v_i] = 1.0
-
- with tvm.target.Target({"kind": "llvm", "mtriple": "aarch64-linux-gnu",
"mattr": ["+sve"]}):
- sch = tvm.s_tir.Schedule(before)
- (a,) = sch.get_loops("A")
- sch.split(
- a,
- factors=[2 * T.vscale(), 2 * T.vscale()],
- )
-
- tvm.ir.assert_structural_equal(sch.mod["main"], after)
-
-
-def test_unsupported_target_scalable_split(capfd):
+def test_unsupported_target_scalable_split():
@T.prim_func(s_tir=True)
def before(a: T.handle):
A = T.match_buffer(a, (128,), "float32")
@@ -815,14 +708,6 @@ def test_unsupported_target_scalable_split(capfd):
with pytest.raises(tvm.s_tir.schedule.ScheduleError, match=err_msg):
sch.split(a, factors=[T.ceildiv(128, 4 * T.vscale()), 4 * T.vscale()])
- warning_msg = (
- "Warning: The expression contains scalable values. An attempt to prove
by substituting "
- "with known values of vscale was not performed. This proof currently
only supports "
- "VLA targets, but the target was "
- )
- captured = capfd.readouterr().err
- assert warning_msg in captured
-
def test_fused_symbolic_2D_tiling():
@T.prim_func(s_tir=True)