================ @@ -1871,115 +1872,98 @@ static VectorType getCollapsedVecType(VectorType type, return VectorType::get(newShape, type.getElementType(), newScalableFlags); } -/// Vectorize a `linalg::UnPackOp` to these 4 Ops: -/// Vector::TransferReadOp - Reads a vector from the source tensor -/// vector::TransposeOp - Transpose the Source tensor -/// ShapeCastOp - Reshape the data based on the target. -/// vector::TransferWriteOp. - Write the result vector back to the destination -/// tensor. -/// If the vector sizes are not provided: -/// * the vector sizes are determined by the input operand and attributes, -/// * update the inBounds attribute instead of masking. +/// Vectorize `linalg.unpack` as: +/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write +/// +/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes +/// for the xfer_read operation). This is sufficient to infer the other vector +/// sizes required here. +/// +/// If the vector sizes are not provided: +/// * the vector sizes are determined by the operands, +/// * the inBounds attribute is used instead of masking. 
+/// +/// EXAMPLE (no vector sizes): +/// ``` +/// %unpack = linalg.unpack %src +/// inner_dims_pos = [0, 1] +/// inner_tiles = [8, 8] +/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32> +/// ``` +/// is vectorized as: +/// ``` +/// vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32> +/// ``` static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, ArrayRef<int64_t> inputVectorSizes, + ArrayRef<bool> inputScalableVecDims, SmallVectorImpl<Value> &newResults) { + if (!inputVectorSizes.empty()) { + assert(inputVectorSizes.size() == unpackOp.getSourceRank() && + "Invalid number of input vector sizes!"); + assert(inputVectorSizes.size() == inputScalableVecDims.size() && + "Incompatible number of vector sizes and vector scalable flags!"); + } // TODO: Introduce a parent class that will handle the insertion point update. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(unpackOp); RankedTensorType unpackTensorType = unpackOp.getSourceType(); - ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos(); - ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles(); ArrayRef<int64_t> sourceShape = unpackTensorType.getShape(); + ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape(); bool useInBoundsInsteadOfMasking = false; - ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); - - auto destSize = unpackOp.getDestRank(); - - if (!inputVectorSizes.empty()) - assert(inputVectorSizes.size() == destSize && - "Incorrect number of input vector sizes"); - - // vectorSizes is the shape of the vector that will be used to do final - // write on the destination tensor. It is set like this: Let's say the - // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M. - // Thus: - // 1. vectorSizes = sourceShape.take_front(N) - // 2. if outer_dims_perms is present: do that permutation on vectorSizes. - // 3. 
multiply all the locations in vectorSize pointed by innerDimPos by the - // innerTiles attribute value. - SmallVector<int64_t> vectorSizes(inputVectorSizes); - if (vectorSizes.empty()) { - llvm::append_range(vectorSizes, sourceShape.take_front(destSize)); - if (!outerDimsPerm.empty()) - applyPermutationToVector(vectorSizes, outerDimsPerm); - for (auto [i, pos] : llvm::enumerate(innerDimPos)) - vectorSizes[pos] *= innerTiles[i]; - useInBoundsInsteadOfMasking = true; - } + Location loc = unpackOp->getLoc(); - // readVectorSizes is the size of tensor used to read and apply mask. It is - // set like this: Let's say the vectorSize (VS) array is size 'N' and - // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of - // size M-N - // Thus: - // - initially: readVectorSizes = vectorInputSizes - // - Divide all the readMaskShape locations pointed by innerDimPos - // by the innerTileSize attribute value. - // - if outer_dims_perms is present: do that permutation on readVectorSizes. - // - Append the remaining shape from SS - // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16> - // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512, - // 128] and outer_dims_perm is [1, 0] then read shape is: - // ReadVectorSizes(initial): [512, 128] - // Final Value(after innerDim Adjustment): [512/32, 128/16] - // = [16, 8] - // After applying outer_dims_perm: [8, 16] - // After appending the rest of the sourceShape: [8, 16, 32, 16] - - SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end()); - - for (auto [index, size] : enumerate(innerTiles)) { - readVectorSizes[innerDimPos[index]] = - llvm::divideCeil(readVectorSizes[innerDimPos[index]], size); - } - if (!outerDimsPerm.empty()) { - applyPermutationToVector(readVectorSizes, outerDimsPerm); - } - readVectorSizes.append(sourceShape.begin() + vectorSizes.size(), - sourceShape.end()); + // 1. Obtain vector sizes for the read and write operations. 
+ SmallVector<int64_t> readVectorSizes; + SmallVector<bool> readScalableVectorFlags; - Location loc = unpackOp->getLoc(); + if (!inputVectorSizes.empty()) { + // CASE 1.1: Vector sizes are user-specified. + readVectorSizes.assign(inputVectorSizes.begin(), + inputVectorSizes.begin() + sourceShape.size()); + readScalableVectorFlags.assign(inputScalableVecDims.begin(), + inputScalableVecDims.begin() + + sourceShape.size()); + } else { + // CASE 1.2: Vector sizes are inferred from the static input tensor + // shapes. + if (ShapedType::isDynamicShape(destShape) || + ShapedType::isDynamicShape(sourceShape)) + return failure(); + + readVectorSizes.assign(sourceShape.begin(), sourceShape.end()); + useInBoundsInsteadOfMasking = true; + } ---------------- hanhanW wrote:
I think now we can assign the whole content directly.

```cpp
readVectorSizes.assign(inputVectorSizes.begin(), inputVectorSizes.end());
readScalableVectorFlags.assign(inputScalableVecDims.begin(), inputScalableVecDims.end());
```

I'd do it this way now:

```cpp
SmallVector<int64_t> readVectorSizes(inputVectorSizes);
SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);

// Vector sizes are inferred from the static input tensor shapes.
if (readVectorSizes.empty()) {
  // ... the inference
}
```

https://github.com/llvm/llvm-project/pull/149293
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits