This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b4b866d65 [REFACTOR][ARITH] Phase out arith/scalable_expression; arith no longer proves over scalable vectors (#19638)
6b4b866d65 is described below

commit 6b4b866d65f2afcc215bcf94895b8be2dba2b77b
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu May 28 22:02:42 2026 -0400

    [REFACTOR][ARITH] Phase out arith/scalable_expression; arith no longer proves over scalable vectors (#19638)
    
    ## Summary
    
    Phase out `src/arith/scalable_expression.{h,cc}`. The arith layer no
    longer attempts to prove anything about scalable vectors — proofs that
    depended on `Target::Current()` are removed. Scalable vectors remain a
    first-class concept; arith just doesn't reason about their lengths.
    
    ## Use-site summary
    
    Live call sites (kept) total 16, across 5 of the 7 symbols; the
    proof-related call sites in `analyzer.cc` and `const_int_bound.cc` are
    deleted.
    
    | Symbol | Live callers (kept) | Proof callers (deleted) | New home |
    |---|---|---|---|
    | `ExtractVscaleFactor` | 4 × `arith/rewrite_simplify.cc` + 2 × `tirx/ir/expr.cc` | — | file-local in each |
    | `IsVScaleCall` | 1 × `tirx/op/op.cc` + 1 × `tirx/transform/vectorize_loop.cc` | — | file-local helper in each consumer |
    | `ContainsVscaleCall` | 4 × `arith/rewrite_simplify.cc` + 1 × `s_tir/schedule/ir_comparator.cc` | analyzer.cc | file-local helper in each consumer |
    | `TargetHasVLA` | 2 × `tirx/transform/vectorize_loop.cc` | analyzer.cc + const_int_bound.cc | local in vectorize_loop.cc |
    | `GetVScaleValues` | 1 × `target/llvm/codegen_aarch64.cc` | analyzer.cc + const_int_bound.cc | inlined at codegen_aarch64 |
    | `CanProveVscaleExpressionFromKnownValues` | — | analyzer.cc | DELETE |
    | `SubstituteVScaleWithKnownValue` | — | internal only | DELETE |
    
    ## Changes (6 commits)
    
    1. Move `ExtractVscaleFactor` to file-local anonymous-namespace helpers
    in `rewrite_simplify.cc` and `tirx/ir/expr.cc`. Function is small;
    per-file duplication is cleaner than a shared header.
    2. Replace `IsVScaleCall` / `ContainsVscaleCall` / `TargetHasVLA` with
    small file-local copies in an anonymous namespace in each consuming `.cc`.
    3. Drop the scalable-vector proof scaffolding from `arith/analyzer.cc`
    (substitution-proof loop) and `arith/const_int_bound.cc` (vscale
    branch). `vscale()` calls fall back to `Everything()` — no special bound
    narrowing.
    4. Delete `scalable_expression.{h,cc}`. Inline the `GetVScaleValues`
    body at `codegen_aarch64.cc` (computes `max_val = vector_width / 8`
    floor-rounded to a power of two for the LLVM `vscale_range` attribute).
    5. Mark 19 tests that relied on the deleted substitution-proof loop
    with `pytest.mark.xfail`, and remove the SVE scalable-split tests from
    `test_tir_schedule_split_fuse.py`.
    6. `pre-commit` line-length cleanup.
    
    ## Compatibility / intentional regression
    
    This is a hard break for any consumer of the deleted symbols. They were
    declared only in a private header (`src/arith/scalable_expression.h`),
    not in the public `include/` tree.
    
    19 tests that proved vscale-bearing inequalities on SVE / RVV are
    xfailed. The proofs were target-dependent and the new policy is that
    arith does not attempt them.
---
 src/arith/analyzer.cc                              |  18 ---
 src/arith/const_int_bound.cc                       |   6 -
 src/arith/rewrite_simplify.cc                      |  36 +++++-
 src/arith/scalable_expression.cc                   | 127 ---------------------
 src/arith/scalable_expression.h                    |  96 ----------------
 src/s_tir/schedule/ir_comparator.cc                |  20 +++-
 src/target/llvm/codegen_aarch64.cc                 |  20 +++-
 src/tirx/ir/expr.cc                                |  27 ++++-
 src/tirx/op/op.cc                                  |  14 ++-
 src/tirx/transform/vectorize_loop.cc               |  28 ++++-
 tests/python/arith/test_arith_rewrite_simplify.py  |   7 ++
 tests/python/arith/test_arith_simplify.py          |   8 ++
 tests/python/s_tir/dlight/test_cpu_reduction.py    |   5 +
 .../s_tir/schedule/test_tir_schedule_split_fuse.py | 117 +------------------
 14 files changed, 146 insertions(+), 383 deletions(-)

diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index 8bce80f4ef..38c699692e 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -28,7 +28,6 @@
 #include <tvm/tirx/expr.h>
 #include <tvm/tirx/op.h>
 
-#include "./scalable_expression.h"
 #include "const_fold.h"
 #include "product_normal_form.h"
 
@@ -231,23 +230,6 @@ bool Analyzer::CanProve(const PrimExpr& expr, 
ProofStrength strength) {
     }
   }
 
-  // Current analysis may not be powerful enough to prove expressions 
containing
-  // the same symbolic value multiple times. However, when the symbolic values 
are
-  // "T.vscale" and the compile target uses a scalable architecture extension 
like
-  // VLA, we can make some assumptions about the value of vscale and iterate 
over a
-  // space of pre-defined values to attempt to prove the expression.
-  Target curr_target = Target::Current();
-  if (ContainsVscaleCall(simplified)) {
-    if (TargetHasVLA(curr_target)) {
-      auto kVScaleValues = GetVScaleValues(curr_target);
-      return CanProveVscaleExpressionFromKnownValues(this, simplified, 
kVScaleValues);
-    }
-    LOG(WARNING)
-        << "The expression contains scalable values. An attempt to prove by 
substituting "
-           "with known values of vscale was not performed. This proof 
currently only supports "
-           "VLA targets, but the target was "
-        << curr_target;
-  }
   return false;
 }
 
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 4de7ab4d40..c1dc4826f7 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -33,7 +33,6 @@
 #include "constraint_extract.h"
 #include "int_operator.h"
 #include "pattern_match.h"
