yongwww commented on a change in pull request #4312:
URL: https://github.com/apache/incubator-tvm/pull/4312#discussion_r427533545
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1947,57 +1948,81 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
// NOTE: Discard "const" qualifier here.
auto *params =
const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
+ CHECK(params != nullptr);
+ Array<Integer> begin, end, strides;
+ const ConstantNode *cbegin, *cend, *cstrides;
+ if ((cbegin = params->begin.as<ConstantNode>()) &&
+ (cend = params->end.as<ConstantNode>()) &&
+ (cstrides = params->strides.as<ConstantNode>())) {
+
+ int32_t* strides_val = reinterpret_cast<int32_t*>(cstrides->data->data);
+ for (size_t i = 0; i < cstrides->data.Shape().front(); ++i){
+ strides.push_back(strides_val[i]);
+ }
+ int32_t* begin_val = reinterpret_cast<int32_t*>(cbegin->data->data);
+ for (size_t i = 0; i < cbegin->data.Shape().front(); ++i){
+ begin.push_back(begin_val[i]);
+ }
+ int32_t* end_val = reinterpret_cast<int32_t*>(cend->data->data);
+ for (size_t i = 0; i < cend->data.Shape().front(); ++i){
+ end.push_back(end_val[i]);
+ }
+ }
Array<Integer> new_begin, new_end;
- for (size_t i = 0; i < params->begin.size(); i++) {
+ for (size_t i = 0; i < begin.size(); i++) {
const LayoutAxis& axis = layout[i];
if (!axis.IsPrimal()) {
// original layout that contains splitted axes is not supported
return {{Layout::Undef()}, {Layout::Undef()}};
}
auto factor = new_layout.FactorOf(axis);
if (factor == -1) {
- new_begin.push_back(params->begin[i]);
- new_end.push_back(params->end[i]);
+ new_begin.push_back(begin[i]);
+ new_end.push_back(end[i]);
} else {
- if (params->strides.defined() && i < params->strides.size()) {
- auto stride = params->strides[i];
+ if (strides.defined() && i < strides.size()) {
+ auto stride = strides[i];
// arbitrary stride is not supported
if (stride.defined() && stride->value != 1) {
return {{Layout::Undef()}, {Layout::Undef()}};
}
}
- int64_t begin = params->begin[i].defined() ? params->begin[i]->value :
0;
- int64_t end = params->end[i].defined() ? params->end[i]->value :
+ int64_t bg = begin[i].defined() ? begin[i]->value : 0;
Review comment:
Please keep this consistent with the corresponding code in topi.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]