================ @@ -0,0 +1,304 @@ +//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements lowering patterns from vector.contract to +// SVE I8MM operations. +// +//===--- + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "mlir/Dialect/UB/IR/UBOps.h" + +#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm" + +using namespace mlir; +using namespace mlir::arm_sve; + +namespace { +// Check if the given value is a result of the operation `T` (which must be +// sign- or zero- extend) from i8 to i32. Return the value before the extension. +template <typename T> +inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> || + std::is_base_of_v<arith::ExtUIOp, T>), + std::optional<Value>> +extractExtOperand(Value v, Type i8Ty, Type i32Ty) { + auto extOp = dyn_cast_or_null<T>(v.getDefiningOp()); + if (!extOp) + return {}; + + auto inOp = extOp.getIn(); + auto inTy = dyn_cast<VectorType>(inOp.getType()); + if (!inTy || inTy.getElementType() != i8Ty) + return {}; + + auto outTy = dyn_cast<VectorType>(extOp.getType()); + if (!outTy || outTy.getElementType() != i32Ty) + return {}; + + return inOp; +} + +// Designate the operation (resp. instruction) used to do sub-tile matrix +// multiplications. +enum class MMLA { + Signed, // smmla + Unsigned, // ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped +}; + +// Create the matrix multply and accumulate operation according to `op`. +Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, + mlir::VectorType accType, Value acc, Value lhs, Value rhs) { + switch (op) { + case MMLA::Signed: + return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs); + case MMLA::Unsigned: + return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs); + case MMLA::Mixed: + return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs); + case MMLA::MixedSwapped: + // The accumulator comes transposed and the result will be transposed + // later, so all we have to do here is swap the operands. + return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs); + } +} + +class LowerContractionToSVEI8MMPattern + : public OpRewritePattern<vector::ContractionOp> { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + mlir::VectorType lhsType = op.getLhsType(); + mlir::VectorType rhsType = op.getRhsType(); + + // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we + // eventually expect from MMT4D. M and N dimensions must be even and at + // least 2. + if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() || + rhsType.getRank() != 2) + return failure(); + + if (lhsType.isScalable() || !rhsType.isScalable()) + return failure(); + + // M, N, and K are the conventional names for matrix dimensions in the + // context of matrix multiplication. + auto M = lhsType.getDimSize(0); + auto N = rhsType.getDimSize(0); + auto K = rhsType.getDimSize(1); + + if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 || + N % 2 != 0 || !rhsType.getScalableDims()[0]) + return failure(); + + // Check permutation maps. For now only accept + // lhs: (d0, d1, d2) -> (d0, d2) + // rhs: (d0, d1, d2) -> (d1, d2) + // acc: (d0, d1, d2) -> (d0, d1) + // Note: RHS is transposed. + if (op.getIndexingMapsArray()[0] != + AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, + op.getContext()) || + op.getIndexingMapsArray()[1] != + AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, + op.getContext()) || + op.getIndexingMapsArray()[2] != + AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, + op.getContext())) + return failure(); + + // Check iterator types for matrix multiplication. + auto itTypes = op.getIteratorTypesArray(); + if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || + itTypes[1] != vector::IteratorType::parallel || + itTypes[2] != vector::IteratorType::reduction) + return failure(); + + // Check the combining kind is addition. + if (op.getKind() != vector::CombiningKind::ADD) + return failure(); + + // Check the output is a vector of i32 elements. + auto outTy = dyn_cast<VectorType>(op.getType()); + if (!outTy || outTy.getElementType() != rewriter.getI32Type()) + return failure(); + + // Check inputs are sign-/zero- extensions from i8 to i32. Get the values + // before the extension. All four signed/unsigned combinations for input + // operands are supported, but they are lowered to different operations. + // Determina which is the appropriate operation to lower to. + MMLA mmlaOp = MMLA::Signed; + auto maybeLhs = extractExtOperand<arith::ExtSIOp>( + op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type()); + if (!maybeLhs) { + mmlaOp = MMLA::Unsigned; + maybeLhs = extractExtOperand<arith::ExtUIOp>( + op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type()); + } + if (!maybeLhs) + return failure(); + + auto maybeRhs = extractExtOperand<arith::ExtSIOp>( + op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type()); + if (maybeRhs) { + if (mmlaOp == MMLA::Unsigned) + mmlaOp = MMLA::Mixed; + } else { + if (mmlaOp == MMLA::Signed) + mmlaOp = MMLA::MixedSwapped; + maybeRhs = extractExtOperand<arith::ExtUIOp>( + op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type()); + } + if (!maybeRhs) + return failure(); + + // One-dimensional vector types for arm_sve.*mmla + auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true}); + auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true}); + + // Extract LHS sub-tiles. ---------------- banach-space wrote:
[nit] Could you specify the dims? That would be helpful. https://github.com/llvm/llvm-project/pull/135636 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits