This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 72f40c673f [Arith] Let IRMutatorWithAnalyzer take a const Analyzer& 
(#19829)
72f40c673f is described below

commit 72f40c673f7c7e87fc8229a398239e55ba652918
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 18 15:21:08 2026 -0400

    [Arith] Let IRMutatorWithAnalyzer take a const Analyzer& (#19829)
    
    This PR is a follow-up to #19675. IRMutatorWithAnalyzer and its
    subclasses took AnalyzerObj*, so callers had to pass analyzer.get().
    This adds a const Analyzer& constructor and migrates the family (and
    their Apply() / free-function entry points) to take the handle directly,
    dropping .get() at the call sites.
    
    The original AnalyzerObj* overload is kept because
    RewriteSimplifier::Impl is constructed from the AnalyzerObj's own `this`
    pointer during the analyzer's construction, when no Analyzer handle exists
    yet. This is the same dual-overload pattern already used by
    arith::ConstraintContext.
---
 src/arith/ir_mutator_with_analyzer.h                   | 1 +
 src/backend/trn/transform/lower_trainium_layout.cc     | 4 ++--
 src/s_tir/backend/adreno/inject_texture_alloc.cc       | 4 ++--
 src/s_tir/schedule/primitive/blockize_tensorize.cc     | 2 +-
 src/s_tir/schedule/primitive/layout_transformation.cc  | 6 +++---
 src/s_tir/schedule/transform.h                         | 4 ++--
 src/s_tir/transform/hoist_expression.cc                | 4 ++--
 src/s_tir/transform/inject_permuted_layout.cc          | 4 ++--
 src/s_tir/transform/inject_virtual_thread.cc           | 2 +-
 src/s_tir/transform/lower_async_dma.cc                 | 4 ++--
 src/s_tir/transform/renormalize_split_pattern.cc       | 4 ++--
 src/s_tir/transform/using_assume_to_reduce_branches.cc | 4 ++--
 src/tirx/transform/flatten_buffer.cc                   | 4 ++--
 src/tirx/transform/lower_intrin.cc                     | 7 +++----
 src/tirx/transform/lower_tirx_cleanup.cc               | 4 ++--
 src/tirx/transform/remove_no_op.cc                     | 8 ++++----
 src/tirx/transform/remove_no_op.h                      | 2 +-
 src/tirx/transform/stmt_simplify.cc                    | 8 ++++----
 src/tirx/transform/stmt_simplify.h                     | 2 +-
 19 files changed, 39 insertions(+), 39 deletions(-)

diff --git a/src/arith/ir_mutator_with_analyzer.h 
b/src/arith/ir_mutator_with_analyzer.h
index 5e2fa6ab00..6fee93a16c 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -47,6 +47,7 @@ namespace arith {
  */
 class IRMutatorWithAnalyzer : public tirx::StmtExprMutator {
  public:
+  explicit IRMutatorWithAnalyzer(const Analyzer& analyzer) : 
analyzer_(analyzer.get()) {}
   explicit IRMutatorWithAnalyzer(AnalyzerObj* analyzer) : analyzer_(analyzer) 
{}
 
   using StmtExprMutator::VisitExpr_;
diff --git a/src/backend/trn/transform/lower_trainium_layout.cc 
b/src/backend/trn/transform/lower_trainium_layout.cc
index b0fba77eba..3409a372bd 100644
--- a/src/backend/trn/transform/lower_trainium_layout.cc
+++ b/src/backend/trn/transform/lower_trainium_layout.cc
@@ -58,7 +58,7 @@ class TrainiumLayoutApplier : public 
arith::IRMutatorWithAnalyzer {
   static std::pair<Stmt, ffi::Map<Var, Buffer>> Lower(
       const Stmt& stmt, const ffi::Map<tirx::Var, Buffer> buffer_map) {
     arith::Analyzer ana;
-    TrainiumLayoutApplier storage_lower(ana.get());
+    TrainiumLayoutApplier storage_lower(ana);
     std::unordered_map<Var, Buffer> new_buffer_map;
     std::vector<Buffer> param_flattened_buffers;
     for (const auto& kv : buffer_map) {
@@ -83,7 +83,7 @@ class TrainiumLayoutApplier : public 
arith::IRMutatorWithAnalyzer {
   using IRMutatorWithAnalyzer::VisitExpr_;
   using IRMutatorWithAnalyzer::VisitStmt_;
 
-  explicit TrainiumLayoutApplier(arith::AnalyzerObj* analyzer)
+  explicit TrainiumLayoutApplier(const arith::Analyzer& analyzer)
       : arith::IRMutatorWithAnalyzer(analyzer) {}
 
   ffi::Any VisitAny(const ffi::Any& any) {
diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc 
b/src/s_tir/backend/adreno/inject_texture_alloc.cc
index 709e7c3336..e4e7c322ef 100644
--- a/src/s_tir/backend/adreno/inject_texture_alloc.cc
+++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc
@@ -46,7 +46,7 @@ class TextureAllocInjector : public 
arith::IRMutatorWithAnalyzer {
  public:
   static PrimFunc Inject(PrimFunc func) {
     arith::Analyzer ana;
-    auto pass = TextureAllocInjector(ana.get());
+    auto pass = TextureAllocInjector(ana);
     auto writer = func.CopyOnWrite();
     pass.MarkBufferMapShapes(func);
     writer->body = pass.VisitStmt(func->body);
@@ -59,7 +59,7 @@ class TextureAllocInjector : public 
arith::IRMutatorWithAnalyzer {
   using IRMutatorWithAnalyzer::VisitStmt;
   using IRMutatorWithAnalyzer::VisitStmt_;
 
-  explicit TextureAllocInjector(arith::AnalyzerObj* ana) : 
IRMutatorWithAnalyzer(ana) {}
+  explicit TextureAllocInjector(const arith::Analyzer& ana) : 
IRMutatorWithAnalyzer(ana) {}
 
   Stmt VisitStmt_(const AllocBufferNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc 
b/src/s_tir/schedule/primitive/blockize_tensorize.cc
index 9c6df3c8b7..87150b0431 100644
--- a/src/s_tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc
@@ -769,7 +769,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, 
const TensorIntrin& int
   }
 
   arith::Analyzer analyzer;
-  PrimFunc intrin_desc = StmtSimplify(intrin->desc, analyzer.get());
+  PrimFunc intrin_desc = StmtSimplify(intrin->desc, analyzer);
   PrimFunc intrin_impl = DeepCopy(intrin->impl);
 
   int index_dtype_bits = -1;
diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc 
b/src/s_tir/schedule/primitive/layout_transformation.cc
index 1964ead616..0ac3802135 100644
--- a/src/s_tir/schedule/primitive/layout_transformation.cc
+++ b/src/s_tir/schedule/primitive/layout_transformation.cc
@@ -765,7 +765,7 @@ class TransformLayoutRewriter : private 
arith::IRMutatorWithAnalyzer {
                                                    pad_value, analyzer.get())
                     : TransformLayoutPlanner::NoPaddingRequired();
 
-    TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, 
analyzer.get());
+    TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, 
analyzer);
     SBlock result = Downcast<SBlock>(rewriter(scope_stmt));
     if (auto plan_ptr = 
std::get_if<TransformLayoutPlanner::ProloguePlan>(&plan)) {
       auto write_ptr = result.CopyOnWrite();
@@ -782,7 +782,7 @@ class TransformLayoutRewriter : private 
arith::IRMutatorWithAnalyzer {
   TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer,
                           const IndexMap& index_map,
                           const TransformLayoutPlanner::TransformPlan& plan,
-                          arith::AnalyzerObj* analyzer)
+                          const arith::Analyzer& analyzer)
       : IRMutatorWithAnalyzer(analyzer),
         old_buffer_(old_buffer),
         new_buffer_(new_buffer),
@@ -1456,7 +1456,7 @@ void TransformBlockLayout(ScheduleState self, const 
StmtSRef& block_sref,
   SBlock new_block =
       Downcast<SBlock>(Substitute(ffi::GetRef<SBlock>(block_ptr), 
inverse_subst_map));
   new_block.CopyOnWrite()->iter_vars = new_block_iters;
-  new_block = 
Downcast<SBlock>(BlockBufferAccessSimplifier::Simplify(new_block, 
analyzer.get()));
+  new_block = 
Downcast<SBlock>(BlockBufferAccessSimplifier::Simplify(new_block, analyzer));
 
   // Step 5.3: Create outer loops for each new block iter.
 
diff --git a/src/s_tir/schedule/transform.h b/src/s_tir/schedule/transform.h
index 6221cb35de..da6d54a966 100644
--- a/src/s_tir/schedule/transform.h
+++ b/src/s_tir/schedule/transform.h
@@ -236,13 +236,13 @@ class BlockBufferAccessSimplifier : public 
arith::IRMutatorWithAnalyzer {
    * \param analyzer The arithmetic analyzer
    * \return The simplified statement
    */
-  static Stmt Simplify(const Stmt& stmt, arith::AnalyzerObj* analyzer) {
+  static Stmt Simplify(const Stmt& stmt, const arith::Analyzer& analyzer) {
     BlockBufferAccessSimplifier simplifier(analyzer);
     return simplifier(stmt);
   }
 
  private:
-  explicit BlockBufferAccessSimplifier(arith::AnalyzerObj* analyzer)
+  explicit BlockBufferAccessSimplifier(const arith::Analyzer& analyzer)
       : IRMutatorWithAnalyzer(analyzer) {}
 
   using IRMutatorWithAnalyzer::VisitExpr_;
diff --git a/src/s_tir/transform/hoist_expression.cc 
b/src/s_tir/transform/hoist_expression.cc
index ac48593bd2..24cdfd6f5d 100644
--- a/src/s_tir/transform/hoist_expression.cc
+++ b/src/s_tir/transform/hoist_expression.cc
@@ -450,7 +450,7 @@ class ExpressionHoister : public 
arith::IRMutatorWithAnalyzer {
     auto loop_info = HoistInfoCollector::Collect(stmt, config);
 
     arith::Analyzer analyzer;
-    ExpressionHoister hoister(std::move(loop_info), config, analyzer.get());
+    ExpressionHoister hoister(std::move(loop_info), config, analyzer);
     stmt = hoister(std::move(stmt));
     stmt = ConvertSSA(std::move(stmt));
     return stmt;
@@ -462,7 +462,7 @@ class ExpressionHoister : public 
arith::IRMutatorWithAnalyzer {
   using Parent::VisitStmt_;
 
   explicit ExpressionHoister(std::vector<HoistInfoCollector::HoistInfo> 
loop_info,
-                             HoistExpressionConfig config, arith::AnalyzerObj* 
analyzer)
+                             HoistExpressionConfig config, const 
arith::Analyzer& analyzer)
       : Parent(analyzer), config_(config) {
     for (auto& info : loop_info) {
       // Mark let bindings to use if they are enabled on their own.
diff --git a/src/s_tir/transform/inject_permuted_layout.cc 
b/src/s_tir/transform/inject_permuted_layout.cc
index 8ef5051ae0..6cbba28110 100644
--- a/src/s_tir/transform/inject_permuted_layout.cc
+++ b/src/s_tir/transform/inject_permuted_layout.cc
@@ -46,14 +46,14 @@ class PermutedLayoutInjector : private 
IRMutatorWithAnalyzer {
   static PrimFunc Transform(PrimFunc func) {
     Analyzer analyzer;
 
-    auto new_body = PermutedLayoutInjector(func, analyzer.get())(func->body);
+    auto new_body = PermutedLayoutInjector(func, analyzer)(func->body);
     auto func_node = func.CopyOnWrite();
     func_node->body = new_body;
     return func;
   }
 
  private:
-  explicit PermutedLayoutInjector(PrimFunc func, AnalyzerObj* analyzer)
+  explicit PermutedLayoutInjector(PrimFunc func, const Analyzer& analyzer)
       : IRMutatorWithAnalyzer(analyzer) {
     buffer_map_.insert(func->buffer_map.begin(), func->buffer_map.end());
   }
diff --git a/src/s_tir/transform/inject_virtual_thread.cc 
b/src/s_tir/transform/inject_virtual_thread.cc
index 6e50970623..7853bcc049 100644
--- a/src/s_tir/transform/inject_virtual_thread.cc
+++ b/src/s_tir/transform/inject_virtual_thread.cc
@@ -541,7 +541,7 @@ Pass InjectVirtualThread() {
 
     arith::Analyzer analyzer;
 
-    n->body = VirtualThreadInjector(analyzer.get())(std::move(n->body));
+    n->body = VirtualThreadInjector(analyzer)(std::move(n->body));
     n->body = ConvertSSA(std::move(n->body));
     return f;
   };
diff --git a/src/s_tir/transform/lower_async_dma.cc 
b/src/s_tir/transform/lower_async_dma.cc
index 1178c1aa48..89660d4fef 100644
--- a/src/s_tir/transform/lower_async_dma.cc
+++ b/src/s_tir/transform/lower_async_dma.cc
@@ -46,7 +46,7 @@ using namespace tvm::tirx;
 
 class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
  public:
-  explicit AsyncDMALowerer(bool dma_bypass_cache, arith::AnalyzerObj* analyzer)
+  explicit AsyncDMALowerer(bool dma_bypass_cache, const arith::Analyzer& 
analyzer)
       : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}
 
   // TODO(leiwang1999): split lower async DMA support for CUDA and Hexagon 
Backend
@@ -176,7 +176,7 @@ Pass LowerAsyncDMA() {
     arith::Analyzer analyzer;
     bool dma_bypass_cache =
         ctx->GetConfig<bool>("tirx.experimental_dma_bypass_cache", 
false).value();
-    fptr->body = AsyncDMALowerer(dma_bypass_cache, 
analyzer.get())(std::move(fptr->body));
+    fptr->body = AsyncDMALowerer(dma_bypass_cache, 
analyzer)(std::move(fptr->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerAsyncDMA", {});
diff --git a/src/s_tir/transform/renormalize_split_pattern.cc 
b/src/s_tir/transform/renormalize_split_pattern.cc
index f185e66a07..2fbadfabd4 100644
--- a/src/s_tir/transform/renormalize_split_pattern.cc
+++ b/src/s_tir/transform/renormalize_split_pattern.cc
@@ -52,7 +52,7 @@ using namespace arith;
 
 class SplitPatternReNormalizer : public IRMutatorWithAnalyzer {
  public:
-  explicit SplitPatternReNormalizer(AnalyzerObj* analyzer) : 
IRMutatorWithAnalyzer(analyzer) {}
+  explicit SplitPatternReNormalizer(const Analyzer& analyzer) : 
IRMutatorWithAnalyzer(analyzer) {}
 
   using IRMutatorWithAnalyzer::VisitExpr_;
 
@@ -201,7 +201,7 @@ Pass RenormalizeSplitPattern() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
     arith::Analyzer analyzer;
-    n->body = SplitPatternReNormalizer(analyzer.get())(std::move(n->body));
+    n->body = SplitPatternReNormalizer(analyzer)(std::move(n->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "s_tir.RenormalizeSplitPattern", {});
diff --git a/src/s_tir/transform/using_assume_to_reduce_branches.cc 
b/src/s_tir/transform/using_assume_to_reduce_branches.cc
index 0ad2ad69f5..74e2c6c7d9 100644
--- a/src/s_tir/transform/using_assume_to_reduce_branches.cc
+++ b/src/s_tir/transform/using_assume_to_reduce_branches.cc
@@ -115,7 +115,7 @@ class ParseAssumeAndOvercompute : public 
IRMutatorWithAnalyzer {
 
  public:
   using Parent = IRMutatorWithAnalyzer;
-  explicit ParseAssumeAndOvercompute(AnalyzerObj* analyzer) : Parent(analyzer) 
{}
+  explicit ParseAssumeAndOvercompute(const Analyzer& analyzer) : 
Parent(analyzer) {}
 
  private:
   using Parent::VisitExpr_;
@@ -380,7 +380,7 @@ Pass UseAssumeToReduceBranches() {
 
           if (assume_checker.has_assume) {
             // Leverage from assume and eliminate the branch
-            ParseAssumeAndOvercompute func_analyzer_mutator(analyzer.get());
+            ParseAssumeAndOvercompute func_analyzer_mutator(analyzer);
             n->body = func_analyzer_mutator(std::move(n->body));
           }
         }
diff --git a/src/tirx/transform/flatten_buffer.cc 
b/src/tirx/transform/flatten_buffer.cc
index 7298c2df20..267b09484e 100644
--- a/src/tirx/transform/flatten_buffer.cc
+++ b/src/tirx/transform/flatten_buffer.cc
@@ -45,7 +45,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
  public:
   static PrimFunc Flatten(PrimFunc func) {
     arith::Analyzer ana;
-    auto pass = BufferFlattener(ana.get());
+    auto pass = BufferFlattener(ana);
     pass.MarkBufferMapShapes(func);
     auto body = pass.VisitStmt(func->body);
 
@@ -78,7 +78,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
   using IRMutatorWithAnalyzer::VisitStmt;
   using IRMutatorWithAnalyzer::VisitStmt_;
 
-  explicit BufferFlattener(arith::AnalyzerObj* ana) : 
IRMutatorWithAnalyzer(ana) {}
+  explicit BufferFlattener(const arith::Analyzer& ana) : 
IRMutatorWithAnalyzer(ana) {}
 
   Stmt VisitStmt_(const SBlockNode* op) final {
     TVM_FFI_ICHECK_EQ(op->match_buffers.size(), 0)
diff --git a/src/tirx/transform/lower_intrin.cc 
b/src/tirx/transform/lower_intrin.cc
index c85d762ab1..94ee4e1044 100644
--- a/src/tirx/transform/lower_intrin.cc
+++ b/src/tirx/transform/lower_intrin.cc
@@ -47,7 +47,7 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
   using IRMutatorWithAnalyzer::VisitStmt_;
   using FLowerGeneral = ffi::TypedFunction<PrimExpr(PrimExpr)>;
 
-  IntrinInjecter(arith::AnalyzerObj* analyzer, const Target& tgt, bool 
enable_fast_math)
+  IntrinInjecter(const arith::Analyzer& analyzer, const Target& tgt, bool 
enable_fast_math)
       : IRMutatorWithAnalyzer(analyzer) {
     std::string target = tgt->kind->name;
     ffi::String mtriple = tgt->GetAttr<ffi::String>("mtriple").value_or("");
@@ -368,8 +368,7 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
   arith::Analyzer analyzer;
   bool enable_fast_math =
       
transform::PassContext::Current()->GetConfig<bool>("tirx.enable_fast_math", 
false).value();
-  return IntrinInjecter(analyzer.get(), Target(ffi::String(target)),
-                        enable_fast_math)(std::move(stmt));
+  return IntrinInjecter(analyzer, Target(ffi::String(target)), 
enable_fast_math)(std::move(stmt));
 }
 
 namespace transform {
@@ -381,7 +380,7 @@ Pass LowerIntrin() {
     TVM_FFI_ICHECK(target.defined()) << "LowerIntrin: Require the target 
attribute";
     arith::Analyzer analyzer;
     bool enable_fast_math = ctx->GetConfig<bool>("tirx.enable_fast_math", 
false).value();
-    n->body = IntrinInjecter(analyzer.get(), target.value(), 
enable_fast_math)(std::move(n->body));
+    n->body = IntrinInjecter(analyzer, target.value(), 
enable_fast_math)(std::move(n->body));
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tirx.LowerIntrin", {});
diff --git a/src/tirx/transform/lower_tirx_cleanup.cc 
b/src/tirx/transform/lower_tirx_cleanup.cc
index b20c52248a..433a0365a2 100644
--- a/src/tirx/transform/lower_tirx_cleanup.cc
+++ b/src/tirx/transform/lower_tirx_cleanup.cc
@@ -47,7 +47,7 @@ class LayoutApplier : public arith::IRMutatorWithAnalyzer {
   static std::pair<Stmt, ffi::Map<Var, Buffer>> Flatten(
       const Stmt& stmt, const ffi::Map<tirx::Var, Buffer> buffer_map, const 
Target& target) {
     arith::Analyzer ana;
-    LayoutApplier storage_lower(ana.get(), target);
+    LayoutApplier storage_lower(ana, target);
     std::unordered_map<Var, Buffer> new_buffer_map;
     std::vector<Buffer> param_flattened_buffers;
     for (const auto& kv : buffer_map) {
@@ -72,7 +72,7 @@ class LayoutApplier : public arith::IRMutatorWithAnalyzer {
   using IRMutatorWithAnalyzer::VisitExpr_;
   using IRMutatorWithAnalyzer::VisitStmt_;
 
-  explicit LayoutApplier(arith::AnalyzerObj* analyzer, const Target& target)
+  explicit LayoutApplier(const arith::Analyzer& analyzer, const Target& target)
       : arith::IRMutatorWithAnalyzer(analyzer), target_(target) {}
 
   ffi::Any VisitAny(const ffi::Any& any) {
diff --git a/src/tirx/transform/remove_no_op.cc 
b/src/tirx/transform/remove_no_op.cc
index 6394eb2198..833bf1f45f 100644
--- a/src/tirx/transform/remove_no_op.cc
+++ b/src/tirx/transform/remove_no_op.cc
@@ -75,7 +75,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tirx.RemoveNoOp", 
RemoveNoOpConfig);
 // Mark the statement of each stage.
 class NoOpRemover : public arith::IRMutatorWithAnalyzer {
  public:
-  static Stmt Apply(Stmt stmt, arith::AnalyzerObj* analyzer, bool 
ignore_profiler_call = false) {
+  static Stmt Apply(Stmt stmt, const arith::Analyzer& analyzer, bool 
ignore_profiler_call = false) {
     NoOpRemover visitor(analyzer, ignore_profiler_call);
     return visitor(std::move(stmt));
   }
@@ -85,7 +85,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
   using Parent::VisitStmt;
   using Parent::VisitStmt_;
 
-  NoOpRemover(arith::AnalyzerObj* analyzer, bool ignore_profiler_call = false)
+  NoOpRemover(const arith::Analyzer& analyzer, bool ignore_profiler_call = 
false)
       : Parent(analyzer), ignore_profiler_call_(ignore_profiler_call) {}
 
   Stmt VisitStmt_(const BindNode* op) final {
@@ -266,7 +266,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer {
   bool ignore_profiler_call_{false};
 };
 
-Stmt RemoveNoOp(Stmt stmt, arith::AnalyzerObj* analyzer, bool 
ignore_profiler_call) {
+Stmt RemoveNoOp(Stmt stmt, const arith::Analyzer& analyzer, bool 
ignore_profiler_call) {
   return NoOpRemover::Apply(std::move(stmt), analyzer, ignore_profiler_call);
 }
 
@@ -286,7 +286,7 @@ Pass RemoveNoOp() {
     {
       auto* write_ptr = f.CopyOnWrite();
       write_ptr->body =
-          NoOpRemover::Apply(std::move(write_ptr->body), analyzer.get(), 
ignore_profiler_call);
+          NoOpRemover::Apply(std::move(write_ptr->body), analyzer, 
ignore_profiler_call);
     }
     return f;
   };
diff --git a/src/tirx/transform/remove_no_op.h 
b/src/tirx/transform/remove_no_op.h
index cd9710b617..03a9823aa6 100644
--- a/src/tirx/transform/remove_no_op.h
+++ b/src/tirx/transform/remove_no_op.h
@@ -41,7 +41,7 @@ namespace tirx {
  *
  * \return The modified statement with no-ops removed
  */
-Stmt RemoveNoOp(Stmt stmt, arith::AnalyzerObj* analyzer, bool 
ignore_profiler_call = false);
+Stmt RemoveNoOp(Stmt stmt, const arith::Analyzer& analyzer, bool 
ignore_profiler_call = false);
 
 }  // namespace tirx
 }  // namespace tvm
diff --git a/src/tirx/transform/stmt_simplify.cc 
b/src/tirx/transform/stmt_simplify.cc
index d7dd4599f4..01038e4165 100644
--- a/src/tirx/transform/stmt_simplify.cc
+++ b/src/tirx/transform/stmt_simplify.cc
@@ -98,7 +98,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tirx.StmtSimplify", 
StmtSimplifyConfig);
 
 class StmtSimplifier : public IRMutatorWithAnalyzer {
  public:
-  static PrimFunc Apply(PrimFunc func, AnalyzerObj* analyzer,
+  static PrimFunc Apply(PrimFunc func, const Analyzer& analyzer,
                         ffi::Optional<StmtSimplifyConfig> config_opt = 
std::nullopt) {
     auto config = config_opt.value_or(MakeDefaultStmtSimplifyConfig());
     
analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());
@@ -110,7 +110,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
   }
 
  private:
-  explicit StmtSimplifier(AnalyzerObj* analyzer, StmtSimplifyConfig config)
+  explicit StmtSimplifier(const Analyzer& analyzer, StmtSimplifyConfig config)
       : IRMutatorWithAnalyzer(analyzer), config_(config) {}
 
   using Parent = IRMutatorWithAnalyzer;
@@ -250,7 +250,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
 
 namespace tirx {
 
-PrimFunc StmtSimplify(PrimFunc func, arith::AnalyzerObj* analyzer) {
+PrimFunc StmtSimplify(PrimFunc func, const arith::Analyzer& analyzer) {
   return arith::StmtSimplifier::Apply(std::move(func), analyzer);
 }
 
@@ -261,7 +261,7 @@ Pass StmtSimplify() {
     arith::Analyzer analyzer;
     auto cfg = ctx->GetConfig<arith::StmtSimplifyConfig>("tirx.StmtSimplify");
 
-    return arith::StmtSimplifier::Apply(f, analyzer.get(), cfg);
+    return arith::StmtSimplifier::Apply(f, analyzer, cfg);
   };
   return CreatePrimFuncPass(pass_func, 0, "tirx.StmtSimplify", {});
 }
diff --git a/src/tirx/transform/stmt_simplify.h 
b/src/tirx/transform/stmt_simplify.h
index 5f10397e83..224df0ed8b 100644
--- a/src/tirx/transform/stmt_simplify.h
+++ b/src/tirx/transform/stmt_simplify.h
@@ -34,7 +34,7 @@ namespace tirx {
  *
  * Applies the same behavior as the tirx.transform.StmtSimplify pass.
  */
-PrimFunc StmtSimplify(PrimFunc func, arith::AnalyzerObj* analyzer);
+PrimFunc StmtSimplify(PrimFunc func, const arith::Analyzer& analyzer);
 
 }  // namespace tirx
 }  // namespace tvm

Reply via email to