| Issue     | 114855 |
|-----------|--------|
| Summary   | [mlir][sparse] mlir-opt crash when lowering softmax with sparse tensors |
| Labels    | mlir |
| Assignees | |
| Reporter  | vmiheer |
Here is an example MLIR module that performs softmax on sparse tensors. The softmax expansion itself was produced by the softmax decomposition in (upstream) MLIR.
<details>
<summary>
input.mlir
</summary>
```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }>
module {
func.func @softmax(%arg0: tensor<?x?x?xf32, #sparse>, %arg1: !llvm.ptr) -> tensor<?x?x?xf32, #sparse> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1_i8 = arith.constant 1 : i8
%c2 = arith.constant 2 : index
%cst = arith.constant 0.000000e+00 : f32
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32, #sparse>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32, #sparse>
%dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32, #sparse>
%0 = tensor.empty(%dim, %dim_0, %dim_1) : tensor<?x?x?xf32, #sparse>
%c0_2 = arith.constant 0 : index
%dim_3 = tensor.dim %arg0, %c0_2 : tensor<?x?x?xf32, #sparse>
%c1_4 = arith.constant 1 : index
%dim_5 = tensor.dim %arg0, %c1_4 : tensor<?x?x?xf32, #sparse>
%c2_6 = arith.constant 2 : index
%dim_7 = tensor.dim %arg0, %c2_6 : tensor<?x?x?xf32, #sparse>
%1 = tensor.empty(%dim_3, %dim_7) : tensor<?x?xf32>
%cst_8 = arith.constant -3.40282347E+38 : f32
%2 = linalg.fill ins(%cst_8 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #sparse>) outs(%2 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%8 = arith.maxnumf %in, %out : f32
linalg.yield %8 : f32
} -> tensor<?x?xf32>
%4 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %3 : tensor<?x?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<?x?x?xf32, #sparse>) {
^bb0(%in: f32, %in_10: f32, %out: f32):
%8 = arith.subf %in, %in_10 : f32
%9 = math.exp %8 : f32
linalg.yield %9 : f32
} -> tensor<?x?x?xf32, #sparse>
%cst_9 = arith.constant 0.000000e+00 : f32
%5 = linalg.fill ins(%cst_9 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%6 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%4 : tensor<?x?x?xf32, #sparse>) outs(%5 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%8 = arith.addf %in, %out : f32
linalg.yield %8 : f32
} -> tensor<?x?xf32>
%7 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %6 : tensor<?x?x?xf32, #sparse>, tensor<?x?xf32>) outs(%0 : tensor<?x?x?xf32, #sparse>) {
^bb0(%in: f32, %in_10: f32, %out: f32):
%8 = arith.divf %in, %in_10 : f32
linalg.yield %8 : f32
} -> tensor<?x?x?xf32, #sparse>
return %7 : tensor<?x?x?xf32, #sparse>
}
}
```
</details>
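For reference, the kernel before decomposition presumably looked roughly like the sketch below, which applies `linalg.softmax` directly to the sparse operand over the reduction dimension `d1`. This pre-decomposition form is an assumption for illustration; it is not part of the original reproducer.

```mlir
// Hypothetical pre-decomposition form (assumption, for illustration only).
// The softmax decomposition expands this single op into the
// max / sub-exp / sum / div linalg.generic chain shown in input.mlir above.
#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }>

func.func @softmax(%arg0: tensor<?x?x?xf32, #sparse>) -> tensor<?x?x?xf32, #sparse> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32, #sparse>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32, #sparse>
  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32, #sparse>
  %init = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32, #sparse>
  // Reduce over dimension 1, matching the reduction iterator in the
  // decomposed linalg.generic ops.
  %0 = linalg.softmax dimension(1)
         ins(%arg0 : tensor<?x?x?xf32, #sparse>)
         outs(%init : tensor<?x?x?xf32, #sparse>) -> tensor<?x?x?xf32, #sparse>
  return %0 : tensor<?x?x?xf32, #sparse>
}
```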
- Command line: `mlir-opt --sparsifier input.mlir`
- Git SHA: 33363521ca24f912cc25530f6cecbca53acce8a3
- Discourse discussion: https://discourse.llvm.org/t/sparsifier-crash-while-lowering-softmax/82721
- Quick reproduction using Compiler Explorer: https://godbolt.org/z/G845EEjMo
Possible resolutions:
1. Make the sparsifier report a proper failure when the input uses features that are not supported, instead of crashing.
2. One possible lowering:
<details>
<summary>softmax_sparse</summary>
```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#csrv = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }>
#dense = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : dense) }>
#csr = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
module {
func.func @softmax(%arg0: tensor<?x?x?xf32, #csrv>, %arg1: !llvm.ptr) -> tensor<?x?x?xf32, #csrv> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1_i8 = arith.constant 1 : i8
%c2 = arith.constant 2 : index
%cst = arith.constant 0.000000e+00 : f32
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32, #csrv>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32, #csrv>
%dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32, #csrv>
%0 = tensor.empty(%dim, %dim_0, %dim_1) : tensor<?x?x?xf32, #csrv>
%c0_2 = arith.constant 0 : index
%dim_3 = tensor.dim %arg0, %c0_2 : tensor<?x?x?xf32, #csrv>
%c1_4 = arith.constant 1 : index
%dim_5 = tensor.dim %arg0, %c1_4 : tensor<?x?x?xf32, #csrv>
%c2_6 = arith.constant 2 : index
%dim_7 = tensor.dim %arg0, %c2_6 : tensor<?x?x?xf32, #csrv>
%11 = tensor.empty(%dim_3, %dim_7) : tensor<?x?xf32>
%minus_inf = arith.constant -3.40282347E+38 : f32
%21 = linalg.fill ins(%minus_inf : f32) outs(%11 : tensor<?x?xf32>) -> tensor<?x?xf32>
%31 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #csrv>) outs(%21 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%res = sparse_tensor.reduce %in, %out, %minus_inf : f32 {
^bb0(%x0: f32, %x1: f32):
%00 = arith.maxnumf %x0, %x1 : f32
sparse_tensor.yield %00: f32
}
linalg.yield %res : f32
} -> tensor<?x?xf32>
%3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #csrv>) outs(%arg0 : tensor<?x?x?xf32, #csrv>) {
^bb0(%in: f32, %out: f32):
%x = linalg.index 0: index
%y = linalg.index 1: index
%z = linalg.index 2: index
%result = sparse_tensor.unary %in : f32 to f32
present={
^bb0(%in1: f32):
%maxel = tensor.extract %31[%x, %z]: tensor<?x?xf32>
%8 = arith.subf %in1, %maxel : f32
%ret = math.exp %8 : f32
sparse_tensor.yield %ret : f32
}
absent={}
linalg.yield %result : f32
} -> tensor<?x?x?xf32, #csrv>
%1 = tensor.empty(%dim_3, %dim_7) : tensor<?x?xf32>
%cst_8 = arith.constant 0. : f32
%2 = linalg.fill ins(%cst_8 : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<?x?x?xf32, #csrv>) outs(%2 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%res = sparse_tensor.reduce %in, %out, %cst_8 : f32 {
^bb0(%x0: f32, %x1: f32):
%00 = arith.addf %x0, %x1 : f32
sparse_tensor.yield %00: f32
}
linalg.yield %res : f32
} -> tensor<?x?xf32>
%5 = linalg.generic {indexing_maps = [#map],
iterator_types = ["parallel", "parallel", "parallel"]}
outs(%3: tensor<?x?x?xf32, #csrv>) {
^bb0(%in: f32):
%x = linalg.index 0: index
%z = linalg.index 2: index
%result = sparse_tensor.unary %in : f32 to f32
present={
^bb0(%in1: f32):
%denom = tensor.extract %4[%x, %z]: tensor<?x?xf32>
%ret = arith.divf %in1, %denom : f32
sparse_tensor.yield %ret : f32
}
absent={}
linalg.yield %result : f32
} -> tensor<?x?x?xf32, #csrv>
return %5 : tensor<?x?x?xf32, #csrv>
}
}
```
</details>