Lunderberg commented on code in PR #16826:
URL: https://github.com/apache/tvm/pull/16826#discussion_r1550061085


##########
src/relax/op/tensor/index.cc:
##########
@@ -102,117 +107,326 @@ TVM_REGISTER_OP("relax.take")
 /* relax.strided_slice */
 TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
 
-Expr strided_slice(Expr x,                             //
-                   Array<Integer> axes,                //
-                   Array<PrimExpr> begin,              //
-                   Array<PrimExpr> end,                //
-                   Optional<Array<PrimExpr>> strides,  //
+Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional<Expr> 
strides,
                    bool assume_inbound) {
-  int n_axis = axes.size();
-  CHECK_EQ(static_cast<int>(begin.size()), n_axis)
-      << "StridedSlice requires the number of begin indices to equal the 
number of axes.";
-  CHECK_EQ(static_cast<int>(end.size()), n_axis)
-      << "StridedSlice requires the number of end indices to equal the number 
of axes.";
-  if (strides.defined()) {
-    CHECK_EQ(static_cast<int>(strides.value().size()), n_axis)
-        << "StridedSlice requires the number of strides to equal the number of 
axes.";
-  }
-
-  // Todo(relax-team): We are going to support dynamic strided slice, where
-  // begin/end/stride can be not static at compile time. Therefore, 
begin/end/stride
-  // should not be part of StridedSliceAttrs, as we only allow static values to
-  // reside in attributes. However, using ShapeExpr to represent these
-  // arrays is not conceptually right, because they are not describing a
-  // concrete shape. The proper way to support dynamic strided slice is to use
-  // Tuple of PrimValue to represent begin/end/stride. Since at this moment
-  // we have no support for PrimValue, we store begin/end/stride as attribute
-  // fields as a workaround.
-  // Will switch to Tuple of PrimValue after introducing PrimValue.
-  auto f_convert_to_int64 = [](const PrimExpr& value) {
-    if (value->IsInstance<IntImmNode>()) {
-      return cast(DataType::Int(64), value);
+  // Initial validation of the arguments.  A more complete validation
+  // will be done when inferring the StructInfo, but that requires the
+  // StructInfo of all arguments to be populated.
+
+  std::optional<std::tuple<const char*, size_t, Expr>> known_length;
+  auto check_tuple = [&known_length](const char* name, Expr expr) {
+    if (const auto* tuple = expr.as<TupleNode>()) {
+      size_t length = tuple->fields.size();
+      if (known_length.has_value()) {
+        const auto& prev = known_length.value();
+        CHECK_EQ(length, std::get<size_t>(prev))
+            << "The strided_slice operator requires that "
+            << "the axes, begin, end, and strides tuples are all the same 
length.  "
+            << "However, the " << std::get<const char*>(prev) << " argument ("
+            << std::get<Expr>(prev) << ") has " << std::get<size_t>(prev) << " 
elements, while the "
+            << name << " argument (" << expr << ") has " << length << " 
elements.";
+      } else {
+        known_length = std::tuple{name, length, expr};
+      }
     }
-    CHECK(value.dtype() == DataType::Int(64)) << "strided_slice expects the 
input begin/end/stride "
-                                                 "values to be all int64. 
However, the given "
-                                              << value << " has dtype " << 
value->dtype;
-    return value;
   };
+  check_tuple("axes", axes);
+  check_tuple("begin", begin);
+  check_tuple("end", end);
+  if (strides.defined()) check_tuple("strides", strides.value());
 
   ObjectPtr<StridedSliceAttrs> attrs = make_object<StridedSliceAttrs>();
-  attrs->axes = std::move(axes);
-  attrs->begin = begin.Map(f_convert_to_int64);
-  attrs->end = end.Map(f_convert_to_int64);
-  attrs->strides = strides.defined() ? strides.value().Map(f_convert_to_int64) 
: strides;
   attrs->assume_inbound = assume_inbound;
 
+  Array<Expr> args = {x, axes, begin, end};
+  if (strides.defined()) {
+    args.push_back(strides.value());
+  }
+
   static const Op& op = Op::Get("relax.strided_slice");
-  return Call(op, {std::move(x)}, Attrs(attrs), {});
+  auto call = Call(op, args, Attrs(attrs));
+
+  return call;
 }
 
 TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);
 
-inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t 
stride,
-                                  bool assume_inbound) {
-  // Same as topi strided slice CanonicalizeIndex function in
-  // include/tvm/topi/detail/strided_slice.h
-  PrimExpr begin_range = stride < 0 ? -1 : 0;
-  PrimExpr end_range = stride < 0 ? extent - 1 : extent;
+inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr 
stride) {
+  // Handle Python-style negative indices
   index = if_then_else(index < 0, index + extent, index);
-  return assume_inbound ? index : min(max(index, begin_range), end_range);  // 
NOLINT
+  // Clamp the result to valid indices
+  PrimExpr lower_bound = tvm::if_then_else(stride < 0, -1, 0);
+  PrimExpr upper_bound = tvm::if_then_else(stride < 0, extent - 1, extent);
+  index = tvm::min(tvm::max(index, lower_bound), upper_bound);
+
+  // PrimExpr bounds_offset = tvm::if_then_else(stride < 0, -1, 0);
+  // index = tvm::min(tvm::max(index, 0 + bounds_offset), extent + 
bounds_offset);

Review Comment:
   Thank you, and deleted!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to