This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new de89da6b18 [IR] Rename Call annotations to attrs (#19618)
de89da6b18 is described below

commit de89da6b18bd516c40007d15dc1d012d9506eb62
Author: Shushi Hong <[email protected]>
AuthorDate: Wed May 27 06:55:13 2026 -0400

    [IR] Rename Call annotations to attrs (#19618)
    
    This PR renames `tirx::CallNode::annotations` to `attrs`, matching the
    existing Relax `CallNode::attrs` convention.
    
    Previously, TIRX Call metadata was stored in a `Map<String, Any>` field
    named `annotations`. This PR makes it a first-class `Attrs` field
    instead, so that call-level metadata follows the same representation
    and naming style as Relax calls.
---
 include/tvm/tirx/expr.h                          | 16 ++++------
 python/tvm/tirx/expr.py                          | 14 +++++----
 python/tvm/tirx/expr_functor.py                  |  2 +-
 python/tvm/tirx/op.py                            | 12 +++----
 src/arith/ir_mutator_with_analyzer.cc            |  2 +-
 src/arith/rewrite_simplify.cc                    |  4 +--
 src/s_tir/transform/inject_software_pipeline.cc  |  6 ++--
 src/tirx/analysis/deep_equal.cc                  |  4 ++-
 src/tirx/ir/data_type_rewriter.cc                |  5 ++-
 src/tirx/ir/expr.cc                              | 18 ++++++-----
 src/tirx/ir/expr_functor.cc                      |  2 +-
 src/tirx/script/printer/expr.cc                  |  4 +--
 src/tirx/transform/lower_warp_memory.cc          |  2 +-
 src/tirx/transform/storage_rewrite.cc            |  2 +-
 src/tirx/transform/tile_primitive_dispatch.cc    |  2 +-
 src/tirx/transform/unsupported_dtype_legalize.cc |  5 ++-
 src/tirx/transform/vectorize_loop.cc             | 18 +++++------
 tests/python/tirx-base/test_tir_constructor.py   | 40 ++++++++++++++----------
 18 files changed, 81 insertions(+), 77 deletions(-)

diff --git a/include/tvm/tirx/expr.h b/include/tvm/tirx/expr.h
index 68cfcbd361..e2c1c3f33d 100644
--- a/include/tvm/tirx/expr.h
+++ b/include/tvm/tirx/expr.h
@@ -28,6 +28,7 @@
 #include <tvm/ffi/container/array.h>
 #include <tvm/ffi/container/map.h>
 #include <tvm/ffi/string.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/ir/cow.h>
 #include <tvm/ir/expr.h>
 #include <tvm/ir/node_functor.h>
@@ -731,20 +732,15 @@ class CallNode : public PrimExprNode {
   /*! \brief The arguments. */
   ffi::Array<PrimExpr> args;
 
-  /*!
-   * \brief Additional annotations about the call.
-   *
-   * These annotations can be used to carry target-specific metadata through
-   * TIRX transformations and codegen.
-   */
-  ffi::Map<ffi::String, ffi::Any> annotations;
+  /*! \brief The additional attributes. */
+  Attrs attrs;
 
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<CallNode>()
         .def_ro("op", &CallNode::op)
         .def_ro("args", &CallNode::args)
-        .def_ro("annotations", &CallNode::annotations);
+        .def_ro("attrs", &CallNode::attrs);
   }
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Call", CallNode, PrimExprNode);
 };
@@ -755,9 +751,9 @@ class CallNode : public PrimExprNode {
  */
 class Call : public PrimExpr {
  public:
-  TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args,
-               ffi::Map<ffi::String, ffi::Any> annotations = 
ffi::Map<ffi::String, ffi::Any>(),
+  TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Attrs 
attrs = Attrs(),
                Span span = Span());
+  TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span 
span);
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
 };
diff --git a/python/tvm/tirx/expr.py b/python/tvm/tirx/expr.py
index bf04b3ae7a..a97171e436 100644
--- a/python/tvm/tirx/expr.py
+++ b/python/tvm/tirx/expr.py
@@ -1303,20 +1303,20 @@ class Call(PrimExprWithOp):
     span : Optional[Span]
         The location of this expression in the source code.
 
-    annotations : Optional[dict]
-        Additional metadata attached to the call.
+    attrs : Optional[tvm.ir.Attrs or dict]
+        Attributes attached to the call.
     """
 
     op: Op
     args: list[PrimExpr]
-    annotations: dict
+    attrs: ir.Attrs | None
 
     def __init__(
         self,
         dtype: str,
         op: Op | str,
         args: list[PrimExpr],
-        annotations: dict | None = None,
+        attrs: ir.Attrs | dict | None = None,
         span: Span | None = None,
     ) -> None:
         if isinstance(op, str):
@@ -1330,9 +1330,11 @@ class Call(PrimExprWithOp):
                     % op
                 )
             op = Op.get(op)
-        if annotations:
+        if isinstance(attrs, dict):
+            attrs = ir.make_node("ir.DictAttrs", **attrs)
+        if attrs:
             self.__init_handle_by_constructor__(  # type: ignore
-                _ffi_api.CallWithAnnotations, dtype, op, args, annotations, 
span
+                _ffi_api.CallWithAttrs, dtype, op, args, attrs, span
             )
         else:
             self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, 
args, span)  # type: ignore
diff --git a/python/tvm/tirx/expr_functor.py b/python/tvm/tirx/expr_functor.py
index b09606602a..8e86a361b6 100644
--- a/python/tvm/tirx/expr_functor.py
+++ b/python/tvm/tirx/expr_functor.py
@@ -495,7 +495,7 @@ class ExprMutator(ExprFunctor):
         if all(old_arg is new_arg for old_arg, new_arg in zip(op.args, args)):
             return op
         else:
-            return tvm.tirx.Call(op.dtype, op.op, args, 
annotations=op.annotations, span=op.span)
+            return tvm.tirx.Call(op.dtype, op.op, args, attrs=op.attrs, 
span=op.span)
 
     def _mutate_binary_op(self, op_cls, op):
         """Helper to mutate binary operators."""
diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py
index 2f227195b3..0ec7605abb 100644
--- a/python/tvm/tirx/op.py
+++ b/python/tvm/tirx/op.py
@@ -190,7 +190,7 @@ def call_cpacked(*args, span=None):
     return Call("int32", Op.get("tirx.tvm_call_cpacked"), call_args, span=span)
 
 
-def call_intrin(dtype, func_name, *args, annotations=None, span=None):
+def call_intrin(dtype, func_name, *args, attrs=None, span=None):
     """Build expression by calling an intrinsic function.
 
     Intrinsics can be overloaded with multiple data types via
@@ -207,8 +207,8 @@ def call_intrin(dtype, func_name, *args, annotations=None, 
span=None):
     args : list
         Positional arguments.
 
-    annotations : Optional[Dict[str, Object]]
-        Additional annotations about the call.
+    attrs : Optional[tvm.ir.Attrs or Dict[str, Object]]
+        Additional attributes for the call.
 
     span : Optional[Span]
         The location of this operator in the source code.
