sunggg commented on code in PR #14548:
URL: https://github.com/apache/tvm/pull/14548#discussion_r1162234152


##########
src/relax/op/tensor/index.cc:
##########
@@ -239,5 +239,78 @@ TVM_REGISTER_OP("relax.strided_slice")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow);
 
+/* relax.dynamic_strided_slice */
+Expr dynamic_strided_slice(Expr x,      //
+                           Expr begin,  //
+                           Expr end,    //
+                           Expr strides) {
+  static const Op& op = Op::Get("relax.dynamic_strided_slice");
+  return Call(op, {std::move(x), std::move(begin), std::move(end), 
std::move(strides)}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice);
+
+StructInfo InferStructInfoDynStridedSlice(const Call& call, const 
BlockBuilder& ctx) {
+  const auto* data_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* begin_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const auto* end_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+  const auto* strides_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[3]);
+
+  ICHECK(data_sinfo);
+  if (data_sinfo->IsUnknownNdim()) {
+    LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes 
begin/end/stride "
+                    "tensors are well-formed. It could produce runtime error 
when this assumption "
+                    "turns out to be wrong.";
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+  }
+  if (data_sinfo->IsUnknownDtype()) {
+    LOG(WARNING) << "When data type is unknown, dynamic strided slice assumes 
to have a valid "
+                    "dtype. It could produce runtime error when this 
assumption "
+                    "turns out to be wrong.";
+  }
+
+  int n_axis = data_sinfo->ndim;
+  auto diag_def = [&](const TensorStructInfoNode* sinfo, String name) {
+    ICHECK(sinfo) << "Dynamic strided slice requires the input " << name
+                  << " to be have the struct info. Please try normalizing the 
inputs.";
+    CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name
+                             << " to be 1d tensor (list of values).";
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    ICHECK(shape) << "Dynamic strided slice requires the input " << name
+                  << " to have well-defined shape.";
+    // NOTE(tvm-team): This strong restriction seems necessary for now until 
we have a generic
+    // solution in converting 1d Tensor with unknown num_elem to 
Array<PrimExpr>.
+    const auto* num_elem = shape->values[0].as<IntImmNode>();
+    ICHECK(num_elem) << "Dynamic strided slice requires the input " << name
+                     << " to have a known integer shape value.";
+    CHECK_EQ(num_elem->value, n_axis) << "Dynamic strided slice requires the 
number of indices in "
+                                      << name << " to equal the number of 
axes.";
+    if (sinfo->IsUnknownDtype()) {
+      LOG(WARNING) << "Dynamic strided slice assumes " << name
+                   << " to be int64 when it is not specified.";
+    } else {
+      CHECK(sinfo->dtype == DataType::Int(64))
+          << "Dynamic strided_slice expects the input " << name
+          << "values to be all int64. However, " << name << " has dtype " << 
sinfo->dtype << ".";
+    }
+  };
+  diag_def(begin_sinfo, "begin");
+  diag_def(end_sinfo, "end");
+  diag_def(strides_sinfo, "stride");
+
+  // The output shape will depend on the runtime value in begin/end/stride 
tensors.
+  // TODO(tvm-team): Extract more compile-time info when those tensors are 
constants.
+  return TensorStructInfo(data_sinfo->dtype, n_axis);

Review Comment:
   Interesting. Is there any example code that I can try?
   Also, if we can create the `te::Tensor` with symbolic variables, may I ask why we cannot extract those values? Couldn't we access each value by indexing into the tensor, as in the example above?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to