Issue |
141667
|
Summary |
[MLIR] Bufferization of `tensor.generate` does not account for repeated execution of its body
|
Labels |
mlir
|
Assignees |
|
Reporter |
erick-xanadu
|
Hello,
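I believe there is a bug in the bufferization of `tensor.generate`. When `tensor.generate` is bufferized, its body region is bufferized with the same rules as code outside the region, even though the body is executed once per element of the result. In the following example, `@circuit_0` is called inside the body with a different argument on each iteration. That argument is produced by extracting an element of `%arg0` (a tensor defined outside the region), adding a constant, and inserting the result back into `%arg0`. In tensor semantics, `tensor.insert` produces a new tensor. Each call therefore receives a copy of `%arg0` in which only element `%arg2` has been increased by 0.3, and `%arg0` itself is never modified.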
```mlir
// Reproducer (tensor level). Element (i, j) of the result is
// (circuit_0(%arg0 with element j increased by 0.3)[i] - circuit_0(%arg0)[i]) / 0.3.
func.func private @circuit_0.finitediff0(%arg0: tensor<2xf64>) -> tensor<2x2xf64> {
%cst = arith.constant 3.000000e-01 : f64
%cst_0 = arith.constant dense<3.000000e-01> : tensor<2x2xf64>
// Unperturbed evaluation, computed once outside the region.
%0 = call @circuit_0(%arg0) : (tensor<2xf64>) -> tensor<2xf64>
// The body below is evaluated for every index pair (%arg1, %arg2) of the 2x2 result.
%generated = tensor.generate {
^bb0(%arg1: index, %arg2: index):
// important bit
// %arg0 is defined outside the region. tensor.insert yields a new tensor,
// so %arg0 is never modified and each iteration perturbs only element
// %arg2 of the original %arg0.
%extracted = tensor.extract %arg0[%arg2] : tensor<2xf64>
%2 = arith.addf %extracted, %cst : f64
%inserted = tensor.insert %2 into %arg0[%arg2] : tensor<2xf64>
// a new value is passed here each time we loop through tensor.generate
%3 = func.call @circuit_0(%inserted) : (tensor<2xf64>) -> tensor<2xf64>
%4 = arith.subf %3, %0 : tensor<2xf64>
%extracted_1 = tensor.extract %4[%arg1] : tensor<2xf64>
tensor.yield %extracted_1 : f64
} : tensor<2x2xf64>
// Divide every element by 0.3 (the same constant added to the perturbed element above).
%1 = arith.divf %generated, %cst_0 : tensor<2x2xf64>
return %1 : tensor<2x2xf64>
}
```
However, after bufferization, we see the following code:
```mlir
// Bufferized output produced for the reproducer above.
func.func private @circuit_0.finitediff0(%arg0: memref<2xf64>) -> memref<2x2xf64> {
%cst = arith.constant 3.000000e-01 : f64
%0 = memref.get_global @__constant_2x2xf64 : memref<2x2xf64>
%1 = call @circuit_0(%arg0) : (memref<2xf64>) -> memref<2xf64>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<2x2xf64>
// tensor.generate became a linalg.map writing into %alloc; its body runs once per output element.
linalg.map outs(%alloc : memref<2x2xf64>)
() {
%2 = linalg.index 0 : index
%3 = linalg.index 1 : index
// PROBLEM: the tensor.insert was bufferized in place. The store below writes
// directly into the function argument %arg0 rather than into a copy, so the
// increment from each iteration is seen by all later iterations (and by the
// caller that owns %arg0).
%4 = memref.load %arg0[%3] : memref<2xf64>
%5 = arith.addf %4, %cst : f64
memref.store %5, %arg0[%3] : memref<2xf64>
// value of arg0 changes
// with each iteration
%6 = func.call @circuit_0(%arg0) : (memref<2xf64>) -> memref<2xf64>
// The subtraction is also done in place: its result is written into %6 (outs).
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %1 : memref<2xf64>, memref<2xf64>) outs(%6 : memref<2xf64>) {
^bb0(%in: f64, %in_0: f64, %out: f64):
%8 = arith.subf %in, %in_0 : f64
linalg.yield %8 : f64
}
%7 = memref.load %6[%2] : memref<2xf64>
linalg.yield %7 : f64
}
// Elementwise division by the 0.3 constant, written back into %alloc.
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc, %0 : memref<2x2xf64>, memref<2x2xf64>) outs(%alloc : memref<2x2xf64>) {
^bb0(%in: f64, %in_0: f64, %out: f64):
%2 = arith.divf %in, %in_0 : f64
linalg.yield %2 : f64
}
return %alloc : memref<2x2xf64>
}
```
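Here `tensor.insert` has been bufferized in place: the `memref.store` writes directly into `%arg0`. Because the `linalg.map` body runs once per output element, every iteration sees the increments stored by previous iterations. The caller's `%arg0` buffer is modified as well. In the tensor version, each iteration inserts into the original, unmodified `%arg0`. The insert should therefore have been bufferized out of place, for example by allocating and copying a fresh buffer inside the body. It looks like this may stem from the lack of bufferization of the `linalg.map` op, but I am not entirely sure.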
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs