https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/66563
From afd923169445f8800365859145c8abd0823c5ef7 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajc...@google.com>
Date: Fri, 15 Sep 2023 17:22:34 -0700
Subject: [PATCH] [mlir][sparse] refine sparse fusion with empty tensors
 materialization

This is a minor step towards deprecating bufferization.alloc_tensor().
It replaces the examples with tensor.empty() and adjusts the underlying
rewriting logic to prepare for this upcoming change.

---
 .../Transforms/SparseTensorRewriting.cpp    | 28 +++++-----
 .../Dialect/SparseTensor/sparse_sddmm.mlir  | 54 +++++++++----------
 2 files changed, 41 insertions(+), 41 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 38e6621d54b331d..08482de5879ded7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -50,8 +50,8 @@ static bool isSparseTensor(Value v) {
 }
 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
 
-// Helper method to find zero/uninitialized allocation.
-static bool isAlloc(OpOperand *op, bool isZero) {
+// Helper method to find zero/uninitialized tensor materialization.
+static bool isMaterializing(OpOperand *op, bool isZero) {
   Value val = op->get();
   // Check allocation, with zero alloc when required.
   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
@@ -60,6 +60,9 @@ static bool isAlloc(OpOperand *op, bool isZero) {
       return copy && isZeroValue(copy);
     return !copy;
   }
+  // Check for empty tensor materialization.
+  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
+    return !isZero;
   // Last resort for zero alloc: the whole value is zero.
   return isZero && isZeroValue(val);
 }
@@ -219,24 +222,22 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp op,
                                 PatternRewriter &rewriter) const override {
     if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
-        !isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
+        !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
         !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
       return failure();
     auto outputType = getRankedTensorType(op.getResult(0));
-    // Yielding zero on newly allocated (all-zero) sparse tensors can be
-    // optimized out directly (regardless of dynamic or static size).
+    // Yielding zero on newly materialized sparse tensor can be
+    // optimized directly (regardless of dynamic or static size).
     if (getSparseTensorEncoding(outputType)) {
       rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
       return success();
     }
-    // Incorporate zero value into allocation copy.
+    // Use static zero value directly instead of materialization.
     if (!outputType.hasStaticShape())
       return failure();
-    Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
-    AllocTensorOp a =
-        op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
-    rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
-    rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
+    Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
+    rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
+    rewriter.eraseOp(def);
     return success();
   }
 };
@@ -286,8 +287,8 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
         !prod.getResult(0).hasOneUse())
       return failure();
     // Sampling consumer and sum of multiplication chain producer.
-    if (!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
-        !isAlloc(prod.getDpsInitOperand(0), /*isZero=*/true) ||
+    if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
+        !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
         !isSampling(op) || !isSumOfMul(prod))
       return failure();
     // Modify operand structure of producer and consumer.
@@ -327,6 +328,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
       last = rewriter.clone(*acc, mapper)->getResult(0);
     rewriter.create<linalg::YieldOp>(loc, last);
     // Force initial value on merged allocation for dense outputs.
+    // TODO: deal with non alloc tensor here one day
     if (!getSparseTensorEncoding(op.getResult(0).getType())) {
       Value init = prod.getDpsInitOperand(0)
                        ->get()
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
index 610ff30a48c4a4f..707648e42cbd849 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -21,13 +21,12 @@
 }
 
 // CHECK-LABEL: func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
-// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false]} : tensor<1024x1024xf64>
-// CHECK: return %[[VAL_1]] : tensor<1024x1024xf64>
+// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK: return %[[C0]] : tensor<1024x1024xf64>
 // CHECK: }
 func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
   %cst = arith.constant 0.000000e+00 : f64
-  %0 = bufferization.alloc_tensor() : tensor<1024x1024xf64>
+  %0 = tensor.empty() : tensor<1024x1024xf64>
   %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
                                         affine_map<(d0, d1) -> (d0, d1)>],
                        iterator_types = ["parallel", "parallel"]}
@@ -40,13 +39,12 @@ func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
 }
 
 // CHECK-LABEL: func.func @fold_yield_direct_zero() -> tensor<32xf64> {
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
-// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false]} : tensor<32xf64>
-// CHECK: return %[[VAL_1]] : tensor<32xf64>
+// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
+// CHECK: return %[[C0]] : tensor<32xf64>
 // CHECK: }
 func.func @fold_yield_direct_zero() -> tensor<32xf64> {
   %cst = arith.constant 0.000000e+00 : f64
-  %0 = bufferization.alloc_tensor() : tensor<32xf64>
+  %0 = tensor.empty() : tensor<32xf64>
   %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
                        iterator_types = ["parallel"]}
     outs(%0 : tensor<32xf64>) {
@@ -92,9 +90,9 @@ func.func @fold_yield_direct_zero() -> tensor<32xf64> {
 // CHECK: %[[VAL_32:.*]] = arith.mulf %[[VAL_30]], %[[VAL_31]] : f64
 // CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_28]], %[[VAL_32]] : f64
 // CHECK: memref.store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_27]]] : memref<8x8xf64>
-// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: }
+// CHECK: }
+// CHECK: }
 // CHECK: %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
 // CHECK: return %[[VAL_34]] : tensor<8x8xf64>
 // CHECK: }
@@ -123,9 +121,9 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 }
 
 // CHECK-LABEL: func.func @sparse_sampled_dd_unfused(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> {
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
@@ -133,19 +131,19 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant true
 // CHECK-DAG: %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
 // CHECK-DAG: %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
-// CHECK-DAG: %[[VAL_10:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK-DAG: %[[VAL_10:.*]] = tensor.empty() : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
 // CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
-// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>
+// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
 // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>) {
+// CHECK: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
 // CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
+// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
 // CHECK: %[[VAL_28:.*]] = scf.for %[[VAL_29:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_30:.*]] = %[[VAL_27]]) -> (index) {
 // CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_29]]] : memref<8x8xf64>
 // CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_21]]] : memref<?xindex>
@@ -170,15 +168,15 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 // CHECK: scf.yield %[[VAL_37]] : index
 // CHECK: }
 // CHECK: memref.store %[[VAL_44]], %[[VAL_24]]{{\[}}%[[VAL_38]]] : memref<?xf64>
-// CHECK: scf.yield %[[VAL_49:.*]] : index
+// CHECK: scf.yield %[[VAL_47]] : index
 // CHECK: }
-// CHECK: scf.yield %[[VAL_50:.*]] : index
+// CHECK: scf.yield %[[VAL_35]] : index
 // CHECK: }
-// CHECK: %[[VAL_51:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_52:.*]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK: scf.yield %[[VAL_51]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_49:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_28]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: scf.yield %[[VAL_49]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: }
-// CHECK: %[[VAL_53:.*]] = sparse_tensor.load %[[VAL_54:.*]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK: return %[[VAL_53]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_50:.*]] = sparse_tensor.load %[[VAL_20]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[VAL_50]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: }
 func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
                                      %arga: tensor<8x8xf64>,
@@ -194,7 +192,7 @@ func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
       linalg.yield %q : f64
   } -> tensor<8x8xf64>
   // Sample the result with elements-wise multiplication with sparse matrix.
-  %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
+  %3 = tensor.empty() : tensor<8x8xf64, #SM>
   %4 = linalg.generic #trait_scale
     ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
     outs(%3 : tensor<8x8xf64, #SM>) {
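For readers skimming the patch: the net effect of the updated FoldInvariantYield pattern on a dense, statically shaped output is easiest to see in a minimal sketch distilled from the reworked sparse_sddmm.mlir test above (nothing here goes beyond what that test already exercises):

  // Before: a linalg.generic that only yields zero over a fresh
  // tensor.empty() materialization.
  func.func @fold_yield_direct_zero() -> tensor<32xf64> {
    %cst = arith.constant 0.000000e+00 : f64
    %0 = tensor.empty() : tensor<32xf64>
    %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
                         iterator_types = ["parallel"]}
      outs(%0 : tensor<32xf64>) {
      ^bb0(%out: f64):
        linalg.yield %cst : f64
    } -> tensor<32xf64>
    return %1 : tensor<32xf64>
  }

  // After the rewrite, the generic op and its init are gone; the pattern
  // now erases the defining op and folds the result to a constant:
  //   %cst = arith.constant dense<0.000000e+00> : tensor<32xf64>
  //   return %cst : tensor<32xf64>

For sparse outputs the same pattern instead forwards the materialized init unchanged, since an empty sparse tensor already represents all zeros.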
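The isZero flag in the renamed isMaterializing helper carries the subtle part of the change: tensor.empty() has undefined contents, so it only qualifies when a rewrite merely needs an uninitialized init (isZero = false), never when it needs a provably all-zero one. A sketch of the cases, using an illustrative #SM encoding and %cst_0 constant that are not part of the patch:

  %cst_0 = arith.constant dense<0.000000e+00> : tensor<8x8xf64>

  // Accepted for isZero = false, e.g. the sampling consumer's init in
  // FuseSparseMultiplyOverAdd:
  %e = tensor.empty() : tensor<8x8xf64, #SM>

  // Accepted for isZero = true, e.g. the matmul producer's init, whose
  // running sum must start from zero:
  %z = bufferization.alloc_tensor() copy(%cst_0) : tensor<8x8xf64>

  // Rejected for isZero = true: tensor.empty() makes no promise about
  // its contents, so the helper returns !isZero for it.

This is why the producer check in FuseSparseMultiplyOverAdd stays at isMaterializing(..., /*isZero=*/true) and keeps matching only inits that are provably zero, i.e. a zero-copy alloc_tensor or an all-zero value (hence the TODO left in that pattern).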