kevinthesun commented on a change in pull request #4312:
URL: https://github.com/apache/incubator-tvm/pull/4312#discussion_r414959052
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1891,81 +1952,163 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
}
CHECK(old_in_layouts.defined());
- CHECK_EQ(old_in_layouts.size(), 1);
+ CHECK_GE(old_in_layouts.size(), 1);
CHECK(old_in_shapes.defined());
- CHECK_EQ(old_in_shapes.size(), 1);
+ CHECK_GE(old_in_shapes.size(), 1);
auto layout = old_in_layouts[0];
if (layout.defined() && new_in_layouts.defined()) {
- CHECK_EQ(new_in_layouts.size(), 1);
+ CHECK_GE(new_in_layouts.size(), 1);
auto new_layout = new_in_layouts[0];
auto shape = old_in_shapes[0];
// NOTE: Discard "const" qualifier here.
auto *params =
const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
+ CHECK(params != nullptr);
+ Array<Integer> begin, end, strides;
+ const ConstantNode *cbegin, *cend, *cstrides;
+ if ((cbegin = params->begin.as<ConstantNode>()) &&
+ (cend = params->end.as<ConstantNode>()) &&
+ (cstrides = params->strides.as<ConstantNode>())) {
+ int64_t* strides_val = ToVector(cstrides->data);
+ for (int64_t i = 0; i < cstrides->data.Shape().front(); ++i) {
+ strides.push_back(strides_val[i]);
+ }
+ int64_t* begin_val = ToVector(cbegin->data);
+ for (int64_t i = 0; i < cbegin->data.Shape().front(); ++i) {
+ begin.push_back(begin_val[i]);
+ }
+ int64_t* end_val = ToVector(cend->data);
+ for (int64_t i = 0; i < cend->data.Shape().front(); ++i) {
+ end.push_back(end_val[i]);
+ }
+ }
Array<Integer> new_begin, new_end;
- for (size_t i = 0; i < params->begin.size(); i++) {
+ for (size_t i = 0; i < begin.size(); i++) {
const LayoutAxis& axis = layout[i];
if (!axis.IsPrimal()) {
// original layout that contains splitted axes is not supported
return {{Layout::Undef()}, {Layout::Undef()}};
}
auto factor = new_layout.FactorOf(axis);
if (factor == -1) {
- new_begin.push_back(params->begin[i]);
- new_end.push_back(params->end[i]);
+ new_begin.push_back(begin[i]);
+ new_end.push_back(end[i]);
} else {
- if (params->strides.defined() && i < params->strides.size()) {
- auto stride = params->strides[i];
+ if (strides.defined() && i < strides.size()) {
+ auto stride = strides[i];
// arbitrary stride is not supported
if (stride.defined() && stride->value != 1) {
return {{Layout::Undef()}, {Layout::Undef()}};
}
}
- int64_t begin = params->begin[i].defined() ? params->begin[i]->value :
0;
- int64_t end = params->end[i].defined() ? params->end[i]->value :
+ int64_t bg = begin[i].defined() ? begin[i]->value : 0;
+ int64_t ed = end[i].defined() ? end[i]->value :
shape[i].as<IntImmNode>()->value;
- if (begin % factor || end % factor) {
+ if (bg % factor || ed % factor) {
// transform to original layout
return {{Layout::Undef()}, {Layout::Undef()}};
}
- new_begin.push_back(tvm::Integer(begin / factor));
- new_end.push_back(tvm::Integer(end / factor));
+ new_begin.push_back(tvm::Integer(bg / factor));
+ new_end.push_back(tvm::Integer(ed / factor));
}
}
- layout = new_layout;
- params->begin = new_begin;
- params->end = new_end;
- }
- return {{layout}, {layout}};
-}
+ layout = new_layout;
-// Positional relay function to create StridedSlice operator used by frontend
FFI.
-Expr MakeStridedSlice(Expr data,
- Array<Integer> begin,
- Array<Integer> end,
- Array<Integer> strides) {
- auto attrs = make_object<StridedSliceAttrs>();
- attrs->begin = std::move(begin);
- attrs->end = std::move(end);
- attrs->strides = std::move(strides);
- static const Op& op = Op::Get("strided_slice");
- return Call(op, {data}, Attrs(attrs), {});
+ DLContext ctx;
+ ctx.device_type = kDLCPU;
+ ctx.device_id = 0;
+ auto begin_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+ DataType::Int(64), ctx);
+ auto end_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+ DataType::Int(64), ctx);
+ auto strides_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+ DataType::Int(64), ctx);
+ int64_t* begin_data = static_cast<int64_t*>(begin_ndarray->data);
+ int64_t* end_data = static_cast<int64_t*>(end_ndarray->data);
+ for (size_t i = 0; i < new_begin.size(); ++i) {
+ begin_data[i] = new_begin[i];
+ end_data[i] = new_end[i];
+ }
+ params->begin = Constant(begin_ndarray);
+ params->end = Constant(end_ndarray);
+ }
+ return {{layout, Layout("C"), Layout("C"), Layout("C")}, {layout}};
+}
+
+// Builds a strided-slice compute whose begin/end/strides are supplied as
+// tensors instead of constant attributes (presumably used by the non-constant
+// branch of StridedSliceCompute below -- TODO confirm).
+// For each axis i, output coordinate idx maps to input coordinate
+// idx[i] * strides(i) + begin(i).
+// NOTE(review): `end` is accepted but never read in this function, and every
+// output axis extent is a fresh symbolic tir::Var("dim"), so the output shape
+// is not derived from begin/end/strides here -- TODO confirm it is bound
+// elsewhere.
+// NOTE(review): begin(i) is used as-is; negative begin values are not
+// normalized against the input extent in this function.
+// NOTE(review): assumes begin and strides hold at least one entry per input
+// axis -- TODO confirm callers guarantee this.
+inline te::Tensor DynamicStridedSlice(const te::Tensor& input,
+ const te::Tensor& begin,
+ const te::Tensor& end,
+ const te::Tensor& strides,
+ std::string name = "T_strided_slice_dynamic",
+ std::string tag = topi::kInjective) {
+ // The output has the same rank as the input.
+ int64_t src_tensor_dim = input->shape.size();
+ Array<IndexExpr> out_shape;
+ for (int64_t i = 0; i < src_tensor_dim; ++i) {
+ // Each extent is an unconstrained symbolic variable (see NOTE above).
+ out_shape.push_back(tvm::tir::Var("dim"));
+ }
+ // TODO(yongwww): move the compute into topi
+ return te::compute(out_shape, [&](const Array<tvm::tir::Var>& indices) {
+ Array<IndexExpr> real_indices;
+ // NOTE(review): int32_t index compared against the int64_t rank above.
+ for (int32_t i = 0; i < src_tensor_dim; ++i) {
+ // Map the output coordinate back into the input: start at begin(i) and
+ // step by strides(i).
+ real_indices.push_back(indices[i] * strides(i) + begin(i));
+ }
+ return input(real_indices);
+ }, name, tag);
}
Array<te::Tensor> StridedSliceCompute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type) {
const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
CHECK(param != nullptr);
- return Array<te::Tensor>{
- topi::strided_slice(inputs[0], param->begin, param->end, param->strides)
- };
+ const ConstantNode *cbegin, *cend, *cstrides;
+ if ((cbegin = param->begin.as<ConstantNode>()) &&
+ (cend = param->end.as<ConstantNode>()) &&
+ (cstrides = param->strides.as<ConstantNode>())) {
+ Array<Integer> begin, end, strides;
+ int64_t* strides_val = ToVector(cstrides->data);
+ for (int64_t i = 0; i < cstrides->data.Shape().front(); ++i) {
+ strides.push_back(strides_val[i]);
+ }
+ int64_t* begin_val = ToVector(cbegin->data);
+ for (int64_t i = 0; i < cbegin->data.Shape().front(); ++i) {
+ begin.push_back(begin_val[i]);
+ }
+ int64_t* end_val = ToVector(cend->data);
+ for (int64_t i = 0; i < cend->data.Shape().front(); ++i) {
+ end.push_back(end_val[i]);
+ }
+ return Array<te::Tensor>{
+ topi::strided_slice(inputs[0], begin, end, strides)
+ };
+ } else {
+ te::Tensor data = inputs[0];
+ te::Tensor begin = inputs[1];
+ te::Tensor end = inputs[2];
+ te::Tensor strides = inputs[3];
+ // Dynamic computation
Review comment:
We might want to require users to provide full begin, end, and strides for the
symbolic-attribute case, since handling partial values inside topi would not be ideal.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org