llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

- **[mlir][linalg] Refactor vectorization hooks to improve code reuse**
- **[mlir][linalg] Simplify `createWriteOrMaskedWrite` (NFC)**


---
Full diff: https://github.com/llvm/llvm-project/pull/141567.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+40-78) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp 
b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0113ba86a5ae3..2abb2f0ea467c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1590,61 +1590,46 @@ static bool 
isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` 
attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///   in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes 
are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
-    // Update the inBounds attribute.
     for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+      inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location 
loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location 
loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - 
writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, 
maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, 
linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+                               /*writeIndices=*/{},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, 
linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, 
tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, 
tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, 
sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), 
writeIndices);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+                                              sliceOp.getDest(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));

``````````

</details>


https://github.com/llvm/llvm-project/pull/141567
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to