This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bbe5d6b18c [Unity] Added bounds checking on TupleGetItem index (#15024)
bbe5d6b18c is described below
commit bbe5d6b18c3a06c590d012497bda5057acc84130
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Jun 12 11:48:34 2023 -0400
[Unity] Added bounds checking on TupleGetItem index (#15024)
* [Unity] Added bounds checking on TupleGetItem index
The index provided must be non-negative, and less than the size of the
tuple. When the tuple being accessed has `TupleStructInfo`, the upper
bound can also be checked immediately.
* Lint fix
no-else-return has been disabled since #11327, but it looks like
no-else-raise is still enabled
* Also initialize checked_type_ for TupleGetItem
* Update VMShapeLowerMutator to avoid duplicate struct info
* Remove unused variable
* Remove UpdateStructInfo where no longer needed
---
include/tvm/relax/nested_msg.h | 2 --
python/tvm/relax/expr.py | 12 ++++++++++-
src/relax/backend/vm/vm_shape_lower.cc | 6 ++----
src/relax/ir/expr.cc | 11 ++++++++++
tests/python/relax/test_blockbuilder_core.py | 31 ++++++++++++++++++++++++++++
5 files changed, 55 insertions(+), 7 deletions(-)
diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h
index 0564c26687..761698f437 100644
--- a/include/tvm/relax/nested_msg.h
+++ b/include/tvm/relax/nested_msg.h
@@ -330,7 +330,6 @@ NestedMsg<T> MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) {
field = expr_tuple->fields[i];
} else {
field = TupleGetItem(expr, i);
- UpdateStructInfo(field, tuple->fields[i]);
}
res.push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
}
@@ -513,7 +512,6 @@ Expr TransformTupleLeaf(Expr expr, std::array<NestedMsg<T>, N> msgs, FType ftran
field = expr_tuple->fields[i];
} else {
field = TupleGetItem(expr, i);
- UpdateStructInfo(field, tuple->fields[i]);
}
std::array<NestedMsg<T>, N> sub_msgs;
for (size_t j = 0; j < N; ++j) {
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 6474db1775..22e5cbcddd 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -229,7 +229,17 @@ class ExprWithOp(Expr, Scriptable):
result: ExprWithOp
The result expression.
"""
- return TupleGetItem(self, index)
+ try:
+ return TupleGetItem(self, index)
+ except tvm.TVMError as err:
+ # For Python objects with __getitem__, but without
+ # __len__, tuple unpacking is done by iterating over
+ # sequential indices until IndexError is raised.
+ # Therefore, convert from TVMError to IndexError for
+ # compatibility.
+ if "Index out of bounds" in err.args[0]:
+ raise IndexError from err
+ raise
@tvm._ffi.register_object("relax.expr.Call")
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 694bcd40d6..97e20a6b86 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -634,11 +634,9 @@ class VMShapeLowerMutator
Expr MakeTupleGetItem(Expr value, int64_t index) {
if (auto* tuple_expr = value.as<TupleNode>()) {
return tuple_expr->fields[index];
- } else if (auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(value)) {
+ } else if (GetStructInfoAs<TupleStructInfoNode>(value)) {
// value is tuple type, it is OK to run tuple get item.
- auto ret = TupleGetItem(value, index);
- UpdateStructInfo(ret, tuple_sinfo->fields[index]);
- return ret;
+ return TupleGetItem(value, index);
} else {
// call runtime tuple get item, and return a object.
Call call(builtin_tuple_getitem_, {value, PrimValue::Int64(index)}, Attrs(), {object_sinfo_});
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 7cd356e0ca..3dafc0ddef 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -166,7 +166,18 @@ Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields, Optional<Span> o
}
TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) {
+ CHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple
+ << " cannot be accessed with negative index " << index;
ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>();
+
+ if (auto* tuple_info = tuple->struct_info_.as<TupleStructInfoNode>()) {
+ CHECK_LT(index, tuple_info->fields.size())
+ << "Index out of bounds: Tuple " << tuple << " is of size " << tuple_info->fields.size()
+ << ", and cannot be accessed with index " << index;
+ auto sinfo = tuple_info->fields[index];
+ n->struct_info_ = sinfo;
+ n->checked_type_ = GetStaticType(sinfo);
+ }
n->tuple = std::move(tuple);
n->index = index;
n->span = std::move(span);
diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py
index f0b14933d1..4ba25bdffc 100644
--- a/tests/python/relax/test_blockbuilder_core.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -303,6 +303,37 @@ def test_normalize():
assert isinstance(tuple_2.struct_info.fields[1].fields[1], rx.TensorStructInfo)
+def test_tuple_indexing():
+ m = tir.Var("m", "int64")
+ n = tir.Var("n", "int64")
+
+ shape_x = rx.TensorStructInfo([m, n], "float16")
+ shape_y = rx.TensorStructInfo([n], "float16")
+ relax_tuple = rx.Var("relax_tuple", rx.TupleStructInfo([shape_x, shape_y]))
+
+ assert isinstance(relax_tuple.struct_info, rx.TupleStructInfo)
+ assert isinstance(relax_tuple.struct_info.fields[0], rx.TensorStructInfo)
+ assert isinstance(relax_tuple.struct_info.fields[1], rx.TensorStructInfo)
+
+ # TupleGetItem will initialize struct info from the
+ # TupleStructInfo, if present.
+ x = relax_tuple[0]
+ tvm.ir.assert_structural_equal(x.struct_info, shape_x)
+
+ y = relax_tuple[1]
+ tvm.ir.assert_structural_equal(y.struct_info, shape_y)
+
+ # Tuple unpacking produces TupleGetItem structs
+ x_unpack, y_unpack = relax_tuple
+ tvm.ir.assert_structural_equal(x, x_unpack)
+ tvm.ir.assert_structural_equal(y, y_unpack)
+
+ # When TupleStructInfo is available, tuple unpacking fails immediately
+ # for incorrect number of arguments.
+ with pytest.raises(ValueError):
+ x_unpack, y_unpack, z_unpack = relax_tuple
+
+
def test_call_te():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")