This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6ea9599c2b [ARITH] Use IntImm in canonical scalar hot paths (#19885)
6ea9599c2b is described below
commit 6ea9599c2ba86249a0dcd4ccb5e2b208fe06c61b
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jun 25 13:27:45 2026 -0400
[ARITH] Use IntImm in canonical scalar hot paths (#19885)
## Summary
Canonical simplify operates on scalar index expressions in these paths,
so constructing constants directly with `IntImm` avoids the generic
`MakeConst` scalar/vector dispatch. `MakeConst` is kept at generic helper
sites, while the constants built in the split/sum normal-form paths now
use `IntImm` directly.
Main changes:
- Use `IntImm` for scalar constants in canonical split/sum normalization
and related division/mod paths
- Apply the same cleanup to the scalar-index `Mod` and `FloorMod`
constant-fold cases that fold `x mod 1` to zero
---
src/arith/canonical_simplify.cc | 43 ++++++++++++++++++++---------------------
src/arith/const_fold.h | 6 ++----
2 files changed, 23 insertions(+), 26 deletions(-)
diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index ce906ff143..1c7c979ba4 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -133,15 +133,15 @@ class SplitExprNode : public CanonicalExprNode {
return IntImm(dtype, 0);
}
if (this->upper_factor != SplitExprNode::kPosInf) {
- res = ModImpl(res, MakeConst(dtype, this->upper_factor), div_mode);
+ res = ModImpl(res, IntImm(dtype, this->upper_factor), div_mode);
}
if (this->lower_factor != 1) {
- res = DivImpl(res, MakeConst(dtype, this->lower_factor), div_mode);
+ res = DivImpl(res, IntImm(dtype, this->lower_factor), div_mode);
}
sscale *= this->scale;
if (sscale != 1) {
TVM_FFI_ICHECK(dtype.code() != DLDataTypeCode::kDLUInt || sscale > 0);
- res = res * MakeConst(dtype, sscale);
+ res = res * IntImm(dtype, sscale);
}
return res;
}
@@ -172,20 +172,20 @@ class SplitExprNode : public CanonicalExprNode {
return false;
}
if (this->upper_factor != SplitExprNode::kPosInf) {
- res = ModImpl(res, MakeConst(this->ty(), this->upper_factor), div_mode);
+ res = ModImpl(res, IntImm(this->ty(), this->upper_factor), div_mode);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
if (this->lower_factor != 1) {
- res = DivImpl(res, MakeConst(this->ty(), this->lower_factor), div_mode);
+ res = DivImpl(res, IntImm(this->ty(), this->lower_factor), div_mode);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
if (this->scale != 1) {
TVM_FFI_ICHECK(this->ty().code() != DLDataTypeCode::kDLUInt ||
this->scale > 0);
- res = res * MakeConst(this->ty(), this->scale);
+ res = res * IntImm(this->ty(), this->scale);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
@@ -252,7 +252,7 @@ class SumExprNode : public CanonicalExprNode {
PrimExpr Normalize() const final {
// quick path 1.
if (this->args.size() == 0) {
- return MakeConst(this->ty(), this->base);
+ return IntImm(this->ty(), this->base);
}
return Normalize_(this->ty(), SimplifySplitExprs(args), base);
}
@@ -354,7 +354,7 @@ class SumExprNode : public CanonicalExprNode {
}
}
if (base > 0 || is_min_value) {
- res = res + MakeConst(dtype, base);
+ res = res + IntImm(dtype, base);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
@@ -369,7 +369,7 @@ class SumExprNode : public CanonicalExprNode {
}
}
if (base < 0 && !is_min_value) {
- res = res - MakeConst(dtype, -base);
+ res = res - IntImm(dtype, -base);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
@@ -507,7 +507,7 @@ class SumExprNode : public CanonicalExprNode {
}
}
if (base > 0 || is_min_value) {
- res = res + MakeConst(dtype, base);
+ res = res + IntImm(dtype, base);
}
// negative scales follows using sub.
for (size_t i = 0; i < args.size(); ++i) {
@@ -516,7 +516,7 @@ class SumExprNode : public CanonicalExprNode {
}
}
if (base < 0 && !is_min_value) {
- res = res - MakeConst(dtype, -base);
+ res = res - IntImm(dtype, -base);
}
return res;
}
@@ -837,8 +837,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval,
return ToSplitExpr(IntImm(lhs.ty(), 0));
} else {
// move the upper_factor modular into index.
- lhs.CopyOnWrite()->index =
- ModImpl(lhs->index, MakeConst(lhs.ty(), lhs->upper_factor), div_mode);
+ lhs.CopyOnWrite()->index = ModImpl(lhs->index, IntImm(lhs.ty(), lhs->upper_factor), div_mode);
lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf;
lhs.CopyOnWrite()->scale = 1;
lhs.CopyOnWrite()->lower_factor *= scaled_cval;
@@ -863,8 +862,8 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs,
// collect lhs products and try to eliminate by matching them to prod in rhs
ffi::Array<ffi::Optional<PrimExpr>> lhs_prods;
PrimType rhs_ty = prhs->ty();
- PrimExpr new_rhs = MakeConst(rhs_ty, 1);
- PrimExpr new_common_scale = MakeConst(rhs_ty, 1);
+ PrimExpr new_rhs = IntImm(rhs_ty, 1);
+ PrimExpr new_common_scale = IntImm(rhs_ty, 1);
int64_t lhs_cscale = 1, rhs_cscale = 1;
int num_elimination = 0;
@@ -907,13 +906,13 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs,
// construct prod via canonical form
PrimType lhs_ty = plhs->ty();
- PrimExpr new_lhs = MakeConst(lhs_ty, 1);
+ PrimExpr new_lhs = IntImm(lhs_ty, 1);
for (ffi::Optional<PrimExpr> val : lhs_prods) {
if (val.defined()) new_lhs = new_lhs * val.value();
}
- *plhs = new_lhs * MakeConst(lhs_ty, lhs_cscale);
- *prhs = new_rhs * MakeConst(rhs_ty, rhs_cscale);
- *common_scale = new_common_scale * MakeConst(rhs_ty, cscale_gcd);
+ *plhs = new_lhs * IntImm(lhs_ty, lhs_cscale);
+ *prhs = new_rhs * IntImm(rhs_ty, rhs_cscale);
+ *common_scale = new_common_scale * IntImm(rhs_ty, cscale_gcd);
return true;
}
@@ -1051,7 +1050,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
}
// Apply floormod(floordiv_result, m) to complete the identity
PrimExpr div_result = Normalize(lhs);
- return this->VisitExpr(floormod(div_result, MakeConst(a.ty(), new_mod)));
+ return this->VisitExpr(floormod(div_result, IntImm(a.ty(), new_mod)));
}
}
}
@@ -1098,7 +1097,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval,
// Do a recursive call to simplify the mod with the new factor.
if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) {
auto updated = ToSplitExpr(
- this->VisitExpr(ModImpl(lhs->index, MakeConst(lhs.ty(), new_upper_factor), div_mode)));
+ this->VisitExpr(ModImpl(lhs->index, IntImm(lhs.ty(), new_upper_factor), div_mode)));
// re-apply the lower_factor
if (lhs->lower_factor != 1) {
auto ret = SplitDivConst(updated, lhs->lower_factor, div_mode);
@@ -1416,7 +1415,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
PrimType dtype = divisible->ty();
TVM_FFI_ICHECK(extra->ty() == dtype);
PrimExpr normal_extra = extra->Normalize();
- if (this->analyzer_->CanProve(normal_extra < MakeConst(dtype, gcd)) &&
+ if (this->analyzer_->CanProve(normal_extra < IntImm(dtype, gcd)) &&
this->analyzer_->CanProve(normal_extra >= IntImm(dtype, 0))) {
// Case 1. 0 <= xn < d
divisible.CopyOnWrite()->DivideBy(gcd);
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index 4793538316..0f1bd4d8a6 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -272,8 +272,7 @@ inline ffi::Optional<PrimExpr> TryConstFold<tirx::Mod>(PrimExpr a, PrimExpr b) {
if (pa->value == 0) return a;
}
if (pb) {
- // MakeConst can handle both vector and scalar types.
- if (pb->value == 1) return tirx::MakeConst(result_ty, 0);
+ if (pb->value == 1) return IntImm(result_ty, 0);
TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero";
}
});
@@ -329,8 +328,7 @@ inline ffi::Optional<PrimExpr> TryConstFold<tirx::FloorMod>(PrimExpr a, PrimExpr
if (pa->value == 0) return a;
}
if (pb) {
- // MakeConst can handle both vector and scalar types.
- if (pb->value == 1) return tirx::MakeConst(result_ty, 0);
+ if (pb->value == 1) return IntImm(result_ty, 0);
TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero";
}
});