icemelon9 commented on a change in pull request #6289:
URL: https://github.com/apache/incubator-tvm/pull/6289#discussion_r471814828
##########
File path: include/tvm/topi/transform.h
##########
@@ -481,26 +481,29 @@ inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name
*
* \return A Tensor whose op member is the split operation
*/
-inline Array<Tensor> split(const Tensor& x, Array<Integer> split_indices, int axis,
+inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
std::string name = "T_split", std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(x->shape.size());
}
CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
- auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
- std::vector<int> begin_ids;
+ auto src_axis_size = x->shape[axis];
+ std::vector<PrimExpr> begin_ids;
begin_ids.push_back(0);
- for (Integer idx : split_indices) {
- int val = static_cast<int>(idx->value);
- CHECK_GT(val, begin_ids.back()) << "split_indices must be sorted";
Review comment:
Could you add this check (that `split_indices` is sorted in strictly increasing order) to the shape function?
##########
File path: include/tvm/topi/transform.h
##########
@@ -668,15 +671,18 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
}
CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
- auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
+ auto src_axis_size = x->shape[axis];
CHECK_GT(num_sections, 0) << "Slice count must be > 0";
- CHECK_EQ(src_axis_size % num_sections, 0)
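Review comment:
Same here: could you also add this check (that the axis size is divisible by `num_sections`) to the shape function?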
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org