kevinthesun commented on a change in pull request #4312: [TOPI][Relay][OP] 
Dynamic NMS and strided_slice
URL: https://github.com/apache/incubator-tvm/pull/4312#discussion_r410751616
 
 

 ##########
 File path: src/relay/op/tensor/transform.cc
 ##########
 @@ -1775,105 +1776,165 @@ Array<Integer> GetIntArray(Array<IndexExpr> arr) {
   return Downcast<Array<Integer> >(arr);
 }
 
-
 // strided_slice
 TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
+
+int64_t* ToVector(const runtime::NDArray& array) {
+  size_t len = array.Shape().front();
+  int64_t* rel_vec = new int64_t[len];
+  if (array->dtype.code == kDLInt) {
+    if (array->dtype.bits == 8) {
+      int8_t* init_array = reinterpret_cast<int8_t*>(array->data);
+      for (size_t i = 0; i < len; ++i) {
+        rel_vec[i] = int64_t(init_array[i]);
+      }
+      return rel_vec;
+    } else if (array->dtype.bits == 16) {
+      int16_t* init_array = reinterpret_cast<int16_t*>(array->data);
+      for (size_t i = 0; i < len; ++i) {
+        rel_vec[i] = int64_t(init_array[i]);
+      }
+      return rel_vec;
+    } else if (array->dtype.bits == 32) {
+      int32_t* init_array = reinterpret_cast<int32_t*>(array->data);
+      for (size_t i = 0; i < len; ++i) {
+        rel_vec[i] = int64_t(init_array[i]);
+      }
+      return rel_vec;
+    } else if (array->dtype.bits == 64) {
+      int64_t* init_array = reinterpret_cast<int64_t*>(array->data);
+      for (size_t i = 0; i < len; ++i) {
+        rel_vec[i] = int64_t(init_array[i]);
+      }
+      return rel_vec;
+    }
+  } else if (array->dtype.code == kDLUInt) {
+    if (array->dtype.bits == 8) {
+      uint8_t* init_array = reinterpret_cast<uint8_t*>(array->data);
+      for (size_t i = 0; i < len; ++i) {
+        rel_vec[i] = int64_t(init_array[i]);
+      }
+      return rel_vec;
+    } else if (array->dtype.bits == 16) {
+      uint16_t* init_array = reinterpret_cast<uint16_t*>(array->data);
+      for (size_t i = 0; i < len; ++i) {
+        rel_vec[i] = int64_t(init_array[i]);
+      }
+      return rel_vec;
+    } else if (array->dtype.bits == 32) {
+      uint32_t* init_array = reinterpret_cast<uint32_t*>(array->data);
+      for (size_t i = 0; i < len; ++i) {
+        rel_vec[i] = int64_t(init_array[i]);
+      }
+      return rel_vec;
+    } else if (array->dtype.bits == 64) {
+      uint64_t* init_array = reinterpret_cast<uint64_t*>(array->data);
+      for (size_t i = 0; i < len; ++i) {
+        rel_vec[i] = int64_t(init_array[i]);
+      }
+      return rel_vec;
+    }
+  }
+  LOG(FATAL) << "Unknown data type: " << 
tvm::runtime::DLDataType2String(array->dtype);
+  return rel_vec;
+}
+
 bool StridedSliceRel(const Array<Type>& types,
                      int num_inputs,
                      const Attrs& attrs,
                      const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
+  CHECK_EQ(types.size(), 5);
+  const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
   CHECK(param != nullptr);
-
+  const auto* data = types[0].as<TensorTypeNode>();
+  CHECK(data != nullptr);
   auto dshape = data->shape;
-  auto num_axis = dshape.size();
-
-  std::vector<int64_t> stride_vec;
-  for (Integer i : param->strides) {
-    CHECK(i.defined());
-    stride_vec.push_back(i->value);
-  }
-  for (size_t i = stride_vec.size(); i < num_axis; ++i) {
-    stride_vec.push_back(1);
-  }
-  const int64_t max_range = std::numeric_limits<int64_t>::max();
-
-  std::vector<int64_t> begin_vec;
-  for (size_t i = 0; i < param->begin.size(); ++i) {
-    if (!param->begin[i].defined()) {
-      // value=None
+  int64_t num_axis = dshape.size();
+
+  // calculate output shape
+  std::vector<IndexExpr> oshape(num_axis);
+  const ConstantNode *cbegin, *cend, *cstrides;
+  if ((cbegin = param->begin.as<ConstantNode>()) &&
+      (cend = param->end.as<ConstantNode>()) &&
+      (cstrides = param->strides.as<ConstantNode>())) {
+    std::vector<int64_t> stride_vec;
+    int64_t* strides_val = ToVector(cstrides->data);
+    for (int64_t i = 0; i < cstrides->data.Shape().front(); ++i) {
+      stride_vec.push_back(strides_val[i]);
+    }
+    for (int64_t i = stride_vec.size(); i < num_axis; ++i) {
+      stride_vec.push_back(1);
+    }
+    const int64_t max_range = std::numeric_limits<int64_t>::max();
+    std::vector<int64_t> begin_vec;
+    int64_t* begin_val = ToVector(cbegin->data);
+    for (int64_t i = 0; i < cbegin->data.Shape().front(); ++i) {
+      begin_vec.push_back(begin_val[i]);
+    }
+    for (int64_t i = begin_vec.size(); i < num_axis; ++i) {
       begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
-    } else {
-      begin_vec.push_back(param->begin[i]->value);
     }
-  }
-  for (size_t i = begin_vec.size(); i < num_axis; ++i) {
-    begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
-  }
-
-  std::vector<int64_t> end_vec;
-  for (size_t i = 0; i < param->end.size(); ++i) {
-    // allow end to be None
-    if (!param->end[i].defined()) {
+    std::vector<int64_t> end_vec;
+    int64_t* end_val = ToVector(cend->data);
+    for (int64_t i = 0; i < cend->data.Shape().front(); ++i) {
+      end_vec.push_back(end_val[i]);
+    }
+    for (int64_t i = end_vec.size(); i < num_axis; ++i) {
       end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
-    } else {
-      end_vec.push_back(param->end[i]->value);
     }
-  }
-  for (size_t i = end_vec.size(); i < num_axis; ++i) {
-    end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
-  }
-
-  std::vector<IndexExpr> oshape(dshape.size());
-  for (size_t i = 0; i < num_axis; ++i) {
-    int64_t stride_v = stride_vec[i];
-    int64_t begin_v = begin_vec[i];
-    int64_t end_v = end_vec[i];
-
-    if ((stride_v == 1 &&
-         begin_v == 0 &&
-         end_v == max_range) ||
-        (stride_v == -1 &&
-         begin_v == max_range &&
-         end_v == 0)) {
-      // Quick path, do not slice this dimension.
-      oshape[i] = dshape[i];
-      continue;
+
+    for (int64_t i = 0; i < num_axis; ++i) {
+      int64_t stride_v = stride_vec[i];
+      int64_t begin_v = begin_vec[i];
+      int64_t end_v = end_vec[i];
+
+      if ((stride_v == 1 &&
+           begin_v == 0 &&
+           end_v == max_range) ||
+          (stride_v == -1 &&
+           begin_v == max_range &&
+           end_v == 0)) {
+        // Quick path, do not slice this dimension.
+        oshape[i] = dshape[i];
+        continue;
+      }
+      // Normal path, require the shape to be concrete integer.
+      // Require concrete integer as symbolic inference of min/max
+      // can get complicated and not very helpful.
+      const int64_t* p_dim_size = tir::as_const_int(dshape[i]);
+      if (!p_dim_size) {
+        oshape[i] = dshape[i];
+        continue;
+      }
+      int64_t dim_size = p_dim_size[0];
+      begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
+      end_v = (end_v < 0) ? dim_size + end_v : end_v;
+
+      int64_t slice_range, step;
+      if (stride_v < 0) {
+        if (end_v < -1) end_v = -1;
+        CHECK_LT(end_v, begin_v)
 
 Review comment:
   ```suggestion
           CHECK_LT(end_v, begin_v) -> CHECK_LE(end_v, begin_v)
   ```
  Since we should allow an empty slice (i.e. `begin_v == end_v`), `CHECK_LT` should be `CHECK_LE`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to