This is an automated email from the ASF dual-hosted git repository.
spectrometerHBH pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cae6cb89b7 [IR] Add annotations to Call nodes (#19597)
cae6cb89b7 is described below
commit cae6cb89b732fd7f874b8fabc5fdba95edb41339
Author: Shushi Hong <[email protected]>
AuthorDate: Sun May 24 18:57:37 2026 -0400
[IR] Add annotations to Call nodes (#19597)
This PR adds annotation support to `tirx.Call` so downstream codegen
users can attach call-level metadata and preserve it through TIRX
transforms.
What changed:
- Add `CallNode::annotations` and expose it through reflection.
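- Add Python `tvm.tirx.Call(..., annotations=...)` support. Note that
`annotations` is inserted before `span` in both the Python and C++ `Call`
signatures, so callers that passed `span` positionally must now pass it by
keyword in Python (`span=...`) or pass `{}` for `annotations` in C++.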
- Preserve call annotations in C++ and Python expression mutators.
- Preserve annotations across TIRX/arith passes that rebuild equivalent
calls.
- Print annotated calls as `Tx.Call(..., annotations={...})` and support
script roundtrip.
- Add regression coverage for annotated calls, preservation through
mutators, script round-tripping, and preservation through simplification.
This PR also completes some cleanups that #19596 left unfinished.
---
include/tvm/tirx/expr.h | 17 ++++-
include/tvm/tirx/op.h | 28 ++++-----
python/tvm/rpc/tracker.py | 4 +-
python/tvm/tirx/expr.py | 18 +++++-
python/tvm/tirx/expr_functor.py | 2 +-
python/tvm/tirx/op.py | 29 ++++++---
src/arith/ir_mutator_with_analyzer.cc | 2 +-
src/arith/rewrite_simplify.cc | 3 +-
src/s_tir/transform/inject_software_pipeline.cc | 6 +-
src/tirx/ir/data_type_rewriter.cc | 6 +-
src/tirx/ir/expr.cc | 79 ++++++++++++++----------
src/tirx/ir/expr_functor.cc | 2 +-
src/tirx/ir/stmt.cc | 2 +-
src/tirx/op/op.cc | 51 +++++++--------
src/tirx/script/printer/expr.cc | 14 +++++
src/tirx/transform/lower_warp_memory.cc | 2 +-
src/tirx/transform/storage_rewrite.cc | 3 +-
src/tirx/transform/tile_primitive_dispatch.cc | 2 +-
src/tirx/transform/unsupported_dtype_legalize.cc | 5 +-
src/tirx/transform/vectorize_loop.cc | 18 +++---
tests/python/contrib/test_rpc_tracker.py | 4 +-
tests/python/tirx-base/test_tir_constructor.py | 57 +++++++++++++++++
22 files changed, 239 insertions(+), 115 deletions(-)
diff --git a/include/tvm/tirx/expr.h b/include/tvm/tirx/expr.h
index db54984c82..68cfcbd361 100644
--- a/include/tvm/tirx/expr.h
+++ b/include/tvm/tirx/expr.h
@@ -731,9 +731,20 @@ class CallNode : public PrimExprNode {
/*! \brief The arguments. */
ffi::Array<PrimExpr> args;
+ /*!
+ * \brief Additional annotations about the call.
+ *
+ * These annotations can be used to carry target-specific metadata through
+ * TIRX transformations and codegen.
+ */
+ ffi::Map<ffi::String, ffi::Any> annotations;
+
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<CallNode>().def_ro("op", &CallNode::op).def_ro("args",
&CallNode::args);
+ refl::ObjectDef<CallNode>()
+ .def_ro("op", &CallNode::op)
+ .def_ro("args", &CallNode::args)
+ .def_ro("annotations", &CallNode::annotations);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Call", CallNode, PrimExprNode);
};
@@ -744,7 +755,9 @@ class CallNode : public PrimExprNode {
*/
class Call : public PrimExpr {
public:
- TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span
span = Span());
+ TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args,
+ ffi::Map<ffi::String, ffi::Any> annotations =
ffi::Map<ffi::String, ffi::Any>(),
+ Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};
diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h
index 9093c2c453..549aab4df8 100644
--- a/include/tvm/tirx/op.h
+++ b/include/tvm/tirx/op.h
@@ -736,19 +736,19 @@ inline void CheckMathUnaryOpInputDType(const char*
op_name, DataType dtype) {
}
// Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \
- inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
- static const Op& op = Op::Get("tirx." #OpName); \
- CheckInputDType(#OpName, x.dtype()); \
- if (x.dtype().is_bfloat16()) { \
- DataType bf16_dtype = x.dtype(); \
- DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \
- PrimExpr x_fp32 = tirx::Cast(fp32_dtype, {x}, span); \
- PrimExpr result_fp32 = tirx::Call(fp32_dtype, op, {x_fp32}, span); \
- return tirx::Cast(bf16_dtype, {result_fp32}, span); \
- } else { \
- return tirx::Call(x.dtype(), op, {x}, span); \
- } \
+#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \
+ inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
+ static const Op& op = Op::Get("tirx." #OpName); \
+ CheckInputDType(#OpName, x.dtype()); \
+ if (x.dtype().is_bfloat16()) { \
+ DataType bf16_dtype = x.dtype(); \
+ DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \
+ PrimExpr x_fp32 = tirx::Cast(fp32_dtype, {x}, span); \
+ PrimExpr result_fp32 = tirx::Call(fp32_dtype, op, {x_fp32}, {}, span); \
+ return tirx::Cast(bf16_dtype, {result_fp32}, span); \
+ } else { \
+ return tirx::Call(x.dtype(), op, {x}, {}, span); \
+ } \
}
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
@@ -786,7 +786,7 @@ TVM_DECLARE_INTRIN_UNARY(clz);
#define TVM_DECLARE_INTRIN_BINARY(OpName) \
inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \
static const Op& op = Op::Get("tirx." #OpName); \
- return tirx::Call(x.dtype(), op, {x, y}, span); \
+ return tirx::Call(x.dtype(), op, {x, y}, {}, span); \
}
TVM_DECLARE_INTRIN_BINARY(atan2);
diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py
index 0714c64fc9..1af2a26985 100644
--- a/python/tvm/rpc/tracker.py
+++ b/python/tvm/rpc/tracker.py
@@ -248,9 +248,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
try:
self.call_handler(json.loads(msg))
except Exception: # pylint: disable=broad-except
- logger.warning(
- "Error handling message from %s", self.name(),
exc_info=True
- )
+ logger.warning("Error handling message from %s",
self.name(), exc_info=True)
self.close()
return
else:
diff --git a/python/tvm/tirx/expr.py b/python/tvm/tirx/expr.py
index e3c341c4e9..bf04b3ae7a 100644
--- a/python/tvm/tirx/expr.py
+++ b/python/tvm/tirx/expr.py
@@ -1302,13 +1302,22 @@ class Call(PrimExprWithOp):
span : Optional[Span]
The location of this expression in the source code.
+
+ annotations : Optional[dict]
+ Additional metadata attached to the call.
"""
op: Op
args: list[PrimExpr]
+ annotations: dict
def __init__(
- self, dtype: str, op: Op | str, args: list[PrimExpr], span: Span |
None = None
+ self,
+ dtype: str,
+ op: Op | str,
+ args: list[PrimExpr],
+ annotations: dict | None = None,
+ span: Span | None = None,
) -> None:
if isinstance(op, str):
if not op.startswith("tirx."):
@@ -1321,7 +1330,12 @@ class Call(PrimExprWithOp):
% op
)
op = Op.get(op)
- self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args,
span) # type: ignore
+ if annotations:
+ self.__init_handle_by_constructor__( # type: ignore
+ _ffi_api.CallWithAnnotations, dtype, op, args, annotations,
span
+ )
+ else:
+ self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op,
args, span) # type: ignore
@tvm_ffi.register_object("tirx.Let")
diff --git a/python/tvm/tirx/expr_functor.py b/python/tvm/tirx/expr_functor.py
index e89ed19c1e..b09606602a 100644
--- a/python/tvm/tirx/expr_functor.py
+++ b/python/tvm/tirx/expr_functor.py
@@ -495,7 +495,7 @@ class ExprMutator(ExprFunctor):
if all(old_arg is new_arg for old_arg, new_arg in zip(op.args, args)):
return op
else:
- return tvm.tirx.Call(op.dtype, op.op, args)
+ return tvm.tirx.Call(op.dtype, op.op, args,
annotations=op.annotations, span=op.span)
def _mutate_binary_op(self, op_cls, op):
"""Helper to mutate binary operators."""
diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py
index 924bec91dc..2f227195b3 100644
--- a/python/tvm/tirx/op.py
+++ b/python/tvm/tirx/op.py
@@ -62,8 +62,10 @@ tir = tirx # alias for backward compat with upstream
tir.convert() calls
def _pack_buffer(buf, span=None):
"""Build intrinsics that packs the buffer."""
- shape = Call("handle", "tirx.tvm_stack_make_shape", buf.shape, span)
- strides = Call("handle", "tirx.tvm_stack_make_shape", buf.strides, span)
if buf.strides else 0
+ shape = Call("handle", "tirx.tvm_stack_make_shape", buf.shape, span=span)
+ strides = (
+ Call("handle", "tirx.tvm_stack_make_shape", buf.strides, span=span) if
buf.strides else 0
+ )
pack_args = [
buf.data,
shape,
@@ -72,7 +74,7 @@ def _pack_buffer(buf, span=None):
const(0, dtype=buf.dtype),
buf.elem_offset,
]
- return Call("handle", Op.get("tirx.tvm_stack_make_array"), pack_args, span)
+ return Call("handle", Op.get("tirx.tvm_stack_make_array"), pack_args,
span=span)
def call_packed_lowered(*args, span=None):
@@ -101,7 +103,7 @@ def call_packed_lowered(*args, span=None):
te.extern : Create tensor with extern function call.
"""
call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
- return Call("int32", Op.get("tirx.tvm_call_packed_lowered"), call_args,
span)
+ return Call("int32", Op.get("tirx.tvm_call_packed_lowered"), call_args,
span=span)
def call_cpacked_lowered(*args, span=None):
@@ -127,7 +129,7 @@ def call_cpacked_lowered(*args, span=None):
te.extern : Create tensor with extern function call.
"""
call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
- return Call("int32", Op.get("tirx.tvm_call_cpacked_lowered"), call_args,
span)
+ return Call("int32", Op.get("tirx.tvm_call_cpacked_lowered"), call_args,
span=span)
def call_packed(*args, span=None):
@@ -158,7 +160,7 @@ def call_packed(*args, span=None):
te.extern : Create tensor with extern function call.
"""
call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
- return Call("int32", Op.get("tirx.tvm_call_packed"), call_args, span)
+ return Call("int32", Op.get("tirx.tvm_call_packed"), call_args, span=span)
def call_cpacked(*args, span=None):
@@ -185,10 +187,10 @@ def call_cpacked(*args, span=None):
te.extern : Create tensor with extern function call.
"""
call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args]
- return Call("int32", Op.get("tirx.tvm_call_cpacked"), call_args, span)
+ return Call("int32", Op.get("tirx.tvm_call_cpacked"), call_args, span=span)
-def call_intrin(dtype, func_name, *args, span=None):
+def call_intrin(dtype, func_name, *args, annotations=None, span=None):
"""Build expression by calling an intrinsic function.
Intrinsics can be overloaded with multiple data types via
@@ -205,6 +207,9 @@ def call_intrin(dtype, func_name, *args, span=None):
args : list
Positional arguments.
+ annotations : Optional[Dict[str, Object]]
+ Additional annotations about the call.
+
span : Optional[Span]
The location of this operator in the source code.
@@ -213,7 +218,11 @@ def call_intrin(dtype, func_name, *args, span=None):
call : PrimExpr
The call expression.
"""
- return Call(dtype, func_name, args, span)
+ if annotations is not None:
+ annotations = {
+ k: const(v) if isinstance(v, int | bool) else v for k, v in
annotations.items()
+ }
+ return Call(dtype, func_name, args, annotations=annotations, span=span)
def call_pure_extern(dtype, func_name, *args, span=None):
@@ -238,7 +247,7 @@ def call_pure_extern(dtype, func_name, *args, span=None):
call : PrimExpr
The call expression.
"""
- return Call(dtype, Op.get("tirx.call_pure_extern"), [func_name, *args],
span)
+ return Call(dtype, Op.get("tirx.call_pure_extern"), [func_name, *args],
span=span)
def call_extern(dtype, func_name, *args, span=None):
diff --git a/src/arith/ir_mutator_with_analyzer.cc
b/src/arith/ir_mutator_with_analyzer.cc
index 1d35da952f..e902d32aba 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -310,7 +310,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode*
op) {
false_value.same_as(op->args[2])) {
return ffi::GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, {cond, true_value, false_value});
+ return Call(op->dtype, op->op, {cond, true_value, false_value},
op->annotations, op->span);
}
}
return StmtExprMutator::VisitExpr_(op);
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index c9fec7f599..1765a6b04a 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -2343,7 +2343,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
CallNode* op) {
// Only check constant cases to avoid recursion
if (is_const_number(inner_else_expr) && is_const_number(else_expr) &&
analyzer_->CanProve(inner_else_expr == else_expr)) {
- return if_then_else(cond && inner_cond, inner_then_expr, else_expr);
+ return Call(op->dtype, op->op, {cond && inner_cond, inner_then_expr,
else_expr},
+ op->annotations, op->span);
}
}
}
diff --git a/src/s_tir/transform/inject_software_pipeline.cc
b/src/s_tir/transform/inject_software_pipeline.cc
index 1512644052..717b9b7dc8 100644
--- a/src/s_tir/transform/inject_software_pipeline.cc
+++ b/src/s_tir/transform/inject_software_pipeline.cc
@@ -119,7 +119,7 @@ class PipelineOpaqueAccessRewriter {
ffi::Array<PrimExpr> new_args = call->args;
const Buffer& new_buffer = (*it).second;
new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer,
call->args[4]));
- return Call(call->dtype, call->op, new_args, call->span);
+ return Call(call->dtype, call->op, new_args, call->annotations,
call->span);
}
} else if (call->op.same_as(mma_sync)) {
ffi::Array<PrimExpr> new_args = call->args;
@@ -133,7 +133,7 @@ class PipelineOpaqueAccessRewriter {
new_args.Set(i * 2 + 1, new_index);
}
}
- return Call(call->dtype, call->op, new_args, call->span);
+ return Call(call->dtype, call->op, new_args, call->annotations,
call->span);
} else if (call->op.same_as(access_ptr)) {
return RewriteBufferAccess(call, {1});
} else if (call->op.same_as(ptx_mma)) {
@@ -196,7 +196,7 @@ class PipelineOpaqueAccessRewriter {
new_args.Set(i + 1, new_index);
}
}
- return Call(call->dtype, call->op, new_args, call->span);
+ return Call(call->dtype, call->op, new_args, call->annotations,
call->span);
}
const ffi::Map<Var, Buffer>& buffer_data_to_buffer_;
diff --git a/src/tirx/ir/data_type_rewriter.cc
b/src/tirx/ir/data_type_rewriter.cc
index 901d18e5c4..6fab0e3e09 100644
--- a/src/tirx/ir/data_type_rewriter.cc
+++ b/src/tirx/ir/data_type_rewriter.cc
@@ -248,7 +248,8 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
} else if (op->op.same_as(builtin_pow_)) {
return pow(op->args[0], op->args[1]);
} else if (op->op.same_as(builtin::if_then_else())) {
- return if_then_else(op->args[0], op->args[1], op->args[2]);
+ return Call(op->dtype, op->op, {op->args[0], op->args[1], op->args[2]},
op->annotations,
+ op->span);
} else if (op->op.same_as(Op::Get("tirx.clz"))) {
DataType before_dtype = before->args[0]->dtype;
DataType after_dtype = op->args[0]->dtype;
@@ -564,7 +565,8 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode*
op) {
is_condition_ = true;
PrimExpr cond = VisitExpr(op->args[0]);
is_condition_ = is_condition;
- return if_then_else(cond, VisitExpr(op->args[1]), VisitExpr(op->args[2]));
+ return Call(op->dtype, op->op, {cond, VisitExpr(op->args[1]),
VisitExpr(op->args[2])},
+ op->annotations, op->span);
}
return Parent::VisitExpr_(op);
}
diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc
index 3248c009b4..5faa7c24bd 100644
--- a/src/tirx/ir/expr.cc
+++ b/src/tirx/ir/expr.cc
@@ -590,7 +590,39 @@ TVM_FFI_STATIC_INIT_BLOCK() {
}
// Call
-Call::Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span span)
{
+using CallArg = ffi::Variant<ffi::String, DLDataType, IterVar, BufferRegion,
PrimExpr>;
+
+static ffi::Array<PrimExpr> ConvertCallArgs(ffi::Array<CallArg> args) {
+ ffi::Array<PrimExpr> prim_expr_args;
+ for (const auto& it : args) {
+ if (auto opt_str = it.as<ffi::String>()) {
+ prim_expr_args.push_back(StringImm(opt_str.value()));
+ } else if (auto opt_dtype = it.as<DLDataType>()) {
+
prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value())));
+ } else if (const auto* iter_var = it.as<IterVarNode>()) {
+ prim_expr_args.push_back(iter_var->var);
+ } else if (const auto* br = it.as<BufferRegionNode>()) {
+ ffi::Array<PrimExpr> indices;
+ for (Range r : br->region) {
+ if (is_one(r->extent)) {
+ indices.push_back(r->min);
+ } else if (r->extent.as<IntImmNode>()) {
+ indices.push_back(tirx::Ramp(r->min, make_const(r->min->dtype, 1),
r->extent));
+ } else {
+ TVM_FFI_THROW(ValueError)
+ << "Cannot convert to BufferLoad: " <<
ffi::GetRef<BufferRegion>(br);
+ }
+ }
+ prim_expr_args.push_back(BufferLoad(br->buffer, indices));
+ } else {
+ prim_expr_args.push_back(Downcast<PrimExpr>(it));
+ }
+ }
+ return prim_expr_args;
+}
+
+Call::Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args,
+ ffi::Map<ffi::String, ffi::Any> annotations, Span span) {
for (size_t i = 0; i < args.size(); ++i) {
TVM_FFI_ICHECK(args[i].defined()) << "arg " << i << " is not defined()";
}
@@ -599,44 +631,25 @@ Call::Call(DataType dtype, RelaxExpr op,
ffi::Array<PrimExpr> args, Span span) {
node->dtype = dtype;
node->op = std::move(op);
node->args = std::move(args);
+ node->annotations = std::move(annotations);
node->span = std::move(span);
data_ = std::move(node);
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def(
- "tirx.Call",
- [](ffi::Optional<DataType> dtype, RelaxExpr op,
- ffi::Array<ffi::Variant<ffi::String, DLDataType, IterVar,
BufferRegion, PrimExpr>> args,
- Span span) {
- ffi::Array<PrimExpr> prim_expr_args;
- for (const auto& it : args) {
- if (auto opt_str = it.as<ffi::String>()) {
- prim_expr_args.push_back(StringImm(opt_str.value()));
- } else if (auto opt_dtype = it.as<DLDataType>()) {
-
prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value())));
- } else if (const auto* iter_var = it.as<IterVarNode>()) {
- prim_expr_args.push_back(iter_var->var);
- } else if (const auto* br = it.as<BufferRegionNode>()) {
- ffi::Array<PrimExpr> indices;
- for (Range r : br->region) {
- if (is_one(r->extent)) {
- indices.push_back(r->min);
- } else if (r->extent.as<IntImmNode>()) {
- indices.push_back(tirx::Ramp(r->min, make_const(r->min->dtype,
1), r->extent));
- } else {
- TVM_FFI_THROW(ValueError)
- << "Cannot convert to BufferLoad: " <<
ffi::GetRef<BufferRegion>(br);
- }
- }
- prim_expr_args.push_back(BufferLoad(br->buffer, indices));
- } else {
- prim_expr_args.push_back(Downcast<PrimExpr>(it));
- }
- }
- return Call(dtype.value_or(DataType::Void()), op, prim_expr_args,
span);
- });
+ refl::GlobalDef()
+ .def("tirx.Call",
+ [](ffi::Optional<DataType> dtype, RelaxExpr op, ffi::Array<CallArg>
args, Span span) {
+ return Call(dtype.value_or(DataType::Void()), op,
ConvertCallArgs(args),
+ ffi::Map<ffi::String, ffi::Any>(), span);
+ })
+ .def("tirx.CallWithAnnotations",
+ [](ffi::Optional<DataType> dtype, RelaxExpr op, ffi::Array<CallArg>
args,
+ ffi::Optional<ffi::Map<ffi::String, ffi::Any>> annotations, Span
span) {
+ return Call(dtype.value_or(DataType::Void()), op,
ConvertCallArgs(args),
+ annotations.value_or(ffi::Map<ffi::String,
ffi::Any>()), span);
+ });
}
// Shuffle
diff --git a/src/tirx/ir/expr_functor.cc b/src/tirx/ir/expr_functor.cc
index dc9913060e..921ce45623 100644
--- a/src/tirx/ir/expr_functor.cc
+++ b/src/tirx/ir/expr_functor.cc
@@ -155,7 +155,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
if (args.same_as(op->args)) {
return ffi::GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, args);
+ return Call(op->dtype, op->op, args, op->annotations, op->span);
}
}
diff --git a/src/tirx/ir/stmt.cc b/src/tirx/ir/stmt.cc
index 1a9abe6ca8..1943135926 100644
--- a/src/tirx/ir/stmt.cc
+++ b/src/tirx/ir/stmt.cc
@@ -674,7 +674,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
PrimExpr TypeAnnotation(DataType dtype, Span span) {
static auto op = Op::Get("tirx.type_annotation");
- return tirx::Call(dtype, op, {}, span);
+ return tirx::Call(dtype, op, {}, {}, span);
}
TVM_TIRX_REGISTER_OP("type_annotation")
diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc
index c2772ad69f..bc500a54cc 100644
--- a/src/tirx/op/op.cc
+++ b/src/tirx/op/op.cc
@@ -128,14 +128,14 @@ Type GetTypeFromRuntimeDataType(const DataType& dtype) {
PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) {
return tirx::Call(
t, tirx::builtin::large_uint_imm(),
- {make_const(DataType::UInt(32), low, span),
make_const(DataType::UInt(32), high, span)},
+ {make_const(DataType::UInt(32), low, span),
make_const(DataType::UInt(32), high, span)}, {},
span);
}
// Q-multiplication
PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span
span) {
return tirx::Call(DataType::Int(32, x.dtype().lanes()),
tirx::builtin::q_multiply_shift(),
- {x, y, q, s}, span);
+ {x, y, q, s}, {}, span);
}
void BroadcastToMatchLanes(PrimExpr& op_a, PrimExpr& op_b) { // NOLINT(*)
@@ -263,19 +263,19 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs,
Span span) { // NOLINT(*)
PrimExpr ret(PrimExpr value, Span span) {
TVM_FFI_ICHECK(value.defined());
- return tirx::Call(value.dtype(), tirx::builtin::ret(), {value}, span);
+ return tirx::Call(value.dtype(), tirx::builtin::ret(), {value}, {}, span);
}
PrimExpr thread_return(Span span) {
- return tirx::Call(DataType::Void(), tirx::builtin::thread_return(), {},
span);
+ return tirx::Call(DataType::Void(), tirx::builtin::thread_return(), {}, {},
span);
}
PrimExpr continue_loop(Span span) {
- return tirx::Call(DataType::Void(), tirx::builtin::continue_loop(), {},
span);
+ return tirx::Call(DataType::Void(), tirx::builtin::continue_loop(), {}, {},
span);
}
PrimExpr break_loop(Span span) {
- return tirx::Call(DataType::Void(), tirx::builtin::break_loop(), {}, span);
+ return tirx::Call(DataType::Void(), tirx::builtin::break_loop(), {}, {},
span);
}
TVM_FFI_STATIC_INIT_BLOCK() {
@@ -512,7 +512,7 @@ PrimExpr reinterpret(const DataType& t, PrimExpr value,
Span span) {
value.dtype().bytes() * value.dtype().lanes() == t.bytes()
* t.lanes()))
<< "Reinterpret requires size match " << t << " vs " << value.dtype();
}
- return tirx::Call(t, tirx::builtin::reinterpret(), {value}, span);
+ return tirx::Call(t, tirx::builtin::reinterpret(), {value}, {}, span);
}
// operator+
@@ -654,13 +654,13 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value,
PrimExpr false_value,
}
return tirx::Call(true_value.dtype(), tirx::builtin::if_then_else(),
- {cond, true_value, false_value}, span);
+ {cond, true_value, false_value}, {}, span);
}
// likely
PrimExpr likely(PrimExpr cond, Span span) {
if (is_const_int(cond)) return cond;
- return tirx::Call(cond.dtype(), tirx::builtin::likely(), {cond}, span);
+ return tirx::Call(cond.dtype(), tirx::builtin::likely(), {cond}, {}, span);
}
// operator>
@@ -786,7 +786,7 @@ PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) {
}
});
- return tirx::Call(a.dtype(), tirx::builtin::shift_right(), {a, b}, span);
+ return tirx::Call(a.dtype(), tirx::builtin::shift_right(), {a, b}, {}, span);
}
// shift left
@@ -805,7 +805,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
if (pb->value == 0) return a;
}
});
- return tirx::Call(a.dtype(), tirx::builtin::shift_left(), {a, b}, span);
+ return tirx::Call(a.dtype(), tirx::builtin::shift_left(), {a, b}, {}, span);
}
// bitwise and
@@ -817,7 +817,7 @@ PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, (pa->value & pb->value), span);
});
- return tirx::Call(a.dtype(), tirx::builtin::bitwise_and(), {a, b}, span);
+ return tirx::Call(a.dtype(), tirx::builtin::bitwise_and(), {a, b}, {}, span);
}
// bitwise_or
@@ -829,7 +829,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, (pa->value | pb->value), span);
});
- return tirx::Call(a.dtype(), tirx::builtin::bitwise_or(), {a, b}, span);
+ return tirx::Call(a.dtype(), tirx::builtin::bitwise_or(), {a, b}, {}, span);
}
// bitwise_xor
@@ -841,7 +841,7 @@ PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value), span);
});
- return tirx::Call(a.dtype(), tirx::builtin::bitwise_xor(), {a, b}, span);
+ return tirx::Call(a.dtype(), tirx::builtin::bitwise_xor(), {a, b}, {}, span);
}
// bitwise_not
@@ -849,7 +849,7 @@ PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); }
PrimExpr bitwise_neg(PrimExpr a, Span span) {
type_check_int_or_bool_args(a, "~ operator (bitwise NOT)");
- return tirx::Call(a.dtype(), tirx::builtin::bitwise_not(), {a}, span);
+ return tirx::Call(a.dtype(), tirx::builtin::bitwise_not(), {a}, {}, span);
}
TVM_FFI_STATIC_INIT_BLOCK() {
@@ -889,7 +889,7 @@ PrimExpr pow(PrimExpr x, PrimExpr y, Span span) {
}
static auto op = Op::Get("tirx.pow");
- return tirx::Call(x.dtype(), op, {x, y}, span);
+ return tirx::Call(x.dtype(), op, {x, y}, {}, span);
}
TVM_TIR_REGISTER_PURE_BINARY_OP("pow").set_attr<TVectorizable>("TVectorizable",
true);
@@ -910,7 +910,7 @@ PrimExpr abs(PrimExpr x, Span span) {
return FloatImm(x.dtype(), std::fabs(fx->value), fx->span);
}
static auto op = Op::Get("tirx.fabs");
- return tirx::Call(x.dtype(), op, {x}, span);
+ return tirx::Call(x.dtype(), op, {x}, {}, span);
} else if (x.dtype().is_uint()) {
return x;
} else {
@@ -935,9 +935,10 @@ PrimExpr isnan(PrimExpr x, Span span) {
}
static auto op = Op::Get("tirx.isnan");
if (x.dtype().bits() == 16) {
- return tirx::Call(t, op, {cast(DataType::Float(32, t.lanes()),
std::move(x), span)}, span);
+ return tirx::Call(t, op, {cast(DataType::Float(32, t.lanes()),
std::move(x), span)}, {},
+ span);
} else {
- return tirx::Call(t, op, {x}, span);
+ return tirx::Call(t, op, {x}, {}, span);
}
} else {
TVM_FFI_THROW(InternalError) << "Data type " << x.dtype()
@@ -1025,7 +1026,7 @@ PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) {
BinaryOpMatchTypes(x, y, span);
TVM_FFI_ICHECK(x.dtype().is_float()) << "fmod only applies to float";
static auto op = Op::Get("tirx.fmod");
- return tirx::Call(x.dtype(), op, {x, y}, span);
+ return tirx::Call(x.dtype(), op, {x, y}, {}, span);
}
TVM_TIR_REGISTER_PURE_UNARY_OP("fmod");
@@ -1039,7 +1040,7 @@ PrimExpr floor(PrimExpr x, Span span) {
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::floor(fx->value), fx->span);
static auto op = Op::Get("tirx.floor");
- return tirx::Call(x.dtype(), op, {x}, span);
+ return tirx::Call(x.dtype(), op, {x}, {}, span);
}
TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr<TVectorizable>("TVectorizable",
true);
@@ -1053,7 +1054,7 @@ PrimExpr ceil(PrimExpr x, Span span) {
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::ceil(fx->value), fx->span);
static auto op = Op::Get("tirx.ceil");
- return tirx::Call(x.dtype(), op, {x}, span);
+ return tirx::Call(x.dtype(), op, {x}, {}, span);
}
TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr<TVectorizable>("TVectorizable",
true);
@@ -1067,7 +1068,7 @@ PrimExpr round(PrimExpr x, Span span) {
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span);
static auto op = Op::Get("tirx.round");
- return tirx::Call(x.dtype(), op, {x}, span);
+ return tirx::Call(x.dtype(), op, {x}, {}, span);
}
TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr<TVectorizable>("TVectorizable",
true);
@@ -1081,7 +1082,7 @@ PrimExpr nearbyint(PrimExpr x, Span span) {
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span);
static auto op = Op::Get("tirx.nearbyint");
- return tirx::Call(x.dtype(), op, {x}, span);
+ return tirx::Call(x.dtype(), op, {x}, {}, span);
}
TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint");
@@ -1098,7 +1099,7 @@ PrimExpr trunc(PrimExpr x, Span span) {
fx->span);
}
static auto op = Op::Get("tirx.trunc");
- return tirx::Call(x.dtype(), op, {x}, span);
+ return tirx::Call(x.dtype(), op, {x}, {}, span);
}
TVM_TIR_REGISTER_PURE_UNARY_OP("trunc").set_attr<TVectorizable>("TVectorizable",
true);
diff --git a/src/tirx/script/printer/expr.cc b/src/tirx/script/printer/expr.cc
index 4b852cd4fa..2c9f5daed3 100644
--- a/src/tirx/script/printer/expr.cc
+++ b/src/tirx/script/printer/expr.cc
@@ -258,6 +258,20 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tirx::Call>("", [](tirx::Call call, AccessPath call_p,
IRDocsifier d) -> Doc {
+ if (!call->annotations.empty()) {
+ ffi::Array<ExprDoc> call_args;
+ int n_args = call->args.size();
+ call_args.reserve(n_args);
+ for (int i = 0; i < n_args; ++i) {
+ call_args.push_back(d->AsDoc<ExprDoc>(call->args[i],
call_p->Attr("args")->ArrayItem(i)));
+ }
+ ExprDoc op_doc = call->op.as<Op>()
+ ?
LiteralDoc::Str(call->op.as<Op>().value()->name, call_p->Attr("op"))
+ : d->AsDoc<ExprDoc>(call->op, call_p->Attr("op"));
+ return TIR(d, "Call")->Call(
+ {LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")), op_doc,
ListDoc(call_args)},
+ {"annotations"}, {d->AsDoc<DictDoc>(call->annotations,
call_p->Attr("annotations"))});
+ }
static const OpAttrMap<tirx::TScriptPrinterName>& op_names =
Op::GetAttrMap<tirx::TScriptPrinterName>("TScriptPrinterName");
static const OpAttrMap<tirx::TScriptDtypePrintLocation> dtype_locations =
diff --git a/src/tirx/transform/lower_warp_memory.cc
b/src/tirx/transform/lower_warp_memory.cc
index ed98c5dfe6..2c3d84fad6 100644
--- a/src/tirx/transform/lower_warp_memory.cc
+++ b/src/tirx/transform/lower_warp_memory.cc
@@ -291,7 +291,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
new_args.Set(i + 1, local_index);
}
}
- return Call(op->dtype, op->op, new_args);
+ return Call(op->dtype, op->op, new_args, op->annotations, op->span);
}
PrimExpr VisitExpr_(const CallNode* op) override {
diff --git a/src/tirx/transform/storage_rewrite.cc
b/src/tirx/transform/storage_rewrite.cc
index da31b2f9f5..c0c4243bd3 100644
--- a/src/tirx/transform/storage_rewrite.cc
+++ b/src/tirx/transform/storage_rewrite.cc
@@ -496,7 +496,8 @@ class StoragePlanRewriter : public StmtExprMutator {
if (se->bits_offset != 0) {
offset = make_const(offset.dtype(), se->bits_offset / elem_bits) +
offset;
}
- return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset,
extent, op->args[4]});
+ return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset,
extent, op->args[4]},
+ op->annotations, op->span);
} else {
return StmtExprMutator::VisitExpr_(op);
}
diff --git a/src/tirx/transform/tile_primitive_dispatch.cc
b/src/tirx/transform/tile_primitive_dispatch.cc
index 70509bd3e0..fbc7786d92 100644
--- a/src/tirx/transform/tile_primitive_dispatch.cc
+++ b/src/tirx/transform/tile_primitive_dispatch.cc
@@ -1160,7 +1160,7 @@ class TilePrimitiveDispatcher : public StmtExprMutator {
args.push_back(new_arg);
}
if (changed) {
- return tirx::Call(call->dtype, call->op, args, call->span);
+ return tirx::Call(call->dtype, call->op, args, call->annotations,
call->span);
}
}
return pred;
diff --git a/src/tirx/transform/unsupported_dtype_legalize.cc
b/src/tirx/transform/unsupported_dtype_legalize.cc
index 15f5876075..c2934c2d86 100644
--- a/src/tirx/transform/unsupported_dtype_legalize.cc
+++ b/src/tirx/transform/unsupported_dtype_legalize.cc
@@ -238,12 +238,13 @@ class ComputeLegalizer : public StmtExprMutator {
auto fmutate = [this](const PrimExpr& e) { return
PromoteToTarget(this->VisitExpr(e)); };
ffi::Array<PrimExpr> args = op->args.Map(fmutate);
if (MatchDType(op->dtype)) {
- return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args);
+ return Call(promote_dtype_.with_lanes(op->dtype.lanes()), op->op, args,
op->annotations,
+ op->span);
}
if (args.same_as(op->args)) {
return ffi::GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, args);
+ return Call(op->dtype, op->op, args, op->annotations, op->span);
}
}
diff --git a/src/tirx/transform/vectorize_loop.cc
b/src/tirx/transform/vectorize_loop.cc
index cdf0bddf4d..da90338956 100644
--- a/src/tirx/transform/vectorize_loop.cc
+++ b/src/tirx/transform/vectorize_loop.cc
@@ -491,9 +491,10 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
t = BroadcastTo(t, lanes, is_scalable);
f = BroadcastTo(f, lanes, is_scalable);
if (is_scalable) {
- return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{cond, t, f});
+ return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{cond, t, f},
+ op->annotations, op->span);
} else {
- return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
+ return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f},
op->annotations, op->span);
}
}
}
@@ -506,13 +507,14 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
} else {
int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) {
- return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{value});
+ return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{value}, op->annotations,
+ op->span);
} else {
int new_lanes = (op->dtype != DataType::Float4E2M1FN() &&
op->args[0].dtype() != DataType::Float4E2M1FN())
? (value.dtype().bits() * value.dtype().lanes()) /
op->dtype.bits()
: value.dtype().lanes();
- return Call(op->dtype.with_lanes(new_lanes), op->op, {value});
+ return Call(op->dtype.with_lanes(new_lanes), op->op, {value},
op->annotations, op->span);
}
}
}
@@ -534,7 +536,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
auto new_args = op->args;
new_args.pop_back();
new_args.push_back(fcd[0]);
- return Call(op->dtype.with_lanes(lane), op->op, new_args);
+ return Call(op->dtype.with_lanes(lane), op->op, new_args,
op->annotations, op->span);
} else if (op->op.same_as(builtin::texture2d_store())) {
int lane = 0;
// Vectorize the value to store
@@ -549,7 +551,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
<< "Expected Data to be Written equal to Texture Store length";
ffi::Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2],
op->args[3], op->args[4],
mutated_value[0]};
- return Call(op->dtype.with_lanes(lane), op->op, new_args);
+ return Call(op->dtype.with_lanes(lane), op->op, new_args,
op->annotations, op->span);
} else if (op->op.same_as(builtin::reinterpret())) {
return MutateReinterpretExpr_(op);
}
@@ -571,7 +573,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (op->args.same_as(new_args)) {
return ffi::GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype, op->op, new_args);
+ return Call(op->dtype, op->op, new_args, op->annotations, op->span);
}
} else {
int lane = 0;
@@ -597,7 +599,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (op->args.same_as(new_args)) {
return ffi::GetRef<PrimExpr>(op);
} else {
- return Call(op->dtype.with_lanes(lane), op->op, new_args);
+ return Call(op->dtype.with_lanes(lane), op->op, new_args,
op->annotations, op->span);
}
}
}
diff --git a/tests/python/contrib/test_rpc_tracker.py
b/tests/python/contrib/test_rpc_tracker.py
index 486d5abce4..a5351b62a6 100644
--- a/tests/python/contrib/test_rpc_tracker.py
+++ b/tests/python/contrib/test_rpc_tracker.py
@@ -139,9 +139,7 @@ def check_tracker_rejects_oversized_msg_size():
break
time.sleep(0.05)
else:
- raise AssertionError(
- "tracker did not close connection after oversized msg_size"
- )
+ raise AssertionError("tracker did not close connection after
oversized msg_size")
finally:
tserver.terminate()
except ImportError:
diff --git a/tests/python/tirx-base/test_tir_constructor.py
b/tests/python/tirx-base/test_tir_constructor.py
index 16f85f9625..00cd63fa85 100644
--- a/tests/python/tirx-base/test_tir_constructor.py
+++ b/tests/python/tirx-base/test_tir_constructor.py
@@ -19,6 +19,19 @@ import pytest
import tvm
from tvm import te, topi
+from tvm.tirx.expr_functor import ExprMutator
+
+
+class ReplaceVar(ExprMutator):
+ def __init__(self, old_var, new_var):
+ super().__init__()
+ self.old_var = old_var
+ self.new_var = new_var
+
+ def visit_var_(self, op):
+ if op.same_as(self.old_var):
+ return self.new_var
+ return op
def test_expr_constructor():
@@ -120,6 +133,50 @@ def test_expr_constructor():
assert x.dtype == "float32"
assert x.op.name == "tirx.call_extern"
assert x.args[1] == a
+ assert len(x.annotations) == 0
+
+ annotated_arg = tvm.tirx.Var("annotated_arg", "float32")
+ x_with_annotations = tvm.tirx.Call(
+ "float32",
+ "tirx.call_extern",
+ [tvm.tirx.StringImm("xyz"), annotated_arg],
+ annotations={"disable_tma": True},
+ )
+ assert bool(x_with_annotations.annotations["disable_tma"])
+ assert not tvm.ir.structural_equal(x, x_with_annotations)
+ script = tvm.tirx.Evaluate(x_with_annotations).script()
+ assert "annotations" in script
+ assert "disable_tma" in script
+ func = tvm.tirx.PrimFunc([], tvm.tirx.Evaluate(x_with_annotations))
+ assert tvm.script.from_source(func.script()).script() == func.script()
+
+ y = tvm.tirx.Var("y", "float32")
+ mutated = ReplaceVar(annotated_arg, y)(x_with_annotations)
+ assert bool(mutated.annotations["disable_tma"])
+ assert mutated.args[1].same_as(y)
+
+ x_from_intrin = tvm.tirx.call_intrin(
+ "float32", "tirx.call_extern", tvm.tirx.StringImm("xyz"),
annotations={"disable_tma": True}
+ )
+ assert int(x_from_intrin.annotations["disable_tma"]) == 1
+
+ cond0 = tvm.tirx.Var("cond0", "bool")
+ cond1 = tvm.tirx.Var("cond1", "bool")
+ inner_if = tvm.tirx.Call(
+ "int32",
+ "tirx.if_then_else",
+ [cond1, tvm.tirx.IntImm("int32", 1), tvm.tirx.IntImm("int32", 0)],
+ )
+ outer_if = tvm.tirx.Call(
+ "int32",
+ "tirx.if_then_else",
+ [cond0, inner_if, tvm.tirx.IntImm("int32", 0)],
+ annotations={"keep": True},
+ )
+ simplified = tvm.tirx.transform.Simplify()(
+ tvm.IRModule({"main": tvm.tirx.PrimFunc([],
tvm.tirx.Evaluate(outer_if))})
+ )["main"].body.value
+ assert bool(simplified.annotations["keep"])
v = tvm.tirx.Var("aa", "int32")
x = tvm.tirx.Let(v, 1, v)