| Issue |
177593
|
| Summary |
[mlir][linalg] bitcast is not preserved during matmul specialisation from generic
|
| Labels |
mlir
|
| Assignees |
|
| Reporter |
meshtag
|
```
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
%B: tensor<8x32xi32>,
%Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.generic
{indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
ins(%A, %B : tensor<16x8xi32>, tensor<8x32xi32>) outs(%Out : tensor<16x32xf32>) {
^bb0(%in: i32, %in_0: i32, %out: f32):
%1 = arith.bitcast %in : i32 to f32
%2 = arith.bitcast %in_0 : i32 to f32
%3 = arith.mulf %1, %2 : f32
%4 = arith.addf %out, %3 : f32
linalg.yield %4 : f32
} -> tensor<16x32xf32>
return %0 : tensor<16x32xf32>
}
```
specialising the above linalg.generic like this
```
./mlir-opt --linalg-specialize-generic-ops c1.mlir -o c2.mlir
```
leads to
```
module {
func.func @op_matmul_bitcast_int_to_float(%arg0: tensor<16x8xi32>, %arg1: tensor<8x32xi32>, %arg2: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<16x8xi32>, tensor<8x32xi32>) outs(%arg2 : tensor<16x32xf32>) -> tensor<16x32xf32>
return %0 : tensor<16x32xf32>
}
}
```
and converting the above to linalg.generic again
```
./mlir-opt --linalg-generalize-named-ops c2.mlir -o c3.mlir
```
leads to the following
```
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func.func @op_matmul_bitcast_int_to_float(%arg0: tensor<16x8xi32>, %arg1: tensor<8x32xi32>, %arg2: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x8xi32>, tensor<8x32xi32>) outs(%arg2 : tensor<16x32xf32>) {
^bb0(%in: i32, %in_0: i32, %out: f32):
%1 = arith.sitofp %in : i32 to f32
%2 = arith.sitofp %in_0 : i32 to f32
%3 = arith.mulf %1, %2 : f32
%4 = arith.addf %out, %3 : f32
linalg.yield %4 : f32
} -> tensor<16x32xf32>
return %0 : tensor<16x32xf32>
}
}
```
Notice that the `arith.bitcast` op was not preserved during the roundtrip and we instead see `arith.sitofp`. Ideally, we should not be losing any information here.
_______________________________________________
llvm-bugs mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs