kparzysz-quic commented on code in PR #15112:
URL: https://github.com/apache/tvm/pull/15112#discussion_r1232195776
##########
src/relax/transform/rewrite_dataflow_reshape.cc:
##########
@@ -89,8 +95,32 @@ class DataflowReshapeRewriter : public ExprMutator {
auto arg = arg_tuple[used_arg_indices[0]];
- TensorStructInfo res_sinfo =
Downcast<TensorStructInfo>(call->struct_info_);
- ICHECK(res_sinfo->shape.defined());
+ // The reshape operator expects that the number of elements in the source
is the same
+ // as the number of elements in the result. There are operators that could
have a reshape
+ // pattern that don't meet this requirement (e.g. strided_slice), and they
should not be
+ // converted to reshape.
+ ICHECK(arg->struct_info_.defined() && call->struct_info_.defined());
+ TensorStructInfo arg_sinfo =
Downcast<TensorStructInfo>(arg->struct_info_.value());
+ TensorStructInfo res_sinfo =
Downcast<TensorStructInfo>(call->struct_info_.value());
+
+ if (arg_sinfo->IsUnknownDtype() || arg_sinfo->dtype != res_sinfo->dtype) {
+ return Unchanged;
+ }
+ ICHECK(arg_sinfo->shape.defined() && res_sinfo->shape.defined());
+ if (arg_sinfo->IsUnknownNdim() || res_sinfo->IsUnknownNdim()) {
+ return Unchanged;
+ }
+ auto product = [](Array<PrimExpr> args) -> PrimExpr {
+ ICHECK(!args.empty());
+ return std::reduce(args.begin(), args.end(), PrimExpr(1),
+ [](auto a, auto b) { return a * b; });
+ };
+ auto arg_count = product(arg_sinfo->GetShape().value());
+ auto res_count = product(res_sinfo->GetShape().value());
+ if (!arith::Analyzer().CanProveEqual(arg_count, res_count)) {
+ return Unchanged;
+ }
Review Comment:
This code needs the input arg, so I'll move the call to
`IsCallingTIRReshape` to after the point where the arg is known. I hope that's OK.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]