@@ -218,11 +218,7 @@ def call_intrin(dtype, func_name, *args, annotations=None, 
span=None):
     call : PrimExpr
         The call expression.
     """
-    if annotations is not None:
-        annotations = {
-            k: const(v) if isinstance(v, int | bool) else v for k, v in 
annotations.items()
-        }
-    return Call(dtype, func_name, args, annotations=annotations, span=span)
+    return Call(dtype, func_name, args, attrs=attrs, span=span)
 
 
 def call_pure_extern(dtype, func_name, *args, span=None):
diff --git a/src/arith/ir_mutator_with_analyzer.cc 
b/src/arith/ir_mutator_with_analyzer.cc
index e902d32aba..39b7faad84 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -310,7 +310,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* 
op) {
         false_value.same_as(op->args[2])) {
       return ffi::GetRef<PrimExpr>(op);
     } else {
-      return Call(op->dtype, op->op, {cond, true_value, false_value}, 
op->annotations, op->span);
+      return Call(op->dtype, op->op, {cond, true_value, false_value}, 
op->attrs, op->span);
     }
   }
   return StmtExprMutator::VisitExpr_(op);
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 1765a6b04a..804cb3cd97 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -2343,8 +2343,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
CallNode* op) {
       // Only check constant cases to avoid recursion
       if (is_const_number(inner_else_expr) && is_const_number(else_expr) &&
           analyzer_->CanProve(inner_else_expr == else_expr)) {
-        return Call(op->dtype, op->op, {cond && inner_cond, inner_then_expr, 
else_expr},
-                    op->annotations, op->span);
+        return Call(op->dtype, op->op, {cond && inner_cond, inner_then_expr, 
else_expr}, op->attrs,
+                    op->span);
       }
     }
   }
diff --git a/src/s_tir/transform/inject_software_pipeline.cc 
b/src/s_tir/transform/inject_software_pipeline.cc
index 86fc6028e1..ba6c3bf666 100644
--- a/src/s_tir/transform/inject_software_pipeline.cc
+++ b/src/s_tir/transform/inject_software_pipeline.cc
@@ -119,7 +119,7 @@ class PipelineOpaqueAccessRewriter {
         ffi::Array<PrimExpr> new_args = call->args;
         const Buffer& new_buffer = (*it).second;
         new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, 
call->args[4]));
-        return Call(call->dtype, call->op, new_args, call->annotations, 
call->span);
+        return Call(call->dtype, call->op, new_args, call->attrs, call->span);
       }
     } else if (call->op.same_as(mma_sync)) {
       ffi::Array<PrimExpr> new_args = call->args;
@@ -133,7 +133,7 @@ class PipelineOpaqueAccessRewriter {
           new_args.Set(i * 2 + 1, new_index);
         }
       }
-      return Call(call->dtype, call->op, new_args, call->annotations, 
call->span);
+      return Call(call->dtype, call->op, new_args, call->attrs, call->span);
     } else if (call->op.same_as(access_ptr)) {
       return RewriteBufferAccess(call, {1});
     } else if (call->op.same_as(ptx_mma)) {
@@ -196,7 +196,7 @@ class PipelineOpaqueAccessRewriter {
         new_args.Set(i + 1, new_index);
       }
     }
-    return Call(call->dtype, call->op, new_args, call->annotations, 
call->span);
+    return Call(call->dtype, call->op, new_args, call->attrs, call->span);
   }
 
   const ffi::Map<Var, Buffer>& buffer_data_to_buffer_;
diff --git a/src/tirx/analysis/deep_equal.cc b/src/tirx/analysis/deep_equal.cc
index f164ba427c..53700a85a9 100644
--- a/src/tirx/analysis/deep_equal.cc
+++ b/src/tirx/analysis/deep_equal.cc
@@ -21,6 +21,7 @@
  * \file tirx/analysis/deep_equal.cc
  * \brief Deep equality checking.
  */
+#include <tvm/ffi/extra/structural_equal.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/tirx/analysis.h>
@@ -124,7 +125,8 @@ class ExprDeepEqualChecker : private ExprFunctor<bool(const 
PrimExpr&, const Pri
   bool VisitExpr_(const CallNode* plhs, const PrimExpr& rhs) final {
     const auto* prhs = rhs.as<CallNode>();
     return plhs->dtype == prhs->dtype && plhs->op.same_as(prhs->op) &&
-           ArrayDeepEqual(plhs->args, prhs->args);
+           ArrayDeepEqual(plhs->args, prhs->args) &&
+           ffi::StructuralEqual()(plhs->attrs, prhs->attrs);
   }
 
   bool VisitExpr_(const ReduceNode* plhs, const PrimExpr& rhs) final {
diff --git a/src/tirx/ir/data_type_rewriter.cc 
b/src/tirx/ir/data_type_rewriter.cc
index 6fab0e3e09..cc4c2d5f78 100644
--- a/src/tirx/ir/data_type_rewriter.cc
+++ b/src/tirx/ir/data_type_rewriter.cc
@@ -248,8 +248,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
   } else if (op->op.same_as(builtin_pow_)) {
     return pow(op->args[0], op->args[1]);
   } else if (op->op.same_as(builtin::if_then_else())) {
-    return Call(op->dtype, op->op, {op->args[0], op->args[1], op->args[2]}, 
op->annotations,
-                op->span);
+    return Call(op->dtype, op->op, {op->args[0], op->args[1], op->args[2]}, 
op->attrs, op->span);
   } else if (op->op.same_as(Op::Get("tirx.clz"))) {
     DataType before_dtype = before->args[0]->dtype;
     DataType after_dtype = op->args[0]->dtype;
@@ -566,7 +565,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* 
op) {
     PrimExpr cond = VisitExpr(op->args[0]);
     is_condition_ = is_condition;
     return Call(op->dtype, op->op, {cond, VisitExpr(op->args[1]), 
VisitExpr(op->args[2])},
-                op->annotations, op->span);
+                op->attrs, op->span);
   }
   return Parent::VisitExpr_(op);
 }
diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc
index 5faa7c24bd..d458c09180 100644
--- a/src/tirx/ir/expr.cc
+++ b/src/tirx/ir/expr.cc
@@ -621,8 +621,7 @@ static ffi::Array<PrimExpr> 
ConvertCallArgs(ffi::Array<CallArg> args) {
   return prim_expr_args;
 }
 
-Call::Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args,
-           ffi::Map<ffi::String, ffi::Any> annotations, Span span) {
+Call::Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Attrs 
attrs, Span span) {
   for (size_t i = 0; i < args.size(); ++i) {
     TVM_FFI_ICHECK(args[i].defined()) << "arg " << i << " is not defined()";
   }
@@ -631,24 +630,27 @@ Call::Call(DataType dtype, RelaxExpr op, 
ffi::Array<PrimExpr> args,
   node->dtype = dtype;
   node->op = std::move(op);
   node->args = std::move(args);
-  node->annotations = std::move(annotations);
+  node->attrs = std::move(attrs);
   node->span = std::move(span);
   data_ = std::move(node);
 }
 
+Call::Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span span)
+    : Call(dtype, std::move(op), std::move(args), Attrs(), std::move(span)) {}
+
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("tirx.Call",
            [](ffi::Optional<DataType> dtype, RelaxExpr op, ffi::Array<CallArg> 
args, Span span) {
-             return Call(dtype.value_or(DataType::Void()), op, 
ConvertCallArgs(args),
-                         ffi::Map<ffi::String, ffi::Any>(), span);
+             return Call(dtype.value_or(DataType::Void()), op, 
ConvertCallArgs(args), Attrs(),
+                         span);
            })
-      .def("tirx.CallWithAnnotations",
+      .def("tirx.CallWithAttrs",
            [](ffi::Optional<DataType> dtype, RelaxExpr op, ffi::Array<CallArg> 
args,
-              ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations, Span 
span) {
+              ffi::Optional<Attrs> attrs, Span span) {
              return Call(dtype.value_or(DataType::Void()), op, 
ConvertCallArgs(args),
-                         annotations.value_or(ffi::Map<ffi::String, 
ffi::Any>()), span);
+                         attrs.value_or(Attrs()), span);
            });
 }
 
diff --git a/src/tirx/ir/expr_functor.cc b/src/tirx/ir/expr_functor.cc
index 921ce45623..aba96aae8c 100644
--- a/src/tirx/ir/expr_functor.cc
+++ b/src/tirx/ir/expr_functor.cc
@@ -155,7 +155,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
   if (args.same_as(op->args)) {
     return ffi::GetRef<PrimExpr>(op);
   } else {
-    return Call(op->dtype, op->op, args, op->annotations, op->span);
+    return Call(op->dtype, op->op, args, op->attrs, op->span);
   }
 }
 
diff --git a/src/tirx/script/printer/expr.cc b/src/tirx/script/printer/expr.cc
index 8aff24f7d8..cd33f59e3c 100644
--- a/src/tirx/script/printer/expr.cc
+++ b/src/tirx/script/printer/expr.cc
@@ -258,7 +258,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tirx::Call>("", [](tirx::Call call, AccessPath call_p, 
IRDocsifier d) -> Doc {
-      if (!call->annotations.empty()) {
+      if (call->attrs.defined()) {
         ffi::Array<ExprDoc> call_args;
         int n_args = call->args.size();
         call_args.reserve(n_args);
@@ -270,7 +270,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
                              : d->AsDoc<ExprDoc>(call->op, call_p->Attr("op"));
         return TIR(d, "Call")->Call(
             {LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")), op_doc, 
ListDoc(call_args)},
-            {"annotations"}, {d->AsDoc<DictDoc>(call->annotations, 
call_p->Attr("annotations"))});
+            {"attrs"}, {d->AsDoc<ExprDoc>(call->attrs, 
call_p->Attr("attrs"))});
       }
       static const OpAttrMap<tirx::TScriptPrinterName>& op_names =
           Op::GetAttrMap<tirx::TScriptPrinterName>("TScriptPrinterName");
diff --git a/src/tirx/transform/lower_warp_memory.cc 
b/src/tirx/transform/lower_warp_memory.cc
index 9c80ed599d..99c815bf66 100644
--- a/src/tirx/transform/lower_warp_memory.cc
+++ b/src/tirx/transform/lower_warp_memory.cc
@@ -291,7 +291,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
         new_args.Set(i + 1, local_index);
       }
     }
-    return Call(op->dtype, op->op, new_args, op->annotations, op->span);
+    return Call(op->dtype, op->op, new_args, op->attrs, op->span);
   }
 
   PrimExpr VisitExpr_(const CallNode* op) override {
diff --git a/src/tirx/transform/storage_rewrite.cc 
b/src/tirx/transform/storage_rewrite.cc
index c7e8caaf52..66d04ca899 100644
--- a/src/tirx/transform/storage_rewrite.cc
+++ b/src/tirx/transform/storage_rewrite.cc
@@ -497,7 +497,7 @@ class StoragePlanRewriter : public StmtExprMutator {
         offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + 
offset;
       }
       return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, 
extent, op->args[4]},
-                  op->annotations, op->span);
+                  op->attrs, op->span);
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
diff --git a/src/tirx/transform/tile_primitive_dispatch.cc 
b/src/tirx/transform/tile_primitive_dispatch.cc
index fbc7786d92..de01ee5db6 100644
--- a/src/tirx/transform/tile_primitive_dispatch.cc
+++ b/src/tirx/transform/tile_primitive_dispatch.cc
@@ -1160,7 +1160,7 @@ class TilePrimitiveDispatcher : public StmtExprMutator {
         args.push_back(new_arg);
       }
       if (changed) {
-        return tirx::Call(call->dtype, call->op, args, call->annotations, 
call->span);
+        return tirx::Call(call->dtype, call->op, args, call->attrs, 
call->span);
       }
     }
     return pred;
diff --git a/src/tirx/transform/unsupported_dtype_legalize.cc 
b/src/tirx/transform/unsupported_dtype_legalize.cc
index c2934c2d86..558a3ca437 100644
--- a/src/tirx/transform/unsupported_dtype_legalize.cc
+++ b/src/tirx/transform/unsupported_dtype_legalize.cc
@@ -238,13 +238,12 @@ class ComputeLegalizer : public StmtExprMutator {
     auto fmutate = [this](const PrimExpr& e) { return 
PromoteToTarget(this->VisitExpr(e)); };
     ffi::Array<PrimExpr> args = op->args.Map(fmutate);
     if (MatchDType(op->dtype)) {
-      return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args, 
op->annotations,
-                  op->span);
+      return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args, 
op->attrs, op->span);
     }
     if (args.same_as(op->args)) {
       return ffi::GetRef<PrimExpr>(op);
     } else {
-      return Call(op->dtype, op->op, args, op->annotations, op->span);
+      return Call(op->dtype, op->op, args, op->attrs, op->span);
     }
   }
 
diff --git a/src/tirx/transform/vectorize_loop.cc 
b/src/tirx/transform/vectorize_loop.cc
index f444c17822..0ac9680d0a 100644
--- a/src/tirx/transform/vectorize_loop.cc
+++ b/src/tirx/transform/vectorize_loop.cc
@@ -491,10 +491,10 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
       t = BroadcastTo(t, lanes, is_scalable);
       f = BroadcastTo(f, lanes, is_scalable);
       if (is_scalable) {
-        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, 
{cond, t, f},
-                    op->annotations, op->span);
+        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, 
{cond, t, f}, op->attrs,
+                    op->span);
       } else {
-        return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, 
op->annotations, op->span);
+        return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, 
op->attrs, op->span);
       }
     }
   }
@@ -507,14 +507,14 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     } else {
       int lanes = value.dtype().get_lanes_or_vscale_factor();
       if (value.dtype().is_scalable_vector()) {
-        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, 
{value}, op->annotations,
+        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, 
{value}, op->attrs,
                     op->span);
       } else {
         int new_lanes = (op->dtype != DataType::Float4E2M1FN() &&
                          op->args[0].dtype() != DataType::Float4E2M1FN())
                             ? (value.dtype().bits() * value.dtype().lanes()) / 
op->dtype.bits()
                             : value.dtype().lanes();
-        return Call(op->dtype.with_lanes(new_lanes), op->op, {value}, 
op->annotations, op->span);
+        return Call(op->dtype.with_lanes(new_lanes), op->op, {value}, 
op->attrs, op->span);
       }
     }
   }
@@ -536,7 +536,7 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
       auto new_args = op->args;
       new_args.pop_back();
       new_args.push_back(fcd[0]);
-      return Call(op->dtype.with_lanes(lane), op->op, new_args, 
op->annotations, op->span);
+      return Call(op->dtype.with_lanes(lane), op->op, new_args, op->attrs, 
op->span);
     } else if (op->op.same_as(builtin::texture2d_store())) {
       int lane = 0;
       // Vectorize the value to store
@@ -551,7 +551,7 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
           << "Expected Data to be Written equal to Texture Store length";
       ffi::Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2],
                                     op->args[3], op->args[4], 
mutated_value[0]};
-      return Call(op->dtype.with_lanes(lane), op->op, new_args, 
op->annotations, op->span);
+      return Call(op->dtype.with_lanes(lane), op->op, new_args, op->attrs, 
op->span);
     } else if (op->op.same_as(builtin::reinterpret())) {
       return MutateReinterpretExpr_(op);
     }
@@ -573,7 +573,7 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
       if (op->args.same_as(new_args)) {
         return ffi::GetRef<PrimExpr>(op);
       } else {
-        return Call(op->dtype, op->op, new_args, op->annotations, op->span);
+        return Call(op->dtype, op->op, new_args, op->attrs, op->span);
       }
     } else {
       int lane = 0;
@@ -599,7 +599,7 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
       if (op->args.same_as(new_args)) {
         return ffi::GetRef<PrimExpr>(op);
       } else {
-        return Call(op->dtype.with_lanes(lane), op->op, new_args, 
op->annotations, op->span);
+        return Call(op->dtype.with_lanes(lane), op->op, new_args, op->attrs, 
op->span);
       }
     }
   }
diff --git a/tests/python/tirx-base/test_tir_constructor.py 
b/tests/python/tirx-base/test_tir_constructor.py
index d084fe2b25..eda7fd9ebf 100644
--- a/tests/python/tirx-base/test_tir_constructor.py
+++ b/tests/python/tirx-base/test_tir_constructor.py
@@ -19,6 +19,7 @@ import pytest
 
 import tvm
 from tvm import te, topi
+from tvm.tirx.analysis import expr_deep_equal
 from tvm.tirx.expr_functor import ExprMutator
 
 
@@ -133,32 +134,39 @@ def test_expr_constructor():
     assert x.dtype == "float32"
     assert x.op.name == "tirx.call_extern"
     assert x.args[1] == a
-    assert len(x.annotations) == 0
+    assert x.attrs is None
 
-    annotated_arg = tvm.tirx.Var("annotated_arg", "float32")
-    x_with_annotations = tvm.tirx.Call(
+    attr_arg = tvm.tirx.Var("attr_arg", "float32")
+    x_with_attrs = tvm.tirx.Call(
         "float32",
         "tirx.call_extern",
-        [tvm.tirx.StringImm("xyz"), annotated_arg],
-        annotations={"disable_tma": True},
+        [tvm.tirx.StringImm("xyz"), attr_arg],
+        attrs={"disable_tma": True},
     )
-    assert bool(x_with_annotations.annotations["disable_tma"])
-    assert not tvm.ir.structural_equal(x, x_with_annotations)
-    script = tvm.tirx.Evaluate(x_with_annotations).script()
-    assert "annotations" in script
+    assert x_with_attrs.attrs["disable_tma"] is True
+    assert not tvm.ir.structural_equal(x, x_with_attrs)
+    script = tvm.tirx.Evaluate(x_with_attrs).script()
+    assert "attrs" in script
     assert "disable_tma" in script
-    func = tvm.tirx.PrimFunc([], tvm.tirx.Evaluate(x_with_annotations))
+    func = tvm.tirx.PrimFunc([], tvm.tirx.Evaluate(x_with_attrs))
     assert tvm.script.from_source(func.script()).script() == func.script()
 
     y = tvm.tirx.Var("y", "float32")
-    mutated = ReplaceVar(annotated_arg, y)(x_with_annotations)
-    assert bool(mutated.annotations["disable_tma"])
+    mutated = ReplaceVar(attr_arg, y)(x_with_attrs)
+    assert mutated.attrs["disable_tma"] is True
     assert mutated.args[1].same_as(y)
 
     x_from_intrin = tvm.tirx.call_intrin(
-        "float32", "tirx.call_extern", tvm.tirx.StringImm("xyz"), 
annotations={"disable_tma": True}
+        "float32", "tirx.call_extern", tvm.tirx.StringImm("xyz"), 
attrs={"disable_tma": True}
     )
-    assert int(x_from_intrin.annotations["disable_tma"]) == 1
+    assert x_from_intrin.attrs["disable_tma"] is True
+    x_with_other_attrs = tvm.tirx.Call(
+        "float32",
+        "tirx.call_extern",
+        [tvm.tirx.StringImm("xyz"), attr_arg],
+        attrs={"disable_tma": False},
+    )
+    assert not expr_deep_equal(x_with_attrs, x_with_other_attrs)
 
     cond0 = tvm.tirx.Var("cond0", "bool")
     cond1 = tvm.tirx.Var("cond1", "bool")
@@ -171,12 +179,12 @@ def test_expr_constructor():
         "int32",
         "tirx.if_then_else",
         [cond0, inner_if, tvm.tirx.IntImm("int32", 0)],
-        annotations={"keep": True},
+        attrs={"keep": True},
     )
     simplified = tvm.tirx.transform.StmtSimplify()(
         tvm.IRModule({"main": tvm.tirx.PrimFunc([], 
tvm.tirx.Evaluate(outer_if))})
     )["main"].body.value
-    assert bool(simplified.annotations["keep"])
+    assert simplified.attrs["keep"] is True
 
     v = tvm.tirx.Var("aa", "int32")
     x = tvm.tirx.Let(v, 1, v)

Reply via email to