llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> The LLVM dialect no longer has its own vector types. It uses `mlir::VectorType` everywhere. Remove `LLVM::getFixedVectorType/getScalableVectorType` and use `VectorType::get` instead. This commit addresses a [comment](https://github.com/llvm/llvm-project/pull/133286#discussion_r2022192500) on the PR that deleted the LLVM vector types. Depends on #<!-- -->134981. --- Full diff: https://github.com/llvm/llvm-project/pull/135051.diff 7 Files Affected: - (modified) mlir/docs/Dialects/LLVM.md (-4) - (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h (-8) - (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+16-17) - (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (-12) - (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+14-9) - (modified) mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp (+10-14) - (modified) mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp (+5-4) ``````````diff diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md index 468f69c419071..4b5d518ca4eab 100644 --- a/mlir/docs/Dialects/LLVM.md +++ b/mlir/docs/Dialects/LLVM.md @@ -336,10 +336,6 @@ compatible with the LLVM dialect: vector type compatible with the LLVM dialect; - `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number of elements in any vector type compatible with the LLVM dialect; -- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type - with the given element type and size; the resulting type is either a - built-in or an LLVM dialect vector type depending on which one supports the - given element type. #### Examples of Compatible Vector Types diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index a2a76c49a2bda..17561f79d135a 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -126,14 +126,6 @@ Type getVectorType(Type elementType, unsigned numElements, /// and length. Type getVectorType(Type elementType, const llvm::ElementCount &numElements); -/// Creates an LLVM dialect-compatible type with the given element type and -/// length. -Type getFixedVectorType(Type elementType, unsigned numElements); - -/// Creates an LLVM dialect-compatible type with the given element type and -/// length. -Type getScalableVectorType(Type elementType, unsigned numElements); - /// Returns the size of the given primitive LLVM dialect-compatible type /// (including vectors) in bits, for example, the size of i16 is 16 and /// the size of vector<4xi16> is 64. Returns 0 for non-primitive diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 51507c6507b69..69fa62c8196e4 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -61,13 +61,13 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) { static Type inferIntrinsicResultType(Type vectorResultType) { MLIRContext *ctx = vectorResultType.getContext(); auto a = cast<LLVM::LLVMArrayType>(vectorResultType); - auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); + auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx)); auto i32Ty = IntegerType::get(ctx, 32); - auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); + auto i32x2Ty = VectorType::get(2, i32Ty); Type f64Ty = Float64Type::get(ctx); - Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); + Type f64x2Ty = VectorType::get(2, f64Ty); Type f32Ty = Float32Type::get(ctx); - Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); + Type f32x2Ty = VectorType::get(2, f32Ty); if (a.getElementType() == f16x2Ty) { return LLVM::LLVMStructType::getLiteral( ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty)); @@ -85,7 +85,7 @@ static Type inferIntrinsicResultType(Type vectorResultType) { ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty)); } - if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) { + if (a.getElementType() == VectorType::get(1, f32Ty)) { return LLVM::LLVMStructType::getLiteral( ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty)); } @@ -106,11 +106,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, Type i32Ty = rewriter.getI32Type(); Type f32Ty = rewriter.getF32Type(); Type f64Ty = rewriter.getF64Type(); - Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2); - Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); - Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); - Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); - Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); + Type f16x2Ty = VectorType::get(2, rewriter.getF16Type()); + Type i32x2Ty = VectorType::get(2, i32Ty); + Type f64x2Ty = VectorType::get(2, f64Ty); + Type f32x2Ty = VectorType::get(2, f32Ty); + Type f32x1Ty = VectorType::get(1, f32Ty); auto makeConst = [&](int32_t index) -> Value { return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), @@ -181,9 +181,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, Type f64Ty = b.getF64Type(); Type f32Ty = b.getF32Type(); Type i64Ty = b.getI64Type(); - Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4); - Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8); - Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); + Type i8x4Ty = VectorType::get(4, b.getI8Type()); + Type i4x8Ty = VectorType::get(8, b.getIntegerType(4)); + Type f32x1Ty = VectorType::get(1, f32Ty); auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType()); for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { @@ -268,8 +268,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { if (!vectorResultType) { return failure(); } - Type innerVectorType = LLVM::getFixedVectorType( - vectorResultType.getElementType(), vectorResultType.getDimSize(1)); + Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1), + vectorResultType.getElementType()); int64_t num32BitRegs = vectorResultType.getDimSize(0); @@ -627,8 +627,7 @@ struct NVGPUMmaSparseSyncLowering // Bitcast the sparse metadata from vector<2xf16> to an i32. Value sparseMetadata = adaptor.getSparseMetadata(); - if (sparseMetadata.getType() != - LLVM::getFixedVectorType(rewriter.getI16Type(), 2)) + if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type())) return op->emitOpError() << "Expected metadata type to be LLVM " "VectorType of 2 i16 elements"; sparseMetadata = diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index b3c2a29309528..29cf38c1fefea 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -851,18 +851,6 @@ Type mlir::LLVM::getVectorType(Type elementType, /*isScalable=*/false); } -Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) { - assert(VectorType::isValidElementType(elementType) && - "incompatible element type"); - return VectorType::get(numElements, elementType); -} - -Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) { - // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as - // scalable/non-scalable. - return VectorType::get(numElements, elementType, /*scalableDims=*/true); -} - llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { assert(isCompatibleType(type) && "expected a type compatible with the LLVM dialect"); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 09bff6101edd3..b9d6952f67671 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -144,7 +144,7 @@ LogicalResult BulkStoreOp::verify() { std::optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) { auto half2Type = - LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2); + VectorType::get(2, Float16Type::get(operandElType.getContext())); if (operandElType.isF64()) return NVVM::MMATypes::f64; if (operandElType.isF16() || operandElType == half2Type) @@ -243,7 +243,8 @@ void MmaOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); // Print the types of the operands and result. - p << " : " << "("; + p << " : " + << "("; llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(), frags[1].regs[0].getType(), frags[2].regs[0].getType()}, @@ -404,7 +405,7 @@ LogicalResult MmaOp::verify() { MLIRContext *context = getContext(); auto f16Ty = Float16Type::get(context); auto i32Ty = IntegerType::get(context, 32); - auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2); + auto f16x2Ty = VectorType::get(2, f16Ty); auto f32Ty = Float32Type::get(context); auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); @@ -506,7 +507,7 @@ LogicalResult MmaOp::verify() { expectedA.emplace_back(1, f64Ty); expectedB.emplace_back(1, f64Ty); expectedC.emplace_back(2, f64Ty); - // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2)); + // expectedC.emplace_back(1, VectorType::get(2, f64Ty)); expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral( context, SmallVector<Type>(2, f64Ty))); allowedShapes.push_back({8, 8, 4}); @@ -992,7 +993,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() { ss << "},"; // Need to map read/write registers correctly. regCnt = (regCnt * 2); - ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p"; + ss << " $" << (regCnt) << "," + << " $" << (regCnt + 1) << "," + << " p"; if (getTypeD() != WGMMATypes::s32) { ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4); } @@ -1219,7 +1222,7 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims, : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile) #define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \ - [&]() -> auto { \ + [&]() -> auto{ \ switch (dims) { \ case 1: \ return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \ @@ -1234,7 +1237,8 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims, default: \ llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \ } \ - }() + } \ + () llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID( int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) { @@ -1364,13 +1368,14 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1) #define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \ - [&]() -> auto { \ + [&]() -> auto{ \ if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \ return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \ if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \ return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \ return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \ - }() + } \ + () llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) { auto curOp = cast<NVVM::Tcgen05CpOp>(op); diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp index 39cca7d363e0d..e80360aa08ed5 100644 --- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp +++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp @@ -103,16 +103,15 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) { Type elType = type.vectorType.getElementType(); if (elType.isF16()) { - return FragmentElementInfo{ - LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32, - inferNumRegistersPerMatrixFragment(type)}; + return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32, + inferNumRegistersPerMatrixFragment(type)}; } // f64 operand Type f64Ty = Float64Type::get(ctx); if (elType.isF64()) { return isAccum - ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128, + ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128, inferNumRegistersPerMatrixFragment(type)} : FragmentElementInfo{f64Ty, 1, 64, inferNumRegistersPerMatrixFragment(type)}; @@ -120,30 +119,27 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) { // int8 operand if (elType.isInteger(8)) { - return FragmentElementInfo{ - LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32, - inferNumRegistersPerMatrixFragment(type)}; + return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4, + 32, inferNumRegistersPerMatrixFragment(type)}; } // int4 operand if (elType.isInteger(4)) { - return FragmentElementInfo{ - LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32, - inferNumRegistersPerMatrixFragment(type)}; + return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8, + 32, inferNumRegistersPerMatrixFragment(type)}; } // Integer 32bit acc operands if (elType.isInteger(32)) { - return FragmentElementInfo{ - LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64, - inferNumRegistersPerMatrixFragment(type)}; + return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2, + 64, inferNumRegistersPerMatrixFragment(type)}; } // Floating point 32bit operands if (elType.isF32()) { Type f32Ty = Float32Type::get(ctx); return isAccum - ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64, + ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64, inferNumRegistersPerMatrixFragment(type)} : FragmentElementInfo{f32Ty, 1, 32, inferNumRegistersPerMatrixFragment(type)}; diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp index bc9765fff2953..c46aa3e80d51a 100644 --- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp @@ -124,14 +124,15 @@ class TypeFromLLVMIRTranslatorImpl { /// Translates the given fixed-vector type. Type translate(llvm::FixedVectorType *type) { - return LLVM::getFixedVectorType(translateType(type->getElementType()), - type->getNumElements()); + return VectorType::get(type->getNumElements(), + translateType(type->getElementType())); } /// Translates the given scalable-vector type. Type translate(llvm::ScalableVectorType *type) { - return LLVM::getScalableVectorType(translateType(type->getElementType()), - type->getMinNumElements()); + return VectorType::get(type->getMinNumElements(), + translateType(type->getElementType()), + /*scalable=*/true); } /// Translates the given target extension type. `````````` </details> https://github.com/llvm/llvm-project/pull/135051 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits