electriclilies commented on a change in pull request #6316:
URL: https://github.com/apache/incubator-tvm/pull/6316#discussion_r474312010



##########
File path: python/tvm/relay/op/transform.py
##########
@@ -827,13 +828,17 @@ def strided_slice(data, begin, end, strides=None, 
slice_mode="end"):
     ret : relay.Expr
         The computed result.
     """
-    strides = strides or const([1], dtype="int32")
-    if isinstance(begin, (tuple, list)):
-        begin = const(list(begin))
-    if isinstance(end, (tuple, list)):
-        end = const(list(end))
-    if isinstance(strides, (tuple, list)):
-        strides = const(list(strides))
+    strides = strides or [1]
+    if (isinstance(begin, Expr) or isinstance(end, Expr) or 
isinstance(strides, Expr)):
+        if isinstance(begin, (tuple, list)):
+            begin = const(list(begin))
+        if isinstance(end, (tuple, list)):
+            end = const(list(end))
+        if isinstance(strides, (tuple, list)):
+            strides = const(list(strides))
+        begin = _make.where(begin < cast_like(const(0), begin),

Review comment:
       Can you rename this reassigned `begin` (e.g. to a distinct local name) for clarity?

##########
File path: python/tvm/relay/op/dyn/_transform.py
##########
@@ -145,3 +146,53 @@ def one_hot_shape_func(attrs, inputs, _):
     """
     axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis
     return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))]
+
+
+@script
+def _strided_slice_shape_func_input_data(data, begin, end, strides,

Review comment:
       What's the difference between `_strided_slice_shape_func_input_shape` 
and `_strided_slice_shape_func_input_data`?

##########
File path: tests/python/relay/test_op_level4.py
##########
@@ -343,7 +337,7 @@ def verify(dshape, begin, end, strides, output, 
slice_mode="end",
         text = func.astext()
         assert "begin=" in text
         assert "end=" in text
-
+        

Review comment:
       This line adds trailing whitespace; please remove it.

##########
File path: src/relay/op/dyn/tensor/transform.cc
##########
@@ -430,6 +434,114 @@ RELAY_REGISTER_OP("dyn.full")
     .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
     .set_attr<TOpPattern>("TOpPattern", kElemWise);
 
+bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 5);

Review comment:
       It would be nice to add a comment describing each of the input types
(i.e., types = [type1_description, ..., ret_type]).

##########
File path: src/relay/op/dyn/tensor/transform.cc
##########
@@ -430,6 +434,114 @@ RELAY_REGISTER_OP("dyn.full")
     .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
     .set_attr<TOpPattern>("TOpPattern", kElemWise);
 
+bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 5);
+  const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
+  if (param == nullptr) {
+    return false;
+  }
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+  auto dshape = data->shape;
+  int64_t num_axis = dshape.size();
+
+  // calculate output shape
+  std::vector<IndexExpr> oshape(num_axis);
+  for (int64_t i = 0; i < num_axis; ++i) {
+    oshape[i] = Any();
+  }
+
+  reporter->Assign(types[4], TensorType(oshape, data->dtype));
+  return true;
+}
+
+inline te::Tensor DynamicStridedSlice(const te::Tensor& input, const 
te::Tensor& begin,
+                                      const te::Tensor& end, const te::Tensor& 
strides,
+                                      std::string name = 
"T_strided_slice_dynamic",
+                                      std::string tag = topi::kInjective) {
+  int64_t src_tensor_dim = input->shape.size();
+  Array<IndexExpr> out_shape;
+  for (int64_t i = 0; i < src_tensor_dim; ++i) {
+    out_shape.push_back(tvm::tir::Var("dim"));
+  }
+  // TODO(yongwww): move the compute into topi
+  return te::compute(
+      out_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<IndexExpr> real_indices;
+        for (int32_t i = 0; i < src_tensor_dim; ++i) {
+          real_indices.push_back(indices[i] * strides(i) + begin(i));
+        }
+        return input(real_indices);
+      },
+      name, tag);
+}
+
+Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const 
Array<te::Tensor>& inputs,
+                                      const Type& out_type) {
+  te::Tensor data = inputs[0];
+  te::Tensor begin = inputs[1];
+  te::Tensor end = inputs[2];
+  te::Tensor strides = inputs[3];
+  // Dynamic computation
+  int64_t attr_size = data->shape.size();

Review comment:
       Does this stand for "attribute size"? If so, the name seems a bit
inaccurate, since it holds the number of dimensions of `data` (`data->shape.size()`).

##########
File path: src/relay/op/dyn/tensor/transform.cc
##########
@@ -430,6 +434,114 @@ RELAY_REGISTER_OP("dyn.full")
     .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
     .set_attr<TOpPattern>("TOpPattern", kElemWise);
 
+bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& 
attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 5);
+  const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
+  if (param == nullptr) {
+    return false;
+  }
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+  auto dshape = data->shape;
+  int64_t num_axis = dshape.size();
+
+  // calculate output shape
+  std::vector<IndexExpr> oshape(num_axis);
+  for (int64_t i = 0; i < num_axis; ++i) {
+    oshape[i] = Any();
+  }
+
+  reporter->Assign(types[4], TensorType(oshape, data->dtype));
+  return true;
+}
+
+inline te::Tensor DynamicStridedSlice(const te::Tensor& input, const 
te::Tensor& begin,
+                                      const te::Tensor& end, const te::Tensor& 
strides,
+                                      std::string name = 
"T_strided_slice_dynamic",
+                                      std::string tag = topi::kInjective) {
+  int64_t src_tensor_dim = input->shape.size();
+  Array<IndexExpr> out_shape;
+  for (int64_t i = 0; i < src_tensor_dim; ++i) {
+    out_shape.push_back(tvm::tir::Var("dim"));
+  }
+  // TODO(yongwww): move the compute into topi
+  return te::compute(
+      out_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<IndexExpr> real_indices;
+        for (int32_t i = 0; i < src_tensor_dim; ++i) {
+          real_indices.push_back(indices[i] * strides(i) + begin(i));
+        }
+        return input(real_indices);
+      },
+      name, tag);
+}
+
+Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const 
Array<te::Tensor>& inputs,
+                                      const Type& out_type) {
+  te::Tensor data = inputs[0];
+  te::Tensor begin = inputs[1];
+  te::Tensor end = inputs[2];
+  te::Tensor strides = inputs[3];
+  // Dynamic computation
+  int64_t attr_size = data->shape.size();
+  CHECK(begin->shape[0].as<IntImmNode>()->value == attr_size &&
+        end->shape[0].as<IntImmNode>()->value == attr_size &&
+        strides->shape[0].as<IntImmNode>()->value == attr_size)
+      << "begin, end, and strides are required to have the same length"
+      << " if they are non-constant.";

Review comment:
       The wording of this error is a bit confusing; "begin, end, and strides
are required to have the same length or must all be constants" might be better.

##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -2069,12 +2070,9 @@ bool StridedSliceRel(const Array<Type>& types, int 
num_inputs, const Attrs& attr
       oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) 
/ step);
     }
   } else {
-    for (int64_t i = 0; i < num_axis; ++i) {
-      oshape[i] = Any();
-    }
+    CHECK(false) << "strided_slice recieved invalid params";

Review comment:
       You could state in this error that strided_slice received an invalid
`begin`, `end`, or `strides` tensor. (Also, "recieved" in the message should be "received".)

##########
File path: src/relay/transforms/dynamic_to_static.cc
##########
@@ -139,6 +139,24 @@ class DynamicToStaticMutator : public MixedModeMutator {
            }
            return Expr(nullptr);
          }},
+        {Op::Get("dyn.strided_slice"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* begin = 
call_node->args[1].as<ConstantNode>()) {
+             if (const ConstantNode* end = 
call_node->args[2].as<ConstantNode>()) {

Review comment:
       It would be cleaner to pull these definitions out of the nested if statements
and then check whether they are null in a single if statement, though that
may be slightly slower.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to