sunggg commented on code in PR #14548:
URL: https://github.com/apache/tvm/pull/14548#discussion_r1162233085


##########
include/tvm/topi/transform.h:
##########
@@ -2035,6 +2034,73 @@ inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices,
       name, tag);
 }
 
+namespace relax {
+// relax dynamic slice
+inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
+                                        const te::Tensor& end, const te::Tensor& strides,
+                                        Array<PrimExpr> output_shape,
+                                        std::string name = "T_strided_slice_dynamic",
+                                        std::string tag = kInjective) {
+  const size_t num_dynamic_axes = x.ndim();
+  ICHECK_EQ(begin.ndim(), 1);
+  ICHECK_EQ(end.ndim(), 1);
+  ICHECK_EQ(strides.ndim(), 1);
+  const auto* len_begin = begin->shape[0].as<IntImmNode>();
+  const auto* len_end = end->shape[0].as<IntImmNode>();
+  const auto* len_strides = strides->shape[0].as<IntImmNode>();
+  ICHECK(len_begin);
+  ICHECK(len_end);
+  ICHECK(len_strides);
+  ICHECK_EQ(len_begin->value, num_dynamic_axes);
+  ICHECK_EQ(len_end->value, num_dynamic_axes);
+  ICHECK_EQ(len_strides->value, num_dynamic_axes);
+
+  return te::compute(
+      output_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<PrimExpr> real_indices;
+        for (size_t i = 0; i < num_dynamic_axes; ++i) {
+          auto ind = make_const(DataType::Int(64), i);
+          real_indices.push_back(indices[i] * strides(ind) + tvm::min(begin(ind), x->shape[i] - 1));
+        }
+        return x(real_indices);
+      },
+      name, tag);
+}
+
+inline te::Tensor shape_func_dynamic_strided_slice(
+    const te::Tensor& data, const te::Tensor& begin, const te::Tensor& end,
+    const te::Tensor& strides, std::string name = "T_shape_func_strided_slice_dynamic") {
+  return te::compute(
+      {begin->shape[0]},
+      [&](const Array<tvm::tir::Var>& indices) {
+        ICHECK(indices.size() == 1);
+        auto CanonicalizeIndex = [&](PrimExpr index, PrimExpr extent, PrimExpr stride) {
+          PrimExpr begin_range = if_then_else(stride < 0, -1, 0);
+          PrimExpr end_range = if_then_else(stride < 0, extent - 1, extent);
+          index = if_then_else(index < 0, index + extent, index);
+          return min(max(index, begin_range), end_range);
+        };
+
+        auto GetLength = [&](PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr length) {
+          begin = CanonicalizeIndex(begin, length, stride);
+          end = CanonicalizeIndex(end, length, stride);
+          PrimExpr len1 = ceildiv(begin - end, -stride);
+          PrimExpr len2 = ceildiv(end - begin, stride);
+          return if_then_else(stride < 0, len1, len2);
+        };
+        PrimExpr length(-1);
+        int ndim = data.ndim();
+        for (int i = 0; i < ndim; i++) {
+          length = if_then_else(indices[0] == i, data->shape[i], length);
+        }
+        return GetLength(begin(indices), end(indices), strides(indices), length);

Review Comment:
   One of the main goals of https://github.com/apache/tvm/pull/14278 is to automate the current repetitive and tedious op registration process, including legalization, struct_info, and so on. Unlike the current flow, which requires us to look at multiple different sites and do the work manually, https://github.com/apache/tvm/pull/14278 will let us put all of that information in a single place and have the parser handle the rest. I believe the legalization function should live there, too.
   
   > I think the question is whether or not we consider shape funcs a part of "op definition".
   
   You brought up very good points. I see the shape func as part of the op definition, just like the legalization function for each operator. When https://github.com/apache/tvm/pull/14278 lands, the shape function can be registered alongside it only when necessary, so we don't complicate op registration unnecessarily.
   
   > but the current discussion in this PR suggests shape funcs need not be implemented under topi.
   
   I forgot to mention earlier, but https://github.com/apache/tvm/pull/14278 deduces an operator's output shape by reusing the existing shape computation logic in TOPI to automatically generate the struct_info logic. I'm not sure whether this is a hard requirement, though. @junrushao, would it be okay to put the shape computation elsewhere, for example in the legalizer?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
