comaniac commented on a change in pull request #7214:
URL: https://github.com/apache/tvm/pull/7214#discussion_r552931908
##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -419,6 +419,77 @@ bool TransposeRel(const Array<Type>& types, int
num_inputs, const Attrs& attrs,
return true;
}
+Array<Array<Layout>> TransposeInferCorrectLayout(const Attrs& attrs,
+                                                 const Array<Layout>& new_in_layouts,
+                                                 const Array<Layout>& old_in_layouts,
+                                                 const Array<tvm::relay::Type>& old_in_types) {
+ // Discard "const" qualifier.
+ auto* params = const_cast<TransposeAttrs*>(attrs.as<TransposeAttrs>());
+ ICHECK(params != nullptr);
+
+ std::string in_layout_str = "";
+ std::string out_layout_str = "";
+
+ // Infer the input layout string and update the axes.
+ if (old_in_layouts.defined()) {
+ ICHECK_EQ(old_in_layouts.size(), 1);
+ auto old_layout = old_in_layouts[0];
+
+ // Deal with default axes.
+ if (!params->axes.defined() || params->axes.size() == 0) {
+ Array<Integer> axes = Array<Integer>();
+ for (int i = old_layout.ndim() - 1; i >= 0; --i) {
+ axes.push_back(i);
+ }
+ params->axes = std::move(axes);
+ }
+
+ if (new_in_layouts.defined()) {
+ ICHECK_EQ(new_in_layouts.size(), 1);
+ auto new_layout = new_in_layouts[0];
+
+ // Update the axes based on the new layout.
+ Array<Integer> new_axes;
+ if (new_layout.ndim() == old_layout.ndim()) {
+ // Make sure old and new layouts have consistent dimensions.
+        // For example, transpose does not support NCHW8c so it cannot be the new layout.
Review comment:
```suggestion
        // For example, transpose does not support NCHW[x]c so it cannot be the new layout.
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]