slyubomirsky commented on code in PR #16826:
URL: https://github.com/apache/tvm/pull/16826#discussion_r1548789846
##########
src/relax/op/tensor/index.cc:
##########
@@ -102,117 +107,326 @@ TVM_REGISTER_OP("relax.take")
/* relax.strided_slice */
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
-Expr strided_slice(Expr x, //
- Array<Integer> axes, //
- Array<PrimExpr> begin, //
- Array<PrimExpr> end, //
- Optional<Array<PrimExpr>> strides, //
+Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional<Expr>
strides,
bool assume_inbound) {
- int n_axis = axes.size();
- CHECK_EQ(static_cast<int>(begin.size()), n_axis)
- << "StridedSlice requires the number of begin indices to equal the
number of axes.";
- CHECK_EQ(static_cast<int>(end.size()), n_axis)
- << "StridedSlice requires the number of end indices to equal the number
of axes.";
- if (strides.defined()) {
- CHECK_EQ(static_cast<int>(strides.value().size()), n_axis)
- << "StridedSlice requires the number of strides to equal the number of
axes.";
- }
-
- // Todo(relax-team): We are going to support dynamic strided slice, where
- // begin/end/stride can be not static at compile time. Therefore,
begin/end/stride
- // should not be part of StridedSliceAttrs, as we only allow static values to
- // reside in attributes. However, using ShapeExpr to represent these
- // arrays is not conceptually right, because they are not describing a
- // concrete shape. The proper way to support dynamic strided slice is to use
- // Tuple of PrimValue to represent begin/end/stride. Since at this moment
- // we have no support for PrimValue, we store begin/end/stride as attribute
- // fields as a workaround.
- // Will switch to Tuple of PrimValue after introducing PrimValue.
- auto f_convert_to_int64 = [](const PrimExpr& value) {
- if (value->IsInstance<IntImmNode>()) {
- return cast(DataType::Int(64), value);
+ // Initial validation of the arguments. A more complete validation
+ // will be done when inferring the StructInfo, but that requires the
+ // StructInfo of all arguments to be populated.
+
+ std::optional<std::tuple<const char*, size_t, Expr>> known_length;
+ auto check_tuple = [&known_length](const char* name, Expr expr) {
+ if (const auto* tuple = expr.as<TupleNode>()) {
+ size_t length = tuple->fields.size();
+ if (known_length.has_value()) {
+ const auto& prev = known_length.value();
+ CHECK_EQ(length, std::get<size_t>(prev))
+ << "The strided_slice operator requires that "
+ << "the axes, begin, end, and strides tuples are all the same
length. "
+ << "However, the " << std::get<const char*>(prev) << " argument ("
+ << std::get<Expr>(prev) << ") has " << std::get<size_t>(prev) << "
elements, while the "
+ << name << " argument (" << expr << ") has " << length << "
elements.";
+ } else {
+ known_length = std::tuple{name, length, expr};
+ }
}
- CHECK(value.dtype() == DataType::Int(64)) << "strided_slice expects the
input begin/end/stride "
- "values to be all int64.
However, the given "
- << value << " has dtype " <<
value->dtype;
- return value;
};
+ check_tuple("axes", axes);
+ check_tuple("begin", begin);
+ check_tuple("end", end);
+ if (strides.defined()) check_tuple("strides", strides.value());
ObjectPtr<StridedSliceAttrs> attrs = make_object<StridedSliceAttrs>();
- attrs->axes = std::move(axes);
- attrs->begin = begin.Map(f_convert_to_int64);
- attrs->end = end.Map(f_convert_to_int64);
- attrs->strides = strides.defined() ? strides.value().Map(f_convert_to_int64)
: strides;
attrs->assume_inbound = assume_inbound;
+ Array<Expr> args = {x, axes, begin, end};
+ if (strides.defined()) {
+ args.push_back(strides.value());
+ }
+
static const Op& op = Op::Get("relax.strided_slice");
- return Call(op, {std::move(x)}, Attrs(attrs), {});
+ auto call = Call(op, args, Attrs(attrs));
+
+ return call;
}
TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);
-inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t
stride,
- bool assume_inbound) {
- // Same as topi strided slice CanonicalizeIndex function in
- // include/tvm/topi/detail/strided_slice.h
- PrimExpr begin_range = stride < 0 ? -1 : 0;
- PrimExpr end_range = stride < 0 ? extent - 1 : extent;
+inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr
stride) {
+ // Handle Python-style negative indices
index = if_then_else(index < 0, index + extent, index);
- return assume_inbound ? index : min(max(index, begin_range), end_range); //
NOLINT
+ // Clamp the result to valid indices
+ PrimExpr lower_bound = tvm::if_then_else(stride < 0, -1, 0);
+ PrimExpr upper_bound = tvm::if_then_else(stride < 0, extent - 1, extent);
+ index = tvm::min(tvm::max(index, lower_bound), upper_bound);
+
+ // PrimExpr bounds_offset = tvm::if_then_else(stride < 0, -1, 0);
+ // index = tvm::min(tvm::max(index, 0 + bounds_offset), extent +
bounds_offset);
Review Comment:
This commented-out code should probably just be deleted, I assume.
##########
src/relax/op/tensor/index.cc:
##########
@@ -102,117 +107,326 @@ TVM_REGISTER_OP("relax.take")
/* relax.strided_slice */
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
-Expr strided_slice(Expr x, //
- Array<Integer> axes, //
- Array<PrimExpr> begin, //
- Array<PrimExpr> end, //
- Optional<Array<PrimExpr>> strides, //
+Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional<Expr>
strides,
bool assume_inbound) {
- int n_axis = axes.size();
- CHECK_EQ(static_cast<int>(begin.size()), n_axis)
- << "StridedSlice requires the number of begin indices to equal the
number of axes.";
- CHECK_EQ(static_cast<int>(end.size()), n_axis)
- << "StridedSlice requires the number of end indices to equal the number
of axes.";
- if (strides.defined()) {
- CHECK_EQ(static_cast<int>(strides.value().size()), n_axis)
- << "StridedSlice requires the number of strides to equal the number of
axes.";
- }
-
- // Todo(relax-team): We are going to support dynamic strided slice, where
- // begin/end/stride can be not static at compile time. Therefore,
begin/end/stride
- // should not be part of StridedSliceAttrs, as we only allow static values to
- // reside in attributes. However, using ShapeExpr to represent these
- // arrays is not conceptually right, because they are not describing a
- // concrete shape. The proper way to support dynamic strided slice is to use
- // Tuple of PrimValue to represent begin/end/stride. Since at this moment
- // we have no support for PrimValue, we store begin/end/stride as attribute
- // fields as a workaround.
- // Will switch to Tuple of PrimValue after introducing PrimValue.
- auto f_convert_to_int64 = [](const PrimExpr& value) {
- if (value->IsInstance<IntImmNode>()) {
- return cast(DataType::Int(64), value);
+ // Initial validation of the arguments. A more complete validation
+ // will be done when inferring the StructInfo, but that requires the
+ // StructInfo of all arguments to be populated.
+
+ std::optional<std::tuple<const char*, size_t, Expr>> known_length;
+ auto check_tuple = [&known_length](const char* name, Expr expr) {
+ if (const auto* tuple = expr.as<TupleNode>()) {
+ size_t length = tuple->fields.size();
+ if (known_length.has_value()) {
+ const auto& prev = known_length.value();
+ CHECK_EQ(length, std::get<size_t>(prev))
+ << "The strided_slice operator requires that "
+ << "the axes, begin, end, and strides tuples are all the same
length. "
+ << "However, the " << std::get<const char*>(prev) << " argument ("
+ << std::get<Expr>(prev) << ") has " << std::get<size_t>(prev) << "
elements, while the "
+ << name << " argument (" << expr << ") has " << length << "
elements.";
+ } else {
+ known_length = std::tuple{name, length, expr};
+ }
}
- CHECK(value.dtype() == DataType::Int(64)) << "strided_slice expects the
input begin/end/stride "
- "values to be all int64.
However, the given "
- << value << " has dtype " <<
value->dtype;
- return value;
};
+ check_tuple("axes", axes);
+ check_tuple("begin", begin);
+ check_tuple("end", end);
+ if (strides.defined()) check_tuple("strides", strides.value());
ObjectPtr<StridedSliceAttrs> attrs = make_object<StridedSliceAttrs>();
- attrs->axes = std::move(axes);
- attrs->begin = begin.Map(f_convert_to_int64);
- attrs->end = end.Map(f_convert_to_int64);
- attrs->strides = strides.defined() ? strides.value().Map(f_convert_to_int64)
: strides;
attrs->assume_inbound = assume_inbound;
+ Array<Expr> args = {x, axes, begin, end};
+ if (strides.defined()) {
+ args.push_back(strides.value());
+ }
+
static const Op& op = Op::Get("relax.strided_slice");
- return Call(op, {std::move(x)}, Attrs(attrs), {});
+ auto call = Call(op, args, Attrs(attrs));
+
+ return call;
}
TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);
-inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t
stride,
- bool assume_inbound) {
- // Same as topi strided slice CanonicalizeIndex function in
- // include/tvm/topi/detail/strided_slice.h
- PrimExpr begin_range = stride < 0 ? -1 : 0;
- PrimExpr end_range = stride < 0 ? extent - 1 : extent;
+inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr
stride) {
+ // Handle Python-style negative indices
index = if_then_else(index < 0, index + extent, index);
- return assume_inbound ? index : min(max(index, begin_range), end_range); //
NOLINT
+ // Clamp the result to valid indices
+ PrimExpr lower_bound = tvm::if_then_else(stride < 0, -1, 0);
+ PrimExpr upper_bound = tvm::if_then_else(stride < 0, extent - 1, extent);
+ index = tvm::min(tvm::max(index, lower_bound), upper_bound);
+
+ // PrimExpr bounds_offset = tvm::if_then_else(stride < 0, -1, 0);
+ // index = tvm::min(tvm::max(index, 0 + bounds_offset), extent +
bounds_offset);
+
+ return index;
}
-PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const
PrimExpr& length,
+PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr
extent,
bool assume_inbound) {
- begin = CanonicalizeIndex(begin, length, stride, assume_inbound);
- end = CanonicalizeIndex(end, length, stride, assume_inbound);
- arith::Analyzer ana;
- if (stride < 0) {
- return ana.Simplify(ceildiv(begin - end, IntImm(DataType::Int(64),
-stride)));
+ if (assume_inbound) {
+ return ceildiv(end - begin, stride);
} else {
- return ana.Simplify(ceildiv(end - begin, IntImm(DataType::Int(64),
stride)));
+ begin = CanonicalizeIndex(begin, extent, stride);
+ end = CanonicalizeIndex(end, extent, stride);
+ return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride),
+ ceildiv(end - begin, stride));
}
}
-StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder&
ctx) {
- TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
- const auto* attrs = call->attrs.as<StridedSliceAttrs>();
- if (attrs->axes.empty()) {
- return data_sinfo;
- }
+/* \brief Helper function to unpack a relax::Tuple
+ *
+ * A `relax::Tuple` may be provided to an operator as an in-line
+ * expression, as a variable bound to known tuple within the current
+ * function, as a function argument, etc. The StructInfo of the tuple
+ * tracks the known values of any `PrimValue` elements, but it can be
+ * tedious to extract. This utility extracts the `PrimExpr` contents
+ * of a `relax::Tuple`.
+ *
+ * If the StructInfo cannot contain a tuple of the type specified,
+ * this function will throw an exception. (e.g. Attempting to extract
+ * a tuple from a `TensorStructInfo`.)
+ *
+ * \tparam PrimType The subtype of PrimExpr to extract. For example,
+ * extracting an `Array<Integer>`
+ *
+ * \param sinfo The StructInfo to inspect
+ *
+ * \returns An array of the `PrimType`, if it can be extracted.
+ * Otherwise, `NullOpt`.
+ */
+template <typename PrimType = PrimExpr,
+ typename = std::enable_if_t<std::is_base_of_v<PrimExpr, PrimType>>>
+Optional<Array<PrimType>> UnpackTupleOfPrimValue(Optional<StructInfo> sinfo) {
+ if (!sinfo) return NullOpt;
- if (data_sinfo->IsUnknownNdim()) {
- return TensorStructInfo(data_sinfo->dtype, kUnknownNDim,
data_sinfo->vdevice);
+ // An ObjectStructInfo may contain a tuple of the desired type, but
+ // it isn't yet known whether it does. Return early, as we cannot
+ // provide a known `Array<PrimType>` to the caller.
+ if (sinfo.as<ObjectStructInfoNode>()) return NullOpt;
+
+ auto tuple = sinfo.as<TupleStructInfoNode>();
+ CHECK(tuple) << "TypeError: "
+ << "The struct info " << sinfo << " cannot contain a tuple
whose elements are "
+ << PrimType::ContainerType::_type_key;
+
+ Array<PrimType> output;
+ for (size_t i = 0; i < tuple->fields.size(); i++) {
+ auto field = tuple->fields[i];
+
+ if (field.as<ObjectStructInfoNode>()) return NullOpt;
+
+ auto prim_sinfo = field.as<PrimStructInfoNode>();
+ CHECK(prim_sinfo) << "TypeError: "
+ << "The struct info " << sinfo
+ << " cannot contain a tuple whose elements are "
+ << PrimType::ContainerType::_type_key << ", because
element " << i
+ << " has struct info " << field;
+
+ if (!prim_sinfo->value.defined()) return NullOpt;
+
+ Optional<PrimType> element = prim_sinfo->value.as<PrimType>();
+ if (!element) return NullOpt;
+
+ output.push_back(element.value());
}
+ return output;
+}
- std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim,
attrs->axes);
- const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
- if (data_shape == nullptr) {
- return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim,
data_sinfo->vdevice);
+/* \brief Helper function to unpack a relax::Tuple
+ *
+ * A `relax::Tuple` may be provided to an operator as an in-line
+ * expression, as a variable bound to known tuple within the current
+ * function, as a function argument, etc. The StructInfo of the tuple
+ * tracks the known values of any `PrimValue` elements, but it can be
+ * tedious to extract. This utility extracts the `PrimExpr` contents
+ * of a `relax::Tuple`.
+ *
+ * If the StructInfo cannot contain a tuple of the type specified,
+ * this function will throw an exception. (e.g. Attempting to extract
+ * a tuple from a `TensorStructInfo`.)
+ *
+ * \tparam PrimType The subtype of PrimExpr to extract. For example,
+ * extracting an `Array<Integer>`
+ *
+ * \param expr The `relax::Expr` to inspect
+ *
+ * \returns An array of the `PrimType`, if it can be extracted.
+ * Otherwise, `NullOpt`.
+ */
+template <typename PrimType = PrimExpr,
+ typename = std::enable_if_t<std::is_base_of_v<PrimExpr, PrimType>>>
+Optional<Array<PrimType>> UnpackTupleOfPrimValue(Optional<Expr> expr) {
+ if (expr) {
+ return UnpackTupleOfPrimValue<PrimType>(GetStructInfo(expr.value()));
+ } else {
+ return NullOpt;
}
+}
- int n_axis = axes.size();
- Array<PrimExpr> strides = attrs->strides.defined()
- ? attrs->strides.value()
- : Array<PrimExpr>(n_axis,
IntImm(DataType::Int(64), 1));
- std::vector<int64_t> int_strides;
- int_strides.reserve(n_axis);
- // Only do output shape inference when all the begin/end/strides values are
integers.
- for (int i = 0; i < n_axis; ++i) {
- const auto* int_stride = strides[i].as<IntImmNode>();
- if (!int_stride) {
- return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim,
data_sinfo->vdevice);
+StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder&
ctx) {
+ size_t n_args = call->args.size();
+ CHECK(4 <= n_args && n_args <= 5)
+ << "Operator " << call->op << " accepts either three arguments (data,
axes, begin, end) "
+ << " or four arguments (data, axes, begin, end, strides), "
+ << "but received " << n_args << " in expression " << call;
+
+ Expr data = call->args[0];
+ Expr axes = call->args[1];
+ Expr begin = call->args[2];
+ Expr end = call->args[3];
+ Optional<Expr> strides = [&]() -> Optional<Expr> {
+ if (n_args > 4) {
+ return call->args[4];
+ } else {
+ return NullOpt;
+ }
+ }();
+
+ auto axes_sinfo = GetStructInfo(call->args[1]);
+ auto begin_sinfo = GetStructInfo(call->args[2]);
+ auto end_sinfo = GetStructInfo(call->args[3]);
+ auto strides_sinfo = [&]() -> Optional<StructInfo> {
+ if (n_args > 4) {
+ return GetStructInfo(call->args[4]);
+ } else {
+ return NullOpt;
}
- int_strides.push_back(int_stride->value);
+ }();
+
+ CHECK(IsBaseOf(relax::TensorStructInfo(DataType::Void(), kUnknownNDim),
GetStructInfo(data)))
+ << "Operator " << call->op << " requires the first argument to be a
tensor. "
+ << "However, in expression " << call << ", the first argument " << data
<< " has struct info "
+ << GetStructInfo(data);
+
+ // TODO(Lunderberg): Implement this check using `IsBaseOf`. Doing
+ // so will require a way to represent a `relax::TupleStructInfo` of
+ // unknown length, where each element has the same `StructInfo`.
Review Comment:
Hm, I think the idea of having a list type has come up before. We've
represented lists using `Object`s and `PackedFunc`s in the past, probably for
gradient support. Reifying it in the type system could be feasible if it's a
common enough use case.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]