https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/166618
>From 186a5f9dd5545db6e3ccb228174e9f6edbce95d5 Mon Sep 17 00:00:00 2001 From: makslevental <[email protected]> Date: Wed, 5 Nov 2025 11:13:09 -0800 Subject: [PATCH 1/5] check float cast --- mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 632e1a7f02602..99d181f6262cd 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -583,9 +583,11 @@ struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> { auto parent = op->getParentOfType<ModuleOp>(); if (!parent) return failure(); + auto floatTy = dyn_cast<FloatType>(op.getType()); + if (!floatTy) + return failure(); FailureOr<Operation *> adder = LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent); - auto floatTy = cast<FloatType>(op.getType()); // Cast operands to 64-bit integers. Location loc = op.getLoc(); >From 45b3830b7cc440bc62b975d169837159201e0f3c Mon Sep 17 00:00:00 2001 From: makslevental <[email protected]> Date: Wed, 5 Nov 2025 13:26:59 -0800 Subject: [PATCH 2/5] fix creates --- mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 99d181f6262cd..6fe4c22178c03 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -591,16 +591,16 @@ struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> { // Cast operands to 64-bit integers. 
Location loc = op.getLoc(); - Value lhsBits = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(), - adaptor.getLhs()); - Value rhsBits = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(), - adaptor.getRhs()); + Value lhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), + adaptor.getLhs()); + Value rhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), + adaptor.getRhs()); // Call software implementation of floating point addition. int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); - Value semValue = rewriter.create<LLVM::ConstantOp>( - loc, rewriter.getI32Type(), + Value semValue = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); SmallVector<Value> params = {semValue, lhsBits, rhsBits}; auto resultOp = @@ -608,8 +608,8 @@ struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> { SymbolRefAttr::get(*adder), params); // Truncate result to the original width. 
- Value truncatedBits = rewriter.create<LLVM::TruncOp>( - loc, rewriter.getIntegerType(floatTy.getWidth()), + Value truncatedBits = LLVM::TruncOp::create( + rewriter, loc, rewriter.getIntegerType(floatTy.getWidth()), resultOp->getResult(0)); rewriter.replaceOp(op, truncatedBits); return success(); >From dfad041fe5146c7c32ef043750a2247564bd7d29 Mon Sep 17 00:00:00 2001 From: makslevental <[email protected]> Date: Wed, 5 Nov 2025 13:40:18 -0800 Subject: [PATCH 3/5] check fp8 types --- mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 6fe4c22178c03..370e048bb3af5 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -583,9 +583,13 @@ struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> { auto parent = op->getParentOfType<ModuleOp>(); if (!parent) return failure(); - auto floatTy = dyn_cast<FloatType>(op.getType()); - if (!floatTy) + if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, + Float8E5M2FNUZType, Float8E4M3FNUZType, + Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType, + Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>( + op.getType())) return failure(); + auto floatTy = cast<FloatType>(op.getType()); FailureOr<Operation *> adder = LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent); @@ -630,7 +634,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns( // clang-format off patterns.add< - //AddFOpLowering, + AddFOpLowering, FancyAddFLowering, AddIOpLowering, AndIOpLowering, >From 3680b1119cad1577f01bc952e74e675f5af46488 Mon Sep 17 00:00:00 2001 From: makslevental <[email protected]> Date: Wed, 5 Nov 2025 19:25:10 -0800 Subject: [PATCH 4/5] add X-macros --- .../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 20 ++++- .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 33 +++++--- 
mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 75 ++++++++++++++++--- 3 files changed, 104 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 8564d0f4205cf..01f7f75c210ef 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -55,9 +55,23 @@ lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, FailureOr<LLVM::LLVMFuncOp> lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables = nullptr); -FailureOr<LLVM::LLVMFuncOp> -lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp, - SymbolTableCollection *symbolTables = nullptr); + +#define APFLOAT_BIN_OPS(X) \ + X(add) \ + X(subtract) \ + X(multiply) \ + X(divide) \ + X(remainder) \ + X(mod) + +#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \ + FailureOr<LLVM::LLVMFuncOp> lookupOrCreateApFloat##OP##Fn( \ + OpBuilder &b, Operation *moduleOp, \ + SymbolTableCollection *symbolTables = nullptr); + +APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL) + +#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL /// Declares a function to print a C-string. 
/// If a custom runtime function is defined via `runtimeFunctionName`, it must diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 8ee039be60568..cb6ee76f8cbfb 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -31,7 +31,14 @@ static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; static constexpr llvm::StringRef kPrintApFloat = "printApFloat"; -static constexpr llvm::StringRef kApFloatAddF = "APFloat_add"; + +#define APFLOAT_EXTERN_K(OP) kApFloat_##OP + +#define APFLOAT_EXTERN_NAME(OP) \ + static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "APFloat_" #OP; + +APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME) + static constexpr llvm::StringRef kPrintString = "printString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; @@ -172,16 +179,20 @@ mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } -FailureOr<LLVM::LLVMFuncOp> -mlir::LLVM::lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp, - SymbolTableCollection *symbolTables) { - return lookupOrCreateReservedFn( - b, moduleOp, kApFloatAddF, - {IntegerType::get(moduleOp->getContext(), 32), - IntegerType::get(moduleOp->getContext(), 64), - IntegerType::get(moduleOp->getContext(), 64)}, - IntegerType::get(moduleOp->getContext(), 64), symbolTables); -} +#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \ + FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateApFloat##OP##Fn( \ + OpBuilder &b, Operation *moduleOp, \ + SymbolTableCollection *symbolTables) { \ + return lookupOrCreateReservedFn( \ + b, moduleOp, APFLOAT_EXTERN_K(OP), \ + {IntegerType::get(moduleOp->getContext(), 32), \ + 
IntegerType::get(moduleOp->getContext(), 64), \ + IntegerType::get(moduleOp->getContext(), 64)}, \ + IntegerType::get(moduleOp->getContext(), 64), symbolTables); \ + } + +APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN) +#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { return LLVM::LLVMPointerType::get(context); diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index 7879c75803355..8d2848bd7cf77 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -7,26 +7,81 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/APFloat.h" +#include "llvm/Support/Debug.h" + #include <iostream> +#define DEBUG_TYPE "mlir-apfloat-wrapper" + #if (defined(_WIN32) || defined(__CYGWIN__)) #define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport) #else #define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default"))) #endif +static std::string_view +apFloatOpStatusToStr(llvm::APFloatBase::opStatus status) { + switch (status) { + case llvm::APFloatBase::opOK: + return "OK"; + case llvm::APFloatBase::opInvalidOp: + return "InvalidOp"; + case llvm::APFloatBase::opDivByZero: + return "DivByZero"; + case llvm::APFloatBase::opOverflow: + return "Overflow"; + case llvm::APFloatBase::opUnderflow: + return "Underflow"; + case llvm::APFloatBase::opInexact: + return "Inexact"; + } + llvm::report_fatal_error("unhandled llvm::APFloatBase::opStatus variant"); +} + +#define APFLOAT_BINARY_OP(OP) \ + int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast<llvm::APFloatBase::Semantics>(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, 
llvm::APInt(bitWidth, b)); \ + llvm::APFloatBase::opStatus status = lhs.OP(rhs); \ + assert(status == llvm::APFloatBase::opOK && "expected " #OP \ + " opstatus to be OK"); \ + return lhs.bitcastToAPInt().getZExtValue(); \ + } + +#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \ + int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast<llvm::APFloatBase::Semantics>(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + llvm::APFloatBase::opStatus status = lhs.OP(rhs, ROUNDING_MODE); \ + assert(status == llvm::APFloatBase::opOK && "expected " #OP \ + " opstatus to be OK"); \ + return lhs.bitcastToAPInt().getZExtValue(); \ + } + extern "C" { -int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_add(int32_t semantics, - uint64_t a, uint64_t b) { - const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( - static_cast<llvm::APFloatBase::Semantics>(semantics)); - unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); - llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); - llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); - auto status = lhs.add(rhs, llvm::RoundingMode::NearestTiesToEven); - return lhs.bitcastToAPInt().getZExtValue(); -} +#define BIN_OPS_WITH_ROUNDING(X) \ + X(add, llvm::RoundingMode::NearestTiesToEven) \ + X(subtract, llvm::RoundingMode::NearestTiesToEven) \ + X(multiply, llvm::RoundingMode::NearestTiesToEven) \ + X(divide, llvm::RoundingMode::NearestTiesToEven) + +BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE) +#undef BIN_OPS_WITH_ROUNDING +#undef APFLOAT_BINARY_OP_ROUNDING_MODE + +APFLOAT_BINARY_OP(remainder) +APFLOAT_BINARY_OP(mod) + +#undef APFLOAT_BINARY_OP void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics, uint64_t a) { >From 
c42f471126db73c9063818f7cacb67377813ea35 Mon Sep 17 00:00:00 2001 From: makslevental <[email protected]> Date: Thu, 6 Nov 2025 15:14:09 -0800 Subject: [PATCH 5/5] add arith-to-apfloat --- .../ArithToAPFloat/ArithToAPFloat.h | 28 ++++ mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 13 ++ mlir/include/mlir/Dialect/Func/Utils/Utils.h | 8 ++ .../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 17 --- .../ArithToAPFloat/ArithToAPFloat.cpp | 136 ++++++++++++++++++ .../Conversion/ArithToAPFloat/CMakeLists.txt | 17 +++ .../Conversion/ArithToLLVM/ArithToLLVM.cpp | 48 ------- mlir/lib/Conversion/CMakeLists.txt | 1 + mlir/lib/Dialect/Func/Utils/Utils.cpp | 42 ++++++ .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 23 --- mlir/lib/ExecutionEngine/APFloatWrappers.cpp | 22 --- .../Arith/CPU/test-apfloat-emulation.mlir | 21 ++- 13 files changed, 266 insertions(+), 111 deletions(-) create mode 100644 mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h create mode 100644 mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp create mode 100644 mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h new file mode 100644 index 0000000000000..a5df4647f1acc --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h @@ -0,0 +1,28 @@ +//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H +#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H + +#include <memory> + +namespace mlir { + +class DialectRegistry; +class RewritePatternSet; +class Pass; + +#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" + +namespace arith { +void populateArithToAPFloatConversionPatterns(RewritePatternSet &patterns); +} // namespace arith +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 40d866ec7bf10..82bdfd02661a6 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -12,6 +12,7 @@ #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" +#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" #include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 70e3e45c225db..2bcd2870949f3 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -186,6 +186,19 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> { ]; } +//===----------------------------------------------------------------------===// +// ArithToAPFloat +//===----------------------------------------------------------------------===// + +def ArithToAPFloatConversionPass : Pass<"convert-arith-to-apfloat"> { + let summary = "Convert Arith dialect ops on FP8 types to APFloat lib calls"; + let description = [{ + This pass
converts supported Arith ops which manipulate FP8 typed values to APFloat lib calls. + }]; + let dependentDialects = ["func::FuncDialect"]; + let options = []; +} + //===----------------------------------------------------------------------===// // ArithToSPIRV //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h index 3576126a487ac..9c9973cf84368 100644 --- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h @@ -60,6 +60,14 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>> deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp); +/// Look up or create a FuncOp `name` with type `(paramTypes) -> (resultTypes)`. +/// Return failure if an existing FuncOp `name` has a different type. +FailureOr<FuncOp> +lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, + ArrayRef<Type> paramTypes = {}, + ArrayRef<Type> resultTypes = {}, bool setPrivate = false, + SymbolTableCollection *symbolTables = nullptr); + } // namespace func } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 01f7f75c210ef..b09d32022e348 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -56,23 +56,6 @@ FailureOr<LLVM::LLVMFuncOp> lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables = nullptr); -#define APFLOAT_BIN_OPS(X) \ - X(add) \ - X(subtract) \ - X(multiply) \ - X(divide) \ - X(remainder) \ - X(mod) - -#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \ - FailureOr<LLVM::LLVMFuncOp> lookupOrCreateApFloat##OP##Fn( \ - OpBuilder &b, Operation *moduleOp, \ - SymbolTableCollection *symbolTables = nullptr); -
-APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL) - -#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL - /// Declares a function to print a C-string. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp new file mode 100644 index 0000000000000..bc451e88eb3bd --- /dev/null +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -0,0 +1,136 @@ +//===- ArithToAPFloat.cpp - Arithmetic to APFloat impl conversion ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Utils/Utils.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::func; + +#define APFLOAT_BIN_OPS(X) \ + X(add) \ + X(subtract) \ + X(multiply) \ + X(divide) \ + X(remainder) \ + X(mod) + +#define APFLOAT_EXTERN_K(OP) kApFloat_##OP + +#define APFLOAT_EXTERN_NAME(OP) \ + static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "_mlir_" \ + "apfloat_" #OP; + +namespace mlir::func { +#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \ + FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \ + OpBuilder &b, Operation *moduleOp, \ + SymbolTableCollection *symbolTables = nullptr); + 
+APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL) + +#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL + +APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME) + +#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \ + FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \ + OpBuilder &b, Operation *moduleOp, \ + SymbolTableCollection *symbolTables) { \ + return lookupOrCreateFn(b, moduleOp, APFLOAT_EXTERN_K(OP), \ + {IntegerType::get(moduleOp->getContext(), 32), \ + IntegerType::get(moduleOp->getContext(), 64), \ + IntegerType::get(moduleOp->getContext(), 64)}, \ + {IntegerType::get(moduleOp->getContext(), 64)}, \ + /*setPrivate*/ true, symbolTables); \ + } + +APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN) +#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN +} // namespace mlir::func + +struct FancyAddFLowering : OpRewritePattern<arith::AddFOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::AddFOp op, + PatternRewriter &rewriter) const override { + // Get APFloat adder function from runtime library. + auto parent = op->getParentOfType<ModuleOp>(); + if (!parent) + return failure(); + if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, + Float8E5M2FNUZType, Float8E4M3FNUZType, + Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType, + Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>( + op.getType())) + return failure(); + FailureOr<Operation *> adder = lookupOrCreateApFloataddFn(rewriter, parent); + + // Cast operands to 64-bit integers. + Location loc = op.getLoc(); + auto floatTy = cast<FloatType>(op.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + auto int64Type = rewriter.getI64Type(); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs())); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs())); + + // Call software implementation of floating point addition. 
+ int32_t sem = + llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + Value semValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), + rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + SmallVector<Value> params = {semValue, lhsBits, rhsBits}; + auto resultOp = + func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*adder), params); + + // Truncate result to the original width. + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, + resultOp->getResult(0)); + rewriter.replaceAllUsesWith( + op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits)); + return success(); + } +}; + +void arith::populateArithToAPFloatConversionPatterns( + RewritePatternSet &patterns) { + patterns.add<FancyAddFLowering>(patterns.getContext()); +} + +namespace { +struct ArithToAPFloatConversionPass final + : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> { + using impl::ArithToAPFloatConversionPassBase< + ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase; + + void runOnOperation() override { + Operation *op = getOperation(); + RewritePatternSet patterns(op->getContext()); + arith::populateArithToAPFloatConversionPatterns(patterns); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt new file mode 100644 index 0000000000000..b0d1e46b3655f --- /dev/null +++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRArithToAPFloat + ArithToAPFloat.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRArithTransforms + MLIRFuncDialect + ) diff --git 
a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 370e048bb3af5..5972c8cc9987e 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -573,53 +573,6 @@ void mlir::arith::registerConvertArithToLLVMInterface( }); } -struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(arith::AddFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Get APFloat adder function from runtime library. - auto parent = op->getParentOfType<ModuleOp>(); - if (!parent) - return failure(); - if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, - Float8E5M2FNUZType, Float8E4M3FNUZType, - Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType, - Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>( - op.getType())) - return failure(); - auto floatTy = cast<FloatType>(op.getType()); - FailureOr<Operation *> adder = - LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent); - - // Cast operands to 64-bit integers. - Location loc = op.getLoc(); - Value lhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), - adaptor.getLhs()); - Value rhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), - adaptor.getRhs()); - - // Call software implementation of floating point addition. - int32_t sem = - llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); - Value semValue = LLVM::ConstantOp::create( - rewriter, loc, rewriter.getI32Type(), - rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); - SmallVector<Value> params = {semValue, lhsBits, rhsBits}; - auto resultOp = - LLVM::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), - SymbolRefAttr::get(*adder), params); - - // Truncate result to the original width. 
- Value truncatedBits = LLVM::TruncOp::create( - rewriter, loc, rewriter.getIntegerType(floatTy.getWidth()), - resultOp->getResult(0)); - rewriter.replaceOp(op, truncatedBits); - return success(); - } -}; - //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// @@ -635,7 +588,6 @@ void mlir::arith::populateArithToLLVMConversionPatterns( // clang-format off patterns.add< AddFOpLowering, - FancyAddFLowering, AddIOpLowering, AndIOpLowering, AddUIExtendedOpLowering, diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index bebf1b8fff3f9..613dc6d242ceb 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard) add_subdirectory(AMDGPUToROCDL) add_subdirectory(ArithCommon) add_subdirectory(ArithToAMDGPU) +add_subdirectory(ArithToAPFloat) add_subdirectory(ArithToArmSME) add_subdirectory(ArithToEmitC) add_subdirectory(ArithToLLVM) diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp index b4cb0932ef631..e187e62cf6555 100644 --- a/mlir/lib/Dialect/Func/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp @@ -254,3 +254,45 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp, return std::make_pair(*newFuncOpOrFailure, newCallOp); } + +FailureOr<func::FuncOp> +func::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, + ArrayRef<Type> paramTypes, ArrayRef<Type> resultTypes, + bool setPrivate, SymbolTableCollection *symbolTables) { + assert(moduleOp->hasTrait<OpTrait::SymbolTable>() && + "expected SymbolTable operation"); + + FuncOp func; + if (symbolTables) { + func = symbolTables->lookupSymbolIn<FuncOp>( + moduleOp, StringAttr::get(moduleOp->getContext(), name)); + } else { + func = llvm::dyn_cast_or_null<FuncOp>( + SymbolTable::lookupSymbolIn(moduleOp, 
name)); + } + + FunctionType funcT = + FunctionType::get(b.getContext(), paramTypes, resultTypes); + // If the function already exists, its signature must match the expected one. + if (func) { + if (funcT != func.getFunctionType()) { + func.emitError("redefinition of function '") + << name << "' of different type " << funcT << " is prohibited"; + return failure(); + } + return func; + } + + OpBuilder::InsertionGuard g(b); + assert(!moduleOp->getRegion(0).empty() && "expected non-empty region"); + b.setInsertionPointToStart(&moduleOp->getRegion(0).front()); + FuncOp funcOp = FuncOp::create(b, moduleOp->getLoc(), name, funcT); + if (setPrivate) + funcOp.setPrivate(); + if (symbolTables) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp); + symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin()); + } + + return funcOp; +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index cb6ee76f8cbfb..160b6ae89215c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -31,14 +31,6 @@ static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; static constexpr llvm::StringRef kPrintApFloat = "printApFloat"; - -#define APFLOAT_EXTERN_K(OP) kApFloat_##OP - -#define APFLOAT_EXTERN_NAME(OP) \ - static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "APFloat_" #OP; - -APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME) - static constexpr llvm::StringRef kPrintString = "printString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; @@ -179,21 +171,6 @@ mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } -#define 
LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \ - FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateApFloat##OP##Fn( \ - OpBuilder &b, Operation *moduleOp, \ - SymbolTableCollection *symbolTables) { \ - return lookupOrCreateReservedFn( \ - b, moduleOp, APFLOAT_EXTERN_K(OP), \ - {IntegerType::get(moduleOp->getContext(), 32), \ - IntegerType::get(moduleOp->getContext(), 64), \ - IntegerType::get(moduleOp->getContext(), 64)}, \ - IntegerType::get(moduleOp->getContext(), 64), symbolTables); \ - } - -APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN) -#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN - static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { return LLVM::LLVMPointerType::get(context); } diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index 8d2848bd7cf77..a5049436d03c1 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -7,37 +7,15 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/APFloat.h" -#include "llvm/Support/Debug.h" #include <iostream> -#define DEBUG_TYPE "mlir-apfloat-wrapper" - #if (defined(_WIN32) || defined(__CYGWIN__)) #define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport) #else #define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default"))) #endif -static std::string_view -apFloatOpStatusToStr(llvm::APFloatBase::opStatus status) { - switch (status) { - case llvm::APFloatBase::opOK: - return "OK"; - case llvm::APFloatBase::opInvalidOp: - return "InvalidOp"; - case llvm::APFloatBase::opDivByZero: - return "DivByZero"; - case llvm::APFloatBase::opOverflow: - return "Overflow"; - case llvm::APFloatBase::opUnderflow: - return "Underflow"; - case llvm::APFloatBase::opInexact: - return "Inexact"; - } - llvm::report_fatal_error("unhandled llvm::APFloatBase::opStatus variant"); -} - #define APFLOAT_BINARY_OP(OP) \ int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP( \ int32_t semantics, uint64_t a, uint64_t b) { \ diff --git 
a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir index 5cd83688d1710..d4c2394474b15 100644 --- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -1,7 +1,7 @@ // Check that the ceildivsi lowering is correct. // We do not check any poison or UB values, as it is not possible to catch them. -// RUN: mlir-opt %s --convert-to-llvm +// RUN: mlir-opt %s --convert-arith-to-apfloat // Put rhs into separate function so that it won't be constant-folded. func.func @foo() -> f4E2M1FN { @@ -17,3 +17,22 @@ func.func @entry() { return } +// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64 + +// CHECK-LABEL: func.func @foo() -> f4E2M1FN { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 4.000000e+00 : f4E2M1FN +// CHECK: return %[[CONSTANT_0]] : f4E2M1FN +// CHECK: } + +// CHECK-LABEL: func.func @entry() { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 18 : i32 +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 6 : i64 +// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f4E2M1FN +// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[VAL_0]] : f4E2M1FN to i4 +// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i4 to i64 +// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_0]], %[[EXTUI_0]], %[[CONSTANT_1]]) : (i32, i64, i64) -> i64 +// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i4 +// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[TRUNCI_0]] : i4 to f4E2M1FN +// CHECK: vector.print %[[BITCAST_1]] : f4E2M1FN +// CHECK: return +// CHECK: } \ No newline at end of file _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
