kevinthesun commented on a change in pull request #4312:
URL: https://github.com/apache/incubator-tvm/pull/4312#discussion_r414959052



##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1891,81 +1952,163 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
   }
 
   CHECK(old_in_layouts.defined());
-  CHECK_EQ(old_in_layouts.size(), 1);
+  CHECK_GE(old_in_layouts.size(), 1);
   CHECK(old_in_shapes.defined());
-  CHECK_EQ(old_in_shapes.size(), 1);
+  CHECK_GE(old_in_shapes.size(), 1);
 
   auto layout = old_in_layouts[0];
   if (layout.defined() && new_in_layouts.defined()) {
-    CHECK_EQ(new_in_layouts.size(), 1);
+    CHECK_GE(new_in_layouts.size(), 1);
     auto new_layout = new_in_layouts[0];
     auto shape = old_in_shapes[0];
 
     // NOTE: Discard "const" qualifier here.
     auto *params = 
const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
+    CHECK(params != nullptr);
+    Array<Integer> begin, end, strides;
+    const ConstantNode *cbegin, *cend, *cstrides;
+    if ((cbegin = params->begin.as<ConstantNode>()) &&
+        (cend = params->end.as<ConstantNode>()) &&
+        (cstrides = params->strides.as<ConstantNode>())) {
+      int64_t* strides_val = ToVector(cstrides->data);
+      for (int64_t i = 0; i < cstrides->data.Shape().front(); ++i) {
+        strides.push_back(strides_val[i]);
+      }
+      int64_t* begin_val = ToVector(cbegin->data);
+      for (int64_t i = 0; i < cbegin->data.Shape().front(); ++i) {
+        begin.push_back(begin_val[i]);
+      }
+      int64_t* end_val = ToVector(cend->data);
+      for (int64_t i = 0; i < cend->data.Shape().front(); ++i) {
+        end.push_back(end_val[i]);
+      }
+    }
 
     Array<Integer> new_begin, new_end;
 
-    for (size_t i = 0; i < params->begin.size(); i++) {
+    for (size_t i = 0; i < begin.size(); i++) {
       const LayoutAxis& axis = layout[i];
       if (!axis.IsPrimal()) {
         // original layout that contains splitted axes is not supported
         return {{Layout::Undef()}, {Layout::Undef()}};
       }
       auto factor = new_layout.FactorOf(axis);
       if (factor == -1) {
-        new_begin.push_back(params->begin[i]);
-        new_end.push_back(params->end[i]);
+        new_begin.push_back(begin[i]);
+        new_end.push_back(end[i]);
       } else {
-        if (params->strides.defined() && i < params->strides.size()) {
-          auto stride = params->strides[i];
+        if (strides.defined() && i < strides.size()) {
+          auto stride = strides[i];
           // arbitrary stride is not supported
           if (stride.defined() && stride->value != 1) {
             return {{Layout::Undef()}, {Layout::Undef()}};
           }
         }
-        int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 
0;
-        int64_t end = params->end[i].defined() ? params->end[i]->value :
+        int64_t bg = begin[i].defined() ? begin[i]->value : 0;
+        int64_t ed = end[i].defined() ? end[i]->value :
             shape[i].as<IntImmNode>()->value;
-        if (begin % factor || end % factor) {
+        if (bg % factor || ed % factor) {
           // transform to original layout
           return {{Layout::Undef()}, {Layout::Undef()}};
         }
-        new_begin.push_back(tvm::Integer(begin / factor));
-        new_end.push_back(tvm::Integer(end / factor));
+        new_begin.push_back(tvm::Integer(bg / factor));
+        new_end.push_back(tvm::Integer(ed / factor));
       }
     }
-    layout = new_layout;
-    params->begin = new_begin;
-    params->end = new_end;
-  }
-  return {{layout}, {layout}};
-}
 
+    layout = new_layout;
 
-// Positional relay function to create StridedSlice operator used by frontend 
FFI.
-Expr MakeStridedSlice(Expr data,
-                      Array<Integer> begin,
-                      Array<Integer> end,
-                      Array<Integer> strides) {
-  auto attrs = make_object<StridedSliceAttrs>();
-  attrs->begin = std::move(begin);
-  attrs->end = std::move(end);
-  attrs->strides = std::move(strides);
-  static const Op& op = Op::Get("strided_slice");
-  return Call(op, {data}, Attrs(attrs), {});
+    DLContext ctx;
+    ctx.device_type = kDLCPU;
+    ctx.device_id = 0;
+    auto begin_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+                                                 DataType::Int(64), ctx);
+    auto end_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+                                               DataType::Int(64), ctx);
+    auto strides_ndarray = runtime::NDArray::Empty({int64_t(new_begin.size())},
+                                                   DataType::Int(64), ctx);
+    int64_t* begin_data = static_cast<int64_t*>(begin_ndarray->data);
+    int64_t* end_data = static_cast<int64_t*>(end_ndarray->data);
+    for (size_t i = 0; i < new_begin.size(); ++i) {
+      begin_data[i] = new_begin[i];
+      end_data[i] = new_end[i];
+    }
+    params->begin = Constant(begin_ndarray);
+    params->end = Constant(end_ndarray);
+  }
+  return {{layout, Layout("C"), Layout("C"), Layout("C")}, {layout}};
+}
+
+inline te::Tensor DynamicStridedSlice(const te::Tensor& input,
+                                  const te::Tensor& begin,
+                                  const te::Tensor& end,
+                                  const te::Tensor& strides,
+                                  std::string name = "T_strided_slice_dynamic",
+                                  std::string tag = topi::kInjective) {
+  int64_t src_tensor_dim = input->shape.size();
+  Array<IndexExpr> out_shape;
+  for (int64_t i = 0; i < src_tensor_dim; ++i) {
+    out_shape.push_back(tvm::tir::Var("dim"));
+  }
+  // TODO(yongwww): move the compute into topi
+  return te::compute(out_shape, [&](const Array<tvm::tir::Var>& indices) {
+      Array<IndexExpr> real_indices;
+      for (int32_t i = 0; i < src_tensor_dim; ++i) {
+        real_indices.push_back(indices[i] * strides(i) + begin(i));
+      }
+      return input(real_indices);
+    }, name, tag);
 }
 
 Array<te::Tensor> StridedSliceCompute(const Attrs& attrs,
                                       const Array<te::Tensor>& inputs,
                                       const Type& out_type) {
   const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
   CHECK(param != nullptr);
-  return Array<te::Tensor>{
-    topi::strided_slice(inputs[0], param->begin, param->end, param->strides)
-  };
+  const ConstantNode *cbegin, *cend, *cstrides;
+  if ((cbegin = param->begin.as<ConstantNode>()) &&
+      (cend = param->end.as<ConstantNode>()) &&
+      (cstrides = param->strides.as<ConstantNode>())) {
+    Array<Integer> begin, end, strides;
+    int64_t* strides_val = ToVector(cstrides->data);
+    for (int64_t i = 0; i < cstrides->data.Shape().front(); ++i) {
+      strides.push_back(strides_val[i]);
+    }
+    int64_t* begin_val = ToVector(cbegin->data);
+    for (int64_t i = 0; i < cbegin->data.Shape().front(); ++i) {
+      begin.push_back(begin_val[i]);
+    }
+    int64_t* end_val = ToVector(cend->data);
+    for (int64_t i = 0; i < cend->data.Shape().front(); ++i) {
+      end.push_back(end_val[i]);
+    }
+    return Array<te::Tensor>{
+      topi::strided_slice(inputs[0], begin, end, strides)
+    };
+  } else {
+    te::Tensor data = inputs[0];
+    te::Tensor begin = inputs[1];
+    te::Tensor end = inputs[2];
+    te::Tensor strides = inputs[3];
+    // Dynamic computation

Review comment:
      We might want to require the user to provide the full begin, end, and strides for 
the symbolic-attribute case, since dealing with these inside topi would not be ideal.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to