-#include "scalable_expression.h"
 
 namespace tvm {
 namespace arith {
@@ -417,17 +416,12 @@ class ConstIntBoundAnalyzer::Impl
     // only special handle >> and & which can be
     // used for index calculation.
 
-    auto curr_target = Target::Current();
     if (op->op.same_as(tirx::builtin::shift_right())) {
       return VisitRightShift(op);
     } else if (op->op.same_as(tirx::builtin::shift_left())) {
       return VisitLeftShift(op);
     } else if (op->op.same_as(tirx::builtin::bitwise_and())) {
       return VisitBitwiseAnd(op);
-    } else if (op->op.same_as(tirx::builtin::vscale()) && 
TargetHasVLA(curr_target)) {
-      auto kVScaleValues = GetVScaleValues(curr_target);
-      unsigned int max_val = *std::max_element(kVScaleValues.begin(), 
kVScaleValues.end());
-      return MakeBound(1, max_val);
     } else {
       return Everything(op->dtype);
     }
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 804cb3cd97..06ac29f991 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -34,15 +34,41 @@
 #include <utility>
 
 #include "../target/datatype/registry.h"
+#include "../tirx/analysis/check_contains.h"
 #include "conjunctive_normal_form.h"
 #include "const_fold.h"
 #include "constraint_extract.h"
 #include "pattern_match.h"
-#include "scalable_expression.h"
 
 namespace tvm {
 namespace arith {
 
+namespace {
+// File-local helper: true if `expr` is a call to tirx::builtin::vscale().
+bool IsVScaleCall(const PrimExpr& expr) {
+  if (const auto* call = expr.as<tirx::CallNode>()) {
+    return call->op.same_as(tirx::builtin::vscale());
+  }
+  return false;
+}
+
+// File-local helper: true if `expr` contains a call to 
tirx::builtin::vscale().
+bool ContainsVscaleCall(const PrimExpr& expr) {
+  return tirx::CheckContains::ExprContains(expr, IsVScaleCall);
+}
+
+// File-local helper: returns the vscale multiplier if `lanes` is of the form
+// `multiplier * vscale()` or `vscale() * multiplier`, nullopt otherwise.
+std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
+  PVar<IntImm> multiplier;
+  PCallExpr<PVscaleOp> vscale;
+  if (PMatchesOneOf(multiplier * vscale, vscale * multiplier).Match(lanes)) {
+    return multiplier.Eval()->value;
+  }
+  return std::nullopt;
+}
+}  // namespace
+
 using namespace tirx;
 
 TVM_FFI_STATIC_INIT_BLOCK() { 
RewriteSimplifierStatsNode::RegisterReflection(); }
@@ -789,7 +815,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* 
op) {
         return ramp(div(b1, c2), div(c1, c2), lanes).Eval();
       }
       // If all possible indices in ramp are the same.
-      if (CanProveGreaterEqual(b1.Eval(), 0) && 
!arith::ExtractVscaleFactor(lanes.Eval())) {
+      if (CanProveGreaterEqual(b1.Eval(), 0) && 
!ExtractVscaleFactor(lanes.Eval())) {
         ModularSet bmod = analyzer_->modular_set(b1.Eval());
         int64_t ramp_min = bmod->base / c2val;
         auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
@@ -946,7 +972,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* 
op) {
       // If all possible indices in ramp are the same.
       if (CanProveGreaterEqual(b1.Eval(), 0)) {
         ModularSet bmod = analyzer_->modular_set(b1.Eval());
-        if (!arith::ExtractVscaleFactor(lanes.Eval())) {
+        if (!ExtractVscaleFactor(lanes.Eval())) {
           auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
           int64_t ramp_min = bmod->base / c2val;
           int64_t ramp_max = (bmod->base + (lanes_int - 1) * c1val) / c2val;
@@ -1032,7 +1058,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
FloorDivNode* op) {
         return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval();
       }
       // If all possible indices in ramp are the same.
-      if (!arith::ExtractVscaleFactor(lanes.Eval())) {
+      if (!ExtractVscaleFactor(lanes.Eval())) {
         ModularSet bmod = analyzer_->modular_set(b1.Eval());
         int64_t ramp_min = floordiv(bmod->base, c2val);
         auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
@@ -1186,7 +1212,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
FloorModNode* op) {
       }
       // If all possible indices in ramp are the same.
       ModularSet bmod = analyzer_->modular_set(b1.Eval());
-      if (!arith::ExtractVscaleFactor(lanes.Eval())) {
+      if (!ExtractVscaleFactor(lanes.Eval())) {
         int64_t ramp_min = floordiv(bmod->base, c2val);
         auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
         int64_t ramp_max = floordiv(bmod->base + (lanes_int - 1) * c1val, 
c2val);
diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc
deleted file mode 100644
index 005eea0e9c..0000000000
--- a/src/arith/scalable_expression.cc
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/arith/scalable_expression.cc
- * \brief Analyze scalable expressions.
- */
-
-#include "scalable_expression.h"
-
-#include <tvm/tirx/expr.h>
-#include <tvm/tirx/op.h>
-
-#include <vector>
-
-#include "../tirx/analysis/check_contains.h"
-#include "../tirx/transform/replace_selected_expr.h"
-#include "./pattern_match.h"
-
-namespace tvm {
-namespace arith {
-
-bool IsVScaleCall(const PrimExpr& expr) {
-  if (auto call = expr.as<tirx::CallNode>()) {
-    return call->op.same_as(tirx::builtin::vscale());
-  }
-  return false;
-}
-
-bool ContainsVscaleCall(const PrimExpr& expr) {
-  return tirx::CheckContains::ExprContains(expr, IsVScaleCall);
-}
-
-PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int 
vscale_value) {
-  std::function<bool(const PrimExpr&)> predicate_selector = [](const PrimExpr& 
current_expr) {
-    return IsVScaleCall(current_expr);
-  };
-  std::function<bool(const PrimExpr&)> can_replace_inside = [](const PrimExpr& 
current_expr) {
-    return true;
-  };
-
-  return tirx::ReplaceSelectedExpr::ReplaceSelectedExprInExpr(
-      expr, predicate_selector, tirx::MakeConstScalar(DataType::Int(32), 
vscale_value),
-      can_replace_inside);
-}
-
-std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
-  PVar<IntImm> multiplier;
-  PCallExpr<PVscaleOp> vscale;
-
-  if (PMatchesOneOf(multiplier * vscale, vscale * multiplier).Match(lanes)) {
-    return multiplier.Eval()->value;
-  } else {
-    return std::nullopt;
-  }
-}
-
-bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const 
PrimExpr& expr,
-                                             const std::vector<unsigned int>& 
vscale_values) {
-  bool can_prove_expr = true;
-  for (const unsigned int vscale_value : vscale_values) {
-    PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value);
-    result = analyzer->Simplify(result);
-    const int64_t* as_int = tirx::as_const_int(result);
-    if (!as_int || *as_int == 0) {
-      can_prove_expr = false;
-      break;
-    }
-  }
-  return can_prove_expr;
-}
-
-bool TargetHasVLA(ffi::Optional<Target> target) {
-  if (!target.defined()) {
-    target = Target::Current();
-  }
-  bool has_vla{false};
-  if (target.defined()) {
-    // aarch64
-    has_vla = 
Downcast<Target>(target)->GetAttr<bool>("feature.has_sve").value_or(false);
-    // riscv{32,64}
-    static auto target_has_feature_fn =
-        tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
-    has_vla |= target_has_feature_fn("v", target).cast<bool>();
-  }
-  return has_vla;
-}
-
-const std::vector<unsigned int> GetVScaleValues(ffi::Optional<Target> target) {
-  unsigned int vector_width = 0;
-  std::vector<unsigned int> kVScaleValues;
-  if (!target.defined()) {
-    target = Target::Current();
-  }
-  if (target.defined()) {
-    static auto llvm_get_vector_width_fn =
-        tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
-    vector_width = llvm_get_vector_width_fn(target).cast<int>();
-  }
-  // scale list with powers of two
-  for (unsigned int i = 0;; ++i) {
-    auto power = static_cast<unsigned int>(std::pow(2, i));
-    if (power > (vector_width / 8)) break;
-    kVScaleValues.push_back(power);
-  }
-
-  return kVScaleValues;
-}
-
-}  // namespace arith
-}  // namespace tvm
diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h
deleted file mode 100644
index 88c1402887..0000000000
--- a/src/arith/scalable_expression.h
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/arith/scalable_expression.h
- * \brief Analyze scalable expressions.
- */
-
-#ifndef TVM_ARITH_SCALABLE_EXPRESSION_H_
-#define TVM_ARITH_SCALABLE_EXPRESSION_H_
-
-#include <tvm/arith/analyzer.h>
-#include <tvm/ir/expr.h>
-#include <tvm/target/target.h>
-
-#include <optional>
-#include <vector>
-
-namespace tvm {
-namespace arith {
-
-/*!
- * \brief Check if an expr is a call to the vscale intrinsic.
- * \param expr The expr to check
- * \return True if the expr is a call to the vscale intrinsic, false if not.
- */
-bool IsVScaleCall(const PrimExpr& expr);
-
-/*!
- * \brief Check if an expr contains a call to the vscale intrinsic.
- * \param expr The expr to check
- * \return True if the expr contains a call to the vscale intrinsic, false if 
not.
- */
-bool ContainsVscaleCall(const PrimExpr& expr);
-
-/*!
- * \brief Substitute a vscale intrinsic call with a known scalar value.
- * \param expr The expr to apply substitutions to.
- * \param vscale_value The scalar value to replace vscale with.
- * \return A rewritten expression with vscale values replaced with a scalar 
value.
- */
-PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int 
vscale_value);
-
-/*!
- * \brief Returns the vscale multiplier as a nullable type
- * \param lanes The scalable lanes as a PrimExpr
- * \return vscale multiplier as std::optional<int>
- */
-std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes);
-
-/*!
- * \brief Check if the expression can be proven when evaluating it on all 
possible values
-           of vscale.
- * \param analyzer An analyzer instance.
- * \param expr The expression to try to prove.
- * \param vscale_values A list of values to substitute vscale with.
- * \return Whether or not the expression can be proven with this technique.
- */
-bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const 
PrimExpr& expr,
-                                             const std::vector<unsigned int>& 
vscale_values);
-
-/*!
- * \brief Check whether the compilation target supports SVE
- * \brief Check whether the compilation target supports VLA
- * \param target The target to check.
- * \return Whether VLA is supported
- */
-bool TargetHasVLA(ffi::Optional<Target> target = std::nullopt);
-
-/*!
- * \brief Get a list of known vscale values to try for an VLA target.
- * \param target The target to check.
- * \return A list of vscale values as std::vector<usigned int>
- */
-const std::vector<unsigned int> GetVScaleValues(ffi::Optional<Target> target = 
std::nullopt);
-
-}  // namespace arith
-}  // namespace tvm
-
-#endif  // TVM_ARITH_SCALABLE_EXPRESSION_H_
diff --git a/src/s_tir/schedule/ir_comparator.cc 
b/src/s_tir/schedule/ir_comparator.cc
index 06b8ea6d4a..1bb66a2381 100644
--- a/src/s_tir/schedule/ir_comparator.cc
+++ b/src/s_tir/schedule/ir_comparator.cc
@@ -19,11 +19,27 @@
 #include "./ir_comparator.h"
 
 #include <tvm/ffi/cast.h>
+#include <tvm/tirx/builtin.h>
 
-#include "../../arith/scalable_expression.h"
+#include "../../tirx/analysis/check_contains.h"
 
 namespace tvm {
 
+namespace {
+// File-local helper: true if `expr` is a call to tirx::builtin::vscale().
+bool IsVScaleCall(const PrimExpr& expr) {
+  if (const auto* call = expr.as<tirx::CallNode>()) {
+    return call->op.same_as(tirx::builtin::vscale());
+  }
+  return false;
+}
+
+// File-local helper: true if `expr` contains a call to 
tirx::builtin::vscale().
+bool ContainsVscaleCall(const PrimExpr& expr) {
+  return tirx::CheckContains::ExprContains(expr, IsVScaleCall);
+}
+}  // namespace
+
 namespace s_tir {
 using namespace tvm::tirx;
 
@@ -80,7 +96,7 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const 
PrimExpr& other) {
   bool equal = n.same_as(other) ||
                ((n->type_index() == other->type_index()) &&
                 n.dtype().code() == other.dtype().code() && 
ExprComparator::VisitExpr(n, other)) ||
-               (tvm::arith::ContainsVscaleCall(n) && 
analyzer_.CanProveEqual(n, other));
+               (ContainsVscaleCall(n) && analyzer_.CanProveEqual(n, other));
 
   if (!equal && assert_mode_) {
     std::ostringstream os;
diff --git a/src/target/llvm/codegen_aarch64.cc 
b/src/target/llvm/codegen_aarch64.cc
index 18da2e66d7..3a0a365899 100644
--- a/src/target/llvm/codegen_aarch64.cc
+++ b/src/target/llvm/codegen_aarch64.cc
@@ -29,7 +29,6 @@
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/logging.h>
 
-#include "../../arith/scalable_expression.h"
 #include "codegen_cpu.h"
 #include "llvm_instance.h"
 
@@ -58,9 +57,22 @@ void CodeGenAArch64::AddFunction(const GlobalVar& gvar, 
const PrimFunc& f) {
 void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
   // Add vscale_range() function attribute when appropriate.
   if (llvm_target_->TargetHasCPUFeature("sve") || 
llvm_target_->TargetHasCPUFeature("sme")) {
-    auto kVScaleValues = arith::GetVScaleValues(Target::Current());
-    if (!kVScaleValues.empty()) {
-      unsigned int max_val = *std::max_element(kVScaleValues.begin(), 
kVScaleValues.end());
+    // Compute max_val = largest power-of-two <= vector_width/8.
+    // Guard against calling llvm_get_vector_width_fn when no target is active 
—
+    // Target::Current() returns an undefined Target outside a compilation 
context.
+    static auto llvm_get_vector_width_fn =
+        tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
+    unsigned int max_val = 0;
+    if (auto target = Target::Current(); target.defined()) {
+      unsigned int vector_width =
+          static_cast<unsigned 
int>(llvm_get_vector_width_fn(target).cast<int>());
+      for (unsigned int i = 0;; ++i) {
+        unsigned int power = 1u << i;
+        if (power > (vector_width / 8)) break;
+        max_val = power;
+      }
+    }
+    if (max_val > 0) {
       func->addFnAttr(
           llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 
1, max_val));
     }
diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc
index d458c09180..2fdb083085 100644
--- a/src/tirx/ir/expr.cc
+++ b/src/tirx/ir/expr.cc
@@ -29,13 +29,34 @@
 
 #include <optional>
 
-#include "../../arith/scalable_expression.h"
 #include "../../support/str_escape.h"
 #include "buffer_common.h"
 
 namespace tvm {
 namespace tirx {
 
+namespace {
+// File-local helper: returns the vscale multiplier if `lanes` is of the form
+// `multiplier * vscale()` or `vscale() * multiplier`, nullopt otherwise.
+std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
+  auto is_vscale = [](const PrimExpr& e) -> bool {
+    if (const auto* call = e.as<CallNode>()) {
+      return call->op.same_as(tirx::builtin::vscale());
+    }
+    return false;
+  };
+  if (const auto* mul = lanes.as<MulNode>()) {
+    if (const auto* imm = mul->a.as<IntImmNode>(); imm && is_vscale(mul->b)) {
+      return static_cast<int>(imm->value);
+    }
+    if (const auto* imm = mul->b.as<IntImmNode>(); imm && is_vscale(mul->a)) {
+      return static_cast<int>(imm->value);
+    }
+  }
+  return std::nullopt;
+}
+}  // namespace
+
 TVM_FFI_STATIC_INIT_BLOCK() {
   VarNode::RegisterReflection();
   SizeVarNode::RegisterReflection();
@@ -514,7 +535,7 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, 
Span span) {
     // Stick to int32 lanes for fixed length vectors
     node->lanes = lanes;
   } else { /* scalable vector */
-    std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
+    std::optional<int> vscale_factor = ExtractVscaleFactor(lanes);
     TVM_FFI_ICHECK(vscale_factor) << "Invalid expression for scalable lanes " 
<< lanes;
 
     node->dtype = 
base.dtype().with_scalable_vscale_factor(vscale_factor.value());
@@ -548,7 +569,7 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span 
span) {
     // Stick to int32 lanes for fixed length vectors
     node->lanes = lanes;
   } else { /* scalable vector */
-    std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
+    std::optional<int> vscale_factor = ExtractVscaleFactor(lanes);
     TVM_FFI_ICHECK(vscale_factor) << "Invalid expression for scalable lanes " 
<< lanes;
 
     node->dtype = 
value.dtype().with_scalable_vscale_factor(vscale_factor.value());
diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc
index b7cb5fe820..539c56eecb 100644
--- a/src/tirx/op/op.cc
+++ b/src/tirx/op/op.cc
@@ -34,13 +34,23 @@
 #include <cmath>
 // Centralized header for constant folders.
 #include "../../arith/const_fold.h"
-#include "../../arith/scalable_expression.h"
 #include "../../target/datatype/registry.h"
+#include "../analysis/check_contains.h"
 
 namespace tvm {
 
 using namespace tirx;
 
+namespace {
+// File-local helper: true if `expr` is a call to tirx::builtin::vscale().
+bool IsVScaleCall(const PrimExpr& expr) {
+  if (const auto* call = expr.as<CallNode>()) {
+    return call->op.same_as(builtin::vscale());
+  }
+  return false;
+}
+}  // namespace
+
 // macro to register an unary op
 #define TVM_TIR_REGISTER_PURE_UNARY_OP(OpName)                             \
   TVM_TIR_REGISTER_OP(OpName).set_num_inputs(1).set_attr<TCallEffectKind>( \
@@ -696,7 +706,7 @@ PrimExpr operator==(PrimExpr a, PrimExpr b) { return 
equal(a, b); }
 PrimExpr equal(PrimExpr a, PrimExpr b, Span span) {
   BinaryOpMatchTypes(a, b, span);
   if (auto ret = arith::TryConstFold<tirx::EQ>(a, b)) return ret.value();
-  if (arith::IsVScaleCall(a) && arith::IsVScaleCall(b)) return true;
+  if (IsVScaleCall(a) && IsVScaleCall(b)) return true;
   return tirx::EQ(a, b, span);
 }
 
diff --git a/src/tirx/transform/vectorize_loop.cc 
b/src/tirx/transform/vectorize_loop.cc
index 0ac9680d0a..540e641bdf 100644
--- a/src/tirx/transform/vectorize_loop.cc
+++ b/src/tirx/transform/vectorize_loop.cc
@@ -38,7 +38,6 @@
 #include <unordered_map>
 #include <vector>
 
-#include "../../src/arith/scalable_expression.h"
 #include "../../tirx/analysis/check_contains.h"
 #include "tvm/runtime/data_type.h"
 #include "tvm/tirx/buffer.h"
@@ -46,6 +45,27 @@
 namespace tvm {
 namespace tirx {
 
+namespace {
+// File-local helper: true if `expr` is a call to tirx::builtin::vscale().
+bool IsVScaleCall(const PrimExpr& expr) {
+  if (const auto* call = expr.as<CallNode>()) {
+    return call->op.same_as(builtin::vscale());
+  }
+  return false;
+}
+
+// File-local helper: true if the target supports Variable-Length Array 
extensions
+// (AArch64 SVE or RISC-V V).
+bool TargetHasVLA(Target target) {
+  if (!target.defined()) return false;
+  bool has_vla = target->GetAttr<bool>("feature.has_sve").value_or(false);
+  static auto target_has_feature_fn =
+      tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
+  has_vla |= target_has_feature_fn("v", target).cast<bool>();
+  return has_vla;
+}
+}  // namespace
+
 inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
   if (is_scalable) {
     return Mul(Call(DataType::Int(32), builtin::vscale(), {}), 
lanes_or_vscale_factor);
@@ -86,7 +106,7 @@ bool EnableBufferLevelPredication(Target target) {
   }
 
   // Use buffer-level predication by default for VLA targets
-  return arith::TargetHasVLA(target);
+  return TargetHasVLA(target);
 }
 
 /*!
@@ -956,8 +976,8 @@ class LoopVectorizer : public StmtMutator {
       auto* extent_as_int = op->extent.as<IntImmNode>();
 
       if (!extent_as_int || extent_as_int->value < 1) {
-        bool is_scalable_expr = CheckContains::ExprContains(op->extent, 
arith::IsVScaleCall);
-        TVM_FFI_ICHECK(is_scalable_expr && arith::TargetHasVLA(target_))
+        bool is_scalable_expr = CheckContains::ExprContains(op->extent, 
IsVScaleCall);
+        TVM_FFI_ICHECK(is_scalable_expr && TargetHasVLA(target_))
             << "Failed to vectorize loop with extent " << op->extent << " for 
target " << target_;
       }
       TVM_FFI_ICHECK(is_zero(op->min));
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py 
b/tests/python/arith/test_arith_rewrite_simplify.py
index ee3f67d60f..17a8397ce9 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -917,6 +917,13 @@ class TestMaxIndex(BaseCompare):
     )
 
 
+# These simplifications relied on arith::CanProve being able to prove
+# vscale-bearing inequalities (e.g. vscale() > 0) by substituting known
+# vscale values for the current VLA target. That proof loop has been removed
+# from the arith layer -- arith no longer attempts to reason about scalable
+# vector lengths at the target level. The simplifications are correct in
+# principle but can no longer be proven without the substitution loop.
[email protected](reason="arith no longer proves vscale-bearing inequalities 
via substitution")
 class TestScalableIndex(BaseCompare):
     x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32")
     test_case = tvm.testing.parameter(
diff --git a/tests/python/arith/test_arith_simplify.py 
b/tests/python/arith/test_arith_simplify.py
index d30109fc44..5202dcba2c 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -87,6 +87,11 @@ def test_simplify_symbolic_comparison():
     assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, 
PS.SYMBOLIC_BOUND)
 
 
+# These tests exercised arith::CanProve's substitution-based proof loop for
+# vscale-bearing expressions (iterating over known vscale values for a VLA 
target).
+# That loop has been removed -- arith no longer attempts target-dependent 
proofs
+# about scalable-vector lengths. The LOG(WARNING) for non-VLA targets is also 
gone.
[email protected](reason="arith no longer proves vscale-bearing inequalities 
via substitution")
 @pytest.mark.parametrize(
     "expression",
     [
@@ -103,6 +108,9 @@ def 
test_simplify_vscale_comparison_with_sve_target(expression):
         assert ana.can_prove(expression)
 
 
[email protected](
+    reason="arith no longer emits a LOG(WARNING) for vscale proofs on non-VLA 
targets"
+)
 def test_simplify_vscale_comparison_without_sve_target(capfd):
     ana = tvm.arith.Analyzer()
     vs = tvm.tirx.vscale()
diff --git a/tests/python/s_tir/dlight/test_cpu_reduction.py 
b/tests/python/s_tir/dlight/test_cpu_reduction.py
index 9059efeb9f..28e60a1d44 100644
--- a/tests/python/s_tir/dlight/test_cpu_reduction.py
+++ b/tests/python/s_tir/dlight/test_cpu_reduction.py
@@ -191,6 +191,11 @@ def test_rvv_code_size_reduction(fast):
     )
 
 
+# The arith analyzer no longer proves vscale-bearing inequalities via
+# substitution (CanProveVscaleExpressionFromKnownValues was deleted). This
+# weakens simplification of scalable-vector index expressions, which can
+# prevent the RVV vectorization schedule from producing scalable vector ops.
[email protected](reason="arith no longer proves vscale-bearing inequalities 
via substitution")
 def test_rvv_fast_softmax_vectorizes_exp():
     """fast_softmax + schedule should produce RVV vector instructions
     for the polynomial exp approximation (no scalar exp calls)."""
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py 
b/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
index 58eff502d6..0e6cae7861 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
@@ -691,114 +691,7 @@ def test_split_int64_factors():
     assert_structural_equal_ignore_global_symbol(elementwise_symbolic_split, sch.mod["main"])
 
 
-@pytest.mark.parametrize("num_elements", [128, 115])
-def test_sve_scalable_split_predicated(num_elements):
-    """
-    By default, splitting with by vscale factors over a fixed-length loop will
-    result in loop-level predication being inserted. This is because, at
-    compile-time, we don't know if vscale is a multiple of the extent of the
-    loop to be split.
-    """
-    with tvm.target.Target({"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}):
-        outer_extent = tvm.arith.Analyzer().simplify(T.ceildiv(num_elements, 4 * T.vscale()))
-
-        @T.prim_func(s_tir=True)
-        def before(a: T.handle):
-            A = T.match_buffer(a, (num_elements,), "float32")
-            T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
-            for i in T.serial(num_elements):
-                with T.sblock("A"):
-                    v_i = T.axis.remap("S", [i])
-                    A[v_i] = 1.0
-
-        @T.prim_func(s_tir=True)
-        def after(a: T.handle):
-            A = T.match_buffer(a, (num_elements,), "float32")
-            T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
-            for i_0, i_1 in T.grid(outer_extent, T.vscale() * 4):
-                with T.sblock("A"):
-                    v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4) + i_1)
-                    T.where(i_0 * (T.vscale() * 4) + i_1 < num_elements)
-                    A[v_i] = 1.0
-
-        sch = tvm.s_tir.Schedule(before)
-        (a,) = sch.get_loops("A")
-        sch.split(a, factors=[outer_extent, 4 * T.vscale()])
-
-    tvm.ir.assert_structural_equal(sch.mod["main"], after)
-
-
-def test_sve_scalable_split_assume_exact_multiple():
-    """
-    If the schedule writer knows the extent of the loop to be split will always
-    be a multiple of vscale, they may use `disable_predication=True` to ensure
-    a predicate is not created. This can be used to ensure predication is not
-    inserted.
-    """
-    with tvm.target.Target({"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}):
-        outer_extent = tvm.arith.Analyzer().simplify(T.ceildiv(128, 4 * T.vscale()))
-
-        @T.prim_func(s_tir=True)
-        def before(a: T.handle):
-            A = T.match_buffer(a, (128,), "float32")
-            T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
-            for i in T.serial(128):
-                with T.sblock("A"):
-                    v_i = T.axis.remap("S", [i])
-                    A[v_i] = 1.0
-
-        @T.prim_func(s_tir=True)
-        def after(a: T.handle):
-            A = T.match_buffer(a, (128,), "float32")
-            T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
-            for i_0, i_1 in T.grid(outer_extent, T.vscale() * 4):
-                with T.sblock("A"):
-                    v_i = T.axis.spatial(128, i_0 * (T.vscale() * 4) + i_1)
-                    A[v_i] = 1.0
-
-        sch = tvm.s_tir.Schedule(before)
-        (a,) = sch.get_loops("A")
-        sch.split(
-            a,
-            factors=[outer_extent, 4 * T.vscale()],
-            disable_predication=True,
-        )
-
-    tvm.ir.assert_structural_equal(sch.mod["main"], after)
-
-
-def test_sve_split_over_scalable_loop():
-    @T.prim_func(s_tir=True)
-    def before(a: T.handle):
-        A = T.match_buffer(a, (128,), "float32")
-        T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
-        for i in T.serial(4 * T.vscale()):
-            with T.sblock("A"):
-                v_i = T.axis.remap("S", [i])
-                A[v_i] = 1.0
-
-    @T.prim_func(s_tir=True)
-    def after(a: T.handle):
-        A = T.match_buffer(a, (128,), "float32")
-        T.func_attr({"global_symbol": "my_module", "tirx.noalias": True})
-        for i_0, i_1 in T.grid(T.vscale() * 2, T.vscale() * 2):
-            with T.sblock("A"):
-                v_i = T.axis.spatial(T.vscale() * 4, i_0 * (T.vscale() * 2) + i_1)
-                T.where(i_0 * (T.vscale() * 2) + i_1 < T.vscale() * 4)
-                A[v_i] = 1.0
-
-    with tvm.target.Target({"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}):
-        sch = tvm.s_tir.Schedule(before)
-        (a,) = sch.get_loops("A")
-        sch.split(
-            a,
-            factors=[2 * T.vscale(), 2 * T.vscale()],
-        )
-
-    tvm.ir.assert_structural_equal(sch.mod["main"], after)
-
-
-def test_unsupported_target_scalable_split(capfd):
+def test_unsupported_target_scalable_split():
     @T.prim_func(s_tir=True)
     def before(a: T.handle):
         A = T.match_buffer(a, (128,), "float32")
@@ -815,14 +708,6 @@ def test_unsupported_target_scalable_split(capfd):
     with pytest.raises(tvm.s_tir.schedule.ScheduleError, match=err_msg):
         sch.split(a, factors=[T.ceildiv(128, 4 * T.vscale()), 4 * T.vscale()])
 
-    warning_msg = (
-        "Warning: The expression contains scalable values. An attempt to prove by substituting "
-        "with known values of vscale was not performed. This proof currently only supports "
-        "VLA targets, but the target was "
-    )
-    captured = capfd.readouterr().err
-    assert warning_msg in captured
-
 
 def test_fused_symbolic_2D_tiling():
     @T.prim_func(s_tir=True)


Reply via email to