| Issue | 135389 |
|---|---|
| Summary | [mlir][tosa] tosa-to-tensor pass for tosa.slice fails with known output shape but dynamic size attribute |
| Labels | mlir |
| Assignees | |
| Reporter | sahas3 |
Consider the input IR:
```mlir
func.func @slice(%arg0 : tensor<2x60x59x5xf32>) -> tensor<?x60x58x5xf32> {
  %0 = "tosa.cast"(%arg0) : (tensor<2x60x59x5xf32>) -> tensor<?x60x59x5xf32>
  %1 = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
  %2 = tosa.const_shape {values = dense<[-1, 60, 58, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %3 = tosa.slice %0, %1, %2 : (tensor<?x60x59x5xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x60x58x5xf32>
  return %3 : tensor<?x60x58x5xf32>
}
```
Running `mlir-opt --tosa-to-tensor` succeeds:
```mlir
func.func @slice(%arg0: tensor<2x60x59x5xf32>) -> tensor<?x60x58x5xf32> {
  %0 = tosa.cast %arg0 : (tensor<2x60x59x5xf32>) -> tensor<?x60x59x5xf32>
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %0, %c0 : tensor<?x60x59x5xf32>
  %c0_0 = arith.constant 0 : index
  %1 = arith.subi %dim, %c0_0 : index
  %extracted_slice = tensor.extract_slice %0[0, 0, 0, 0] [%1, 60, 58, 5] [1, 1, 1, 1] : tensor<?x60x59x5xf32> to tensor<?x60x58x5xf32>
  return %extracted_slice : tensor<?x60x58x5xf32>
}
```
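The successful lowering shows how a `-1` size entry is handled: dimension 0 gets a runtime size, `tensor.dim` of the input minus the start offset, while the other sizes stay static. A minimal Python model of that behavior, inferred only from the IR above (the name `lower_slice_sizes` is hypothetical, and this is not the actual conversion pattern), could look like:

```python
# Sentinel MLIR prints for a dynamic entry in static_sizes
# (ShapedType::kDynamic, i.e. INT64_MIN).
K_DYNAMIC = -(2**63)


def lower_slice_sizes(start, size):
    """Simplified model of the observed tosa-to-tensor output: a -1 size entry
    becomes a dynamic size computed at runtime as dim(input, i) - start[i];
    every other entry stays static.

    Returns (static_sizes, dynamic_exprs), where dynamic_exprs holds
    human-readable descriptions of the runtime size computations.
    """
    static_sizes = []
    dynamic_exprs = []
    for i, (s, n) in enumerate(zip(start, size)):
        if n == -1:
            static_sizes.append(K_DYNAMIC)
            dynamic_exprs.append(f"dim(input, {i}) - {s}")
        else:
            static_sizes.append(n)
    return static_sizes, dynamic_exprs


if __name__ == "__main__":
    # Sizes from the example: [-1, 60, 58, 5] with zero offsets.
    print(lower_slice_sizes([0, 0, 0, 0], [-1, 60, 58, 5]))
    # -> ([-9223372036854775808, 60, 58, 5], ['dim(input, 0) - 0'])
```

This simplified model consults only `start` and `size`, never the result type, which is what matters once shape inference makes the result type static.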
However, running `mlir-opt --tosa-infer-shapes --tosa-to-tensor` fails:
```
// -----// IR Dump After TosaInferShapesPass (tosa-infer-shapes) //----- //
func.func @slice(%arg0: tensor<2x60x59x5xf32>) -> tensor<?x60x58x5xf32> {
  %0 = tosa.cast %arg0 : (tensor<2x60x59x5xf32>) -> tensor<2x60x59x5xf32>
  %1 = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
  %2 = tosa.const_shape {values = dense<[-1, 60, 58, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %3 = tosa.slice %0, %1, %2 : (tensor<2x60x59x5xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x5xf32>
  %cast = tensor.cast %3 : tensor<2x60x58x5xf32> to tensor<?x60x58x5xf32>
  return %cast : tensor<?x60x58x5xf32>
}
error: expected type to be 'tensor<?x60x58x5xf32>' or a rank-reduced version. (size mismatch)
%3 = tosa.slice %0, %1, %2 : (tensor<?x60x59x5xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x60x58x5xf32>
^
note: see current operation: %4 = "tensor.extract_slice"(%arg0, %3) <{operandSegmentSizes = array<i32: 1, 0, 1, 0>, static_offsets = array<i64: 0, 0, 0, 0>, static_sizes = array<i64: -9223372036854775808, 60, 58, 5>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<2x60x59x5xf32>, index) -> tensor<2x60x58x5xf32>
// -----// IR Dump After TosaToTensorPass Failed (tosa-to-tensor) //----- //
"builtin.module"() ({
  "func.func"() <{function_type = (tensor<2x60x59x5xf32>) -> tensor<?x60x58x5xf32>, sym_name = "slice"}> ({
  ^bb0(%arg0: tensor<2x60x59x5xf32>):
    %0 = "arith.constant"() <{value = 0 : index}> : () -> index
    %1 = "tensor.dim"(%arg0, %0) : (tensor<2x60x59x5xf32>, index) -> index
    %2 = "arith.constant"() <{value = 0 : index}> : () -> index
    %3 = "arith.subi"(%1, %2) <{overflowFlags = #arith.overflow<none>}> : (index, index) -> index
    %4 = "tensor.extract_slice"(%arg0, %3) <{operandSegmentSizes = array<i32: 1, 0, 1, 0>, static_offsets = array<i64: 0, 0, 0, 0>, static_sizes = array<i64: -9223372036854775808, 60, 58, 5>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<2x60x59x5xf32>, index) -> tensor<2x60x58x5xf32>
    %5 = "tensor.cast"(%4) : (tensor<2x60x58x5xf32>) -> tensor<?x60x58x5xf32>
    "func.return"(%5) : (tensor<?x60x58x5xf32>) -> ()
  }) : () -> ()
}) : () -> ()
```
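I think `--tosa-infer-shapes` does the correct thing by refining the dynamic dimension of the `tosa.slice` input and output to `2`. It leaves the `size` input of `tosa.slice` unchanged (the `values` attribute of the `tosa.const_shape` producing `%2` is still `[-1, 60, 58, 5]`), which also seems correct since, as far as I understand, `tosa-infer-shapes` only updates input/output shapes. The lowering then fails because `tosa-to-tensor` still sees `-1` in the size and emits a dynamic size for dimension 0 (`static_sizes` contains `-9223372036854775808`, i.e. `ShapedType::kDynamic`). The verifier therefore expects the `tensor.extract_slice` result to be `tensor<?x60x58x5xf32>`, which conflicts with the now-static result type `tensor<2x60x58x5xf32>`.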
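To see why the verifier rejects the lowered op, here is a minimal Python model of the result-type check. It is a sketch, assuming the verifier compares the declared result shape against the shape implied by `static_sizes`, where `-9223372036854775808` stands for a dynamic `?` size; rank-reduced results are not modeled, and all helper names are hypothetical.

```python
# Sentinel MLIR prints for a dynamic entry in static_sizes (INT64_MIN).
K_DYNAMIC = -(2**63)


def infer_extract_slice_shape(static_sizes):
    """Result shape implied by static_sizes; None models a '?' dimension."""
    return [None if n == K_DYNAMIC else n for n in static_sizes]


def format_shape(shape, elem="f32"):
    """Render a shape list as an MLIR tensor type string."""
    dims = "x".join("?" if d is None else str(d) for d in shape)
    return f"tensor<{dims}x{elem}>"


def check_result_type(static_sizes, declared_shape):
    """Return None if the declared result shape matches the inferred one,
    otherwise an error string modeled on the verifier message in this report.
    Rank reduction is NOT modeled: shapes must match exactly."""
    inferred = infer_extract_slice_shape(static_sizes)
    if inferred == list(declared_shape):
        return None
    return (f"expected type to be '{format_shape(inferred)}' "
            "or a rank-reduced version. (size mismatch)")


if __name__ == "__main__":
    sizes = [K_DYNAMIC, 60, 58, 5]  # static_sizes from the failing dump
    print(check_result_type(sizes, [None, 60, 58, 5]))  # before shape inference
    print(check_result_type(sizes, [2, 60, 58, 5]))     # after shape inference
```

The same `static_sizes` pass the check against `tensor<?x60x58x5xf32>` but fail against `tensor<2x60x58x5xf32>`, which matches the difference between the two pipelines above.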
I see a couple of possible fixes:
1. Enhance `tosa-to-tensor` so that, when creating the `tensor.extract_slice` op, it ignores a `-1` size entry whenever the corresponding output dimension is static and uses that static dimension instead.
2. Alternatively, as part of `SliceOp` canonicalization, replace `-1` entries in the size with the corresponding output dimensions when the output shape is known.
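Both fixes come down to the same per-dimension rule: a `-1` size entry can be replaced by the output dimension whenever that dimension is static, and must stay dynamic otherwise. A minimal Python sketch of that rule follows; `resolve_slice_sizes` and `static_sizes_for_lowering` are hypothetical names, shapes are modeled as lists with `None` for `?`, and this is not a patch to the MLIR C++ code.

```python
# Sentinel MLIR prints for a dynamic entry in static_sizes (INT64_MIN).
K_DYNAMIC = -(2**63)


def resolve_slice_sizes(size, result_shape):
    """Hypothetical helper: replace each -1 in the tosa.slice size with the
    matching result dimension when that dimension is static (None marks a
    dynamic '?' dimension in result_shape). Entries still -1 afterwards must
    stay dynamic (dim(input, i) - start[i]) when lowered."""
    if len(size) != len(result_shape):
        raise ValueError("size and result shape must have the same rank")
    resolved = []
    for n, d in zip(size, result_shape):
        if n == -1 and d is not None:
            resolved.append(d)  # output dim is known: use it statically
        else:
            resolved.append(n)  # explicit size, or still unknown (-1)
    return resolved


def static_sizes_for_lowering(size, result_shape):
    """Fix 1 sketch: static_sizes for tensor.extract_slice, using K_DYNAMIC
    only where the size is -1 and the result dimension is also dynamic."""
    return [K_DYNAMIC if n == -1 else n
            for n in resolve_slice_sizes(size, result_shape)]


if __name__ == "__main__":
    print(static_sizes_for_lowering([-1, 60, 58, 5], [2, 60, 58, 5]))
    # -> [2, 60, 58, 5]
    print(static_sizes_for_lowering([-1, 60, 58, 5], [None, 60, 58, 5]))
    # -> [-9223372036854775808, 60, 58, 5]
```

Under fix 2, the `resolve_slice_sizes` result would be written back into a new `tosa.const_shape`, so later passes never see the `-1`; under fix 1, the rule would be applied only while building `tensor.extract_slice`, leaving the TOSA IR untouched.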
Any suggestions, @sjarus, @eric-k256, @Tai78641?
Thanks!