yongwww commented on a change in pull request #4312:
URL: https://github.com/apache/incubator-tvm/pull/4312#discussion_r434764491



##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1772,73 +1788,161 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
   }
 
   CHECK(old_in_layouts.defined());
-  CHECK_EQ(old_in_layouts.size(), 1);
+  CHECK_GE(old_in_layouts.size(), 1);
   CHECK(old_in_shapes.defined());
-  CHECK_EQ(old_in_shapes.size(), 1);
+  CHECK_GE(old_in_shapes.size(), 1);
 
   auto layout = old_in_layouts[0];
   if (layout.defined() && new_in_layouts.defined()) {
-    CHECK_EQ(new_in_layouts.size(), 1);
+    CHECK_GE(new_in_layouts.size(), 1);
     auto new_layout = new_in_layouts[0];
     auto shape = old_in_shapes[0];
 
     // NOTE: Discard "const" qualifier here.
     auto* params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
+    CHECK(params != nullptr);
+    Array<Integer> begin, end, strides;
+    if (params->begin && params->end && params->strides) {
+      for (Integer i : params->strides.value()) {
+        CHECK(i.defined());
+        strides.push_back(params->slice_mode ? 1 : i->value);
+      }
+
+      for (Integer i : params->begin.value()) {
+        CHECK(i.defined());
+        begin.push_back(i->value);
+      }
+      for (Integer i : params->end.value()) {
+        CHECK(i.defined());
+        end.push_back(i->value);
+      }
+    }
 
     Array<Integer> new_begin, new_end;
 
-    for (size_t i = 0; i < params->begin.size(); i++) {
+    for (size_t i = 0; i < begin.size(); i++) {
       const LayoutAxis& axis = layout[i];
       if (!axis.IsPrimal()) {
         // original layout that contains splitted axes is not supported
         return {{Layout::Undef()}, {Layout::Undef()}};
       }
       auto factor = new_layout.FactorOf(axis);
       if (factor == -1) {
-        new_begin.push_back(params->begin[i]);
-        new_end.push_back(params->end[i]);
+        new_begin.push_back(begin[i]);
+        new_end.push_back(end[i]);
       } else {
-        if (params->strides.defined() && i < params->strides.size()) {
-          auto stride = params->strides[i];
+        if (strides.defined() && i < strides.size()) {
+          auto stride = strides[i];
           // arbitrary stride is not supported
           if (stride.defined() && stride->value != 1) {
             return {{Layout::Undef()}, {Layout::Undef()}};
           }
         }
-        int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
-        int64_t end =
-            params->end[i].defined() ? params->end[i]->value : shape[i].as<IntImmNode>()->value;
-        if (begin % factor || end % factor) {
+        int64_t bg = begin[i].defined() ? begin[i]->value : 0;
+        int64_t ed;
+        if (!end[i].defined()) {
+          ed = shape[i].as<IntImmNode>()->value;
+        } else if (params->slice_mode) {
+          if (end[i]->value < 0) {
+            ed = shape[i].as<IntImmNode>()->value;
+          } else {
+            ed = bg + end[i]->value;
+          }
+        } else {
+          ed = end[i]->value;
+        }
+
+        if (bg % factor || ed % factor) {
           // transform to original layout
           return {{Layout::Undef()}, {Layout::Undef()}};
         }
-        new_begin.push_back(tvm::Integer(begin / factor));
-        new_end.push_back(tvm::Integer(end / factor));
+        new_begin.push_back(tvm::Integer(bg / factor));
+        new_end.push_back(tvm::Integer(ed / factor));
       }
     }
+
     layout = new_layout;
     params->begin = new_begin;
     params->end = new_end;
   }
-  return {{layout}, {layout}};
+  return {{layout, Layout("C"), Layout("C"), Layout("C")}, {layout}};
 }
 
-// Positional relay function to create StridedSlice operator used by frontend FFI.
-Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides) {
-  auto attrs = make_object<StridedSliceAttrs>();
-  attrs->begin = std::move(begin);
-  attrs->end = std::move(end);
-  attrs->strides = std::move(strides);
-  static const Op& op = Op::Get("strided_slice");
-  return Call(op, {data}, Attrs(attrs), {});
+inline te::Tensor DynamicStridedSlice(const te::Tensor& input, const te::Tensor& begin,
+                                      const te::Tensor& end, const te::Tensor& strides,
+                                      std::string name = "T_strided_slice_dynamic",
+                                      std::string tag = topi::kInjective) {
+  int64_t src_tensor_dim = input->shape.size();
+  Array<IndexExpr> out_shape;
+  for (int64_t i = 0; i < src_tensor_dim; ++i) {
+    out_shape.push_back(tvm::tir::Var("dim"));
+  }
+  // TODO(yongwww): move the compute into topi
+  return te::compute(
+      out_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<IndexExpr> real_indices;
+        for (int32_t i = 0; i < src_tensor_dim; ++i) {
+          real_indices.push_back(indices[i] * strides(i) + begin(i));
+        }
+        return input(real_indices);
+      },
+      name, tag);
 }
 
 Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                       const Type& out_type) {
   const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
   CHECK(param != nullptr);
-  return Array<te::Tensor>{
-      topi::strided_slice(inputs[0], param->begin, param->end, param->strides)};
+  if (param->begin && param->end && param->strides) {
+    Array<Integer> begin, end, strides;
+    begin = param->begin.value();
+    end = param->end.value();
+    strides = param->strides.value();
+    return Array<te::Tensor>{
+        topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)};
+  } else {
+    te::Tensor data = inputs[0];
+    te::Tensor begin = inputs[1];
+    te::Tensor end = inputs[2];
+    te::Tensor strides = inputs[3];
+    // Dynamic computation
+    int64_t attr_size = data->shape.size();
+    CHECK(begin->shape[0].as<IntImmNode>()->value == attr_size &&
+          end->shape[0].as<IntImmNode>()->value == attr_size &&
+          strides->shape[0].as<IntImmNode>()->value == attr_size)
+        << "begin, end, and strides are required to have the same length"
+        << " if they are non-constant.";
+    return Array<te::Tensor>{DynamicStridedSlice(data, begin, end, strides)};
+  }
+}
+
+// Positional relay function to create StridedSlice operator used by frontend FFI.
+Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, bool slice_mode) {
+  auto attrs = make_object<StridedSliceAttrs>();
+  const ConstantNode *cbegin, *cend, *cstrides;
+  if ((cbegin = begin.as<ConstantNode>()) && (cend = end.as<ConstantNode>()) &&
+      (cstrides = strides.as<ConstantNode>())) {
+    CHECK_EQ(cbegin->data->ndim, 1);
+    CHECK_EQ(cend->data->ndim, 1);
+    CHECK_EQ(cstrides->data->ndim, 1);
+    Array<Integer> begin, end, strides;
+    for (int i = 0; i < cbegin->data->shape[0]; i++) {
+      begin.push_back(Integer(static_cast<int>(ToScalar(cbegin->data, i))));

Review comment:
       updated




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to