masahi commented on code in PR #14548:
URL: https://github.com/apache/tvm/pull/14548#discussion_r1162238558


##########
src/relax/op/tensor/index.cc:
##########
@@ -239,5 +239,78 @@ TVM_REGISTER_OP("relax.strided_slice")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow);
 
+/* relax.dynamic_strided_slice */
+/*!
+ * \brief Build a call to the relax.dynamic_strided_slice operator.
+ *
+ * Unlike relax.strided_slice, the slicing bounds are passed as tensor arguments
+ * rather than attributes, so they may only be known at runtime.
+ *
+ * \param x The tensor to slice.
+ * \param begin Tensor of per-axis start indices.
+ * \param end Tensor of per-axis stop indices.
+ * \param strides Tensor of per-axis steps.
+ * \return The constructed Call node.
+ */
+Expr dynamic_strided_slice(Expr x, Expr begin, Expr end, Expr strides) {
+  static const Op& dyn_slice_op = Op::Get("relax.dynamic_strided_slice");
+  return Call(dyn_slice_op,
+              {std::move(x), std::move(begin), std::move(end), std::move(strides)},
+              {});
+}
+
+// Register dynamic_strided_slice in the global function registry under the name
+// "relax.op.dynamic_strided_slice", with its typed C++ signature.
+TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice);
+
+/*!
+ * \brief Infer the output struct info of relax.dynamic_strided_slice.
+ *
+ * begin/end/strides are runtime tensors, so only the output dtype and rank can be
+ * derived here; the result carries no static shape.
+ *
+ * \param call The call node; its args are (data, begin, end, strides).
+ * \param ctx The block builder (unused by the checks in this function).
+ * \return TensorStructInfo with data's dtype and data's ndim, or kUnknownNDim when
+ *         data's rank is unknown.
+ */
+StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ctx) {
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* begin_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const auto* end_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+  const auto* strides_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[3]);
+
+  ICHECK(data_sinfo) << "Dynamic strided slice requires the input data to have tensor struct "
+                        "info. Please try normalizing the inputs.";
+  if (data_sinfo->IsUnknownNdim()) {
+    // Without a known rank there is nothing to validate the index tensors against.
+    LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes begin/end/stride "
+                    "tensors are well-formed. It could produce runtime error when this assumption "
+                    "turns out to be wrong.";
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+  }
+  if (data_sinfo->IsUnknownDtype()) {
+    LOG(WARNING) << "When data type is unknown, dynamic strided slice assumes to have a valid "
+                    "dtype. It could produce runtime error when this assumption "
+                    "turns out to be wrong.";
+  }
+
+  const int n_axis = data_sinfo->ndim;
+  // Validate one index tensor (begin/end/strides). It must be 1-D, its single extent
+  // must be a compile-time integer equal to data's rank, and its dtype, when known,
+  // must be int64.
+  auto check_index_tensor = [&](const TensorStructInfoNode* sinfo, const String& name) {
+    ICHECK(sinfo) << "Dynamic strided slice requires the input " << name
+                  << " to have tensor struct info. Please try normalizing the inputs.";
+    CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name
+                             << " to be 1d tensor (list of values).";
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    ICHECK(shape) << "Dynamic strided slice requires the input " << name
+                  << " to have well-defined shape.";
+    // NOTE(tvm-team): This strong restriction seems necessary for now until we have a generic
+    // solution in converting 1d Tensor with unknown num_elem to Array<PrimExpr>.
+    const auto* num_elem = shape->values[0].as<IntImmNode>();
+    ICHECK(num_elem) << "Dynamic strided slice requires the input " << name
+                     << " to have a known integer shape value.";
+    CHECK_EQ(num_elem->value, n_axis)
+        << "Dynamic strided slice requires the number of indices in " << name
+        << " to equal the number of axes.";
+    if (sinfo->IsUnknownDtype()) {
+      LOG(WARNING) << "Dynamic strided slice assumes " << name
+                   << " to be int64 when it is not specified.";
+    } else {
+      // Leading space keeps the name and "values" from running together in the message.
+      CHECK(sinfo->dtype == DataType::Int(64))
+          << "Dynamic strided_slice expects the input " << name
+          << " values to be all int64. However, " << name << " has dtype " << sinfo->dtype << ".";
+    }
+  };
+  check_index_tensor(begin_sinfo, "begin");
+  check_index_tensor(end_sinfo, "end");
+  check_index_tensor(strides_sinfo, "strides");
+
+  // The output shape will depend on the runtime value in begin/end/strides tensors.
+  // TODO(tvm-team): Extract more compile-time info when those tensors are constants.
+  return TensorStructInfo(data_sinfo->dtype, n_axis);

Review Comment:
   Assuming the "symbolic variable" you mentioned is just another `PrimExpr`, I
expect we can fill a tensor with that value, just like any other `PrimExpr`.
Why not try the `Can begin(ind) be symbolic in practice?` case?
   
   > Can't we access the value by using the index in the above example?
   
   What I meant was: as soon as we put a symbolic expression in a TE tensor, we
lose all symbolic-specific information. You can index it, but what you get is
an opaque value. So we cannot exploit any symbolic information about it.
   
   So an output shape represented by a TE tensor loses its symbolic or constant
shape information. That's why I think it is better to generate the shape
function directly in TIR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to