Author: Aart Bik
Date: 2021-01-05T15:31:39-08:00
New Revision: 8b124c19f52cb8ed0236b602df56787553e1e1b6
URL: https://github.com/llvm/llvm-project/commit/8b124c19f52cb8ed0236b602df56787553e1e1b6 DIFF: https://github.com/llvm/llvm-project/commit/8b124c19f52cb8ed0236b602df56787553e1e1b6.diff LOG: [mlir][sparse] adjust output shape inference to new tensor abstraction Nicolas changed the tensor abstraction so that every output has its own shape definition. This simplifies the "inference" that was used in the sparse compiler. Reviewed By: penpornk Differential Revision: https://reviews.llvm.org/D94119 Added: Modified: mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp mlir/test/Dialect/Linalg/sparse_2d.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp index a6b7277e47e3..ed81d5e24805 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -538,15 +538,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen, // Find lower and upper bound in current dimension. Value up; if (shape[d] == TensorType::kDynamicSize) { - // For the output tensor, we may need to infer the upper bound. - // For all others, we look at the incoming argument. - if (t == numInputs && !op.getNumInitTensors()) { - up = codegen.sizes[i]; - assert(up); // TODO: what else? - } else { - Value arg = t < numInputs ? op.getInput(t) : op.getInitTensors()[0]; - up = rewriter.create<DimOp>(loc, arg, d); - } + Value arg = t < numInputs ? 
op.getInput(t) : op.getOutput(0); + up = rewriter.create<DimOp>(loc, arg, d); args.push_back(up); } else { up = rewriter.create<ConstantIndexOp>(loc, shape[d]); diff --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir index 6612a723f23d..9bb68ca91089 100644 --- a/mlir/test/Dialect/Linalg/sparse_2d.mlir +++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir @@ -1139,19 +1139,19 @@ func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32> // CHECK: %[[VAL_2:.*]] = constant 999 : index // CHECK: %[[VAL_3:.*]] = constant 0 : index // CHECK: %[[VAL_4:.*]] = constant 1 : index -// CHECK: %[[VAL_5:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64> +// CHECK: %[[VAL_5:.*]] = alloca(%[[VAL_2]]) : memref<?xindex> // CHECK: %[[VAL_6:.*]] = alloca(%[[VAL_2]]) : memref<?xindex> -// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref<?xindex> -// CHECK: %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64> -// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_2]]) : memref<?xf64> -// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_5]], %[[VAL_8]]) : memref<?x?xf64> -// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] { -// CHECK: %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex> +// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref<?xf64> +// CHECK: %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64> +// CHECK: %[[VAL_9:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64> +// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_8]], %[[VAL_9]]) : memref<?x?xf64> +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_8]] step %[[VAL_4]] { +// CHECK: %[[VAL_12:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<?xindex> // CHECK: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_4]] : index -// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex> +// CHECK: %[[VAL_14:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_13]]] : memref<?xindex> // CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step 
%[[VAL_4]] { -// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex> -// CHECK: %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xf64> +// CHECK: %[[VAL_16:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex> +// CHECK: %[[VAL_17:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xf64> // CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_1]] : f64 // CHECK: store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<?x?xf64> // CHECK: } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits