This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ea9599c2b [ARITH] Use IntImm in canonical scalar hot paths (#19885)
6ea9599c2b is described below

commit 6ea9599c2ba86249a0dcd4ccb5e2b208fe06c61b
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jun 25 13:27:45 2026 -0400

    [ARITH] Use IntImm in canonical scalar hot paths (#19885)
    
    ## Summary
    
    The canonical simplifier operates on scalar index expressions in these
    paths, so constructing `IntImm` directly avoids the generic `MakeConst`
    scalar/vector dispatch. `MakeConst` is kept at generic helper sites; only
    the constants built for the split/sum normal forms are switched over.
    
    Main changes:
    
    - Use `IntImm` for scalar constants in canonical split/sum normalization
    and related division/mod paths
    - Apply the same cleanup to the scalar-index `Mod` and `FloorMod`
    constant-fold zero cases
---
 src/arith/canonical_simplify.cc | 43 ++++++++++++++++++++---------------------
 src/arith/const_fold.h          |  6 ++----
 2 files changed, 23 insertions(+), 26 deletions(-)

diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index ce906ff143..1c7c979ba4 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -133,15 +133,15 @@ class SplitExprNode : public CanonicalExprNode {
       return IntImm(dtype, 0);
     }
     if (this->upper_factor != SplitExprNode::kPosInf) {
-      res = ModImpl(res, MakeConst(dtype, this->upper_factor), div_mode);
+      res = ModImpl(res, IntImm(dtype, this->upper_factor), div_mode);
     }
     if (this->lower_factor != 1) {
-      res = DivImpl(res, MakeConst(dtype, this->lower_factor), div_mode);
+      res = DivImpl(res, IntImm(dtype, this->lower_factor), div_mode);
     }
     sscale *= this->scale;
     if (sscale != 1) {
       TVM_FFI_ICHECK(dtype.code() != DLDataTypeCode::kDLUInt || sscale > 0);
-      res = res * MakeConst(dtype, sscale);
+      res = res * IntImm(dtype, sscale);
     }
     return res;
   }
@@ -172,20 +172,20 @@ class SplitExprNode : public CanonicalExprNode {
       return false;
     }
     if (this->upper_factor != SplitExprNode::kPosInf) {
-      res = ModImpl(res, MakeConst(this->ty(), this->upper_factor), div_mode);
+      res = ModImpl(res, IntImm(this->ty(), this->upper_factor), div_mode);
       if (!CastIsSafe(dtype, res, analyzer)) {
         return false;
       }
     }
     if (this->lower_factor != 1) {
-      res = DivImpl(res, MakeConst(this->ty(), this->lower_factor), div_mode);
+      res = DivImpl(res, IntImm(this->ty(), this->lower_factor), div_mode);
       if (!CastIsSafe(dtype, res, analyzer)) {
         return false;
       }
     }
     if (this->scale != 1) {
       TVM_FFI_ICHECK(this->ty().code() != DLDataTypeCode::kDLUInt || this->scale > 0);
-      res = res * MakeConst(this->ty(), this->scale);
+      res = res * IntImm(this->ty(), this->scale);
       if (!CastIsSafe(dtype, res, analyzer)) {
         return false;
       }
@@ -252,7 +252,7 @@ class SumExprNode : public CanonicalExprNode {
   PrimExpr Normalize() const final {
     // quick path 1.
     if (this->args.size() == 0) {
-      return MakeConst(this->ty(), this->base);
+      return IntImm(this->ty(), this->base);
     }
     return Normalize_(this->ty(), SimplifySplitExprs(args), base);
   }
@@ -354,7 +354,7 @@ class SumExprNode : public CanonicalExprNode {
       }
     }
     if (base > 0 || is_min_value) {
-      res = res + MakeConst(dtype, base);
+      res = res + IntImm(dtype, base);
       if (!CastIsSafe(dtype, res, analyzer)) {
         return false;
       }
@@ -369,7 +369,7 @@ class SumExprNode : public CanonicalExprNode {
       }
     }
     if (base < 0 && !is_min_value) {
-      res = res - MakeConst(dtype, -base);
+      res = res - IntImm(dtype, -base);
       if (!CastIsSafe(dtype, res, analyzer)) {
         return false;
       }
@@ -507,7 +507,7 @@ class SumExprNode : public CanonicalExprNode {
       }
     }
     if (base > 0 || is_min_value) {
-      res = res + MakeConst(dtype, base);
+      res = res + IntImm(dtype, base);
     }
     // negative scales follows using sub.
     for (size_t i = 0; i < args.size(); ++i) {
@@ -516,7 +516,7 @@ class SumExprNode : public CanonicalExprNode {
       }
     }
     if (base < 0 && !is_min_value) {
-      res = res - MakeConst(dtype, -base);
+      res = res - IntImm(dtype, -base);
     }
     return res;
   }
@@ -837,8 +837,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval,
       return ToSplitExpr(IntImm(lhs.ty(), 0));
     } else {
       // move the upper_factor modular into index.
-      lhs.CopyOnWrite()->index =
-          ModImpl(lhs->index, MakeConst(lhs.ty(), lhs->upper_factor), div_mode);
+      lhs.CopyOnWrite()->index = ModImpl(lhs->index, IntImm(lhs.ty(), lhs->upper_factor), div_mode);
       lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf;
       lhs.CopyOnWrite()->scale = 1;
       lhs.CopyOnWrite()->lower_factor *= scaled_cval;
@@ -863,8 +862,8 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs,
   // collect lhs products and try to eliminate by matching them to prod in rhs
   ffi::Array<ffi::Optional<PrimExpr>> lhs_prods;
   PrimType rhs_ty = prhs->ty();
-  PrimExpr new_rhs = MakeConst(rhs_ty, 1);
-  PrimExpr new_common_scale = MakeConst(rhs_ty, 1);
+  PrimExpr new_rhs = IntImm(rhs_ty, 1);
+  PrimExpr new_common_scale = IntImm(rhs_ty, 1);
   int64_t lhs_cscale = 1, rhs_cscale = 1;
   int num_elimination = 0;
 
@@ -907,13 +906,13 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs,
 
   // construct prod via canonical form
   PrimType lhs_ty = plhs->ty();
-  PrimExpr new_lhs = MakeConst(lhs_ty, 1);
+  PrimExpr new_lhs = IntImm(lhs_ty, 1);
   for (ffi::Optional<PrimExpr> val : lhs_prods) {
     if (val.defined()) new_lhs = new_lhs * val.value();
   }
-  *plhs = new_lhs * MakeConst(lhs_ty, lhs_cscale);
-  *prhs = new_rhs * MakeConst(rhs_ty, rhs_cscale);
-  *common_scale = new_common_scale * MakeConst(rhs_ty, cscale_gcd);
+  *plhs = new_lhs * IntImm(lhs_ty, lhs_cscale);
+  *prhs = new_rhs * IntImm(rhs_ty, rhs_cscale);
+  *common_scale = new_common_scale * IntImm(rhs_ty, cscale_gcd);
   return true;
 }
 
@@ -1051,7 +1050,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
             }
             // Apply floormod(floordiv_result, m) to complete the identity
             PrimExpr div_result = Normalize(lhs);
-            return this->VisitExpr(floormod(div_result, MakeConst(a.ty(), new_mod)));
+            return this->VisitExpr(floormod(div_result, IntImm(a.ty(), new_mod)));
           }
         }
       }
@@ -1098,7 +1097,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval,
       // Do a recursive call to simplify the mod with the new factor.
       if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) {
         auto updated = ToSplitExpr(
-            this->VisitExpr(ModImpl(lhs->index, MakeConst(lhs.ty(), new_upper_factor), div_mode)));
+            this->VisitExpr(ModImpl(lhs->index, IntImm(lhs.ty(), new_upper_factor), div_mode)));
         // re-apply the lower_factor
         if (lhs->lower_factor != 1) {
           auto ret = SplitDivConst(updated, lhs->lower_factor, div_mode);
@@ -1416,7 +1415,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
     PrimType dtype = divisible->ty();
     TVM_FFI_ICHECK(extra->ty() == dtype);
     PrimExpr normal_extra = extra->Normalize();
-    if (this->analyzer_->CanProve(normal_extra < MakeConst(dtype, gcd)) &&
+    if (this->analyzer_->CanProve(normal_extra < IntImm(dtype, gcd)) &&
         this->analyzer_->CanProve(normal_extra >= IntImm(dtype, 0))) {
       // Case 1. 0 <= xn < d
       divisible.CopyOnWrite()->DivideBy(gcd);
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index 4793538316..0f1bd4d8a6 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -272,8 +272,7 @@ inline ffi::Optional<PrimExpr> TryConstFold<tirx::Mod>(PrimExpr a, PrimExpr b) {
       if (pa->value == 0) return a;
     }
     if (pb) {
-      // MakeConst can handle both vector and scalar types.
-      if (pb->value == 1) return tirx::MakeConst(result_ty, 0);
+      if (pb->value == 1) return IntImm(result_ty, 0);
       TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero";
     }
   });
@@ -329,8 +328,7 @@ inline ffi::Optional<PrimExpr> TryConstFold<tirx::FloorMod>(PrimExpr a, PrimExpr
       if (pa->value == 0) return a;
     }
     if (pb) {
-      // MakeConst can handle both vector and scalar types.
-      if (pb->value == 1) return tirx::MakeConst(result_ty, 0);
+      if (pb->value == 1) return IntImm(result_ty, 0);
       TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero";
     }
   });

Reply via email to