Issue |
151795
|
Summary |
[MLIR] Inconsistent output between `-convert-math-to-llvm` and `-convert-math-to-spirv` for `math.tanh`
|
Labels |
mlir
|
Assignees |
|
Reporter |
sweead
|
Tested at commit: [8934a6e](https://github.com/llvm/llvm-project/commit/8934a6e13bd8d2a0ad2609bd62832ca700dab3a7)
## Description:
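Lowering the same `math.tanh` operation through two different paths produces different results. Every input element is 72.42 (`dense<72.4199981>`).

- **LLVM path** (`-convert-math-to-llvm`): every output element is `1`, which is the correct f32 value of tanh(72.42).
- **SPIR-V path** (`-convert-math-to-spirv` followed by `-convert-spirv-to-llvm`): every output element is `-nan`.

A likely cause, still to be confirmed, is in the SPIR-V-to-LLVM conversion. It appears to expand the SPIR-V tanh op as `(exp(2x) - 1) / (exp(2x) + 1)`. For x = 72.42, `exp(144.84)` overflows f32 to +inf, so the quotient becomes inf/inf = NaN.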
## Steps to Reproduce:
### MLIR program (test.mlir):
```
// Repro module. The only non-LLVM-dialect op here is math.tanh (in ^bb6), so
// -convert-math-to-llvm and -convert-math-to-spirv differ only in how that one
// op gets lowered.
module {
// External runtime functions: malloc, and the memref printer provided by
// libmlir_runner_utils (loaded via -shared-libs on the command line).
llvm.func @malloc(i64) -> !llvm.ptr
// Input data: 10x1x1 f32 constant, every element 72.4199981 (~72.42).
// tanh(72.42) rounds to exactly 1.0 in f32.
llvm.mlir.global private constant @__constant_10x1x1xf32(dense<72.4199981> : tensor<10x1x1xf32>) {addr_space = 0 : i32, alignment = 64 : i64} : !llvm.array<10 x array<1 x array<1 x f32>>>
llvm.func @printMemrefF32(i64, !llvm.ptr) attributes {sym_visibility = "private"}
llvm.func @main() {
// Index constants: %0/%4/%5/%6 = 1, %1/%3 = 10, %2 = 0.
%0 = llvm.mlir.constant(1 : index) : i64
%1 = llvm.mlir.constant(10 : index) : i64
%2 = llvm.mlir.constant(0 : index) : i64
%3 = llvm.mlir.constant(10 : index) : i64
%4 = llvm.mlir.constant(1 : index) : i64
%5 = llvm.mlir.constant(1 : index) : i64
%6 = llvm.mlir.constant(1 : index) : i64
// %9 = byte size of 10 f32 elements (getelementptr on a null pointer, then
// ptrtoint). It is not used anywhere below.
%7 = llvm.mlir.zero : !llvm.ptr
%8 = llvm.getelementptr %7[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%9 = llvm.ptrtoint %8 : !llvm.ptr to i64
// Input descriptor %24 over the constant global:
//   [0] = 0xDEADBEEF (3735928559), presumably a placeholder "allocated"
//         pointer that is never dereferenced (NOTE(review): confirm),
//   [1] = pointer to the global's data, [2] = 0,
//   [3] = sizes (10, 1, 1), [4] = strides (1, 1, 1).
%10 = llvm.mlir.addressof @__constant_10x1x1xf32 : !llvm.ptr
%11 = llvm.getelementptr %10[0, 0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x array<1 x array<1 x f32>>>
%12 = llvm.mlir.constant(3735928559 : index) : i64
%13 = llvm.inttoptr %12 : i64 to !llvm.ptr
%14 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%15 = llvm.insertvalue %13, %14[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%16 = llvm.insertvalue %11, %15[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%17 = llvm.mlir.constant(0 : index) : i64
%18 = llvm.insertvalue %17, %16[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%19 = llvm.insertvalue %3, %18[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%20 = llvm.insertvalue %4, %19[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%21 = llvm.insertvalue %5, %20[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%22 = llvm.insertvalue %4, %21[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%23 = llvm.insertvalue %5, %22[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%24 = llvm.insertvalue %6, %23[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// Output buffer: malloc(10 * sizeof(f32) + 64 bytes), then round the returned
// address up to a multiple of 64 to get the aligned data pointer %41.
%25 = llvm.mlir.constant(10 : index) : i64
%26 = llvm.mlir.constant(1 : index) : i64
%27 = llvm.mlir.constant(1 : index) : i64
%28 = llvm.mlir.constant(1 : index) : i64
%29 = llvm.mlir.zero : !llvm.ptr
%30 = llvm.getelementptr %29[%25] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%31 = llvm.ptrtoint %30 : !llvm.ptr to i64
%32 = llvm.mlir.constant(64 : index) : i64
%33 = llvm.add %31, %32 : i64
%34 = llvm.call @malloc(%33) : (i64) -> !llvm.ptr
%35 = llvm.ptrtoint %34 : !llvm.ptr to i64
%36 = llvm.mlir.constant(1 : index) : i64
%37 = llvm.sub %32, %36 : i64
%38 = llvm.add %35, %37 : i64
%39 = llvm.urem %38, %32 : i64
%40 = llvm.sub %38, %39 : i64
%41 = llvm.inttoptr %40 : i64 to !llvm.ptr
// Output descriptor %52: [0] = raw malloc result, [1] = aligned pointer,
// [2] = 0, [3] = sizes (10, 1, 1), [4] = strides (1, 1, 1).
%42 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%43 = llvm.insertvalue %34, %42[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%44 = llvm.insertvalue %41, %43[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%45 = llvm.mlir.constant(0 : index) : i64
%46 = llvm.insertvalue %45, %44[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%47 = llvm.insertvalue %25, %46[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%48 = llvm.insertvalue %26, %47[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%49 = llvm.insertvalue %27, %48[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%50 = llvm.insertvalue %26, %49[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%51 = llvm.insertvalue %27, %50[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%52 = llvm.insertvalue %28, %51[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// Loop nest: i = %53 in [0, 10) (^bb1), j = %55 in [0, 1) (^bb3),
// k = %57 in [0, 1) (^bb5). The innermost body is ^bb6.
llvm.br ^bb1(%2 : i64)
^bb1(%53: i64): // 2 preds: ^bb0, ^bb8
%54 = llvm.icmp "slt" %53, %1 : i64
llvm.cond_br %54, ^bb2, ^bb9
^bb2: // pred: ^bb1
llvm.br ^bb3(%2 : i64)
^bb3(%55: i64): // 2 preds: ^bb2, ^bb7
%56 = llvm.icmp "slt" %55, %0 : i64
llvm.cond_br %56, ^bb4, ^bb8
^bb4: // pred: ^bb3
llvm.br ^bb5(%2 : i64)
^bb5(%57: i64): // 2 preds: ^bb4, ^bb6
%58 = llvm.icmp "slt" %57, %0 : i64
llvm.cond_br %58, ^bb6, ^bb7
^bb6: // pred: ^bb5
// out[i + j + k] = tanh(in[i + j + k]). The linear index is a plain sum of
// the three induction variables, matching the all-ones strides.
%59 = llvm.extractvalue %24[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%60 = llvm.add %53, %55 overflow<nsw, nuw> : i64
%61 = llvm.add %60, %57 overflow<nsw, nuw> : i64
%62 = llvm.getelementptr inbounds|nuw %59[%61] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%63 = llvm.load %62 : !llvm.ptr -> f32
// The op under test. Expected result: 1.0 for every element, as printed by
// the -convert-math-to-llvm pipeline.
// NOTE(review): the SPIR-V pipeline prints -nan here. A likely cause is an
// exp-based tanh expansion, (exp(2x) - 1) / (exp(2x) + 1), where exp(144.84)
// overflows f32 to +inf and inf/inf = NaN. Confirm against the SPIR-V-to-LLVM
// conversion.
%64 = math.tanh %63 : f32
%65 = llvm.extractvalue %52[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
%66 = llvm.add %53, %55 overflow<nsw, nuw> : i64
%67 = llvm.add %66, %57 overflow<nsw, nuw> : i64
%68 = llvm.getelementptr inbounds|nuw %65[%67] : (!llvm.ptr, i64) -> !llvm.ptr, f32
llvm.store %64, %68 : f32, !llvm.ptr
%69 = llvm.add %57, %0 : i64
llvm.br ^bb5(%69 : i64)
^bb7: // pred: ^bb5
%70 = llvm.add %55, %0 : i64
llvm.br ^bb3(%70 : i64)
^bb8: // pred: ^bb3
%71 = llvm.add %53, %0 : i64
llvm.br ^bb1(%71 : i64)
^bb9: // pred: ^bb1
// Print the result: store descriptor %52 into a stack slot (llvm.alloca),
// then pass (3, pointer to that slot) to printMemrefF32. The value 3 matches
// the descriptor's 3-element size/stride arrays.
// The malloc'd output buffer is never freed; no free is declared in this module.
%72 = llvm.mlir.constant(1 : index) : i64
%73 = llvm.alloca %72 x !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> : (i64) -> !llvm.ptr
llvm.store %52, %73 : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>, !llvm.ptr
%74 = llvm.mlir.constant(3 : index) : i64
%75 = llvm.mlir.poison : !llvm.struct<(i64, ptr)>
%76 = llvm.insertvalue %74, %75[0] : !llvm.struct<(i64, ptr)>
%77 = llvm.insertvalue %73, %76[1] : !llvm.struct<(i64, ptr)>
%78 = llvm.extractvalue %77[0] : !llvm.struct<(i64, ptr)>
%79 = llvm.extractvalue %77[1] : !llvm.struct<(i64, ptr)>
llvm.call @printMemrefF32(%78, %79) : (i64, !llvm.ptr) -> ()
llvm.return
}
}
```
### 1. Using `-convert-math-to-llvm`:
```
# Lower math.tanh directly to the LLVM dialect, then run @main with mlir-runner.
/home/workdir/llvm-project/build/bin/mlir-opt test.mlir \
    -convert-math-to-llvm \
  | /home/workdir/llvm-project/build/bin/mlir-runner \
      -e main \
      -entry-point-result=void \
      -shared-libs=/home/workdir/llvm-project/build/lib/libmlir_runner_utils.so
```
#### Output:
```
[[[1]],[[1]],[[1]],[[1]],[[1]],[[1]],[[1]],[[1]],[[1]],[[1]]]
```
### 2. Using `-convert-math-to-spirv`:
```
# Lower math.tanh to SPIR-V, convert the SPIR-V ops to the LLVM dialect,
# then run @main with mlir-runner. Pass order is unchanged.
/home/workdir/llvm-project/build/bin/mlir-opt test.mlir \
    -convert-math-to-spirv \
    -one-shot-bufferize="bufferize-function-boundaries=1" \
    -convert-spirv-to-llvm \
  | /home/workdir/llvm-project/build/bin/mlir-runner \
      -e main \
      -entry-point-result=void \
      -shared-libs=/home/workdir/llvm-project/build/lib/libmlir_runner_utils.so
```
#### Output:
```
[[[-nan]], [[-nan]], [[-nan]], [[-nan]], [[-nan]], [[-nan]], [[-nan]], [[-nan]], [[-nan]], [[-nan]]]
```
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs