This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 72f40c673f [Arith] Let IRMutatorWithAnalyzer take a const Analyzer& (#19829)
72f40c673f is described below
commit 72f40c673f7c7e87fc8229a398239e55ba652918
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 18 15:21:08 2026 -0400
[Arith] Let IRMutatorWithAnalyzer take a const Analyzer& (#19829)
This PR is a follow-up to #19675. IRMutatorWithAnalyzer and its
subclasses took an AnalyzerObj*, so callers had to pass analyzer.get().
This adds a const Analyzer& constructor and migrates those subclasses
(and their Apply() / free-function entry points) to take the Analyzer
handle directly, dropping .get() at the call sites.
The original AnalyzerObj* overload is kept because
RewriteSimplifier::Impl is constructed from the AnalyzerObj's own `this`
pointer while the analyzer itself is still being constructed, when no
Analyzer handle exists yet. Keeping both overloads mirrors the
dual-overload pattern already used by arith::ConstraintContext.
---
src/arith/ir_mutator_with_analyzer.h | 1 +
src/backend/trn/transform/lower_trainium_layout.cc | 4 ++--
src/s_tir/backend/adreno/inject_texture_alloc.cc | 4 ++--
src/s_tir/schedule/primitive/blockize_tensorize.cc | 2 +-
src/s_tir/schedule/primitive/layout_transformation.cc | 6 +++---
src/s_tir/schedule/transform.h | 4 ++--
src/s_tir/transform/hoist_expression.cc | 4 ++--
src/s_tir/transform/inject_permuted_layout.cc | 4 ++--
src/s_tir/transform/inject_virtual_thread.cc | 2 +-
src/s_tir/transform/lower_async_dma.cc | 4 ++--
src/s_tir/transform/renormalize_split_pattern.cc | 4 ++--
src/s_tir/transform/using_assume_to_reduce_branches.cc | 4 ++--
src/tirx/transform/flatten_buffer.cc | 4 ++--
src/tirx/transform/lower_intrin.cc | 7 +++----
src/tirx/transform/lower_tirx_cleanup.cc | 4 ++--
src/tirx/transform/remove_no_op.cc | 8 ++++----
src/tirx/transform/remove_no_op.h | 2 +-
src/tirx/transform/stmt_simplify.cc | 8 ++++----
src/tirx/transform/stmt_simplify.h | 2 +-
19 files changed, 39 insertions(+), 39 deletions(-)
diff --git a/src/arith/ir_mutator_with_analyzer.h
b/src/arith/ir_mutator_with_analyzer.h
index 5e2fa6ab00..6fee93a16c 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -47,6 +47,7 @@ namespace arith {
*/
class IRMutatorWithAnalyzer : public tirx::StmtExprMutator {
public:
+ explicit IRMutatorWithAnalyzer(const Analyzer& analyzer) :
analyzer_(analyzer.get()) {}
explicit IRMutatorWithAnalyzer(AnalyzerObj* analyzer) : analyzer_(analyzer)
{}
using StmtExprMutator::VisitExpr_;
diff --git a/src/backend/trn/transform/lower_trainium_layout.cc
b/src/backend/trn/transform/lower_trainium_layout.cc
index b0fba77eba..3409a372bd 100644
--- a/src/backend/trn/transform/lower_trainium_layout.cc
+++ b/src/backend/trn/transform/lower_trainium_layout.cc
@@ -58,7 +58,7 @@ class TrainiumLayoutApplier : public
arith::IRMutatorWithAnalyzer {
static std::pair<Stmt, ffi::Map<Var, Buffer>> Lower(
const Stmt& stmt, const ffi::Map<tirx::Var, Buffer> buffer_map) {
arith::Analyzer ana;
- TrainiumLayoutApplier storage_lower(ana.get());
+ TrainiumLayoutApplier storage_lower(ana);
std::unordered_map<Var, Buffer> new_buffer_map;
std::vector<Buffer> param_flattened_buffers;
for (const auto& kv : buffer_map) {
@@ -83,7 +83,7 @@ class TrainiumLayoutApplier : public
arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt_;
- explicit TrainiumLayoutApplier(arith::AnalyzerObj* analyzer)
+ explicit TrainiumLayoutApplier(const arith::Analyzer& analyzer)
: arith::IRMutatorWithAnalyzer(analyzer) {}
ffi::Any VisitAny(const ffi::Any& any) {
diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc
b/src/s_tir/backend/adreno/inject_texture_alloc.cc
index 709e7c3336..e4e7c322ef 100644
--- a/src/s_tir/backend/adreno/inject_texture_alloc.cc
+++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc
@@ -46,7 +46,7 @@ class TextureAllocInjector : public
arith::IRMutatorWithAnalyzer {
public:
static PrimFunc Inject(PrimFunc func) {
arith::Analyzer ana;
- auto pass = TextureAllocInjector(ana.get());
+ auto pass = TextureAllocInjector(ana);
auto writer = func.CopyOnWrite();
pass.MarkBufferMapShapes(func);
writer->body = pass.VisitStmt(func->body);
@@ -59,7 +59,7 @@ class TextureAllocInjector : public
arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitStmt;
using IRMutatorWithAnalyzer::VisitStmt_;
- explicit TextureAllocInjector(arith::AnalyzerObj* ana) :
IRMutatorWithAnalyzer(ana) {}
+ explicit TextureAllocInjector(const arith::Analyzer& ana) :
IRMutatorWithAnalyzer(ana) {}
Stmt VisitStmt_(const AllocBufferNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc
b/src/s_tir/schedule/primitive/blockize_tensorize.cc
index 9c6df3c8b7..87150b0431 100644
--- a/src/s_tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc
@@ -769,7 +769,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref,
const TensorIntrin& int
}
arith::Analyzer analyzer;
- PrimFunc intrin_desc = StmtSimplify(intrin->desc, analyzer.get());
+ PrimFunc intrin_desc = StmtSimplify(intrin->desc, analyzer);
PrimFunc intrin_impl = DeepCopy(intrin->impl);
int index_dtype_bits = -1;
diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc
b/src/s_tir/schedule/primitive/layout_transformation.cc
index 1964ead616..0ac3802135 100644
--- a/src/s_tir/schedule/primitive/layout_transformation.cc
+++ b/src/s_tir/schedule/primitive/layout_transformation.cc
@@ -765,7 +765,7 @@ class TransformLayoutRewriter : private
arith::IRMutatorWithAnalyzer {
pad_value, analyzer.get())
: TransformLayoutPlanner::NoPaddingRequired();
- TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan,
analyzer.get());
+ TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan,
analyzer);
SBlock result = Downcast<SBlock>(rewriter(scope_stmt));
if (auto plan_ptr =
std::get_if<TransformLayoutPlanner::ProloguePlan>(&plan)) {
auto write_ptr = result.CopyOnWrite();
@@ -782,7 +782,7 @@ class TransformLayoutRewriter : private
arith::IRMutatorWithAnalyzer {
TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer,
const IndexMap& index_map,
const TransformLayoutPlanner::TransformPlan& plan,
- arith::AnalyzerObj* analyzer)
+ const arith::Analyzer& analyzer)
: IRMutatorWithAnalyzer(analyzer),
old_buffer_(old_buffer),
new_buffer_(new_buffer),
@@ -1456,7 +1456,7 @@ void TransformBlockLayout(ScheduleState self, const
StmtSRef& block_sref,
SBlock new_block =
Downcast<SBlock>(Substitute(ffi::GetRef<SBlock>(block_ptr),
inverse_subst_map));
new_block.CopyOnWrite()->iter_vars = new_block_iters;
- new_block =
Downcast<SBlock>(BlockBufferAccessSimplifier::Simplify(new_block,
analyzer.get()));
+ new_block =
Downcast<SBlock>(BlockBufferAccessSimplifier::Simplify(new_block, analyzer));
// Step 5.3: Create outer loops for each new block iter.
diff --git a/src/s_tir/schedule/transform.h b/src/s_tir/schedule/transform.h
index 6221cb35de..da6d54a966 100644
--- a/src/s_tir/schedule/transform.h
+++ b/src/s_tir/schedule/transform.h
@@ -236,13 +236,13 @@ class BlockBufferAccessSimplifier : public
arith::IRMutatorWithAnalyzer {
* \param analyzer The arithmetic analyzer
* \return The simplified statement
*/
- static Stmt Simplify(const Stmt& stmt, arith::AnalyzerObj* analyzer) {
+ static Stmt Simplify(const Stmt& stmt, const arith::Analyzer& analyzer) {
BlockBufferAccessSimplifier simplifier(analyzer);
return simplifier(stmt);
}
private:
- explicit BlockBufferAccessSimplifier(arith::AnalyzerObj* analyzer)
+ explicit BlockBufferAccessSimplifier(const arith::Analyzer& analyzer)
: IRMutatorWithAnalyzer(analyzer) {}
using IRMutatorWithAnalyzer::VisitExpr_;
diff --git a/src/s_tir/transform/hoist_expression.cc
b/src/s_tir/transform/hoist_expression.cc
index ac48593bd2..24cdfd6f5d 100644
--- a/src/s_tir/transform/hoist_expression.cc
+++ b/src/s_tir/transform/hoist_expression.cc
@@ -450,7 +450,7 @@ class ExpressionHoister : public
arith::IRMutatorWithAnalyzer {
auto loop_info = HoistInfoCollector::Collect(stmt, config);
arith::Analyzer analyzer;
- ExpressionHoister hoister(std::move(loop_info), config, analyzer.get());
+ ExpressionHoister hoister(std::move(loop_info), config, analyzer);
stmt = hoister(std::move(stmt));
stmt = ConvertSSA(std::move(stmt));
return stmt;
@@ -462,7 +462,7 @@ class ExpressionHoister : public
arith::IRMutatorWithAnalyzer {
using Parent::VisitStmt_;
explicit ExpressionHoister(std::vector<HoistInfoCollector::HoistInfo>
loop_info,
- HoistExpressionConfig config, arith::AnalyzerObj*
analyzer)
+ HoistExpressionConfig config, const
arith::Analyzer& analyzer)
: Parent(analyzer), config_(config) {
for (auto& info : loop_info) {
// Mark let bindings to use if they are enabled on their own.
diff --git a/src/s_tir/transform/inject_permuted_layout.cc
b/src/s_tir/transform/inject_permuted_layout.cc
index 8ef5051ae0..6cbba28110 100644
--- a/src/s_tir/transform/inject_permuted_layout.cc
+++ b/src/s_tir/transform/inject_permuted_layout.cc
@@ -46,14 +46,14 @@ class PermutedLayoutInjector : private
IRMutatorWithAnalyzer {
static PrimFunc Transform(PrimFunc func) {
Analyzer analyzer;
- auto new_body = PermutedLayoutInjector(func, analyzer.get())(func->body);
+ auto new_body = PermutedLayoutInjector(func, analyzer)(func->body);
auto func_node = func.CopyOnWrite();
func_node->body = new_body;
return func;
}
private:
- explicit PermutedLayoutInjector(PrimFunc func, AnalyzerObj* analyzer)
+ explicit PermutedLayoutInjector(PrimFunc func, const Analyzer& analyzer)
: IRMutatorWithAnalyzer(analyzer) {
buffer_map_.insert(func->buffer_map.begin(), func->buffer_map.end());
}
diff --git a/src/s_tir/transform/inject_virtual_thread.cc
b/src/s_tir/transform/inject_virtual_thread.cc
index 6e50970623..7853bcc049 100644
--- a/src/s_tir/transform/inject_virtual_thread.cc
+++ b/src/s_tir/transform/inject_virtual_thread.cc
@@ -541,7 +541,7 @@ Pass InjectVirtualThread() {
arith::Analyzer analyzer;
- n->body = VirtualThreadInjector(analyzer.get())(std::move(n->body));
+ n->body = VirtualThreadInjector(analyzer)(std::move(n->body));
n->body = ConvertSSA(std::move(n->body));
return f;
};
diff --git a/src/s_tir/transform/lower_async_dma.cc
b/src/s_tir/transform/lower_async_dma.cc
index 1178c1aa48..89660d4fef 100644
--- a/src/s_tir/transform/lower_async_dma.cc
+++ b/src/s_tir/transform/lower_async_dma.cc
@@ -46,7 +46,7 @@ using namespace tvm::tirx;
class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
public:
- explicit AsyncDMALowerer(bool dma_bypass_cache, arith::AnalyzerObj* analyzer)
+ explicit AsyncDMALowerer(bool dma_bypass_cache, const arith::Analyzer&
analyzer)
: IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}
// TODO(leiwang1999): split lower async DMA support for CUDA and Hexagon
Backend
@@ -176,7 +176,7 @@ Pass LowerAsyncDMA() {
arith::Analyzer analyzer;
bool dma_bypass_cache =
ctx->GetConfig<bool>("tirx.experimental_dma_bypass_cache",
false).value();
- fptr->body = AsyncDMALowerer(dma_bypass_cache,
analyzer.get())(std::move(fptr->body));
+ fptr->body = AsyncDMALowerer(dma_bypass_cache,
analyzer)(std::move(fptr->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerAsyncDMA", {});
diff --git a/src/s_tir/transform/renormalize_split_pattern.cc
b/src/s_tir/transform/renormalize_split_pattern.cc
index f185e66a07..2fbadfabd4 100644
--- a/src/s_tir/transform/renormalize_split_pattern.cc
+++ b/src/s_tir/transform/renormalize_split_pattern.cc
@@ -52,7 +52,7 @@ using namespace arith;
class SplitPatternReNormalizer : public IRMutatorWithAnalyzer {
public:
- explicit SplitPatternReNormalizer(AnalyzerObj* analyzer) :
IRMutatorWithAnalyzer(analyzer) {}
+ explicit SplitPatternReNormalizer(const Analyzer& analyzer) :
IRMutatorWithAnalyzer(analyzer) {}
using IRMutatorWithAnalyzer::VisitExpr_;
@@ -201,7 +201,7 @@ Pass RenormalizeSplitPattern() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
arith::Analyzer analyzer;
- n->body = SplitPatternReNormalizer(analyzer.get())(std::move(n->body));
+ n->body = SplitPatternReNormalizer(analyzer)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "s_tir.RenormalizeSplitPattern", {});
diff --git a/src/s_tir/transform/using_assume_to_reduce_branches.cc
b/src/s_tir/transform/using_assume_to_reduce_branches.cc
index 0ad2ad69f5..74e2c6c7d9 100644
--- a/src/s_tir/transform/using_assume_to_reduce_branches.cc
+++ b/src/s_tir/transform/using_assume_to_reduce_branches.cc
@@ -115,7 +115,7 @@ class ParseAssumeAndOvercompute : public
IRMutatorWithAnalyzer {
public:
using Parent = IRMutatorWithAnalyzer;
- explicit ParseAssumeAndOvercompute(AnalyzerObj* analyzer) : Parent(analyzer)
{}
+ explicit ParseAssumeAndOvercompute(const Analyzer& analyzer) :
Parent(analyzer) {}
private:
using Parent::VisitExpr_;
@@ -380,7 +380,7 @@ Pass UseAssumeToReduceBranches() {
if (assume_checker.has_assume) {
// Leverage from assume and eliminate the branch
- ParseAssumeAndOvercompute func_analyzer_mutator(analyzer.get());
+ ParseAssumeAndOvercompute func_analyzer_mutator(analyzer);
n->body = func_analyzer_mutator(std::move(n->body));
}
}
diff --git a/src/tirx/transform/flatten_buffer.cc
b/src/tirx/transform/flatten_buffer.cc
index 7298c2df20..267b09484e 100644
--- a/src/tirx/transform/flatten_buffer.cc
+++ b/src/tirx/transform/flatten_buffer.cc
@@ -45,7 +45,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
public:
static PrimFunc Flatten(PrimFunc func) {
arith::Analyzer ana;
- auto pass = BufferFlattener(ana.get());
+ auto pass = BufferFlattener(ana);
pass.MarkBufferMapShapes(func);
auto body = pass.VisitStmt(func->body);
@@ -78,7 +78,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitStmt;
using IRMutatorWithAnalyzer::VisitStmt_;
- explicit BufferFlattener(arith::AnalyzerObj* ana) :
IRMutatorWithAnalyzer(ana) {}
+ explicit BufferFlattener(const arith::Analyzer& ana) :
IRMutatorWithAnalyzer(ana) {}
Stmt VisitStmt_(const SBlockNode* op) final {
TVM_FFI_ICHECK_EQ(op->match_buffers.size(), 0)
diff --git a/src/tirx/transform/lower_intrin.cc
b/src/tirx/transform/lower_intrin.cc
index c85d762ab1..94ee4e1044 100644
--- a/src/tirx/transform/lower_intrin.cc
+++ b/src/tirx/transform/lower_intrin.cc
@@ -47,7 +47,7 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitStmt_;
using FLowerGeneral = ffi::TypedFunction<PrimExpr(PrimExpr)>;
- IntrinInjecter(arith::AnalyzerObj* analyzer, const Target& tgt, bool
enable_fast_math)
+ IntrinInjecter(const arith::Analyzer& analyzer, const Target& tgt, bool
enable_fast_math)
: IRMutatorWithAnalyzer(analyzer) {
std::string target = tgt->kind->name;
ffi::String mtriple = tgt->GetAttr<ffi::String>("mtriple").value_or("");
@@ -368,8 +368,7 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
arith::Analyzer analyzer;
bool enable_fast_math =
transform::PassContext::Current()->GetConfig<bool>("tirx.enable_fast_math",
false).value();
- return IntrinInjecter(analyzer.get(), Target(ffi::String(target)),
- enable_fast_math)(std::move(stmt));
+ return IntrinInjecter(analyzer, Target(ffi::String(target)),
enable_fast_math)(std::move(stmt));
}
namespace transform {
@@ -381,7 +380,7 @@ Pass LowerIntrin() {
TVM_FFI_ICHECK(target.defined()) << "LowerIntrin: Require the target
attribute";
arith::Analyzer analyzer;
bool enable_fast_math = ctx->GetConfig<bool>("tirx.enable_fast_math",
false).value();
- n->body = IntrinInjecter(analyzer.get(), target.value(),
enable_fast_math)(std::move(n->body));
+ n->body = IntrinInjecter(analyzer, target.value(),
enable_fast_math)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tirx.LowerIntrin", {});
diff --git a/src/tirx/transform/lower_tirx_cleanup.cc
b/src/tirx/transform/lower_tirx_cleanup.cc
index b20c52248a..433a0365a2 100644
--- a/src/tirx/transform/lower_tirx_cleanup.cc
+++ b/src/tirx/transform/lower_tirx_cleanup.cc
@@ -47,7 +47,7 @@ class LayoutApplier : public arith::IRMutatorWithAnalyzer {
static std::pair<Stmt, ffi::Map<Var, Buffer>> Flatten(
const Stmt& stmt, const ffi::Map<tirx::Var, Buffer> buffer_map, const
Target& target) {
arith::Analyzer ana;
- LayoutApplier storage_lower(ana.get(), target);
+ LayoutApplier storage_lower(ana, target);
std::unordered_map<Var, Buffer> new_buffer_map;
std::vector<Buffer> param_flattened_buffers;
for (const auto& kv : buffer_map) {
@@ -72,7 +72,7 @@ class LayoutApplier : public arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt_;
- explicit LayoutApplier(arith::AnalyzerObj* analyzer, const Target& target)
+ explicit LayoutApplier(const arith::Analyzer& analyzer, const Target& target)
: arith::IRMutatorWithAnalyzer(analyzer), target_(target) {}
ffi::Any VisitAny(const ffi::Any& any) {
diff --git a/src/tirx/transform/remove_no_op.cc
b/src/tirx/transform/remove_no_op.cc
index 6394eb2198..833bf1f45f 100644
--- a/src/tirx/transform/remove_no_op.cc
+++ b/src/tirx/transform/remove_no_op.cc
@@ -75,7 +75,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tirx.RemoveNoOp",
RemoveNoOpConfig);
// Mark the statement of each stage.
class NoOpRemover : public arith::IRMutatorWithAnalyzer {
public:
- static Stmt Apply(Stmt stmt, arith::AnalyzerObj* analyzer, bool
ignore_profiler_call = false) {
+ static Stmt Apply(Stmt stmt, const arith::Analyzer& analyzer, bool
ignore_profiler_call = false) {
NoOpRemover visitor(analyzer, ignore_profiler_call);
return visitor(std::move(stmt));
}
@@ -85,7 +85,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
using Parent::VisitStmt;
using Parent::VisitStmt_;
- NoOpRemover(arith::AnalyzerObj* analyzer, bool ignore_profiler_call = false)
+ NoOpRemover(const arith::Analyzer& analyzer, bool ignore_profiler_call =
false)
: Parent(analyzer), ignore_profiler_call_(ignore_profiler_call) {}
Stmt VisitStmt_(const BindNode* op) final {
@@ -266,7 +266,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
bool ignore_profiler_call_{false};
};
-Stmt RemoveNoOp(Stmt stmt, arith::AnalyzerObj* analyzer, bool
ignore_profiler_call) {
+Stmt RemoveNoOp(Stmt stmt, const arith::Analyzer& analyzer, bool
ignore_profiler_call) {
return NoOpRemover::Apply(std::move(stmt), analyzer, ignore_profiler_call);
}
@@ -286,7 +286,7 @@ Pass RemoveNoOp() {
{
auto* write_ptr = f.CopyOnWrite();
write_ptr->body =
- NoOpRemover::Apply(std::move(write_ptr->body), analyzer.get(),
ignore_profiler_call);
+ NoOpRemover::Apply(std::move(write_ptr->body), analyzer,
ignore_profiler_call);
}
return f;
};
diff --git a/src/tirx/transform/remove_no_op.h
b/src/tirx/transform/remove_no_op.h
index cd9710b617..03a9823aa6 100644
--- a/src/tirx/transform/remove_no_op.h
+++ b/src/tirx/transform/remove_no_op.h
@@ -41,7 +41,7 @@ namespace tirx {
*
* \return The modified statement with no-ops removed
*/
-Stmt RemoveNoOp(Stmt stmt, arith::AnalyzerObj* analyzer, bool
ignore_profiler_call = false);
+Stmt RemoveNoOp(Stmt stmt, const arith::Analyzer& analyzer, bool
ignore_profiler_call = false);
} // namespace tirx
} // namespace tvm
diff --git a/src/tirx/transform/stmt_simplify.cc
b/src/tirx/transform/stmt_simplify.cc
index d7dd4599f4..01038e4165 100644
--- a/src/tirx/transform/stmt_simplify.cc
+++ b/src/tirx/transform/stmt_simplify.cc
@@ -98,7 +98,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tirx.StmtSimplify",
StmtSimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
- static PrimFunc Apply(PrimFunc func, AnalyzerObj* analyzer,
+ static PrimFunc Apply(PrimFunc func, const Analyzer& analyzer,
ffi::Optional<StmtSimplifyConfig> config_opt =
std::nullopt) {
auto config = config_opt.value_or(MakeDefaultStmtSimplifyConfig());
analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());
@@ -110,7 +110,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
private:
- explicit StmtSimplifier(AnalyzerObj* analyzer, StmtSimplifyConfig config)
+ explicit StmtSimplifier(const Analyzer& analyzer, StmtSimplifyConfig config)
: IRMutatorWithAnalyzer(analyzer), config_(config) {}
using Parent = IRMutatorWithAnalyzer;
@@ -250,7 +250,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
namespace tirx {
-PrimFunc StmtSimplify(PrimFunc func, arith::AnalyzerObj* analyzer) {
+PrimFunc StmtSimplify(PrimFunc func, const arith::Analyzer& analyzer) {
return arith::StmtSimplifier::Apply(std::move(func), analyzer);
}
@@ -261,7 +261,7 @@ Pass StmtSimplify() {
arith::Analyzer analyzer;
auto cfg = ctx->GetConfig<arith::StmtSimplifyConfig>("tirx.StmtSimplify");
- return arith::StmtSimplifier::Apply(f, analyzer.get(), cfg);
+ return arith::StmtSimplifier::Apply(f, analyzer, cfg);
};
return CreatePrimFuncPass(pass_func, 0, "tirx.StmtSimplify", {});
}
diff --git a/src/tirx/transform/stmt_simplify.h
b/src/tirx/transform/stmt_simplify.h
index 5f10397e83..224df0ed8b 100644
--- a/src/tirx/transform/stmt_simplify.h
+++ b/src/tirx/transform/stmt_simplify.h
@@ -34,7 +34,7 @@ namespace tirx {
*
* Applies the same behavior as the tirx.transform.StmtSimplify pass.
*/
-PrimFunc StmtSimplify(PrimFunc func, arith::AnalyzerObj* analyzer);
+PrimFunc StmtSimplify(PrimFunc func, const arith::Analyzer& analyzer);
} // namespace tirx
} // namespace tvm