This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bbe5d6b18c [Unity] Added bounds checking on TupleGetItem index (#15024)
bbe5d6b18c is described below

commit bbe5d6b18c3a06c590d012497bda5057acc84130
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Jun 12 11:48:34 2023 -0400

    [Unity] Added bounds checking on TupleGetItem index (#15024)
    
    * [Unity] Added bounds checking on TupleGetItem index
    
    The index provided must be non-negative, and less than the size of the
    tuple.  When the tuple being accessed has `TupleStructInfo`, the upper
    bound can also be checked immediately.
    
    * Lint fix
    
    no-else-return has been disabled since #11327, but it looks like
    no-else-raise is still enabled
    
    * Also initialize checked_type_ for TupleGetItem
    
    * Update VMShapeLowerMutator to avoid duplicate struct info
    
    * Remove unused variable
    
    * Remove UpdateStructInfo where no longer needed
---
 include/tvm/relax/nested_msg.h               |  2 --
 python/tvm/relax/expr.py                     | 12 ++++++++++-
 src/relax/backend/vm/vm_shape_lower.cc       |  6 ++----
 src/relax/ir/expr.cc                         | 11 ++++++++++
 tests/python/relax/test_blockbuilder_core.py | 31 ++++++++++++++++++++++++++++
 5 files changed, 55 insertions(+), 7 deletions(-)

diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h
index 0564c26687..761698f437 100644
--- a/include/tvm/relax/nested_msg.h
+++ b/include/tvm/relax/nested_msg.h
@@ -330,7 +330,6 @@ NestedMsg<T> MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) {
         field = expr_tuple->fields[i];
       } else {
         field = TupleGetItem(expr, i);
-        UpdateStructInfo(field, tuple->fields[i]);
       }
       res.push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
     }
@@ -513,7 +512,6 @@ Expr TransformTupleLeaf(Expr expr, std::array<NestedMsg<T>, N> msgs, FType ftran
         field = expr_tuple->fields[i];
       } else {
         field = TupleGetItem(expr, i);
-        UpdateStructInfo(field, tuple->fields[i]);
       }
       std::array<NestedMsg<T>, N> sub_msgs;
       for (size_t j = 0; j < N; ++j) {
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 6474db1775..22e5cbcddd 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -229,7 +229,17 @@ class ExprWithOp(Expr, Scriptable):
         result: ExprWithOp
             The result expression.
         """
-        return TupleGetItem(self, index)
+        try:
+            return TupleGetItem(self, index)
+        except tvm.TVMError as err:
+            # For Python objects with __getitem__, but without
+            # __len__, tuple unpacking is done by iterating over
+            # sequential indices until IndexError is raised.
+            # Therefore, convert from TVMError to IndexError for
+            # compatibility.
+            if "Index out of bounds" in err.args[0]:
+                raise IndexError from err
+            raise
 
 
 @tvm._ffi.register_object("relax.expr.Call")
diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc
index 694bcd40d6..97e20a6b86 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -634,11 +634,9 @@ class VMShapeLowerMutator
   Expr MakeTupleGetItem(Expr value, int64_t index) {
     if (auto* tuple_expr = value.as<TupleNode>()) {
       return tuple_expr->fields[index];
-    } else if (auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(value)) {
+    } else if (GetStructInfoAs<TupleStructInfoNode>(value)) {
       // value is tuple type, it is OK to run tuple get item.
-      auto ret = TupleGetItem(value, index);
-      UpdateStructInfo(ret, tuple_sinfo->fields[index]);
-      return ret;
+      return TupleGetItem(value, index);
     } else {
       // call runtime tuple get item, and return a object.
       Call call(builtin_tuple_getitem_, {value, PrimValue::Int64(index)}, Attrs(), {object_sinfo_});
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 7cd356e0ca..3dafc0ddef 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -166,7 +166,18 @@ Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields, Optional<Span> o
 }
 
 TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) {
+  CHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple
+                     << " cannot be accessed with negative index " << index;
   ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>();
+
+  if (auto* tuple_info = tuple->struct_info_.as<TupleStructInfoNode>()) {
+    CHECK_LT(index, tuple_info->fields.size())
+        << "Index out of bounds: Tuple " << tuple << " is of size " << tuple_info->fields.size()
+        << ", and cannot be accessed with index " << index;
+    auto sinfo = tuple_info->fields[index];
+    n->struct_info_ = sinfo;
+    n->checked_type_ = GetStaticType(sinfo);
+  }
   n->tuple = std::move(tuple);
   n->index = index;
   n->span = std::move(span);
diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py
index f0b14933d1..4ba25bdffc 100644
--- a/tests/python/relax/test_blockbuilder_core.py
+++ b/tests/python/relax/test_blockbuilder_core.py
@@ -303,6 +303,37 @@ def test_normalize():
     assert isinstance(tuple_2.struct_info.fields[1].fields[1], rx.TensorStructInfo)
 
 
+def test_tuple_indexing():
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+
+    shape_x = rx.TensorStructInfo([m, n], "float16")
+    shape_y = rx.TensorStructInfo([n], "float16")
+    relax_tuple = rx.Var("relax_tuple", rx.TupleStructInfo([shape_x, shape_y]))
+
+    assert isinstance(relax_tuple.struct_info, rx.TupleStructInfo)
+    assert isinstance(relax_tuple.struct_info.fields[0], rx.TensorStructInfo)
+    assert isinstance(relax_tuple.struct_info.fields[1], rx.TensorStructInfo)
+
+    # TupleGetItem will initialize struct info from the
+    # TupleStructInfo, if present.
+    x = relax_tuple[0]
+    tvm.ir.assert_structural_equal(x.struct_info, shape_x)
+
+    y = relax_tuple[1]
+    tvm.ir.assert_structural_equal(y.struct_info, shape_y)
+
+    # Tuple unpacking produces TupleGetItem structs
+    x_unpack, y_unpack = relax_tuple
+    tvm.ir.assert_structural_equal(x, x_unpack)
+    tvm.ir.assert_structural_equal(y, y_unpack)
+
+    # When TupleStructInfo is available, tuple unpacking fails immediately
+    # for incorrect number of arguments.
+    with pytest.raises(ValueError):
+        x_unpack, y_unpack, z_unpack = relax_tuple
+
+
 def test_call_te():
     bb = rx.BlockBuilder()
     n, m = tir.Var("n", "int64"), tir.Var("m", "int64")

Reply via email to