This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e90cd3443d [TIR] Construct proven scalar integer constants directly
(#19934)
e90cd3443d is described below
commit e90cd3443d1c5f9ef0e300046d9aa733ca9c9747
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Jul 4 22:39:39 2026 +0800
[TIR] Construct proven scalar integer constants directly (#19934)
Replace MakeConst with direct IntImm construction at call sites whose
static contracts guarantee scalar integer constants.
The repository-wide audit converts 80 call sites across 32 files.
Generic construction remains where runtime dtype, vector behavior,
unsigned range, overload resolution, or invalid-input diagnostics
require it. Eleven initially proposed conversions were reverted after
compilation and focused tests exposed false positives.
Validation:
- LLVM-enabled compiler and runtime library build
- Focused arithmetic, TIR, S-TIR, TOPI, TE, LLVM-codegen, reflection,
and printing suites
- Changed-file pre-commit and git diff checks
---
include/tvm/topi/nn/pooling.h | 2 +-
include/tvm/topi/transform.h | 4 ++--
src/arith/ir_mutator_with_analyzer.cc | 6 +++---
src/arith/iter_affine_map.cc | 16 ++++++++--------
src/arith/product_normal_form.h | 4 ++--
src/arith/rewrite_simplify.cc | 10 +++++-----
src/arith/solve_linear_equation.cc | 4 ++--
src/arith/solve_linear_inequality.cc | 6 +++---
src/s_tir/schedule/analysis/layout.cc | 2 +-
src/s_tir/schedule/primitive/blockize_tensorize.cc | 2 +-
src/s_tir/schedule/primitive/cache_index.cc | 2 +-
src/s_tir/schedule/primitive/cache_read_write.cc | 12 ++++++------
src/s_tir/schedule/primitive/reduction.cc | 2 +-
src/s_tir/transform/canonicalize_loop.cc | 2 +-
src/s_tir/transform/compact_buffer_region.cc | 8 ++++----
src/s_tir/transform/inject_double_buffer.cc | 8 ++++----
src/s_tir/transform/inject_virtual_thread.cc | 5 ++---
src/s_tir/transform/lower_thread_allreduce.cc | 6 +++---
src/target/intrin_rule.cc | 4 ++--
src/target/llvm/codegen_cpu.cc | 4 ++--
src/target/llvm/codegen_llvm.cc | 2 +-
src/te/operation/create_primfunc.cc | 2 +-
src/tirx/ir/buffer.cc | 4 ++--
src/tirx/ir/layout/tile_slice.cc | 4 ++--
src/tirx/ir/layout/utils.cc | 2 +-
src/tirx/ir/stmt.cc | 6 +++---
src/tirx/transform/ir_utils.h | 4 ++--
src/tirx/transform/lower_intrin.cc | 13 ++++++-------
src/tirx/transform/lower_tvm_builtin.cc | 2 +-
src/tirx/transform/lower_warp_memory.cc | 6 +++---
src/tirx/transform/tvm_ffi_binder.cc | 6 +++---
src/tirx/transform/unroll_loop.cc | 2 +-
32 files changed, 80 insertions(+), 82 deletions(-)
diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index 91b10e7d8d..fc67aff781 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -194,7 +194,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const
Tensor& x,
PrimExpr w_end = min(w_start + kernel_width, width);
h_start = max(h_start, IntImm(h_start.ty(), 0));
w_start = max(w_start, IntImm(w_start.ty(), 0));
- divide_factor = max((h_end - h_start) * (w_end - w_start),
MakeConst(h_end.ty(), 1));
+ divide_factor = max((h_end - h_start) * (w_end - w_start),
IntImm(h_end.ty(), 1));
}
return tvm::sum(
tvm::if_then_else(tirx::And(tirx::And(out_idx[height_axis] >=
out_idx_lower_h,
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 26e3b9a1b7..8c814dbd28 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -843,7 +843,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor&
x, const te::Tensor& b
ffi::Array<PrimExpr> begin_expr, end_expr, strides_expr;
for (int64_t i = 0; i < num_dynamic_axes; ++i) {
- auto ind = MakeConst(index_ty, i);
+ auto ind = IntImm(index_ty, i);
begin_expr.push_back(begin(ind));
end_expr.push_back(end(ind));
strides_expr.push_back(strides(ind));
@@ -939,7 +939,7 @@ inline Tensor strided_slice_with_axes(
for (size_t i = 0; i < out_shape.size(); ++i)
real_indices.push_back(indices[i]);
for (size_t i = 0; i < normalized_axes.size(); ++i) {
int64_t ax = normalized_axes[i];
- auto stride = MakeConst(strides[i]->ty.as_or_throw<PrimType>(),
strides_vec[i]);
+ auto stride = IntImm(strides[i]->ty.as_or_throw<PrimType>(),
strides_vec[i]);
PrimExpr ind = indices[ax] * stride + begin_expr[i];
real_indices.Set(ax, ind);
}
diff --git a/src/arith/ir_mutator_with_analyzer.cc
b/src/arith/ir_mutator_with_analyzer.cc
index 932ea34980..cdc459c752 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -55,10 +55,10 @@ void AppendFloorDivConstraints(const FloorDivNode* div,
int64_t value, CompareKi
if (!TryGetIntImm(div->b, &divisor_value) || divisor_value <= 0) return;
PrimType dtype = div->a.ty();
- PrimExpr divisor = MakeConst(dtype, divisor_value);
- PrimExpr k = MakeConst(dtype, value);
+ PrimExpr divisor = IntImm(dtype, divisor_value);
+ PrimExpr k = IntImm(dtype, value);
PrimExpr lo = k * divisor;
- PrimExpr hi = (k + MakeConst(dtype, 1)) * divisor;
+ PrimExpr hi = (k + IntImm(dtype, 1)) * divisor;
switch (kind) {
case CompareKind::kEQ:
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index ae37276531..7292670823 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -565,7 +565,7 @@ class IterMapRewriter : public ExprMutator {
IterMapLevel check_level) {
std::vector<bool> used(splits.size(), false);
std::vector<IterSplitExpr> iters;
- PrimExpr expected_lower_factor = MakeConst(mark->source.ty(), 1);
+ PrimExpr expected_lower_factor = IntImm(mark->source.ty(), 1);
for (size_t i = 0; i < splits.size(); ++i) {
size_t j = 0;
@@ -790,7 +790,7 @@ class IterMapRewriter : public ExprMutator {
for (IterSplitExpr split : expr->args) {
int64_t symbol_prod_count = 0;
int64_t cscale = 1;
- PrimExpr res = tirx::MakeConst(split.ty(), 1);
+ PrimExpr res = IntImm(split.ty(), 1);
auto fcollect = [&](PrimExpr val) {
if (const auto* intimm = val.as<IntImmNode>()) {
cscale *= intimm->value;
@@ -801,7 +801,7 @@ class IterMapRewriter : public ExprMutator {
};
UnpackReduction<tirx::MulNode>(split->scale, fcollect);
if (cscale != 1) {
- res = res * tirx::MakeConst(res.ty(), cscale);
+ res = res * IntImm(res.ty(), cscale);
}
split.CopyOnWrite()->scale = res;
items.emplace_back(Item{cscale, symbol_prod_count, split});
@@ -1884,12 +1884,12 @@ PrimExpr
IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P
} else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) {
// floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale
rhs = floordiv(rhs, lhs->scale);
- lhs.CopyOnWrite()->scale = MakeConst(rhs.ty(), 1);
+ lhs.CopyOnWrite()->scale = IntImm(rhs.ty(), 1);
} else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base,
lhs->scale)) {
// floordiv(x*c1 + y*c1, c1*c2) = floordiv(x+y, c2), c2=rhs/scale
base = floordiv(base, lhs->scale);
rhs = floordiv(rhs, lhs->scale);
- lhs.CopyOnWrite()->scale = MakeConst(rhs.ty(), 1);
+ lhs.CopyOnWrite()->scale = IntImm(rhs.ty(), 1);
} else {
// mark as unresolved.
ErrorLogger(this) << "Cannot represent as IterMap: the numerator's
scaling factor, "
@@ -1935,7 +1935,7 @@ PrimExpr
IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P
new_split = IterSplitExpr(IterMark(padded, padded->extent),
/* lower_factor = */ rhs,
/* extent = */
analyzer_->Simplify(ceildiv(padded->extent, rhs)),
- /* scale = */ MakeConst(rhs.ty(), 1));
+ /* scale = */ IntImm(rhs.ty(), 1));
}
auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6);
@@ -2381,7 +2381,7 @@ class SubspaceDivider {
// args are sorted from inner to outer
static IterMark MarkFromArgsAndBase(const std::vector<IterSplitExpr>& args,
PrimExpr base) {
std::vector<IterSplitExpr> res;
- PrimExpr extent = MakeConst(base.ty(), 1);
+ PrimExpr extent = IntImm(base.ty(), 1);
for (const IterSplitExpr& it : args) {
IterSplitExpr arg = it;
arg.CopyOnWrite()->scale = extent;
@@ -2435,7 +2435,7 @@ class SubspaceDivider {
bool encountered_boundary = mark_division.IsOuter();
std::vector<bool> used(splits.size(), false);
std::vector<IterSplitExpr> inner_iters, outer_iters;
- PrimExpr expected_lower_factor = MakeConst(expr->source->source.ty(), 1);
+ PrimExpr expected_lower_factor = IntImm(expr->source->source.ty(), 1);
// find the boundary of outer and inner, like case 1 above
for (size_t i = 0; i < splits.size(); ++i) {
size_t j = 0;
diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h
index 79e040287f..40af94914d 100644
--- a/src/arith/product_normal_form.h
+++ b/src/arith/product_normal_form.h
@@ -80,7 +80,7 @@ inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int
sign = 1) {
inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const PrimExpr& rhs) {
int64_t cscale = 1;
PrimType lhs_ty = lhs.ty();
- PrimExpr res = tirx::MakeConst(lhs_ty, 1);
+ PrimExpr res = IntImm(lhs_ty, 1);
auto fcollect = [&](PrimExpr val) {
if (const auto* intimm = val.as<IntImmNode>()) {
cscale *= intimm->value;
@@ -91,7 +91,7 @@ inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const
PrimExpr& rhs) {
UnpackReduction<tirx::MulNode>(lhs, fcollect);
UnpackReduction<tirx::MulNode>(rhs, fcollect);
if (cscale != 1) {
- res = res * tirx::MakeConst(res.ty(), cscale);
+ res = res * IntImm(res.ty(), cscale);
}
return res;
}
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 0effd573c5..3fef41ce71 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -844,7 +844,7 @@ Expr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op)
{
if (truncdiv(c1, c2).Match(ret)) {
int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value;
- return MakeConst(op->ty.as_or_throw<PrimType>(), truncdiv(c1val, c2val));
+ return IntImm(op->ty.as_or_throw<PrimType>(), truncdiv(c1val, c2val));
}
// while it is always true for trunc div
@@ -1025,7 +1025,7 @@ Expr RewriteSimplifier::Impl::VisitExpr_(const ModNode*
op) {
// NOTE: trunc div required
TVM_TRY_RECURSIVE_REWRITE_IF(
truncmod(x, c1),
- truncmod(x, PConst<PrimExpr>(MakeConst(op->ty.as_or_throw<PrimType>(),
-c1.Eval()->value))),
+ truncmod(x, PConst<PrimExpr>(IntImm(op->ty.as_or_throw<PrimType>(),
-c1.Eval()->value))),
c1.Eval()->value < 0);
// try modular analysis
@@ -2017,9 +2017,9 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT
ret) {
} else if (diff == 1) {
return lhs <= rhs;
} else if (diff < 0 && rhs_offset != 0) {
- return lhs + MakeConst(lhs.ty(), -diff) < rhs;
+ return lhs + IntImm(lhs.ty(), -diff) < rhs;
} else if (diff > 0 && lhs_offset != 0) {
- return lhs < rhs + MakeConst(rhs.ty(), diff);
+ return lhs < rhs + IntImm(rhs.ty(), diff);
}
return std::nullopt;
@@ -2374,7 +2374,7 @@ Expr RewriteSimplifier::Impl::VisitExpr_(const CallNode*
op) {
} else if (op->op.same_as(clz_op)) {
if (const auto* arg_int = op->args[0].as<IntImmNode>()) {
int bits = arg_int->ty.as_or_throw<PrimType>().bits();
- if (arg_int->value == 0) return MakeConst(ret_ty, bits);
+ if (arg_int->value == 0) return IntImm(ret_ty, bits);
for (int i = bits - 1; i >= 0; --i) {
if ((int64_t(1) << i) & arg_int->value) {
return IntImm(ret_ty, bits - i - 1);
diff --git a/src/arith/solve_linear_equation.cc
b/src/arith/solve_linear_equation.cc
index fd507ccdd6..7a11888dd3 100644
--- a/src/arith/solve_linear_equation.cc
+++ b/src/arith/solve_linear_equation.cc
@@ -403,12 +403,12 @@ IntConstraintsTransform SolveLinearEquations(const
IntConstraints& system_to_sol
// The j-th variable is just a single value, don't create a tvm variable
// S^{-1}_{nxm} Uy_{mxn}
if (S[j][j] >= 0) {
- PrimExpr a = tirx::MakeConst(Uy[j].ty(), S[j][j]);
+ PrimExpr a = IntImm(Uy[j].ty(), S[j][j]);
solution_for_V_inv_x.push_back(analyzer_problem->Simplify(floordiv(Uy[j], a)));
} else {
// This is required because some simplifiers
// have problems with dividing by negative numbers
- PrimExpr a = tirx::MakeConst(Uy[j].ty(), -S[j][j]);
+ PrimExpr a = IntImm(Uy[j].ty(), -S[j][j]);
solution_for_V_inv_x.push_back(analyzer_problem->Simplify(floordiv(-Uy[j], a)));
}
}
diff --git a/src/arith/solve_linear_inequality.cc
b/src/arith/solve_linear_inequality.cc
index 4d04fd5265..e6130edae2 100644
--- a/src/arith/solve_linear_inequality.cc
+++ b/src/arith/solve_linear_inequality.cc
@@ -252,7 +252,7 @@ PartialSolvedInequalities SolveLinearInequalities(const
IntConstraints& system_t
auto first_gcd = ExtendedEuclidean(pos.first, -neg.first, &gcd_x,
&gcd_y);
PrimType v_ty = v.ty();
PrimExpr c_pos = MakeConst(v_ty, neg.first / first_gcd);
- PrimExpr c_neg = MakeConst(v_ty, pos.first / first_gcd);
+ PrimExpr c_neg = IntImm(v_ty, pos.first / first_gcd);
// eliminate the current variable
PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second;
PrimExpr new_ineq = LE(new_lhs, IntImm(pos.second.ty(), 0));
@@ -306,7 +306,7 @@ PartialSolvedInequalities SolveLinearInequalities(const
IntConstraints& system_t
upper_bounds.push_back(bound);
}
for (const auto& neg : coef_neg) {
- PrimExpr bound = MakeConst(v.ty(), -coef_lcm / neg.first) * neg.second;
+ PrimExpr bound = IntImm(v.ty(), -coef_lcm / neg.first) * neg.second;
bound = analyzer->Simplify(bound, kSimplifyRewriteCanonicalRewrite);
// Don't add if any of the existing bounds is better
if (std::any_of(lower_bounds.begin(), lower_bounds.end(),
@@ -334,7 +334,7 @@ PartialSolvedInequalities SolveLinearInequalities(const
IntConstraints& system_t
std::sort(equal_list.begin(), equal_list.end(), ExprLess());
// Write it to the result.
- IntGroupBounds bnds(MakeConst(v.ty(), coef_lcm),
+ IntGroupBounds bnds(IntImm(v.ty(), coef_lcm),
ffi::Array<PrimExpr>(lower_bounds.begin(),
lower_bounds.end()),
ffi::Array<PrimExpr>(equal_list.begin(),
equal_list.end()),
ffi::Array<PrimExpr>(upper_bounds.begin(),
upper_bounds.end()));
diff --git a/src/s_tir/schedule/analysis/layout.cc
b/src/s_tir/schedule/analysis/layout.cc
index 223bd46832..5591204242 100644
--- a/src/s_tir/schedule/analysis/layout.cc
+++ b/src/s_tir/schedule/analysis/layout.cc
@@ -40,7 +40,7 @@ ffi::Array<PrimExpr> GetStrides(const Buffer& buffer) {
return {};
}
ffi::Array<PrimExpr> strides(ndim, PrimExpr{nullptr});
- PrimExpr stride = MakeConst(PrimType(buffer->DefaultIndexType()), 1);
+ PrimExpr stride = IntImm(PrimType(buffer->DefaultIndexType()), 1);
for (int i = ndim - 1; i >= 0; --i) {
strides.Set(i, stride);
stride = stride * buffer->shape[i];
diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc
b/src/s_tir/schedule/primitive/blockize_tensorize.cc
index 18e1856409..41c991c5d2 100644
--- a/src/s_tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc
@@ -829,7 +829,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref,
const TensorIntrin& int
new_region.reserve(cur->shape.size());
for (int i = 0; i < offset; i++) {
PrimExpr min = indices_base[i];
- PrimExpr extent = MakeConst(min.ty(), 1);
+ PrimExpr extent = IntImm(min.ty(), 1);
new_region.push_back(Range::FromMinExtent(min, extent));
}
for (int i = 0; i < static_cast<int>(old_region.size()); i++) {
diff --git a/src/s_tir/schedule/primitive/cache_index.cc
b/src/s_tir/schedule/primitive/cache_index.cc
index caf929100d..95cadad2c3 100644
--- a/src/s_tir/schedule/primitive/cache_index.cc
+++ b/src/s_tir/schedule/primitive/cache_index.cc
@@ -304,7 +304,7 @@ ffi::Array<SBlock> MakeIndexCacheStage(IndexInfo* info,
const ffi::String& stora
/*IterVarType=*/kDataPar));
access_indices.push_back(var);
- access_region.push_back(Range::FromMinExtent(var, MakeConst(var.ty(),
1)));
+ access_region.push_back(Range::FromMinExtent(var, IntImm(var.ty(), 1)));
block_var_map.Set(block_var, var);
}
diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc
b/src/s_tir/schedule/primitive/cache_read_write.cc
index 5fa315872e..70c56bd31b 100644
--- a/src/s_tir/schedule/primitive/cache_read_write.cc
+++ b/src/s_tir/schedule/primitive/cache_read_write.cc
@@ -270,8 +270,8 @@ SBlock MakeCacheStage(const BufferRegion& cache_region,
CacheStageInfo* info,
/*IterVarType=*/kDataPar));
read_access_indices.push_back(var);
write_access_indices.push_back(var);
- read_access_region.push_back(Range::FromMinExtent(var,
MakeConst(var.ty(), 1)));
- write_access_region.push_back(Range::FromMinExtent(var,
MakeConst(var.ty(), 1)));
+ read_access_region.push_back(Range::FromMinExtent(var, IntImm(var.ty(),
1)));
+ write_access_region.push_back(Range::FromMinExtent(var, IntImm(var.ty(),
1)));
} else {
block_vars.push_back(IterVar(
/*dom=*/Range::FromMinExtent(IntImm(axis_range->extent.ty(), 0),
axis_range->extent),
@@ -281,16 +281,16 @@ SBlock MakeCacheStage(const BufferRegion& cache_region,
CacheStageInfo* info,
// cache_read
read_access_indices.push_back(axis_range->min + var);
read_access_region.push_back(
- Range::FromMinExtent(axis_range->min + var, MakeConst(var.ty(),
1)));
+ Range::FromMinExtent(axis_range->min + var, IntImm(var.ty(), 1)));
write_access_indices.push_back(var);
- write_access_region.push_back(Range::FromMinExtent(var,
MakeConst(var.ty(), 1)));
+ write_access_region.push_back(Range::FromMinExtent(var,
IntImm(var.ty(), 1)));
} else {
// cache_write
write_access_indices.push_back(axis_range->min + var);
write_access_region.push_back(
- Range::FromMinExtent(axis_range->min + var, MakeConst(var.ty(),
1)));
+ Range::FromMinExtent(axis_range->min + var, IntImm(var.ty(), 1)));
read_access_indices.push_back(var);
- read_access_region.push_back(Range::FromMinExtent(var,
MakeConst(var.ty(), 1)));
+ read_access_region.push_back(Range::FromMinExtent(var,
IntImm(var.ty(), 1)));
}
}
}
diff --git a/src/s_tir/schedule/primitive/reduction.cc
b/src/s_tir/schedule/primitive/reduction.cc
index 169508943b..9c806bd686 100644
--- a/src/s_tir/schedule/primitive/reduction.cc
+++ b/src/s_tir/schedule/primitive/reduction.cc
@@ -932,7 +932,7 @@ class RFactorBlockCreator : public BaseBlockCreator {
ffi::Array<Range> region = write_region->region;
region.insert(
region.begin() + factor_axis_,
- Range::FromMinExtent(additional_iter_->var,
MakeConst(additional_iter_->var.ty(), 1)));
+ Range::FromMinExtent(additional_iter_->var,
IntImm(additional_iter_->var.ty(), 1)));
ffi::Optional<Buffer> rf_buffer = buffer_map.Get(write_region->buffer);
TVM_FFI_ICHECK(rf_buffer.defined());
write_regions_.push_back(BufferRegion(rf_buffer.value(),
Substitute(region, var_map_)));
diff --git a/src/s_tir/transform/canonicalize_loop.cc
b/src/s_tir/transform/canonicalize_loop.cc
index 560ac3735e..f5ce387394 100644
--- a/src/s_tir/transform/canonicalize_loop.cc
+++ b/src/s_tir/transform/canonicalize_loop.cc
@@ -48,7 +48,7 @@ class LoopCanonicalizer : public StmtExprMutator {
}
const auto* loop_var = op->loop_var.get();
PrimType loop_var_ty = loop_var->ty.as_or_throw<PrimType>();
- PrimExpr step = op->step.value_or(MakeConst(loop_var_ty, 1));
+ PrimExpr step = op->step.value_or(IntImm(loop_var_ty, 1));
// report warning for negative step, since it would be a forever loop
if (!analyzer_->CanProveGreaterEqual(step, 1)) {
diff --git a/src/s_tir/transform/compact_buffer_region.cc
b/src/s_tir/transform/compact_buffer_region.cc
index ceb8da21bc..a86eb3d0e9 100644
--- a/src/s_tir/transform/compact_buffer_region.cc
+++ b/src/s_tir/transform/compact_buffer_region.cc
@@ -472,7 +472,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
// try estimate a constant upperbound on region's extent
int64_t upperbound = dom_analyzer_->const_int_bound(extent)->max_value;
if (upperbound != arith::ConstIntBound::kPosInf) {
- extent = MakeConst(extent.ty(), upperbound);
+ extent = IntImm(extent.ty(), upperbound);
} else {
result_region.Set(i, original);
continue;
@@ -701,15 +701,15 @@ ffi::Array<PrimExpr> CalcStrides(const BufferAllocInfo&
alloc_info,
if (alloc_info.dim_aligns.size()) {
TVM_FFI_ICHECK(alloc_info.dim_aligns.size() == shape.size());
strides.resize(shape.size());
- PrimExpr stride = MakeConst(shape[0].ty(), 1);
+ PrimExpr stride = IntImm(shape[0].ty(), 1);
for (size_t i = shape.size(); i != 0; --i) {
size_t dim = i - 1;
DimAlignInfo info = alloc_info.dim_aligns[dim];
int align_factor = info.align_factor;
int align_offset = info.align_offset;
if (align_factor != 0) {
- PrimExpr factor = MakeConst(stride.ty(), align_factor);
- PrimExpr offset = MakeConst(stride.ty(), align_offset);
+ PrimExpr factor = IntImm(stride.ty(), align_factor);
+ PrimExpr offset = IntImm(stride.ty(), align_offset);
stride = stride + indexmod(factor + offset - indexmod(stride, factor),
factor);
}
strides[dim] = stride;
diff --git a/src/s_tir/transform/inject_double_buffer.cc
b/src/s_tir/transform/inject_double_buffer.cc
index fe9f57d827..82a17409c6 100644
--- a/src/s_tir/transform/inject_double_buffer.cc
+++ b/src/s_tir/transform/inject_double_buffer.cc
@@ -164,15 +164,15 @@ class DoubleBufferInjector : public StmtExprMutator {
<< "It is better to split with multiple of 2";
TVM_FFI_ICHECK(is_zero(old_loop->min));
PrimExpr zero = old_loop->min;
- PrimExpr new_ext = old_loop->extent -
MakeConst(old_loop->loop_var.ty(), 1);
- PrimExpr factor = MakeConst(new_ext.ty(), split_loop_);
+ PrimExpr new_ext = old_loop->extent - IntImm(old_loop->loop_var.ty(),
1);
+ PrimExpr factor = IntImm(new_ext.ty(), split_loop_);
PrimExpr outer_ext = new_ext / factor;
PrimExpr tail_base = outer_ext * factor;
Var outer_var(old_loop->loop_var->name_hint + ".outer",
old_loop->loop_var.ty());
std::unordered_map<const VarNode*, PrimExpr> vmap;
std::vector<Stmt> loop_seq;
for (int32_t i = 0; i < split_loop_; ++i) {
- vmap[old_loop->loop_var.get()] = outer_var * factor +
MakeConst(factor.ty(), i);
+ vmap[old_loop->loop_var.get()] = outer_var * factor +
IntImm(factor.ty(), i);
loop_seq.emplace_back(Substitute(old_loop->body, vmap));
}
Stmt loop = For(outer_var, zero, outer_ext, old_loop->kind,
SeqStmt::Flatten(loop_seq));
@@ -180,7 +180,7 @@ class DoubleBufferInjector : public StmtExprMutator {
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
for (int32_t i = 0; i < split_loop_; ++i) {
- PrimExpr idx = tail_base + MakeConst(tail_base.ty(), i);
+ PrimExpr idx = tail_base + IntImm(tail_base.ty(), i);
vmap[old_loop->loop_var.get()] = idx;
tail_seq.emplace_back(IfThenElse(idx < old_loop->extent,
Substitute(tail_body, vmap)));
}
diff --git a/src/s_tir/transform/inject_virtual_thread.cc
b/src/s_tir/transform/inject_virtual_thread.cc
index df0408cf71..f0eeec3763 100644
--- a/src/s_tir/transform/inject_virtual_thread.cc
+++ b/src/s_tir/transform/inject_virtual_thread.cc
@@ -468,15 +468,14 @@ class VTInjector : public arith::IRMutatorWithAnalyzer {
// do unrolling if it is inside innermost content.
ffi::Array<Stmt> seq;
for (int i = 0; i < num_threads_; ++i) {
- seq.push_back(Substitute(stmt, {{var_, MakeConst(var_.ty(), i)}}));
+ seq.push_back(Substitute(stmt, {{var_, IntImm(var_.ty(), i)}}));
}
return SeqStmt::Flatten(seq);
} else {
// insert a for loop
Var idx(var_->name_hint + ".s", var_.ty());
stmt = Substitute(stmt, {{var_, idx}});
- return For(idx, IntImm(idx.ty(), 0), MakeConst(idx.ty(), num_threads_),
ForKind::kSerial,
- stmt);
+ return For(idx, IntImm(idx.ty(), 0), IntImm(idx.ty(), num_threads_),
ForKind::kSerial, stmt);
}
}
diff --git a/src/s_tir/transform/lower_thread_allreduce.cc
b/src/s_tir/transform/lower_thread_allreduce.cc
index aa4b369597..1d72716a12 100644
--- a/src/s_tir/transform/lower_thread_allreduce.cc
+++ b/src/s_tir/transform/lower_thread_allreduce.cc
@@ -337,7 +337,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
staging_shared_bufs.reserve(size);
for (size_t i = 0; i < size; ++i) {
Buffer staging_shared_buf = decl_buffer(
- /*shape=*/{MakeConst(reduce_index.ty(), n_warps * group_extent)},
+ /*shape=*/{IntImm(reduce_index.ty(), n_warps * group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging",
/*storage_scope=*/"shared");
staging_shared_bufs.push_back(staging_shared_buf);
new_alloc_bufs.push_back(staging_shared_buf);
@@ -371,7 +371,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
}
std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
values, dtypes, combiner, reduce_index, n_warps, group_index, mask,
- /*predicate=*/reduce_index < MakeConst(reduce_index.ty(),
n_warps), &seq);
+ /*predicate=*/reduce_index < IntImm(reduce_index.ty(), n_warps),
&seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(),
local_bufs.end());
// 5. Create shared memory buffer(s) of `group_extent` elements,
storing
@@ -381,7 +381,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
for (size_t i = 0; i < size; ++i) {
new_alloc_bufs.push_back(reduce_results[i].as_or_throw<BufferLoad>()->buffer);
Buffer broadcast_shared_buf = decl_buffer(
- /*shape=*/{MakeConst(reduce_index.ty(), group_extent)},
+ /*shape=*/{IntImm(reduce_index.ty(), group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_result",
/*storage_scope=*/"shared");
write_result.push_back(
BufferStore(broadcast_shared_buf, reduce_results[i],
{group_index}));
diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index 4f7ec4e939..67ceff9db6 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -134,8 +134,8 @@ TVM_REGISTER_OP("tirx.tvm_access_ptr")
TVM_FFI_ICHECK(call->ty.as_or_throw<PrimType>().IsHandle());
if (dtype.lanes() != 1) {
PrimType offset_ty = offset.ty();
- offset = offset * MakeConst(offset_ty, dtype.lanes());
- offset = Ramp(offset, MakeConst(offset_ty, 1), dtype.lanes());
+ offset = offset * IntImm(offset_ty, dtype.lanes());
+ offset = Ramp(offset, IntImm(offset_ty, 1), dtype.lanes());
}
Buffer dummy_buf(buffer_var, dtype.WithLanes(1), {offset + 1}, {}, 0,
buffer_var->name_hint,
0, 0, kDefault);
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 0ece07dcc8..a3e0f9030b 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -1194,9 +1194,9 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
CreateSerialFor(MakeValue(task_id), MakeValue(end),
MakeValue(num_task), op->loop_var,
op->body);
} else {
- PrimExpr step = (op->extent + num_task - MakeConst(t, 1)) / num_task;
+ PrimExpr step = (op->extent + num_task - IntImm(t, 1)) / num_task;
PrimExpr begin = min(task_id * step, op->extent);
- end = min((task_id + MakeConst(t, 1)) * step, end);
+ end = min((task_id + IntImm(t, 1)) * step, end);
CreateSerialFor(MakeValue(begin), MakeValue(end),
llvm::ConstantInt::getSigned(GetLLVMType(end), 1),
op->loop_var, op->body);
}
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 56facba47f..44077024c1 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -2037,7 +2037,7 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
} else {
TVM_FFI_ICHECK(op->kind == ForKind::kSerial);
}
- PrimExpr step = op->step.value_or(MakeConst(op->extent.ty(), 1));
+ PrimExpr step = op->step.value_or(IntImm(op->extent.ty(), 1));
PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min +
op->extent);
llvm::Value* begin_value = MakeValue(op->min);
llvm::Value* end_value = MakeValue(end);
diff --git a/src/te/operation/create_primfunc.cc
b/src/te/operation/create_primfunc.cc
index 0a166a2406..5a400d5eae 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -517,7 +517,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp&
compute_op, CreateFuncInfo* in
TVM_FFI_ICHECK(scopes[i - 1].axes_remap.count(axis->var));
PrimExpr prev_binding = scopes[i - 1].axes_remap.at(axis->var);
Var block_var("v_" + axis->var->name_hint, index_type);
- Range dom = Range::FromMinExtent(prev_binding, MakeConst(index_type,
1));
+ Range dom = Range::FromMinExtent(prev_binding, IntImm(index_type, 1));
IterVar new_block_iter(dom, block_var, axis->iter_type,
axis->thread_tag, axis->span);
cur_scope.AddBlockIter(axis, new_block_iter, prev_binding);
}
diff --git a/src/tirx/ir/buffer.cc b/src/tirx/ir/buffer.cc
index 64dc7a716f..052793c270 100644
--- a/src/tirx/ir/buffer.cc
+++ b/src/tirx/ir/buffer.cc
@@ -493,7 +493,7 @@ Buffer Buffer::MakeStrideView() const {
const BufferNode* self = operator->();
TVM_FFI_ICHECK(self != nullptr);
auto n = ffi::make_object<BufferNode>(*self);
- PrimExpr acc = MakeConst(PrimType(n->DefaultIndexType()), 1);
+ PrimExpr acc = IntImm(PrimType(n->DefaultIndexType()), 1);
for (size_t i = n->shape.size(); i != 0; --i) {
temp.push_back(acc);
acc = acc * n->shape[i - 1];
@@ -553,7 +553,7 @@ PrimExpr Buffer::access_ptr(int access_mask, PrimType
ptr_type, int content_lane
PrimExpr e_dtype;
PrimExpr extent;
if (self->shape.size() == 0) {
- extent = MakeConst(PrimType(self->DefaultIndexType()), 1);
+ extent = IntImm(PrimType(self->DefaultIndexType()), 1);
} else if (self->strides.size() == self->shape.size()) {
int highest_dim = 0;
extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
diff --git a/src/tirx/ir/layout/tile_slice.cc b/src/tirx/ir/layout/tile_slice.cc
index b172f7fec0..7e9a957177 100644
--- a/src/tirx/ir/layout/tile_slice.cc
+++ b/src/tirx/ir/layout/tile_slice.cc
@@ -118,7 +118,7 @@ ffi::Optional<TileLayout> SlicePerGroup(TileLayout layout,
PrimExpr begin, PrimE
return TileLayout(new_shard, layout->replica, new_offset);
}
- PrimExpr two = MakeConst(rem.ty(), 2);
+ PrimExpr two = IntImm(rem.ty(), 2);
PrimExpr c = analyzer->Simplify(floordiv(rem, two));
bool even = analyzer->CanProveEqual(floormod(rem, two), 0);
bool mid = analyzer->CanProveEqual(analyzer->Simplify(d0[pivot] + c), Ek);
@@ -131,7 +131,7 @@ ffi::Optional<TileLayout> SlicePerGroup(TileLayout layout,
PrimExpr begin, PrimE
PrimExpr delta =
analyzer->Simplify((pivot > 0 ? shard[pivot - 1]->stride :
PrimExpr(0)) - (Ek - c) * Sk);
std::vector<Iter> new_shard;
- new_shard.push_back(Iter(MakeConst(c.ty(), 2), delta, ak));
+ new_shard.push_back(Iter(IntImm(c.ty(), 2), delta, ak));
new_shard.push_back(Iter(c, Sk, ak));
new_shard.insert(new_shard.end(), peeled_rev.rbegin(),
peeled_rev.rend());
return TileLayout(new_shard, layout->replica, new_offset);
diff --git a/src/tirx/ir/layout/utils.cc b/src/tirx/ir/layout/utils.cc
index 05828a6600..abd114ddf9 100644
--- a/src/tirx/ir/layout/utils.cc
+++ b/src/tirx/ir/layout/utils.cc
@@ -73,7 +73,7 @@ std::vector<PrimExpr> GetDefaultStrides(const
ffi::Array<PrimExpr>& data, PrimEx
// get int32 strides and structurally differ from parser output.
PrimExpr current_stride = initial_stride;
if (const auto* imm = current_stride.as<IntImmNode>()) {
- current_stride = MakeConst(data[0].ty(), imm->value);
+ current_stride = IntImm(data[0].ty(), imm->value);
}
for (int i = static_cast<int>(n) - 1; i >= 0; --i) {
strides[i] = current_stride;
diff --git a/src/tirx/ir/stmt.cc b/src/tirx/ir/stmt.cc
index 3a2696000c..306eea63ef 100644
--- a/src/tirx/ir/stmt.cc
+++ b/src/tirx/ir/stmt.cc
@@ -153,7 +153,7 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent,
ForKind kind, Stmt body,
// When extent, min or step is an IntImm but has narrower dtype than loop_var
// we directly promote them without raising errors.
- auto try_promote_imm_dtype = [&](const PrimExpr& e) {
+ auto try_promote_imm_dtype = [&](const PrimExpr& e) -> PrimExpr {
PrimType e_ty = e.ty();
PrimType loop_var_ty = loop_var.ty();
TVM_FFI_ICHECK(e_ty.bits() <= loop_var_ty.bits())
@@ -161,7 +161,7 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent,
ForKind kind, Stmt body,
<< ") is narrower than that of `min` or `extent` (" << e_ty << ")";
const IntImmNode* a = e.as<IntImmNode>();
if (a && e_ty.bits() < loop_var_ty.bits()) {
- return MakeConst(loop_var_ty, a->value);
+ return IntImm(loop_var_ty, a->value);
} else {
return e;
}
@@ -490,7 +490,7 @@ PrimExpr BufferRegionNode::ToPrimExpr() const {
if (tvm::tirx::is_one(r->extent)) {
indices.push_back(r->min);
} else if (r->extent.as<IntImmNode>()) {
- indices.push_back(tirx::Ramp(r->min, tvm::tirx::MakeConst(r->min.ty(), 1), r->extent));
+ indices.push_back(tirx::Ramp(r->min, IntImm(r->min.ty(), 1), r->extent));
} else {
TVM_FFI_THROW(ValueError) << "Cannot convert to BufferLoad: "
<< ffi::GetRef<BufferRegion>(this);
diff --git a/src/tirx/transform/ir_utils.h b/src/tirx/transform/ir_utils.h
index 3ab4709de2..1cc1faefb7 100644
--- a/src/tirx/transform/ir_utils.h
+++ b/src/tirx/transform/ir_utils.h
@@ -126,8 +126,8 @@ inline PrimExpr AddressOffset(Var handle, PrimType dtype,
int offset) {
inline PrimExpr AddressOffset(Var handle, PrimType dtype, PrimExpr offset) {
if (dtype.lanes() != 1) {
PrimType offset_ty = offset.ty();
- offset = offset * MakeConst(offset_ty, dtype.lanes());
- offset = Ramp(offset, MakeConst(offset_ty, 1), dtype.lanes());
+ offset = offset * IntImm(offset_ty, dtype.lanes());
+ offset = Ramp(offset, IntImm(offset_ty, 1), dtype.lanes());
}
ffi::Array<PrimExpr> shape = {offset + 1};
diff --git a/src/tirx/transform/lower_intrin.cc
b/src/tirx/transform/lower_intrin.cc
index 2839ef6f1f..3757588a6e 100644
--- a/src/tirx/transform/lower_intrin.cc
+++ b/src/tirx/transform/lower_intrin.cc
@@ -122,7 +122,7 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
// lower to right shift if possible.
- return op->a >> MakeConst(dtype, shift);
+ return op->a >> IntImm(dtype, shift);
}
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
@@ -135,8 +135,7 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
if (auto opt_c_value = TryFindShiftCoefficientForPositiveRange(op->a,
b_value)) {
int64_t c_value = *opt_c_value;
// now we can safely lower to truncdiv
- return truncdiv(op->a + MakeConst(dtype, b_value * c_value), op->b) -
- MakeConst(dtype, c_value);
+ return truncdiv(op->a + IntImm(dtype, b_value * c_value), op->b) - IntImm(dtype, c_value);
}
}
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
@@ -147,7 +146,7 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
// So we need to correct these cases.
if ((dtype == PrimType::Int(32) || dtype == PrimType::Int(64)) &&
support_bitwise_op_) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
- return rdiv + (rmod >> MakeConst(dtype, dtype.bits() - 1));
+ return rdiv + (rmod >> IntImm(dtype, dtype.bits() - 1));
} else {
return tirx::Select(rmod >= 0, rdiv, rdiv - MakeConst(dtype, 1));
}
@@ -184,7 +183,7 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
// lower to masking if possible.
int64_t mask = (static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
- return op->a & MakeConst(dtype, mask);
+ return op->a & IntImm(dtype, mask);
}
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
@@ -197,7 +196,7 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
if (auto opt_c_value = TryFindShiftCoefficientForPositiveRange(op->a,
b_value)) {
int64_t c_value = *opt_c_value;
// floormod(a, b) == floormod(a + b*c, b) == truncmod(a + b*c, b)
- return truncmod(op->a + MakeConst(dtype, c_value * b_value), op->b);
+ return truncmod(op->a + IntImm(dtype, c_value * b_value), op->b);
}
}
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
@@ -209,7 +208,7 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
- return rmod + (op->b & (rmod >> MakeConst(dtype, dtype.bits() - 1)));
+ return rmod + (op->b & (rmod >> IntImm(dtype, dtype.bits() - 1)));
} else {
return tirx::Select(rmod >= 0, rmod, rmod + op->b);
}
diff --git a/src/tirx/transform/lower_tvm_builtin.cc
b/src/tirx/transform/lower_tvm_builtin.cc
index 5e7b1c9d95..fb871de56b 100644
--- a/src/tirx/transform/lower_tvm_builtin.cc
+++ b/src/tirx/transform/lower_tvm_builtin.cc
@@ -527,7 +527,7 @@ class BuiltinLower : public StmtExprMutator {
PrimExpr elem_offset = op->args[5].as_or_throw<PrimExpr>();
PrimExpr byte_offset;
if (!is_zero(elem_offset)) {
- byte_offset = elem_offset * MakeConst(elem_offset.ty(), data_bytes);
+ byte_offset = elem_offset * IntImm(elem_offset.ty(), data_bytes);
} else {
byte_offset = elem_offset;
}
diff --git a/src/tirx/transform/lower_warp_memory.cc
b/src/tirx/transform/lower_warp_memory.cc
index b22570d00c..6c22ee2afd 100644
--- a/src/tirx/transform/lower_warp_memory.cc
+++ b/src/tirx/transform/lower_warp_memory.cc
@@ -410,10 +410,10 @@ class WarpAccessRewriter : protected StmtExprMutator {
TVM_FFI_ICHECK(arith::ramp(base, 1, index_ty.lanes()).Match(index));
auto [local_index, group] = SplitIndexByGroup(base.Eval());
- local_index = Ramp(local_index, MakeConst(local_index.ty(), 1), index_ty.lanes());
+ local_index = Ramp(local_index, IntImm(local_index.ty(), 1), index_ty.lanes());
return std::make_pair(local_index, group);
}
- PrimExpr m = MakeConst(index_ty, warp_coeff_);
+ PrimExpr m = IntImm(index_ty, warp_coeff_);
// simple case, warp index is on the highest.
if (warp_group_ == 1) {
@@ -424,7 +424,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
PrimExpr y = index / MakeConst(index_ty, warp_coeff_ * width_);
y = y * m + x;
- PrimExpr z = indexdiv(indexmod(index, MakeConst(index_ty, warp_coeff_ * width_)), m);
+ PrimExpr z = indexdiv(indexmod(index, IntImm(index_ty, warp_coeff_ * width_)), m);
return std::make_pair(analyzer_->canonical_simplify(y),
analyzer_->canonical_simplify(z));
}
}
diff --git a/src/tirx/transform/tvm_ffi_binder.cc
b/src/tirx/transform/tvm_ffi_binder.cc
index a05bd4f8d0..28a9c7216c 100644
--- a/src/tirx/transform/tvm_ffi_binder.cc
+++ b/src/tirx/transform/tvm_ffi_binder.cc
@@ -581,7 +581,7 @@ void TVMFFIABIBuilder::BindCompactStrides(const Buffer&
buffer, const Var& strid
const PrimExpr& v_strides_is_null,
const ffi::reflection::AccessPath&
param_path) {
PrimType stype(buffer->DefaultIndexType());
- PrimExpr expect_stride = MakeConst(stype, 1);
+ PrimExpr expect_stride = IntImm(stype, 1);
ffi::Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
@@ -608,7 +608,7 @@ void TVMFFIABIBuilder::BindAutoBroadcastStrides(const
Buffer& buffer, const Var&
const PrimExpr&
v_strides_is_null,
const
ffi::reflection::AccessPath& param_path) {
PrimType stype(buffer->DefaultIndexType());
- PrimExpr stride = MakeConst(stype, 1);
+ PrimExpr stride = IntImm(stype, 1);
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
PrimExpr value = cast(buffer->shape[k].ty(),
LoadInt64ArrayElem(strides_ptr, k));
@@ -657,7 +657,7 @@ void TVMFFIABIBuilder::DecodeParamDLTensor(const Buffer&
buffer, const PrimExpr&
// ── Section: ndim ────────────────────────────────────────────
PrimExpr v_ndim = TVMStructGet(tvm_ndim_type, handle, 0,
builtin::kDLTensorNDim);
- PrimExpr a_ndim = MakeConst(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
+ PrimExpr a_ndim = IntImm(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
EmitAssert(a_ndim == v_ndim, "ValueError", //
"Mismatched ", buf_name, ".ndim on argument #",
std::to_string(param_index),
when_calling_imm_, sig_imm_, "`,\n expected ",
std::to_string(buffer->shape.size()));
diff --git a/src/tirx/transform/unroll_loop.cc
b/src/tirx/transform/unroll_loop.cc
index 5800764699..b63ebb128a 100644
--- a/src/tirx/transform/unroll_loop.cc
+++ b/src/tirx/transform/unroll_loop.cc
@@ -225,7 +225,7 @@ class LoopUnroller : public StmtExprMutator {
ffi::Map<Var, PrimExpr> vmap;
ffi::Array<Stmt> unrolled;
for (int i = 0; i < value; ++i) {
- vmap.Set(op->loop_var, op->min + MakeConst(op->loop_var.ty(), i));
+ vmap.Set(op->loop_var, op->min + IntImm(op->loop_var.ty(), i));
Stmt step = Substitute(body, vmap);
unrolled.push_back(step);
}