mbrookhart commented on a change in pull request #8165:
URL: https://github.com/apache/tvm/pull/8165#discussion_r643532442
##########
File path: include/tvm/topi/transform.h
##########
@@ -594,137 +608,152 @@ inline te::Tensor dynamic_strided_slice(const
te::Tensor& x, const te::Tensor& b
}
/*!
- * \brief strided_slice of a tensor
+ * \brief strided_slice of a tensor with dynamic begin/end/stride
*
* \param x The input tensor
* \param begin The indices to begin with in the slicing
* \param end Indicies indicating end of the slice
* \param strides Specifies the stride values, it can be negative
* in that case, the input tensor will be reversed in that particular axis
- * \param slice_mode Specifies the slice mode
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the split operation
*/
-inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
- const Array<PrimExpr>& end, const Array<PrimExpr>&
strides,
- std::string slice_mode = "end", std::string name =
"T_strided_slice",
- std::string tag = kInjective) {
- size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
- // Quick path for dynamic shape strided slice.
- // This is for ease of use to dynamice strided slice in topi.
- bool is_static = IsConstIntArray(x->shape);
- is_static &= IsConstIntArray(begin);
- is_static &= IsConstIntArray(end);
- is_static &= IsConstIntArray(strides);
-
- Array<PrimExpr> out_shape;
- if (!is_static) {
- ICHECK_EQ(strides.size(), src_tensor_dim);
- for (size_t i = 0; i < src_tensor_dim; ++i) {
- out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
- }
- return te::compute(
- out_shape,
- [&](const Array<tvm::tir::Var>& indices) {
- Array<PrimExpr> real_indices;
- for (size_t i = 0; i < src_tensor_dim; ++i) {
- real_indices.push_back(indices[i] * strides[i] + begin[i]);
- }
- return x(real_indices);
- },
- name, tag);
- }
-
- // Setup the ranges.
- // NOTE: this code duplicates the shape inference logic relay.op
- // Consider to refactor in the future.
- std::vector<int64_t> stride_vec(src_tensor_dim, 1);
- for (size_t i = 0; i < strides.size(); ++i) {
- ICHECK(strides[i].defined());
- stride_vec[i] = GetConstInt(strides[i]);
- }
-
- const int64_t max_range = std::numeric_limits<int64_t>::max();
+inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor&
begin,
+ const te::Tensor& end, const
te::Tensor& strides,
+ std::string name =
"T_strided_slice_dynamic",
+ std::string tag = topi::kInjective) {
+ const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
+ ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
+ ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
- std::vector<int64_t> begin_vec;
- for (size_t i = 0; i < begin.size(); ++i) {
- if (!begin[i].defined()) {
- // value=None
- begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
- } else {
- begin_vec.push_back(GetConstInt(begin[i]));
- }
- }
- for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) {
- begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+ Array<PrimExpr> begin_expr, end_expr, strides_expr;
+ for (int64_t i = 0; i < num_dynamic_axes; ++i) {
+ auto i64_ind = IntImm(DataType::Int(64), i);
+ begin_expr.push_back(begin(i64_ind));
+ end_expr.push_back(end(i64_ind));
+ strides_expr.push_back(strides(i64_ind));
}
+ return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name,
tag);
+}
- std::vector<int64_t> end_vec;
- for (size_t i = 0; i < end.size(); ++i) {
- // allow end to be None
+/*!
+ * \brief Calcluate the output shape of strided_slice, the entry point for
Relay type relation
+ *
+ * \param ishape The input tensor shape
+ * \param begin The indices to begin with in the slicing
+ * \param end Indicies indicating end of the slice
+ * \param strides Specifies the stride values, it can be negative
+ * in that case, the input tensor will be reversed in that particular axis
+ * \param axes Axes along which slicing is applied. When it is specified, the
length of begin, end,
+ * strides, and axes argument must be equal
+ * \param slice_mode Specifies the slice mode
+ *
+ * \return The output shape of strided_slice using the arguments above
+ */
+inline Array<PrimExpr> StridedSliceOutputShape(
+ const Array<PrimExpr>& ishape, const Array<Integer>& begin, const
Array<Integer>& end,
+ const Array<Integer>& strides, const Array<Integer>& axes, const
std::string& slice_mode) {
+ ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
axes.size() == strides.size());
+ std::vector<int64_t> begin_vec, end_vec, strides_vec;
+ std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end,
strides, slice_mode);
+ auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec,
strides_vec, axes,
+ begin[0]->dtype,
slice_mode);
+ return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec,
axes, slice_mode,
+ begin_canonicalized, true);
+}
- if (!end[i].defined()) {
- end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
- } else if (slice_mode == "size") {
- int64_t end_val = GetConstInt(end[i]);
- if (end_val < 0) {
- end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
- } else {
- end_vec.push_back(begin_vec[i] + end_val);
- }
- } else {
- end_vec.push_back(GetConstInt(end[i]));
- }
- }
- for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) {
- end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
- }
- // Compute
- Array<PrimExpr> begin_expr;
- Array<PrimExpr> strides_expr;
-
- for (size_t i = 0; i < src_tensor_dim; ++i) {
- int64_t begin_range = stride_vec[i] < 0 ? -1 : 0;
- int64_t dim_i = GetConstInt(x->shape[i]);
- int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i;
- // transform negative indices to positive value, clips on the correct range
- auto index_canonicalization = [dim_i, begin_range, end_range](int64_t
index) {
- if (index < 0) {
- index += dim_i;
- }
- return std::min(std::max(index, begin_range), end_range);
- };
-
- int64_t begin_i = index_canonicalization(begin_vec[i]);
- int64_t end_i = index_canonicalization(end_vec[i]);
-
- int interval = std::abs(end_i - begin_i);
- int slice_size =
- static_cast<int>((interval + std::abs(stride_vec[i]) - 1) /
std::abs(stride_vec[i]));
- ICHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
- << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
- << "] is invalid for axis=" << i;
-
- begin_expr.push_back(make_const(begin[0].dtype(), begin_i));
- strides_expr.push_back(
- make_const((strides.size() != 0 ? strides[0].dtype() :
begin[0].dtype()), stride_vec[i]));
- out_shape.push_back(slice_size);
- }
+/*!
+ * \brief strided_slice of a tensor
+ *
+ * \param x The input tensor
+ * \param begin The indices to begin with in the slicing
+ * \param end Indicies indicating end of the slice
+ * \param strides Specifies the stride values, it can be negative
+ * in that case, the input tensor will be reversed in that particular axis
+ * \param axes Axes along which slicing is applied. When it is specified, the
length of begin, end,
+ * strides, and axes argument must be equal
+ * \param slice_mode Specifies the slice mode
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the split operation
Review comment:
Should "split" be "slice" here? This function performs a slice, not a split.
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -2445,99 +2445,40 @@ bool StridedSliceRel(const Array<Type>& types, int
num_inputs, const Attrs& attr
return false;
}
- auto dshape = data->shape;
- int64_t num_axis = dshape.size();
-
- // calculate output shape
- std::vector<IndexExpr> oshape(num_axis);
- if (param->begin && param->end && param->strides) {
- // stride will be set as 1 if slice mode is enabled
- std::vector<int64_t> stride_vec(num_axis, 1);
- if (param->slice_mode == "end") {
- for (size_t i = 0; i < param->strides.value().size(); ++i) {
- ICHECK(param->strides.value()[i].defined());
- stride_vec[i] = param->strides.value()[i]->value;
- }
- }
- const int64_t max_range = std::numeric_limits<int64_t>::max();
- std::vector<int64_t> begin_vec;
- for (size_t i = 0; i < param->begin.value().size(); ++i) {
- if (!param->begin.value()[i].defined()) {
- begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
- } else {
- begin_vec.push_back(param->begin.value()[i]->value);
- }
- }
- for (int64_t i = begin_vec.size(); i < num_axis; ++i) {
- begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
- }
+ ICHECK(param->begin) << "strided_slice recieved invalid begin " <<
param->begin;
+ ICHECK(param->end) << "strided_slice recieved invalid end " << param->end;
+ ICHECK(param->strides) << "strided_slice recieved invalid strides " <<
param->strides;
+
+ auto begin = param->begin.value();
+ auto end = param->end.value();
+ auto strides = param->strides.value();
+
+ const size_t src_tensor_dim = static_cast<size_t>(data->shape.size());
+ Array<Integer> axes;
+ if (param->axes) {
+ axes = param->axes.value();
+ ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
+ axes.size() == strides.size())
+ << "axes, begin, end, and strides must have the same length";
+ } else {
+ for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
- std::vector<int64_t> end_vec;
- for (size_t i = 0; i < param->end.value().size(); ++i) {
- // allow end to be None
- if (!param->end.value()[i].defined()) {
- end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
- } else if (param->slice_mode == "size") {
- if (param->end.value()[i]->value < 0) {
- end_vec.push_back(max_range);
- } else {
- end_vec.push_back(begin_vec[i] + param->end.value()[i]->value);
- }
- } else if (param->slice_mode == "end") {
- end_vec.push_back(param->end.value()[i]->value);
- } else {
- LOG(FATAL) << "Unsupported slice mode: " << param->slice_mode;
- }
+ const IntImm one = IntImm(DataType::Int(64), 1);
+ const IntImm zero = IntImm(DataType::Int(64), 0);
+ const IntImm max_range = IntImm(DataType::Int(64),
std::numeric_limits<int64_t>::max());
+
+ for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
+ strides.push_back(one);
}
- for (int64_t i = end_vec.size(); i < num_axis; ++i) {
- end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+ for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
+ begin.push_back(topi::GetConstInt(strides[i]) > 0 ? zero : max_range);
}
-
- for (int64_t i = 0; i < num_axis; ++i) {
- int64_t stride_v = stride_vec[i];
- int64_t begin_v = begin_vec[i];
- int64_t end_v = end_vec[i];
-
- if ((stride_v == 1 && begin_v == 0 && end_v == max_range) ||
- (stride_v == -1 && begin_v == max_range && end_v == 0)) {
- // Quick path, do not slice this dimension.
- oshape[i] = dshape[i];
- continue;
- }
- // Normal path, require the shape to be concrete integer.
- // Require concrete integer as symbolic inference of min/max
- // can get complicated and not very helpful.
- const int64_t* p_dim_size = tir::as_const_int(dshape[i]);
- if (!p_dim_size) {
- oshape[i] = dshape[i];
- continue;
- }
- int64_t dim_size = p_dim_size[0];
- begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
- end_v = (end_v < 0) ? dim_size + end_v : end_v;
-
- int64_t slice_range, step;
- if (stride_v < 0) {
- if (end_v < -1) end_v = -1;
- ICHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis "
<< i;
- begin_v = std::min(dim_size - 1, begin_v);
- slice_range = begin_v - end_v;
- step = -stride_v;
- } else {
- if (begin_v < 0) begin_v = 0;
- ICHECK_GE(stride_v, 0);
- ICHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis
" << i;
- end_v = std::min(dim_size, end_v);
- slice_range = end_v - begin_v;
- step = stride_v;
- }
- oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1)
/ step);
+ for (size_t i = end.size(); i < src_tensor_dim; ++i) {
+ end.push_back(topi::GetConstInt(strides[i]) < 0 ? zero : max_range);
}
- } else {
- ICHECK(param->begin) << "strided_slice recieved invalid begin " <<
param->begin;
- ICHECK(param->end) << "strided_slice recieved invalid end " << param->end;
- ICHECK(param->strides) << "strided_slice recieved invalid strides " <<
param->strides;
}
+ auto oshape =
+ topi::StridedSliceOutputShape(data->shape, begin, end, strides, axes,
param->slice_mode);
Review comment:
Thank you for moving this to a common utility :bowing_man:
##########
File path: python/tvm/relay/op/transform.py
##########
@@ -917,7 +922,7 @@ def strided_slice(data, begin, end, strides=None,
slice_mode="end"):
begin = _make.where(begin < cast_like(const(0), begin), begin +
ishape_slice, begin)
begin = _make.where(begin >= ishape_slice, ishape_slice, begin)
return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
Review comment:
Is `axes` not supported when `begin` is dynamic? This path calls `_dyn_make.strided_slice` without passing `axes`.
##########
File path: src/topi/transform.cc
##########
@@ -174,11 +174,26 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs
args, TVMRetValue* rv) {
});
TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args,
TVMRetValue* rv) {
Review comment:
Why not also accept the `axes` argument in this `topi.strided_slice` registration?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]