Issue 141667
Summary [MLIR] Bufferization of `tensor.generate` does not take into account repetitions
Labels mlir
Assignees
Reporter erick-xanadu
Hello,

I believe there is a bug in the bufferization of `tensor.generate`. When `tensor.generate` is bufferized, its body is bufferized with the same rules as the code outside the body, without taking into account that the body runs once per element of the result. In the following example, `circuit_0` is called inside the body with a different argument value for each index. That value is built by extracting an element from `%arg0`, adding a constant to it, and inserting the result back into `%arg0`, a tensor defined outside the body. Since tensors are immutable values, every execution of the body should start from the original `%arg0`.

```mlir
  func.func private @circuit_0.finitediff0(%arg0: tensor<2xf64>) -> tensor<2x2xf64> {
    %cst = arith.constant 3.000000e-01 : f64
    %cst_0 = arith.constant dense<3.000000e-01> : tensor<2x2xf64>
    %0 = call @circuit_0(%arg0) : (tensor<2xf64>) -> tensor<2xf64>
    %generated = tensor.generate  {
    ^bb0(%arg1: index, %arg2: index):

      // important bit
      %extracted = tensor.extract %arg0[%arg2] : tensor<2xf64>
      %2 = arith.addf %extracted, %cst : f64
      %inserted = tensor.insert %2 into %arg0[%arg2] : tensor<2xf64>
      // new value being passed here each time we loop through tensor.generate
      %3 = func.call @circuit_0(%inserted) : (tensor<2xf64>) -> tensor<2xf64>

      %4 = arith.subf %3, %0 : tensor<2xf64>
      %extracted_1 = tensor.extract %4[%arg1] : tensor<2xf64>
      tensor.yield %extracted_1 : f64
    } : tensor<2x2xf64>
    %1 = arith.divf %generated, %cst_0 : tensor<2x2xf64>
    return %1 : tensor<2x2xf64>
  }
```
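
To make the expected behavior concrete, here is a minimal Python sketch of what the tensor-level IR above computes. It is only a model, not MLIR output: `circuit_0` is a hypothetical stand-in (the real function is not part of this report), and `finitediff_tensor_semantics` is a name made up for illustration. The point it shows is that `tensor.insert` produces a new value, so `arg0` is the same for every index pair.

```python
H = 0.3  # step size, corresponds to %cst in the IR


def circuit_0(x):
    # Hypothetical stand-in for @circuit_0; the real body is not shown in the report.
    return [x[0] * x[0] + x[1], x[0] * x[1]]


def finitediff_tensor_semantics(arg0):
    base = circuit_0(arg0)  # %0
    result = [[0.0, 0.0], [0.0, 0.0]]
    for i in range(2):      # %arg1
        for j in range(2):  # %arg2
            # tensor.insert returns a new tensor; arg0 itself is never modified.
            shifted = list(arg0)
            shifted[j] = arg0[j] + H
            diff = [a - b for a, b in zip(circuit_0(shifted), base)]  # %4
            # tensor.yield of diff[i], followed by the elementwise arith.divf.
            result[i][j] = diff[i] / H
    return result
```

Because nothing is mutated, the order in which the index pairs are visited does not affect the result, and `arg0` is unchanged after the call.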

However, after bufferization, we see the following code:

```mlir
  func.func private @circuit_0.finitediff0(%arg0: memref<2xf64>) -> memref<2x2xf64> {
    %cst = arith.constant 3.000000e-01 : f64
    %0 = memref.get_global @__constant_2x2xf64 : memref<2x2xf64>
    %1 = call @circuit_0(%arg0) : (memref<2xf64>) -> memref<2xf64>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x2xf64>
    linalg.map outs(%alloc : memref<2x2xf64>)
      () {

        %2 = linalg.index 0 : index
        %3 = linalg.index 1 : index
        %4 = memref.load %arg0[%3] : memref<2xf64>
        %5 = arith.addf %4, %cst : f64
        memref.store %5, %arg0[%3] : memref<2xf64>

        // value of arg0 changes
        // with each iteration
        %6 = func.call @circuit_0(%arg0) : (memref<2xf64>) -> memref<2xf64>

        linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %1 : memref<2xf64>, memref<2xf64>) outs(%6 : memref<2xf64>) {
        ^bb0(%in: f64, %in_0: f64, %out: f64):
          %8 = arith.subf %in, %in_0 : f64
          linalg.yield %8 : f64
        }
        %7 = memref.load %6[%2] : memref<2xf64>
        linalg.yield %7 : f64
      }
    linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc, %0 : memref<2x2xf64>, memref<2x2xf64>) outs(%alloc : memref<2x2xf64>) {
    ^bb0(%in: f64, %in_0: f64, %out: f64):
      %2 = arith.divf %in, %in_0 : f64
      linalg.yield %2 : f64
    }
    return %alloc : memref<2x2xf64>
  }
```
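
The effect of the `memref.store` into `%arg0` can be modeled with a small Python sketch. This is an illustration under assumptions, not a trace of the compiled code: `circuit_0` is a hypothetical stand-in, `finitediff_in_place` is a made-up name, and the body of `linalg.map` is assumed to run sequentially in row-major order, which the op itself may not guarantee.

```python
H = 0.3  # step size, corresponds to %cst in the IR


def circuit_0(x):
    # Hypothetical stand-in for @circuit_0; the real body is not shown in the report.
    return [x[0] * x[0] + x[1], x[0] * x[1]]


def finitediff_in_place(arg0):
    base = circuit_0(arg0)  # %1, computed before the loop
    result = [[0.0, 0.0], [0.0, 0.0]]
    for i in range(2):      # linalg.index 0 (assumed sequential order)
        for j in range(2):  # linalg.index 1
            # memref.load / arith.addf / memref.store all act on arg0 itself,
            # so the increment from earlier iterations is still there.
            arg0[j] = arg0[j] + H
            out = circuit_0(arg0)                      # %6
            out = [a - b for a, b in zip(out, base)]   # inner linalg.generic
            result[i][j] = out[i] / H                  # yield, then the final divf
    return result
```

With `arg0 = [1.0, 2.0]`, each element is incremented once per row, so the buffer ends up at approximately `[1.6, 2.6]` and later entries of the result are computed from an already shifted input. Under tensor semantics the input would stay `[1.0, 2.0]` throughout.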

It looks like this may stem from the lack of bufferization of the `linalg.map` op, but I am not entirely sure. 