electriclilies commented on a change in pull request #6316:
URL: https://github.com/apache/incubator-tvm/pull/6316#discussion_r474312010
##########
File path: python/tvm/relay/op/transform.py
##########
@@ -827,13 +828,17 @@ def strided_slice(data, begin, end, strides=None,
slice_mode="end"):
ret : relay.Expr
The computed result.
"""
- strides = strides or const([1], dtype="int32")
- if isinstance(begin, (tuple, list)):
- begin = const(list(begin))
- if isinstance(end, (tuple, list)):
- end = const(list(end))
- if isinstance(strides, (tuple, list)):
- strides = const(list(strides))
+ strides = strides or [1]
+ if (isinstance(begin, Expr) or isinstance(end, Expr) or
isinstance(strides, Expr)):
+ if isinstance(begin, (tuple, list)):
+ begin = const(list(begin))
+ if isinstance(end, (tuple, list)):
+ end = const(list(end))
+ if isinstance(strides, (tuple, list)):
+ strides = const(list(strides))
+ begin = _make.where(begin < cast_like(const(0), begin),
Review comment:
Can you rename this `begin` variable for clarity?
##########
File path: python/tvm/relay/op/dyn/_transform.py
##########
@@ -145,3 +146,53 @@ def one_hot_shape_func(attrs, inputs, _):
"""
axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis
return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))]
+
+
+@script
+def _strided_slice_shape_func_input_data(data, begin, end, strides,
Review comment:
What's the difference between `_strided_slice_shape_func_input_shape`
and `_strided_slice_shape_func_input_data`?
##########
File path: tests/python/relay/test_op_level4.py
##########
@@ -343,7 +337,7 @@ def verify(dshape, begin, end, strides, output,
slice_mode="end",
text = func.astext()
assert "begin=" in text
assert "end=" in text
-
+
Review comment:
There is trailing whitespace on this line — please remove it.
##########
File path: src/relay/op/dyn/tensor/transform.cc
##########
@@ -430,6 +434,114 @@ RELAY_REGISTER_OP("dyn.full")
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);
+bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs&
attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 5);
Review comment:
It would be nice to add a comment describing what each of the input
types is (i.e., types = [type1_description, ..., ret_type]).
##########
File path: src/relay/op/dyn/tensor/transform.cc
##########
@@ -430,6 +434,114 @@ RELAY_REGISTER_OP("dyn.full")
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);
+bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs&
attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 5);
+ const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
+ if (param == nullptr) {
+ return false;
+ }
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) {
+ return false;
+ }
+ auto dshape = data->shape;
+ int64_t num_axis = dshape.size();
+
+ // calculate output shape
+ std::vector<IndexExpr> oshape(num_axis);
+ for (int64_t i = 0; i < num_axis; ++i) {
+ oshape[i] = Any();
+ }
+
+ reporter->Assign(types[4], TensorType(oshape, data->dtype));
+ return true;
+}
+
+inline te::Tensor DynamicStridedSlice(const te::Tensor& input, const
te::Tensor& begin,
+ const te::Tensor& end, const te::Tensor&
strides,
+ std::string name =
"T_strided_slice_dynamic",
+ std::string tag = topi::kInjective) {
+ int64_t src_tensor_dim = input->shape.size();
+ Array<IndexExpr> out_shape;
+ for (int64_t i = 0; i < src_tensor_dim; ++i) {
+ out_shape.push_back(tvm::tir::Var("dim"));
+ }
+ // TODO(yongwww): move the compute into topi
+ return te::compute(
+ out_shape,
+ [&](const Array<tvm::tir::Var>& indices) {
+ Array<IndexExpr> real_indices;
+ for (int32_t i = 0; i < src_tensor_dim; ++i) {
+ real_indices.push_back(indices[i] * strides(i) + begin(i));
+ }
+ return input(real_indices);
+ },
+ name, tag);
+}
+
+Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const
Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ te::Tensor data = inputs[0];
+ te::Tensor begin = inputs[1];
+ te::Tensor end = inputs[2];
+ te::Tensor strides = inputs[3];
+ // Dynamic computation
+ int64_t attr_size = data->shape.size();
Review comment:
Does this stand for "attribute size"? If so, the name seems a bit
inaccurate, since it is derived from the data tensor's rank.
##########
File path: src/relay/op/dyn/tensor/transform.cc
##########
@@ -430,6 +434,114 @@ RELAY_REGISTER_OP("dyn.full")
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);
+bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs&
attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 5);
+ const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
+ if (param == nullptr) {
+ return false;
+ }
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) {
+ return false;
+ }
+ auto dshape = data->shape;
+ int64_t num_axis = dshape.size();
+
+ // calculate output shape
+ std::vector<IndexExpr> oshape(num_axis);
+ for (int64_t i = 0; i < num_axis; ++i) {
+ oshape[i] = Any();
+ }
+
+ reporter->Assign(types[4], TensorType(oshape, data->dtype));
+ return true;
+}
+
+inline te::Tensor DynamicStridedSlice(const te::Tensor& input, const
te::Tensor& begin,
+ const te::Tensor& end, const te::Tensor&
strides,
+ std::string name =
"T_strided_slice_dynamic",
+ std::string tag = topi::kInjective) {
+ int64_t src_tensor_dim = input->shape.size();
+ Array<IndexExpr> out_shape;
+ for (int64_t i = 0; i < src_tensor_dim; ++i) {
+ out_shape.push_back(tvm::tir::Var("dim"));
+ }
+ // TODO(yongwww): move the compute into topi
+ return te::compute(
+ out_shape,
+ [&](const Array<tvm::tir::Var>& indices) {
+ Array<IndexExpr> real_indices;
+ for (int32_t i = 0; i < src_tensor_dim; ++i) {
+ real_indices.push_back(indices[i] * strides(i) + begin(i));
+ }
+ return input(real_indices);
+ },
+ name, tag);
+}
+
+Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const
Array<te::Tensor>& inputs,
+ const Type& out_type) {
+ te::Tensor data = inputs[0];
+ te::Tensor begin = inputs[1];
+ te::Tensor end = inputs[2];
+ te::Tensor strides = inputs[3];
+ // Dynamic computation
+ int64_t attr_size = data->shape.size();
+ CHECK(begin->shape[0].as<IntImmNode>()->value == attr_size &&
+ end->shape[0].as<IntImmNode>()->value == attr_size &&
+ strides->shape[0].as<IntImmNode>()->value == attr_size)
+ << "begin, end, and strides are required to have the same length"
+ << " if they are non-constant.";
Review comment:
The wording of this error is a bit confusing; "begin, end, and strides
are required to have the same length or must all be constants" might be better
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -2069,12 +2070,9 @@ bool StridedSliceRel(const Array<Type>& types, int
num_inputs, const Attrs& attr
oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1)
/ step);
}
} else {
- for (int64_t i = 0; i < num_axis; ++i) {
- oshape[i] = Any();
- }
+ CHECK(false) << "strided_slice received invalid params";
Review comment:
You could state in this error that strided_slice received an incorrect
beginning, end, or strides tensor.
##########
File path: src/relay/transforms/dynamic_to_static.cc
##########
@@ -139,6 +139,24 @@ class DynamicToStaticMutator : public MixedModeMutator {
}
return Expr(nullptr);
}},
+ {Op::Get("dyn.strided_slice"),
+ [](const CallNode* call_node) {
+ if (const ConstantNode* begin =
call_node->args[1].as<ConstantNode>()) {
+ if (const ConstantNode* end =
call_node->args[2].as<ConstantNode>()) {
Review comment:
It would be cleaner to pull these definitions out of the if statements
and then check whether they are null in a single if statement, though
this is potentially slower.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]