llvmbot wrote:
@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

- **[mlir][linalg] Refactor vectorization hooks to improve code reuse**
- **[mlir][linalg] Simplify `createWriteOrMaskedWrite` (NFC)**

---

Full diff: https://github.com/llvm/llvm-project/pull/141567.diff

1 File Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+40-78)
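In the refactored hook, the `inputVecSizesForLeadingDims` parameter is gone: whether to mask is decided purely by comparing the vector's shape against the trailing dimensions of the destination. The sketch below is illustrative only, not part of the patch; the value names and the `4x8` shapes are made up, and `%vec`, `%dest0` and `%dest1` are assumed to be defined earlier.

```mlir
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

// No mask needed: shape(vecToStore) matches the trailing dims of the
// destination. Empty writeIndices default to 0, i.e. [%c0, %c0].
%w0 = vector.transfer_write %vec, %dest0[%c0, %c0] {in_bounds = [true, true]}
    : vector<4x8xf32>, tensor<4x8xf32>

// Mask needed: the destination sizes are dynamic, so they cannot be proven
// equal to shape(vecToStore). The mask type mirrors the vector's shape (with
// i1 elements); the mask sizes come from the destination tensor.
%d0 = tensor.dim %dest1, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %dest1, %c1 : tensor<?x?xf32>
%mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
%w1 = vector.mask %mask {
  vector.transfer_write %vec, %dest1[%c0, %c0] {in_bounds = [true, true]}
      : vector<4x8xf32>, tensor<?x?xf32>
} : vector<4x8xi1> -> tensor<?x?xf32>
```

When `useInBoundsInsteadOfMasking` is set, the same per-dimension shape comparison instead populates the `in_bounds` flags and no mask is generated, as the full diff below shows.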
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0113ba86a5ae3..2abb2f0ea467c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///    %res = vector.transfer_write %vectorToStore into %dest
+///    %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///    %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///    %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///    %res = vector.mask %mask {
-///      vector.transfer_write %vectorToStore into %dest
+///      vector.transfer_write %vecToStore into %dest
 ///    }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///    %write = vector.transfer_write %vectorToStore into %dest
+///    %write = vector.transfer_write %vecToStore into %dest
 ///        in_bounds_flags = (...)
 ///    %res = vector.transfer_write %input into %dest
 ///        {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
-    // Update the inBounds attribute.
     for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+      inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+                               /*writeIndices=*/{},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+                                              sliceOp.getDest(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
``````````

</details>

https://github.com/llvm/llvm-project/pull/141567

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits