Issue 177593
Summary [mlir][linalg] bitcast is not preserved during matmul specialisation from generic
Labels mlir
Assignees
Reporter meshtag
    ```
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
 %B: tensor<8x32xi32>,
 %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.generic
         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
 ins(%A, %B : tensor<16x8xi32>, tensor<8x32xi32>) outs(%Out : tensor<16x32xf32>) {
  ^bb0(%in: i32, %in_0: i32, %out: f32):
    %1 = arith.bitcast %in : i32 to f32
    %2 = arith.bitcast %in_0 : i32 to f32
 %3 = arith.mulf %1, %2 : f32
    %4 = arith.addf %out, %3 : f32
 linalg.yield %4 : f32
  } -> tensor<16x32xf32>
  return %0 : tensor<16x32xf32>
}
```

Specialising the above `linalg.generic` like this
```
./mlir-opt --linalg-specialize-generic-ops c1.mlir -o c2.mlir
```

leads to 
```
module {
  func.func @op_matmul_bitcast_int_to_float(%arg0: tensor<16x8xi32>, %arg1: tensor<8x32xi32>, %arg2: tensor<16x32xf32>) -> tensor<16x32xf32> {
    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<16x8xi32>, tensor<8x32xi32>) outs(%arg2 : tensor<16x32xf32>) -> tensor<16x32xf32>
    return %0 : tensor<16x32xf32>
  }
}
```

and converting the above to linalg.generic again
```
./mlir-opt --linalg-generalize-named-ops c2.mlir -o c3.mlir 
```
leads to the following 
```
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
  func.func @op_matmul_bitcast_int_to_float(%arg0: tensor<16x8xi32>, %arg1: tensor<8x32xi32>, %arg2: tensor<16x32xf32>) -> tensor<16x32xf32> {
    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x8xi32>, tensor<8x32xi32>) outs(%arg2 : tensor<16x32xf32>) {
 ^bb0(%in: i32, %in_0: i32, %out: f32):
      %1 = arith.sitofp %in : i32 to f32
      %2 = arith.sitofp %in_0 : i32 to f32
      %3 = arith.mulf %1, %2 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
 } -> tensor<16x32xf32>
    return %0 : tensor<16x32xf32>
 }
}
```

Notice that the `arith.bitcast` op was not preserved during the roundtrip and we instead see `arith.sitofp`. Ideally, we should not be losing any information here.
_______________________________________________
llvm-bugs mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs

Reply via email to