This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 41c9c3b91a [REFACTOR][TIR] remove legacy tir::any (#17783)
41c9c3b91a is described below
commit 41c9c3b91ab7be4496d32bdbe81b1c37189173c9
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Mar 26 14:43:17 2025 -0400
[REFACTOR][TIR] remove legacy tir::any (#17783)
This PR removes the legacy tir::Any node, which was used to represent unknown
shapes in Relay. As we move toward first-class symbolic shapes, we no longer
need the `?` (unknown) shape in the system.
---
include/tvm/tir/expr.h | 36 ----------------------
include/tvm/tir/expr_functor.h | 4 ---
include/tvm/topi/detail/strided_slice.h | 3 +-
.../msc/framework/tensorrt/transform/pattern.py | 4 +--
python/tvm/contrib/msc/plugin/codegen/sources.py | 6 +---
python/tvm/tir/__init__.py | 2 +-
python/tvm/tir/expr.py | 12 --------
python/tvm/topi/nn/pad.py | 5 +--
python/tvm/topi/utils.py | 6 ++--
src/contrib/msc/core/utils.h | 6 +---
src/script/printer/legacy_repr.cc | 3 --
src/script/printer/tir/expr.cc | 6 ----
src/tir/analysis/deep_equal.cc | 3 --
src/tir/ir/expr.cc | 12 --------
src/tir/ir/expr_functor.cc | 4 ---
src/tir/ir/tir_visitor_with_path.cc | 2 --
src/tir/ir/tir_visitor_with_path.h | 1 -
tests/python/arith/test_arith_rewrite_simplify.py | 3 --
.../python/tvmscript/test_tvmscript_printer_tir.py | 10 ------
19 files changed, 9 insertions(+), 119 deletions(-)
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index a157516f53..06ee75070c 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -1100,42 +1100,6 @@ class Reduce : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode);
};
-/*! \brief Any shape. */
-class AnyNode : public PrimExprNode {
- public:
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &dtype);
- v->Visit("span", &span);
- }
-
- bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
- return equal(dtype, other->dtype);
- }
-
- void SHashReduce(SHashReducer hash_reduce) const {}
-
- /*! \brief Convert to var. */
- Var ToVar() const { return Var("any_dim", DataType::Int(32)); }
-
- /*! \brief Convert to SizeVar. */
- SizeVar ToSizeVar() const { return SizeVar("any_dim", DataType::Int(32)); }
-
- static constexpr const char* _type_key = "tir.Any";
- TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
-};
-
-/*!
- * \brief Managed reference to AnyNode
- * \sa AnyNode
- */
-class Any : public PrimExpr {
- public:
- TVM_DLL Any(Span span = Span());
-
- TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode);
- TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
-};
-
/*
* \brief Template function to convert Map to unordered_map
* Sometimes useful for API gluing when internal uses unordered_map
diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h
index 7a9cf91a65..dfa9d7e1e3 100644
--- a/include/tvm/tir/expr_functor.h
+++ b/include/tvm/tir/expr_functor.h
@@ -149,7 +149,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
virtual R VisitExpr_(const IntImmNode* op, Args... args)
EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImmNode* op, Args... args)
EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImmNode* op, Args... args)
EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
}
@@ -192,7 +191,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
- IR_EXPR_FUNCTOR_DISPATCH(AnyNode);
vtable.Finalize();
return vtable;
}
@@ -244,7 +242,6 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const
PrimExpr&)> {
void VisitExpr_(const IntImmNode* op) override;
void VisitExpr_(const FloatImmNode* op) override;
void VisitExpr_(const StringImmNode* op) override;
- void VisitExpr_(const AnyNode* op) override;
};
/*!
@@ -290,7 +287,6 @@ class TVM_DLL ExprMutator : protected
ExprFunctor<PrimExpr(const PrimExpr&)> {
PrimExpr VisitExpr_(const IntImmNode* op) override;
PrimExpr VisitExpr_(const FloatImmNode* op) override;
PrimExpr VisitExpr_(const StringImmNode* op) override;
- PrimExpr VisitExpr_(const AnyNode* op) override;
};
} // namespace tir
diff --git a/include/tvm/topi/detail/strided_slice.h
b/include/tvm/topi/detail/strided_slice.h
index a69f8f99ae..f2e021ed98 100644
--- a/include/tvm/topi/detail/strided_slice.h
+++ b/include/tvm/topi/detail/strided_slice.h
@@ -122,6 +122,7 @@ inline Array<PrimExpr> StridedSliceOutputShape(const
Array<PrimExpr>& ishape,
const Array<Integer>& axes,
std::string slice_mode,
const Array<PrimExpr>&
begin_canonicalized,
bool use_any = false) {
+ ICHECK(!use_any) << "StridedSliceOutputShape does not support legacy use_any";
const size_t src_tensor_dim = ishape.size();
Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
@@ -140,8 +141,6 @@ inline Array<PrimExpr> StridedSliceOutputShape(const
Array<PrimExpr>& ishape,
ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is
invalid for axis=" << i;
out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(),
PrimExpr(slice_size)));
- } else if (use_any) {
- out_shape.Set(axes[i].IntValue(), tvm::tir::Any());
} else {
out_shape.Set(axes[i].IntValue(), tvm::tir::Var("dim",
out_shape[i]->dtype));
}
diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
index 17aee690e3..cd12b336ab 100644
--- a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
+++ b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py
@@ -143,9 +143,9 @@ def _check_expr(expr: relax.Expr, dtypes: Tuple[str] =
None) -> bool:
return False
unknown_dim = 0
for s in sinfo.shape.values:
- if isinstance(s, (tvm.tir.Var, tvm.tir.Any)):
+ if isinstance(s, tvm.tir.Var):
unknown_dim += 1
elif isinstance(s, tvm.tir.IntImm) and s < 0:
unknown_dim += 1
return unknown_dim <= 1
diff --git a/python/tvm/contrib/msc/plugin/codegen/sources.py
b/python/tvm/contrib/msc/plugin/codegen/sources.py
index 1ea95a958f..3806dabd0e 100644
--- a/python/tvm/contrib/msc/plugin/codegen/sources.py
+++ b/python/tvm/contrib/msc/plugin/codegen/sources.py
@@ -642,11 +642,7 @@ class TVMUtils {
Array<tvm::PrimExpr> tvm_shape;
for (size_t i = 0; i < meta_shape.ndim(); i++) {
auto dim = meta_shape.DimAt(i);
- if (dim == -1) {
- tvm_shape.push_back(tir::Any());
- } else {
- tvm_shape.push_back(Integer(dim));
- }
+ tvm_shape.push_back(Integer(dim));
}
return tvm_shape;
}
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 568e05351a..4f56ec3c15 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -26,7 +26,7 @@ from .expr import Var, SizeVar, Reduce, FloatImm, IntImm,
StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle
-from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any
+from .expr import Call, CallEffectKind, Let, IterVar, CommReducer
from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While
from .stmt import (
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index 37976394f8..6cd4302133 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -1308,15 +1308,3 @@ class Let(PrimExprWithOp):
self, var: Var, value: PrimExpr, body: PrimExpr, span: Optional[Span]
= None
) -> None:
self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body,
span) # type: ignore
-
-
-@tvm._ffi.register_object("tir.Any")
-class Any(PrimExprWithOp):
- """Any node.
-
- span : Optional[Span]
- The location of this expression in the source code.
- """
-
- def __init__(self, span: Optional[Span] = None) -> None:
- self.__init_handle_by_constructor__(_ffi_api.Any, span) # type: ignore
diff --git a/python/tvm/topi/nn/pad.py b/python/tvm/topi/nn/pad.py
index 7bd2b7632b..8833ef38d6 100644
--- a/python/tvm/topi/nn/pad.py
+++ b/python/tvm/topi/nn/pad.py
@@ -59,10 +59,7 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0,
name="PadInput", attrs=
ana = tvm.arith.Analyzer()
dshape = []
for dim in data.shape:
- if isinstance(dim, tvm.tir.Any):
- dshape.append(tvm.te.size_var("dim"))
- else:
- dshape.append(dim)
+ dshape.append(dim)
out_shape = tuple(ana.simplify(dshape[i] + pad_before[i] + pad_after[i])
for i in range(n))
pad_value = (
pad_value
diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py
index 71599ad74a..3a0441ef84 100644
--- a/python/tvm/topi/utils.py
+++ b/python/tvm/topi/utils.py
@@ -23,7 +23,7 @@ from numbers import Integral
import numpy as np
import tvm
from tvm import te
-from tvm.tir import Any, SizeVar, bijective_layout, layout
+from tvm.tir import SizeVar, bijective_layout, layout
from . import cpp, tag
@@ -187,7 +187,7 @@ def get_const_tuple(in_tuple):
ret = []
ana = None
for elem in in_tuple:
- if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)):
+ if isinstance(elem, tvm.tir.Var):
ret.append(elem)
elif not isinstance(elem, (tvm.tir.IntImm, int)):
ana = tvm.arith.Analyzer() if ana is None else ana
@@ -525,4 +525,4 @@ def is_target(names):
def is_dynamic_shape(shape):
"""Checks if any part of a shape is dynamic"""
- return any([isinstance(x, (Any, SizeVar)) for x in shape])
+ return any([isinstance(x, SizeVar) for x in shape])
diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h
index 4156688303..84e3c66741 100644
--- a/src/contrib/msc/core/utils.h
+++ b/src/contrib/msc/core/utils.h
@@ -221,11 +221,7 @@ class ArrayUtils {
TVM_DLL static const Array<T> Cast(const Array<PrimExpr>& src_array) {
Array<T> new_array;
for (const auto& s : src_array) {
- if (s->IsInstance<tvm::tir::AnyNode>()) {
- new_array.push_back(T(-1));
- } else {
- new_array.push_back(Downcast<T>(s));
- }
+ new_array.push_back(Downcast<T>(s));
}
return new_array;
}
diff --git a/src/script/printer/legacy_repr.cc
b/src/script/printer/legacy_repr.cc
index 084a86a6f5..b047657b3d 100644
--- a/src/script/printer/legacy_repr.cc
+++ b/src/script/printer/legacy_repr.cc
@@ -521,9 +521,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
(*p) << ")";
});
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<AnyNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
(*p) << "?"; });
-
TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
.set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
auto* op = static_cast<const BufferLoadNode*>(node.get());
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index 8268e6b35e..8ac0931496 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -296,11 +296,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return prefix->Call(args);
});
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<tir::Any>("", [](tir::Any any, ObjectPath p, IRDocsifier d)
-> Doc {
- return TIR(d, "Any")->Call({});
- });
-
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Reduce>("", [](tir::Reduce r, ObjectPath p, IRDocsifier
d) -> Doc {
ExprDoc combiner = d->AsDoc<ExprDoc>(r->combiner, p->Attr("combiner"));
@@ -415,7 +410,6 @@ TVM_SCRIPT_REPR(tir::CallNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::ShuffleNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::CommReducerNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::IndexMapNode, ReprPrintTIR);
-TVM_SCRIPT_REPR(tir::AnyNode, ReprPrintTIR);
TVM_SCRIPT_REPR(tir::ReduceNode, ReprPrintTIR);
} // namespace printer
diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc
index 1ec9fc5522..d4e0284343 100644
--- a/src/tir/analysis/deep_equal.cc
+++ b/src/tir/analysis/deep_equal.cc
@@ -65,9 +65,6 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const
PrimExpr& rhs) const {
auto* prhs = rhs.as<IntImmNode>();
return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
}
- if (lhs.as<AnyNode>()) {
- return false;
- }
return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, NullOpt);
}
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index ca28520f8f..b52c85df35 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -747,18 +747,6 @@ TVM_REGISTER_GLOBAL("tir.Reduce")
TVM_REGISTER_NODE_TYPE(ReduceNode);
-// Any
-Any::Any(Span span) {
- auto n = make_object<AnyNode>();
- n->dtype = DataType::Int(32);
- n->span = std::move(span);
- data_ = std::move(n);
-}
-
-TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([](Span span) { return
Any(span); });
-
-TVM_REGISTER_NODE_TYPE(AnyNode);
-
// BufferLoad
void BufferLoadNode::LegalizeDType() {
for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) {
diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc
index 3c117b58a7..05e333b78a 100644
--- a/src/tir/ir/expr_functor.cc
+++ b/src/tir/ir/expr_functor.cc
@@ -32,8 +32,6 @@ void ExprVisitor::VisitExpr_(const SizeVarNode* op) {
this->VisitExpr_(static_cast<const VarNode*>(op));
}
-void ExprVisitor::VisitExpr_(const AnyNode* op) {}
-
void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
@@ -119,8 +117,6 @@ PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
return this->VisitExpr_(static_cast<const VarNode*>(op));
}
-PrimExpr ExprMutator::VisitExpr_(const AnyNode* op) { return
GetRef<PrimExpr>(op); }
-
PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> indices = op->indices.Map(fmutate);
diff --git a/src/tir/ir/tir_visitor_with_path.cc
b/src/tir/ir/tir_visitor_with_path.cc
index e0318b21be..4f5007aedb 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -343,8 +343,6 @@ void TIRVisitorWithPath::VisitExpr_(const SizeVarNode* op,
ObjectPath path) {
VisitExpr_(static_cast<const VarNode*>(op), path);
}
-void TIRVisitorWithPath::VisitExpr_(const AnyNode* op, ObjectPath path) {}
-
void TIRVisitorWithPath::VisitExpr_(const BufferLoadNode* op, ObjectPath path)
{
Visit(op->buffer, path->Attr("buffer"));
Visit(op->indices, path->Attr("indices"));
diff --git a/src/tir/ir/tir_visitor_with_path.h
b/src/tir/ir/tir_visitor_with_path.h
index 1ae6df58f7..61441541da 100644
--- a/src/tir/ir/tir_visitor_with_path.h
+++ b/src/tir/ir/tir_visitor_with_path.h
@@ -152,7 +152,6 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const
PrimExpr&, ObjectPat
void VisitExpr_(const IntImmNode* op, ObjectPath path) override;
void VisitExpr_(const FloatImmNode* op, ObjectPath path) override;
void VisitExpr_(const StringImmNode* op, ObjectPath path) override;
- void VisitExpr_(const AnyNode* op, ObjectPath path) override;
// Utility to call EnterDef/ExitDef. Used in the implementation of
// WithDef.
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py
b/tests/python/arith/test_arith_rewrite_simplify.py
index 7fc1862192..ad4abdfe29 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -420,7 +420,6 @@ class TestAddIndex(BaseCompare):
class TestSubIndex(BaseCompare):
x, y, z = te.var("x"), te.var("y"), te.var("z")
- a, b = tvm.tir.Any(), tvm.tir.Any()
test_case = tvm.testing.parameter(
TestCase(x + y - y, x),
@@ -437,8 +436,6 @@ class TestSubIndex(BaseCompare):
TestCase(y - tvm.te.max(x, y), tvm.te.min(y - x, 0)),
# mul co-efficient foldng
TestCase(x - x, 0),
- TestCase(a - a, 0),
- TestCase(a - b, a - b),
TestCase(x * y - x, x * (y + (-1))),
TestCase(x * y - 10 * x, x * (y + (-10))),
TestCase(y * x - x * z, x * (y - z)),
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index 943ba54060..21d8fc9422 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -688,16 +688,6 @@ T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)])
)
-def test_any():
- obj = tir.Any()
- _assert_print(
- obj,
- """
-T.Any()
-""",
- )
-
-
def test_int_imm():
obj = T.int16(1)
_assert_print(