lixiaoquan commented on a change in pull request #7835:
URL: https://github.com/apache/tvm/pull/7835#discussion_r612898547
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -2159,6 +2159,74 @@ Array<te::Tensor> SqueezeCompute(const Attrs& attrs,
const Array<te::Tensor>& in
return {topi::squeeze(inputs[0], param->axis)};
}
+Array<Array<Layout>> SqueezeInferCorrectLayout(const Attrs& attrs,
+ const Array<Layout>&
new_in_layouts,
+ const Array<Layout>&
old_in_layouts,
+ const Array<tvm::relay::Type>&
old_in_types) {
+ // NOTE: Discard "const" qualifier here.
+ SqueezeAttrs* params = const_cast<SqueezeAttrs*>(attrs.as<SqueezeAttrs>());
+
+ Layout inferred_input = new_in_layouts.defined() ? new_in_layouts[0] :
old_in_layouts[0];
+ Layout inferred_output = inferred_input;
+
+ ICHECK(old_in_types[0].as<TensorTypeNode>());
+ const auto& shape = old_in_types[0].as<TensorTypeNode>()->shape;
+
+ // axis to squeeze
+ Array<Integer> axis;
+ if (params->axis.defined()) {
+ axis = params->axis;
+ } else {
+ // if axes is None, squeeze all axes of dimension 1
+ for (size_t i = 0; i < shape.size(); i++) {
+ if (topi::detail::GetConstInt(shape[i]) == 1) {
+ axis.push_back(i);
+ }
+ }
+ }
+
+ if (axis.size() == 0) {
Review comment:
I've removed this check because the subsequent logic already handles it, and
I've added a test case covering the `nothing to squeeze` scenario.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]