swjng opened a new pull request, #19511:
URL: https://github.com/apache/tvm/pull/19511
## Motivation
PyTorch's `torch.flip(x, dims=[...])` reverses every listed axis. The
Relax converter `_flip` (`base_fx_graph_translator.py`) instead coerces
the list to a single integer:
```python
if isinstance(dims, list | tuple) and len(dims) > 0:
    dims = dims[0]
```
Only the first axis is forwarded to `relax.op.flip`, which is itself
single-axis. The remaining axes are silently dropped.
Minimal repro (vs PyTorch eager) on a `(3, 4)` input with
`dims=[-1, -2]`:
```
ref: [11, 10, 9, 8, 7, 6, 5, 4, ...] # both axes flipped
tvm: [ 3, 2, 1, 0, 7, 6, 5, 4, ...] # only last axis flipped
```
`max_abs_diff = 8.0`. Both the `torch.export` and the legacy fx paths share
this converter, so both are affected.
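This sketch makes the discrepancy concrete without depending on TVM or PyTorch. It emulates both behaviours on nested Python lists. `flip_axis` and `flatten` are hypothetical helpers written for illustration and are not part of either library. Running the script reproduces the two rows and the maximum difference from the repro. The difference comes out as the integer 8 here because the emulated input holds ints.

```python
def flip_axis(x, axis, ndim):
    """Reverse a nested-list tensor `x` of rank `ndim` along `axis`.

    Hypothetical pure-Python stand-in for a single-axis flip such as
    relax.op.flip; negative axes count from the end, as in PyTorch.
    """
    axis = axis % ndim
    if axis == 0:
        return list(reversed(x))
    # Recurse into each sub-tensor; the target axis shifts down by one.
    return [flip_axis(sub, axis - 1, ndim - 1) for sub in x]


def flatten(x):
    """Flatten a nested list into a flat list of scalars."""
    if not isinstance(x, list):
        return [x]
    return [v for sub in x for v in flatten(sub)]


# (3, 4) input holding 0..11, matching the repro.
x = [[4 * r + c for c in range(4)] for r in range(3)]
dims = [-1, -2]

# Reference semantics: flip every listed axis, as torch.flip does.
ref = x
for d in dims:
    ref = flip_axis(ref, d, ndim=2)

# Buggy converter: `dims` is coerced to `dims[0]`, so only one axis flips.
buggy = flip_axis(x, dims[0], ndim=2)

max_abs_diff = max(abs(a - b) for a, b in zip(flatten(ref), flatten(buggy)))
print(flatten(ref)[:8])    # [11, 10, 9, 8, 7, 6, 5, 4]
print(flatten(buggy)[:8])  # [3, 2, 1, 0, 7, 6, 5, 4]
print(max_abs_diff)        # 8
```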
## Fix
Iterate over `dims` in the converter and emit one `relax.op.flip` per
axis (flips along distinct axes commute, so the order is irrelevant).
A scalar `dims` is wrapped to a single-element list; non-int /
non-sequence arguments still raise `TypeError`.
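The following framework-free sketch shows the converter change. It assumes that a single-axis flip is available as a callable. `normalize_flip_dims`, `flip_all`, and `flip_2d` are hypothetical names and are not the identifiers used in `base_fx_graph_translator.py`. In the real converter, the callable would wrap an emitted `relax.op.flip`. That wiring is an assumption and is not shown here.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def normalize_flip_dims(dims) -> List[int]:
    """Normalize torch.flip's `dims` argument to a list of axes.

    A scalar int is wrapped into a one-element list, a list or tuple is
    copied, and any other argument type raises TypeError.
    """
    if isinstance(dims, int):
        return [dims]
    if isinstance(dims, (list, tuple)):
        return list(dims)
    raise TypeError(
        f"flip dims must be an int or a list/tuple, got {type(dims).__name__}"
    )


def flip_all(x: T, dims, flip_one: Callable[[T, int], T]) -> T:
    """Apply a single-axis flip once per requested axis.

    Flips along distinct axes commute, so the iteration order does not
    change the result.
    """
    for axis in normalize_flip_dims(dims):
        x = flip_one(x, axis)
    return x


def flip_2d(x, axis):
    """Hypothetical single-axis flip for a 2-D nested list (axis in -2..1)."""
    if axis % 2 == 0:
        return x[::-1]                 # reverse the rows
    return [row[::-1] for row in x]    # reverse within each row


x = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
both = flip_all(x, [-1, -2], flip_2d)
assert both == flip_all(x, (-2, -1), flip_2d)  # order is irrelevant
print(both)                   # [[11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0]]
print(flip_all(x, 0, flip_2d))  # [[8, 9, 10, 11], [4, 5, 6, 7], [0, 1, 2, 3]]
```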
`relax.op.flip` itself is unchanged: it is used elsewhere as a
single-axis op, and widening its signature would expand the scope of
this fix beyond the PyTorch frontend.