| Issue | 172981 |
| Summary | [mlir] FoldTensorCastOfOutputIntoForallOp: reverses scf.forall shared-output write order (semantic mismatch after canonicalization) |
| Labels | mlir |
| Assignees | |
| Reporter | hesse-x |
## Summary
The `FoldTensorCastOfOutputIntoForallOp` canonicalization pattern incorrectly reverses the write order of `scf.forall` shared outputs (the block arguments bound via `shared_outs`), producing a **semantic mismatch** between the IR before and after canonicalization. This breaks the core invariant of canonicalization (semantic equivalence) for `scf.forall` ops with multiple shared outputs.
## Steps to Reproduce
### Original IR (before canonicalization)
```mlir
func.func @forall_test(%arg0: tensor<8x32xf32>, %arg1: tensor<8x32xf32>) -> (tensor<?x32xf32>, tensor<?x32xf32>) {
%c8 = arith.constant 8 : index
%c32 = arith.constant 32 : index
%init = tensor.empty(%c32) : tensor<?x32xf32>
%0:2 = scf.forall (%tidx) in (4) shared_outs(%arg2 = %init, %arg3 = %init) -> (tensor<?x32xf32>, tensor<?x32xf32>) {
%pos = arith.muli %c8, %tidx : index
scf.forall.in_parallel {
// Write %arg0 to %arg3 (second shared output)
tensor.parallel_insert_slice %arg0 into %arg3[%pos, 0] [8, 32] [1, 1] : tensor<8x32xf32> into tensor<?x32xf32>
// Write %arg1 to %arg2 (first shared output)
tensor.parallel_insert_slice %arg1 into %arg2[%pos, 0] [8, 32] [1, 1] : tensor<8x32xf32> into tensor<?x32xf32>
}
}
// Return %0#0 (arg2: %arg1 data) and %0#1 (arg3: %arg0 data)
return %0#0, %0#1 : tensor<?x32xf32>, tensor<?x32xf32>
}
```
### IR after the canonicalize pass (e.g. `mlir-opt --canonicalize`)
```mlir
func.func @forall_test(%arg0: tensor<8x32xf32>, %arg1: tensor<8x32xf32>) -> (tensor<?x32xf32>, tensor<?x32xf32>) {
%c8 = arith.constant 8 : index
%0 = tensor.empty() : tensor<32x32xf32>
%1:2 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0, %arg4 = %0) -> (tensor<32x32xf32>, tensor<32x32xf32>) {
%2 = arith.muli %arg2, %c8 : index
scf.forall.in_parallel {
// Writing %arg0 to %arg3 (first shared output now)
tensor.parallel_insert_slice %arg0 into %arg3[%2, 0] [8, 32] [1, 1] : tensor<8x32xf32> into tensor<32x32xf32>
// Writing %arg1 to %arg4 (second shared output now)
tensor.parallel_insert_slice %arg1 into %arg4[%2, 0] [8, 32] [1, 1] : tensor<8x32xf32> into tensor<32x32xf32>
}
}
    // Return %1#0 (arg3: %arg0 data) and %1#1 (arg4: %arg1 data) — the result order is reversed relative to the original IR!
%cast = tensor.cast %1#0 : tensor<32x32xf32> to tensor<?x32xf32>
%cast_0 = tensor.cast %1#1 : tensor<32x32xf32> to tensor<?x32xf32>
return %cast, %cast_0 : tensor<?x32xf32>, tensor<?x32xf32>
}
```
## Expected Behavior
Canonicalization should preserve semantic equivalence:
- `%arg1` should still be written to the first shared output and `%arg0` to the second, so the forall's first result carries `%arg1`'s data and its second result carries `%arg0`'s data, as in the original IR.
- The final returned tensors should match the original IR's logic (no swap of the `%arg0`/`%arg1` data).
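For reference, a semantically equivalent canonicalized form would fold the `tensor.empty` to its static shape and insert the result casts while keeping each `tensor.parallel_insert_slice` targeting its original shared-output position. This is a hand-written sketch of the expected output, not actual compiler output:
```mlir
func.func @forall_test(%arg0: tensor<8x32xf32>, %arg1: tensor<8x32xf32>) -> (tensor<?x32xf32>, tensor<?x32xf32>) {
  %c8 = arith.constant 8 : index
  %0 = tensor.empty() : tensor<32x32xf32>
  %1:2 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0, %arg4 = %0) -> (tensor<32x32xf32>, tensor<32x32xf32>) {
    %2 = arith.muli %arg2, %c8 : index
    scf.forall.in_parallel {
      // %arg0 still goes to the second shared output (%arg4)
      tensor.parallel_insert_slice %arg0 into %arg4[%2, 0] [8, 32] [1, 1] : tensor<8x32xf32> into tensor<32x32xf32>
      // %arg1 still goes to the first shared output (%arg3)
      tensor.parallel_insert_slice %arg1 into %arg3[%2, 0] [8, 32] [1, 1] : tensor<8x32xf32> into tensor<32x32xf32>
    }
  }
  // Casts restore the original dynamic result types; result order is preserved
  %cast = tensor.cast %1#0 : tensor<32x32xf32> to tensor<?x32xf32>
  %cast_0 = tensor.cast %1#1 : tensor<32x32xf32> to tensor<?x32xf32>
  return %cast, %cast_0 : tensor<?x32xf32>, tensor<?x32xf32>
}
```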
_______________________________________________
llvm-bugs mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs