This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new de89da6b18 [IR] Rename Call annotations to attrs (#19618)
de89da6b18 is described below
commit de89da6b18bd516c40007d15dc1d012d9506eb62
Author: Shushi Hong <[email protected]>
AuthorDate: Wed May 27 06:55:13 2026 -0400
[IR] Rename Call annotations to attrs (#19618)
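This PR renames `tirx::CallNode::annotations` to `attrs`, matching the
existing Relax `CallNode::attrs` convention.
Previously, TIRX Call metadata was stored in a `Map<String, Any>` field
named `annotations`. This PR makes it a first-class `Attrs` field
instead, so call-level metadata follows the same representation and
naming style as Relax calls.
Caller-visible changes: the `annotations` keyword argument of the Python
`tvm.tirx.Call` constructor and of `call_intrin` is renamed to `attrs`
(a plain dict is still accepted and is converted to `ir.DictAttrs`), and
`call_intrin` no longer wraps int/bool values in `const`. A call built
without attributes now has `attrs` equal to `None` rather than an empty
map. The FFI constructor `tirx.CallWithAnnotations` is renamed to
`tirx.CallWithAttrs`, and the deep-equality check in
`src/tirx/analysis/deep_equal.cc` now also compares `attrs`.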
---
include/tvm/tirx/expr.h | 16 ++++------
python/tvm/tirx/expr.py | 14 +++++----
python/tvm/tirx/expr_functor.py | 2 +-
python/tvm/tirx/op.py | 12 +++----
src/arith/ir_mutator_with_analyzer.cc | 2 +-
src/arith/rewrite_simplify.cc | 4 +--
src/s_tir/transform/inject_software_pipeline.cc | 6 ++--
src/tirx/analysis/deep_equal.cc | 4 ++-
src/tirx/ir/data_type_rewriter.cc | 5 ++-
src/tirx/ir/expr.cc | 18 ++++++-----
src/tirx/ir/expr_functor.cc | 2 +-
src/tirx/script/printer/expr.cc | 4 +--
src/tirx/transform/lower_warp_memory.cc | 2 +-
src/tirx/transform/storage_rewrite.cc | 2 +-
src/tirx/transform/tile_primitive_dispatch.cc | 2 +-
src/tirx/transform/unsupported_dtype_legalize.cc | 5 ++-
src/tirx/transform/vectorize_loop.cc | 18 +++++------
tests/python/tirx-base/test_tir_constructor.py | 40 ++++++++++++++----------
18 files changed, 81 insertions(+), 77 deletions(-)
diff --git a/include/tvm/tirx/expr.h b/include/tvm/tirx/expr.h
index 68cfcbd361..e2c1c3f33d 100644
--- a/include/tvm/tirx/expr.h
+++ b/include/tvm/tirx/expr.h
@@ -28,6 +28,7 @@
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/string.h>
+#include <tvm/ir/attrs.h>
#include <tvm/ir/cow.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/node_functor.h>
@@ -731,20 +732,15 @@ class CallNode : public PrimExprNode {
/*! \brief The arguments. */
ffi::Array<PrimExpr> args;
- /*!
- * \brief Additional annotations about the call.
- *
- * These annotations can be used to carry target-specific metadata through
- * TIRX transformations and codegen.
- */
- ffi::Map<ffi::String, ffi::Any> annotations;
+ /*! \brief The additional attributes. */
+ Attrs attrs;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CallNode>()
.def_ro("op", &CallNode::op)
.def_ro("args", &CallNode::args)
- .def_ro("annotations", &CallNode::annotations);
+ .def_ro("attrs", &CallNode::attrs);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Call", CallNode, PrimExprNode);
};
@@ -755,9 +751,9 @@ class CallNode : public PrimExprNode {
*/
class Call : public PrimExpr {
public:
- TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args,
- ffi::Map<ffi::String, ffi::Any> annotations =
ffi::Map<ffi::String, ffi::Any>(),
+ TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Attrs
attrs = Attrs(),
Span span = Span());
+ TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span
span);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};
diff --git a/python/tvm/tirx/expr.py b/python/tvm/tirx/expr.py
index bf04b3ae7a..a97171e436 100644
--- a/python/tvm/tirx/expr.py
+++ b/python/tvm/tirx/expr.py
@@ -1303,20 +1303,20 @@ class Call(PrimExprWithOp):
span : Optional[Span]
The location of this expression in the source code.
- annotations : Optional[dict]
- Additional metadata attached to the call.
+ attrs : Optional[tvm.ir.Attrs or dict]
+ Attributes attached to the call.
"""
op: Op
args: list[PrimExpr]
- annotations: dict
+ attrs: ir.Attrs | None
def __init__(
self,
dtype: str,
op: Op | str,
args: list[PrimExpr],
- annotations: dict | None = None,
+ attrs: ir.Attrs | dict | None = None,
span: Span | None = None,
) -> None:
if isinstance(op, str):
@@ -1330,9 +1330,11 @@ class Call(PrimExprWithOp):
% op
)
op = Op.get(op)
- if annotations:
+ if isinstance(attrs, dict):
+ attrs = ir.make_node("ir.DictAttrs", **attrs)
+ if attrs:
self.__init_handle_by_constructor__( # type: ignore
- _ffi_api.CallWithAnnotations, dtype, op, args, annotations,
span
+ _ffi_api.CallWithAttrs, dtype, op, args, attrs, span
)
else:
self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op,
args, span) # type: ignore
diff --git a/python/tvm/tirx/expr_functor.py b/python/tvm/tirx/expr_functor.py
index b09606602a..8e86a361b6 100644
--- a/python/tvm/tirx/expr_functor.py
+++ b/python/tvm/tirx/expr_functor.py
@@ -495,7 +495,7 @@ class ExprMutator(ExprFunctor):
if all(old_arg is new_arg for old_arg, new_arg in zip(op.args, args)):
return op
else:
- return tvm.tirx.Call(op.dtype, op.op, args,
annotations=op.annotations, span=op.span)
+ return tvm.tirx.Call(op.dtype, op.op, args, attrs=op.attrs,
span=op.span)
def _mutate_binary_op(self, op_cls, op):
"""Helper to mutate binary operators."""
diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py
index 2f227195b3..0ec7605abb 100644
--- a/python/tvm/tirx/op.py
+++ b/python/tvm/tirx/op.py
@@ -190,7 +190,7 @@ def call_cpacked(*args, span=None):
return Call("int32", Op.get("tirx.tvm_call_cpacked"), call_args, span=span)
-def call_intrin(dtype, func_name, *args, annotations=None, span=None):
+def call_intrin(dtype, func_name, *args, attrs=None, span=None):
"""Build expression by calling an intrinsic function.
Intrinsics can be overloaded with multiple data types via
@@ -207,8 +207,8 @@ def call_intrin(dtype, func_name, *args, annotations=None,
span=None):
args : list
Positional arguments.
- annotations : Optional[Dict[str, Object]]
- Additional annotations about the call.
+ attrs : Optional[tvm.ir.Attrs or Dict[str, Object]]
+ Additional attributes for the call.
span : Optional[Span]
The location of this operator in the source code.
@@ -218,11 +218,7 @@ def call_intrin(dtype, func_name, *args, annotations=None,
span=None):
call : PrimExpr
The call expression.
"""
- if annotations is not None:
- annotations = {
- k: const(v) if isinstance(v, int | bool) else v for k, v in
annotations.items()
- }
- return Call(dtype, func_name, args, annotations=annotations, span=span)
+ return Call(dtype, func_name, args, attrs=attrs, span=span)
def call_pure_extern(dtype, func_name, *args, span=None):
diff --git a/src/arith/ir_mutator_with_analyzer.cc
b/src/arith/ir_mutator_with_analyzer.cc
index e902d32aba..39b7faad84 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -310,7 +310,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode*
op) {
false_value.same_as(op->args[2])) {
return ffi::GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, {cond, true_value, false_value},
op->annotations, op->span);
+ return Call(op->dtype, op->op, {cond, true_value, false_value},
op->attrs, op->span);
}
}
return StmtExprMutator::VisitExpr_(op);
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 1765a6b04a..804cb3cd97 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -2343,8 +2343,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
CallNode* op) {
// Only check constant cases to avoid recursion
if (is_const_number(inner_else_expr) && is_const_number(else_expr) &&
analyzer_->CanProve(inner_else_expr == else_expr)) {
- return Call(op->dtype, op->op, {cond && inner_cond, inner_then_expr,
else_expr},
- op->annotations, op->span);
+ return Call(op->dtype, op->op, {cond && inner_cond, inner_then_expr,
else_expr}, op->attrs,
+ op->span);
}
}
}
diff --git a/src/s_tir/transform/inject_software_pipeline.cc
b/src/s_tir/transform/inject_software_pipeline.cc
index 86fc6028e1..ba6c3bf666 100644
--- a/src/s_tir/transform/inject_software_pipeline.cc
+++ b/src/s_tir/transform/inject_software_pipeline.cc
@@ -119,7 +119,7 @@ class PipelineOpaqueAccessRewriter {
ffi::Array<PrimExpr> new_args = call->args;
const Buffer& new_buffer = (*it).second;
new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer,
call->args[4]));
- return Call(call->dtype, call->op, new_args, call->annotations,
call->span);
+ return Call(call->dtype, call->op, new_args, call->attrs, call->span);
}
} else if (call->op.same_as(mma_sync)) {
ffi::Array<PrimExpr> new_args = call->args;
@@ -133,7 +133,7 @@ class PipelineOpaqueAccessRewriter {
new_args.Set(i * 2 + 1, new_index);
}
}
- return Call(call->dtype, call->op, new_args, call->annotations,
call->span);
+ return Call(call->dtype, call->op, new_args, call->attrs, call->span);
} else if (call->op.same_as(access_ptr)) {
return RewriteBufferAccess(call, {1});
} else if (call->op.same_as(ptx_mma)) {
@@ -196,7 +196,7 @@ class PipelineOpaqueAccessRewriter {
new_args.Set(i + 1, new_index);
}
}
- return Call(call->dtype, call->op, new_args, call->annotations,
call->span);
+ return Call(call->dtype, call->op, new_args, call->attrs, call->span);
}
const ffi::Map<Var, Buffer>& buffer_data_to_buffer_;
diff --git a/src/tirx/analysis/deep_equal.cc b/src/tirx/analysis/deep_equal.cc
index f164ba427c..53700a85a9 100644
--- a/src/tirx/analysis/deep_equal.cc
+++ b/src/tirx/analysis/deep_equal.cc
@@ -21,6 +21,7 @@
* \file tirx/analysis/deep_equal.cc
* \brief Deep equality checking.
*/
+#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tirx/analysis.h>
@@ -124,7 +125,8 @@ class ExprDeepEqualChecker : private ExprFunctor<bool(const
PrimExpr&, const Pri
bool VisitExpr_(const CallNode* plhs, const PrimExpr& rhs) final {
const auto* prhs = rhs.as<CallNode>();
return plhs->dtype == prhs->dtype && plhs->op.same_as(prhs->op) &&
- ArrayDeepEqual(plhs->args, prhs->args);
+ ArrayDeepEqual(plhs->args, prhs->args) &&
+ ffi::StructuralEqual()(plhs->attrs, prhs->attrs);
}
bool VisitExpr_(const ReduceNode* plhs, const PrimExpr& rhs) final {
diff --git a/src/tirx/ir/data_type_rewriter.cc
b/src/tirx/ir/data_type_rewriter.cc
index 6fab0e3e09..cc4c2d5f78 100644
--- a/src/tirx/ir/data_type_rewriter.cc
+++ b/src/tirx/ir/data_type_rewriter.cc
@@ -248,8 +248,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
} else if (op->op.same_as(builtin_pow_)) {
return pow(op->args[0], op->args[1]);
} else if (op->op.same_as(builtin::if_then_else())) {
- return Call(op->dtype, op->op, {op->args[0], op->args[1], op->args[2]},
op->annotations,
- op->span);
+ return Call(op->dtype, op->op, {op->args[0], op->args[1], op->args[2]},
op->attrs, op->span);
} else if (op->op.same_as(Op::Get("tirx.clz"))) {
DataType before_dtype = before->args[0]->dtype;
DataType after_dtype = op->args[0]->dtype;
@@ -566,7 +565,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode*
op) {
PrimExpr cond = VisitExpr(op->args[0]);
is_condition_ = is_condition;
return Call(op->dtype, op->op, {cond, VisitExpr(op->args[1]),
VisitExpr(op->args[2])},
- op->annotations, op->span);
+ op->attrs, op->span);
}
return Parent::VisitExpr_(op);
}
diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc
index 5faa7c24bd..d458c09180 100644
--- a/src/tirx/ir/expr.cc
+++ b/src/tirx/ir/expr.cc
@@ -621,8 +621,7 @@ static ffi::Array<PrimExpr>
ConvertCallArgs(ffi::Array<CallArg> args) {
return prim_expr_args;
}
-Call::Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args,
- ffi::Map<ffi::String, ffi::Any> annotations, Span span) {
+Call::Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Attrs
attrs, Span span) {
for (size_t i = 0; i < args.size(); ++i) {
TVM_FFI_ICHECK(args[i].defined()) << "arg " << i << " is not defined()";
}
@@ -631,24 +630,27 @@ Call::Call(DataType dtype, RelaxExpr op,
ffi::Array<PrimExpr> args,
node->dtype = dtype;
node->op = std::move(op);
node->args = std::move(args);
- node->annotations = std::move(annotations);
+ node->attrs = std::move(attrs);
node->span = std::move(span);
data_ = std::move(node);
}
+Call::Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span span)
+ : Call(dtype, std::move(op), std::move(args), Attrs(), std::move(span)) {}
+
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tirx.Call",
[](ffi::Optional<DataType> dtype, RelaxExpr op, ffi::Array<CallArg>
args, Span span) {
- return Call(dtype.value_or(DataType::Void()), op,
ConvertCallArgs(args),
- ffi::Map<ffi::String, ffi::Any>(), span);
+ return Call(dtype.value_or(DataType::Void()), op,
ConvertCallArgs(args), Attrs(),
+ span);
})
- .def("tirx.CallWithAnnotations",
+ .def("tirx.CallWithAttrs",
[](ffi::Optional<DataType> dtype, RelaxExpr op, ffi::Array<CallArg>
args,
- ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations, Span
span) {
+ ffi::Optional<Attrs> attrs, Span span) {
return Call(dtype.value_or(DataType::Void()), op,
ConvertCallArgs(args),
- annotations.value_or(ffi::Map<ffi::String,
ffi::Any>()), span);
+ attrs.value_or(Attrs()), span);
});
}
diff --git a/src/tirx/ir/expr_functor.cc b/src/tirx/ir/expr_functor.cc
index 921ce45623..aba96aae8c 100644
--- a/src/tirx/ir/expr_functor.cc
+++ b/src/tirx/ir/expr_functor.cc
@@ -155,7 +155,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
if (args.same_as(op->args)) {
return ffi::GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, args, op->annotations, op->span);
+ return Call(op->dtype, op->op, args, op->attrs, op->span);
}
}
diff --git a/src/tirx/script/printer/expr.cc b/src/tirx/script/printer/expr.cc
index 8aff24f7d8..cd33f59e3c 100644
--- a/src/tirx/script/printer/expr.cc
+++ b/src/tirx/script/printer/expr.cc
@@ -258,7 +258,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::Call>("", [](tirx::Call call, AccessPath call_p,
IRDocsifier d) -> Doc {
- if (!call->annotations.empty()) {
+ if (call->attrs.defined()) {
ffi::Array<ExprDoc> call_args;
int n_args = call->args.size();
call_args.reserve(n_args);
@@ -270,7 +270,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
: d->AsDoc<ExprDoc>(call->op, call_p->Attr("op"));
return TIR(d, "Call")->Call(
{LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")), op_doc,
ListDoc(call_args)},
- {"annotations"}, {d->AsDoc<DictDoc>(call->annotations,
call_p->Attr("annotations"))});
+ {"attrs"}, {d->AsDoc<ExprDoc>(call->attrs,
call_p->Attr("attrs"))});
}
static const OpAttrMap<tirx::TScriptPrinterName>& op_names =
Op::GetAttrMap<tirx::TScriptPrinterName>("TScriptPrinterName");
diff --git a/src/tirx/transform/lower_warp_memory.cc
b/src/tirx/transform/lower_warp_memory.cc
index 9c80ed599d..99c815bf66 100644
--- a/src/tirx/transform/lower_warp_memory.cc
+++ b/src/tirx/transform/lower_warp_memory.cc
@@ -291,7 +291,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
new_args.Set(i + 1, local_index);
}
}
- return Call(op->dtype, op->op, new_args, op->annotations, op->span);
+ return Call(op->dtype, op->op, new_args, op->attrs, op->span);
}
PrimExpr VisitExpr_(const CallNode* op) override {
diff --git a/src/tirx/transform/storage_rewrite.cc
b/src/tirx/transform/storage_rewrite.cc
index c7e8caaf52..66d04ca899 100644
--- a/src/tirx/transform/storage_rewrite.cc
+++ b/src/tirx/transform/storage_rewrite.cc
@@ -497,7 +497,7 @@ class StoragePlanRewriter : public StmtExprMutator {
offset = make_const(offset.dtype(), se->bits_offset / elem_bits) +
offset;
}
return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset,
extent, op->args[4]},
- op->annotations, op->span);
+ op->attrs, op->span);
} else {
return StmtExprMutator::VisitExpr_(op);
}
diff --git a/src/tirx/transform/tile_primitive_dispatch.cc
b/src/tirx/transform/tile_primitive_dispatch.cc
index fbc7786d92..de01ee5db6 100644
--- a/src/tirx/transform/tile_primitive_dispatch.cc
+++ b/src/tirx/transform/tile_primitive_dispatch.cc
@@ -1160,7 +1160,7 @@ class TilePrimitiveDispatcher : public StmtExprMutator {
args.push_back(new_arg);
}
if (changed) {
- return tirx::Call(call->dtype, call->op, args, call->annotations,
call->span);
+ return tirx::Call(call->dtype, call->op, args, call->attrs,
call->span);
}
}
return pred;
diff --git a/src/tirx/transform/unsupported_dtype_legalize.cc
b/src/tirx/transform/unsupported_dtype_legalize.cc
index c2934c2d86..558a3ca437 100644
--- a/src/tirx/transform/unsupported_dtype_legalize.cc
+++ b/src/tirx/transform/unsupported_dtype_legalize.cc
@@ -238,13 +238,12 @@ class ComputeLegalizer : public StmtExprMutator {
auto fmutate = [this](const PrimExpr& e) { return
PromoteToTarget(this->VisitExpr(e)); };
ffi::Array<PrimExpr> args = op->args.Map(fmutate);
if (MatchDType(op->dtype)) {
- return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args,
op->annotations,
- op->span);
+ return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args,
op->attrs, op->span);
}
if (args.same_as(op->args)) {
return ffi::GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, args, op->annotations, op->span);
+ return Call(op->dtype, op->op, args, op->attrs, op->span);
}
}
diff --git a/src/tirx/transform/vectorize_loop.cc
b/src/tirx/transform/vectorize_loop.cc
index f444c17822..0ac9680d0a 100644
--- a/src/tirx/transform/vectorize_loop.cc
+++ b/src/tirx/transform/vectorize_loop.cc
@@ -491,10 +491,10 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
t = BroadcastTo(t, lanes, is_scalable);
f = BroadcastTo(f, lanes, is_scalable);
if (is_scalable) {
- return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{cond, t, f},
- op->annotations, op->span);
+ return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{cond, t, f}, op->attrs,
+ op->span);
} else {
- return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f},
op->annotations, op->span);
+ return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f},
op->attrs, op->span);
}
}
}
@@ -507,14 +507,14 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
} else {
int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) {
- return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{value}, op->annotations,
+ return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{value}, op->attrs,
op->span);
} else {
int new_lanes = (op->dtype != DataType::Float4E2M1FN() &&
op->args[0].dtype() != DataType::Float4E2M1FN())
? (value.dtype().bits() * value.dtype().lanes()) /
op->dtype.bits()
: value.dtype().lanes();
- return Call(op->dtype.with_lanes(new_lanes), op->op, {value},
op->annotations, op->span);
+ return Call(op->dtype.with_lanes(new_lanes), op->op, {value},
op->attrs, op->span);
}
}
}
@@ -536,7 +536,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
auto new_args = op->args;
new_args.pop_back();
new_args.push_back(fcd[0]);
- return Call(op->dtype.with_lanes(lane), op->op, new_args,
op->annotations, op->span);
+ return Call(op->dtype.with_lanes(lane), op->op, new_args, op->attrs,
op->span);
} else if (op->op.same_as(builtin::texture2d_store())) {
int lane = 0;
// Vectorize the value to store
@@ -551,7 +551,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
<< "Expected Data to be Written equal to Texture Store length";
ffi::Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2],
op->args[3], op->args[4],
mutated_value[0]};
- return Call(op->dtype.with_lanes(lane), op->op, new_args,
op->annotations, op->span);
+ return Call(op->dtype.with_lanes(lane), op->op, new_args, op->attrs,
op->span);
} else if (op->op.same_as(builtin::reinterpret())) {
return MutateReinterpretExpr_(op);
}
@@ -573,7 +573,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (op->args.same_as(new_args)) {
return ffi::GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, new_args, op->annotations, op->span);
+ return Call(op->dtype, op->op, new_args, op->attrs, op->span);
}
} else {
int lane = 0;
@@ -599,7 +599,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (op->args.same_as(new_args)) {
return ffi::GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype.with_lanes(lane), op->op, new_args,
op->annotations, op->span);
+ return Call(op->dtype.with_lanes(lane), op->op, new_args, op->attrs,
op->span);
}
}
}
diff --git a/tests/python/tirx-base/test_tir_constructor.py
b/tests/python/tirx-base/test_tir_constructor.py
index d084fe2b25..eda7fd9ebf 100644
--- a/tests/python/tirx-base/test_tir_constructor.py
+++ b/tests/python/tirx-base/test_tir_constructor.py
@@ -19,6 +19,7 @@ import pytest
import tvm
from tvm import te, topi
+from tvm.tirx.analysis import expr_deep_equal
from tvm.tirx.expr_functor import ExprMutator
@@ -133,32 +134,39 @@ def test_expr_constructor():
assert x.dtype == "float32"
assert x.op.name == "tirx.call_extern"
assert x.args[1] == a
- assert len(x.annotations) == 0
+ assert x.attrs is None
- annotated_arg = tvm.tirx.Var("annotated_arg", "float32")
- x_with_annotations = tvm.tirx.Call(
+ attr_arg = tvm.tirx.Var("attr_arg", "float32")
+ x_with_attrs = tvm.tirx.Call(
"float32",
"tirx.call_extern",
- [tvm.tirx.StringImm("xyz"), annotated_arg],
- annotations={"disable_tma": True},
+ [tvm.tirx.StringImm("xyz"), attr_arg],
+ attrs={"disable_tma": True},
)
- assert bool(x_with_annotations.annotations["disable_tma"])
- assert not tvm.ir.structural_equal(x, x_with_annotations)
- script = tvm.tirx.Evaluate(x_with_annotations).script()
- assert "annotations" in script
+ assert x_with_attrs.attrs["disable_tma"] is True
+ assert not tvm.ir.structural_equal(x, x_with_attrs)
+ script = tvm.tirx.Evaluate(x_with_attrs).script()
+ assert "attrs" in script
assert "disable_tma" in script
- func = tvm.tirx.PrimFunc([], tvm.tirx.Evaluate(x_with_annotations))
+ func = tvm.tirx.PrimFunc([], tvm.tirx.Evaluate(x_with_attrs))
assert tvm.script.from_source(func.script()).script() == func.script()
y = tvm.tirx.Var("y", "float32")
- mutated = ReplaceVar(annotated_arg, y)(x_with_annotations)
- assert bool(mutated.annotations["disable_tma"])
+ mutated = ReplaceVar(attr_arg, y)(x_with_attrs)
+ assert mutated.attrs["disable_tma"] is True
assert mutated.args[1].same_as(y)
x_from_intrin = tvm.tirx.call_intrin(
- "float32", "tirx.call_extern", tvm.tirx.StringImm("xyz"),
annotations={"disable_tma": True}
+ "float32", "tirx.call_extern", tvm.tirx.StringImm("xyz"),
attrs={"disable_tma": True}
)
- assert int(x_from_intrin.annotations["disable_tma"]) == 1
+ assert x_from_intrin.attrs["disable_tma"] is True
+ x_with_other_attrs = tvm.tirx.Call(
+ "float32",
+ "tirx.call_extern",
+ [tvm.tirx.StringImm("xyz"), attr_arg],
+ attrs={"disable_tma": False},
+ )
+ assert not expr_deep_equal(x_with_attrs, x_with_other_attrs)
cond0 = tvm.tirx.Var("cond0", "bool")
cond1 = tvm.tirx.Var("cond1", "bool")
@@ -171,12 +179,12 @@ def test_expr_constructor():
"int32",
"tirx.if_then_else",
[cond0, inner_if, tvm.tirx.IntImm("int32", 0)],
- annotations={"keep": True},
+ attrs={"keep": True},
)
simplified = tvm.tirx.transform.StmtSimplify()(
tvm.IRModule({"main": tvm.tirx.PrimFunc([],
tvm.tirx.Evaluate(outer_if))})
)["main"].body.value
- assert bool(simplified.annotations["keep"])
+ assert simplified.attrs["keep"] is True
v = tvm.tirx.Var("aa", "int32")
x = tvm.tirx.Let(v, 1, v)