swjng opened a new pull request, #19511:
URL: https://github.com/apache/tvm/pull/19511

   ## Motivation
   
   PyTorch's `torch.flip(x, dims=[...])` reverses every listed axis. The
   Relax converter `_flip` (`base_fx_graph_translator.py`) instead coerces
   the list to a single integer:
   
   ```python
   if isinstance(dims, list | tuple) and len(dims) > 0:
       dims = dims[0]
   ```
   
   Only the first axis is forwarded to `relax.op.flip`, which itself
   accepts a single axis. The remaining axes are silently dropped.
   
   Minimal repro (vs PyTorch eager) on a `(3, 4)` input with
   `dims=[-1, -2]`:
   
   ```
   ref: [11, 10,  9,  8,  7,  6,  5,  4, ...]   # both axes flipped
   tvm: [ 3,  2,  1,  0,  7,  6,  5,  4, ...]   # only last axis flipped
   ```
   
   `max_abs_diff` is 8.0. The `torch.export` path and the legacy FX path
   both go through this converter, so both are affected.
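
The repro can be mirrored with NumPy standing in for both sides. This is a sketch under one assumption: the `(3, 4)` input is `arange(12)` reshaped, which is inferred from the printed values and not stated above. A multi-axis `np.flip` plays the role of PyTorch eager, and a flip along `dims[0]` alone plays the role of the current converter.

```python
import numpy as np

# Assumed input: the printed values suggest arange(12) reshaped to (3, 4).
x = np.arange(12, dtype=np.float32).reshape(3, 4)
dims = [-1, -2]

# Reference semantics (torch.flip): reverse every listed axis.
ref = np.flip(x, axis=tuple(dims))

# Current converter behaviour: dims is coerced to dims[0], so only one axis flips.
buggy = np.flip(x, axis=dims[0])

print("ref:", ref.ravel()[:8])
print("tvm:", buggy.ravel()[:8])
print("max_abs_diff =", float(np.abs(ref - buggy).max()))
```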
   
   ## Fix
   
   Iterate over `dims` in the converter and emit one `relax.op.flip` per
   axis. Flips along distinct axes commute, so the emission order does
   not matter. A scalar `dims` is wrapped in a single-element list;
   arguments that are neither an int nor a sequence still raise `TypeError`.
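
The following minimal sketch shows the fix's logic, with NumPy's single-axis `np.flip` standing in for `relax.op.flip`. `normalize_flip_dims` and `flip_per_axis` are hypothetical helpers for illustration only. They are not names from the converter. The checks at the end confirm that chaining single-axis flips reproduces the multi-axis result in every order.

```python
import itertools

import numpy as np


def normalize_flip_dims(dims):
    """Hypothetical helper: accept an int or a list/tuple of ints, return a list."""
    if isinstance(dims, int):
        return [dims]  # scalar dims: wrap in a single-element list
    if isinstance(dims, (list, tuple)):
        return list(dims)
    raise TypeError(f"flip expects an int or a sequence of ints, got {type(dims).__name__}")


def flip_per_axis(x, dims, flip_one=np.flip):
    """Apply one single-axis flip per listed axis (stand-in for emitting relax.op.flip)."""
    for axis in normalize_flip_dims(dims):
        x = flip_one(x, axis=axis)
    return x


x = np.arange(24).reshape(2, 3, 4)
dims = [0, -1, -2]
expected = np.flip(x, axis=tuple(dims))

# Flips along distinct axes commute: every ordering gives the same result.
for order in itertools.permutations(dims):
    assert np.array_equal(flip_per_axis(x, list(order)), expected)

# A scalar dims behaves like a single-element list.
assert np.array_equal(flip_per_axis(x, 1), np.flip(x, axis=1))
```

In this sketch, a repeated axis such as `[0, 0]` flips twice and cancels out. `torch.flip` instead raises an error on duplicate dims, so a real converter may want to reject duplicates explicitly.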
   
   `relax.op.flip` itself is unchanged: it is used elsewhere as a
   single-axis op, and widening its signature would expand the scope of
   this fix beyond the PyTorch frontend.


-- 
This is an automated message from the Apache Git Service.