https://github.com/makslevental updated 
https://github.com/llvm/llvm-project/pull/166618

>From 186a5f9dd5545db6e3ccb228174e9f6edbce95d5 Mon Sep 17 00:00:00 2001
From: makslevental <[email protected]>
Date: Wed, 5 Nov 2025 11:13:09 -0800
Subject: [PATCH 1/5] check float cast

---
 mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp 
b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 632e1a7f02602..99d181f6262cd 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -583,9 +583,11 @@ struct FancyAddFLowering : public 
ConvertOpToLLVMPattern<arith::AddFOp> {
     auto parent = op->getParentOfType<ModuleOp>();
     if (!parent)
       return failure();
+    auto floatTy = dyn_cast<FloatType>(op.getType());
+    if (!floatTy)
+      return failure();
     FailureOr<Operation *> adder =
         LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent);
-    auto floatTy = cast<FloatType>(op.getType());
 
     // Cast operands to 64-bit integers.
     Location loc = op.getLoc();

>From 45b3830b7cc440bc62b975d169837159201e0f3c Mon Sep 17 00:00:00 2001
From: makslevental <[email protected]>
Date: Wed, 5 Nov 2025 13:26:59 -0800
Subject: [PATCH 2/5] fix creates

---
 mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp 
b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 99d181f6262cd..6fe4c22178c03 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -591,16 +591,16 @@ struct FancyAddFLowering : public 
ConvertOpToLLVMPattern<arith::AddFOp> {
 
     // Cast operands to 64-bit integers.
     Location loc = op.getLoc();
-    Value lhsBits = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(),
-                                                  adaptor.getLhs());
-    Value rhsBits = rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI64Type(),
-                                                  adaptor.getRhs());
+    Value lhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
+                                         adaptor.getLhs());
+    Value rhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
+                                         adaptor.getRhs());
 
     // Call software implementation of floating point addition.
     int32_t sem =
         llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
-    Value semValue = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI32Type(),
+    Value semValue = LLVM::ConstantOp::create(
+        rewriter, loc, rewriter.getI32Type(),
         rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
     SmallVector<Value> params = {semValue, lhsBits, rhsBits};
     auto resultOp =
@@ -608,8 +608,8 @@ struct FancyAddFLowering : public 
ConvertOpToLLVMPattern<arith::AddFOp> {
                              SymbolRefAttr::get(*adder), params);
 
     // Truncate result to the original width.
-    Value truncatedBits = rewriter.create<LLVM::TruncOp>(
-        loc, rewriter.getIntegerType(floatTy.getWidth()),
+    Value truncatedBits = LLVM::TruncOp::create(
+        rewriter, loc, rewriter.getIntegerType(floatTy.getWidth()),
         resultOp->getResult(0));
     rewriter.replaceOp(op, truncatedBits);
     return success();

>From dfad041fe5146c7c32ef043750a2247564bd7d29 Mon Sep 17 00:00:00 2001
From: makslevental <[email protected]>
Date: Wed, 5 Nov 2025 13:40:18 -0800
Subject: [PATCH 3/5] check fp8 types

---
 mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp 
b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 6fe4c22178c03..370e048bb3af5 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -583,9 +583,13 @@ struct FancyAddFLowering : public 
ConvertOpToLLVMPattern<arith::AddFOp> {
     auto parent = op->getParentOfType<ModuleOp>();
     if (!parent)
       return failure();
-    auto floatTy = dyn_cast<FloatType>(op.getType());
-    if (!floatTy)
+    if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
+                   Float8E5M2FNUZType, Float8E4M3FNUZType,
+                   Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
+                   Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
+            op.getType()))
       return failure();
+    auto floatTy = cast<FloatType>(op.getType());
     FailureOr<Operation *> adder =
         LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent);
 
@@ -630,7 +634,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
 
   // clang-format off
   patterns.add<
-    //AddFOpLowering,
+    AddFOpLowering,
     FancyAddFLowering,
     AddIOpLowering,
     AndIOpLowering,

>From 3680b1119cad1577f01bc952e74e675f5af46488 Mon Sep 17 00:00:00 2001
From: makslevental <[email protected]>
Date: Wed, 5 Nov 2025 19:25:10 -0800
Subject: [PATCH 4/5] add X-macros

---
 .../mlir/Dialect/LLVMIR/FunctionCallUtils.h   | 20 ++++-
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   | 33 +++++---
 mlir/lib/ExecutionEngine/APFloatWrappers.cpp  | 75 ++++++++++++++++---
 3 files changed, 104 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h 
b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 8564d0f4205cf..01f7f75c210ef 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -55,9 +55,23 @@ lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
                              SymbolTableCollection *symbolTables = nullptr);
-FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp,
-                            SymbolTableCollection *symbolTables = nullptr);
+
+#define APFLOAT_BIN_OPS(X)                                                     
\
+  X(add)                                                                       
\
+  X(subtract)                                                                  
\
+  X(multiply)                                                                  
\
+  X(divide)                                                                    
\
+  X(remainder)                                                                 
\
+  X(mod)
+
+#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP)                                   
\
+  FailureOr<LLVM::LLVMFuncOp> lookupOrCreateApFloat##OP##Fn(                   
\
+      OpBuilder &b, Operation *moduleOp,                                       
\
+      SymbolTableCollection *symbolTables = nullptr);
+
+APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
+
+#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
 
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp 
b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 8ee039be60568..cb6ee76f8cbfb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -31,7 +31,14 @@ static constexpr llvm::StringRef kPrintBF16 = "printBF16";
 static constexpr llvm::StringRef kPrintF32 = "printF32";
 static constexpr llvm::StringRef kPrintF64 = "printF64";
 static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
-static constexpr llvm::StringRef kApFloatAddF = "APFloat_add";
+
+#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
+
+#define APFLOAT_EXTERN_NAME(OP)                                                
\
+  static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "APFloat_" #OP;
+
+APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
+
 static constexpr llvm::StringRef kPrintString = "printString";
 static constexpr llvm::StringRef kPrintOpen = "printOpen";
 static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -172,16 +179,20 @@ mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, 
Operation *moduleOp,
       LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
-FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateApFloatAddFFn(OpBuilder &b, Operation *moduleOp,
-                                        SymbolTableCollection *symbolTables) {
-  return lookupOrCreateReservedFn(
-      b, moduleOp, kApFloatAddF,
-      {IntegerType::get(moduleOp->getContext(), 32),
-       IntegerType::get(moduleOp->getContext(), 64),
-       IntegerType::get(moduleOp->getContext(), 64)},
-      IntegerType::get(moduleOp->getContext(), 64), symbolTables);
-}
+#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP)                                   
\
+  FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateApFloat##OP##Fn(       
\
+      OpBuilder &b, Operation *moduleOp,                                       
\
+      SymbolTableCollection *symbolTables) {                                   
\
+    return lookupOrCreateReservedFn(                                           
\
+        b, moduleOp, APFLOAT_EXTERN_K(OP),                                     
\
+        {IntegerType::get(moduleOp->getContext(), 32),                         
\
+         IntegerType::get(moduleOp->getContext(), 64),                         
\
+         IntegerType::get(moduleOp->getContext(), 64)},                        
\
+        IntegerType::get(moduleOp->getContext(), 64), symbolTables);           
\
+  }
+
+APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
+#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
 
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
   return LLVM::LLVMPointerType::get(context);
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp 
b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 7879c75803355..8d2848bd7cf77 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -7,26 +7,81 @@
 
//===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/APFloat.h"
+#include "llvm/Support/Debug.h"
+
 #include <iostream>
 
+#define DEBUG_TYPE "mlir-apfloat-wrapper"
+
 #if (defined(_WIN32) || defined(__CYGWIN__))
 #define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
 #else
 #define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
 #endif
 
+static std::string_view
+apFloatOpStatusToStr(llvm::APFloatBase::opStatus status) {
+  switch (status) {
+  case llvm::APFloatBase::opOK:
+    return "OK";
+  case llvm::APFloatBase::opInvalidOp:
+    return "InvalidOp";
+  case llvm::APFloatBase::opDivByZero:
+    return "DivByZero";
+  case llvm::APFloatBase::opOverflow:
+    return "Overflow";
+  case llvm::APFloatBase::opUnderflow:
+    return "Underflow";
+  case llvm::APFloatBase::opInexact:
+    return "Inexact";
+  }
+  llvm::report_fatal_error("unhandled llvm::APFloatBase::opStatus variant");
+}
+
+#define APFLOAT_BINARY_OP(OP)                                                  
\
+  int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP(                         
\
+      int32_t semantics, uint64_t a, uint64_t b) {                             
\
+    const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(        
\
+        static_cast<llvm::APFloatBase::Semantics>(semantics));                 
\
+    unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);           
\
+    llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a));                          
\
+    llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b));                          
\
+    llvm::APFloatBase::opStatus status = lhs.OP(rhs);                          
\
+    assert(status == llvm::APFloatBase::opOK && "expected " #OP                
\
+                                                " opstatus to be OK");         
\
+    return lhs.bitcastToAPInt().getZExtValue();                                
\
+  }
+
+#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE)                     
\
+  int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP(                         
\
+      int32_t semantics, uint64_t a, uint64_t b) {                             
\
+    const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(        
\
+        static_cast<llvm::APFloatBase::Semantics>(semantics));                 
\
+    unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);           
\
+    llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a));                          
\
+    llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b));                          
\
+    llvm::APFloatBase::opStatus status = lhs.OP(rhs, ROUNDING_MODE);           
\
+    assert(status == llvm::APFloatBase::opOK && "expected " #OP                
\
+                                                " opstatus to be OK");         
\
+    return lhs.bitcastToAPInt().getZExtValue();                                
\
+  }
+
 extern "C" {
 
-int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_add(int32_t semantics,
-                                                   uint64_t a, uint64_t b) {
-  const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
-      static_cast<llvm::APFloatBase::Semantics>(semantics));
-  unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
-  llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a));
-  llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b));
-  auto status = lhs.add(rhs, llvm::RoundingMode::NearestTiesToEven);
-  return lhs.bitcastToAPInt().getZExtValue();
-}
+#define BIN_OPS_WITH_ROUNDING(X)                                               
\
+  X(add, llvm::RoundingMode::NearestTiesToEven)                                
\
+  X(subtract, llvm::RoundingMode::NearestTiesToEven)                           
\
+  X(multiply, llvm::RoundingMode::NearestTiesToEven)                           
\
+  X(divide, llvm::RoundingMode::NearestTiesToEven)
+
+BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE)
+#undef BIN_OPS_WITH_ROUNDING
+#undef APFLOAT_BINARY_OP_ROUNDING_MODE
+
+APFLOAT_BINARY_OP(remainder)
+APFLOAT_BINARY_OP(mod)
+
+#undef APFLOAT_BINARY_OP
 
 void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
                                                  uint64_t a) {

>From c42f471126db73c9063818f7cacb67377813ea35 Mon Sep 17 00:00:00 2001
From: makslevental <[email protected]>
Date: Thu, 6 Nov 2025 15:14:09 -0800
Subject: [PATCH 5/5] add arith-to-apfloat

---
 .../ArithToAPFloat/ArithToAPFloat.h           |  28 ++++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  13 ++
 mlir/include/mlir/Dialect/Func/Utils/Utils.h  |   8 ++
 .../mlir/Dialect/LLVMIR/FunctionCallUtils.h   |  17 ---
 .../ArithToAPFloat/ArithToAPFloat.cpp         | 136 ++++++++++++++++++
 .../Conversion/ArithToAPFloat/CMakeLists.txt  |  17 +++
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    |  48 -------
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 mlir/lib/Dialect/Func/Utils/Utils.cpp         |  42 ++++++
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   |  23 ---
 mlir/lib/ExecutionEngine/APFloatWrappers.cpp  |  22 ---
 .../Arith/CPU/test-apfloat-emulation.mlir     |  21 ++-
 13 files changed, 266 insertions(+), 111 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
 create mode 100644 mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
 create mode 100644 mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt

diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h 
b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
new file mode 100644
index 0000000000000..a5df4647f1acc
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h
@@ -0,0 +1,28 @@
+//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ 
----*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
+
+#include <memory>
+
+namespace mlir {
+
+class DialectRegistry;
+class RewritePatternSet;
+class Pass;
+
+// Pull in the tablegen-generated declarations for the
+// `convert-arith-to-apfloat` pass.
+#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace arith {
+/// Populate `patterns` with the rewrite patterns that lower Arith ops to
+/// calls into the APFloat runtime functions.
+void populateArithToAPFloatConversionPatterns(RewritePatternSet &patterns);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
diff --git a/mlir/include/mlir/Conversion/Passes.h 
b/mlir/include/mlir/Conversion/Passes.h
index 40d866ec7bf10..82bdfd02661a6 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
 #include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index 70e3e45c225db..2bcd2870949f3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -186,6 +186,19 @@ def ArithToLLVMConversionPass : 
Pass<"convert-arith-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ArithToAPFloat
+//===----------------------------------------------------------------------===//
+
+def ArithToAPFloatConversionPass : Pass<"convert-arith-to-apfloat"> {
+  let summary = "Convert Arith ops on small float types to APFloat lib calls";
+  let description = [{
+    This pass converts supported Arith ops which manipulate small
+    floating-point typed values (FP4, FP6 and FP8 types) to APFloat lib calls.
+
+    The called functions are declared at the start of the enclosing
+    `builtin.module`; ops that are not nested in a module are left untouched.
+  }];
+  let dependentDialects = ["func::FuncDialect"];
+  let options = [];
+}
+
 
//===----------------------------------------------------------------------===//
 // ArithToSPIRV
 
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h 
b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 3576126a487ac..9c9973cf84368 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -60,6 +60,14 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, 
mlir::func::CallOp>>
 deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp 
funcOp,
                         mlir::ModuleOp moduleOp);
 
+/// Look up the FuncOp named `name` in the symbol table op `moduleOp`, or
+/// create it at the start of `moduleOp`'s body with signature
+/// `resultTypes`(`paramTypes`) if no such FuncOp exists. A newly created
+/// FuncOp is marked private when `setPrivate` is true. If `symbolTables` is
+/// provided, it is used for the lookup and the new FuncOp is inserted into it.
+/// Return a failure if the FuncOp found has unexpected signature.
+FailureOr<FuncOp>
+lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+                 ArrayRef<Type> paramTypes = {},
+                 ArrayRef<Type> resultTypes = {}, bool setPrivate = false,
+                 SymbolTableCollection *symbolTables = nullptr);
+
 } // namespace func
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h 
b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 01f7f75c210ef..b09d32022e348 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -56,23 +56,6 @@ FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
                              SymbolTableCollection *symbolTables = nullptr);
 
-#define APFLOAT_BIN_OPS(X)                                                     
\
-  X(add)                                                                       
\
-  X(subtract)                                                                  
\
-  X(multiply)                                                                  
\
-  X(divide)                                                                    
\
-  X(remainder)                                                                 
\
-  X(mod)
-
-#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP)                                   
\
-  FailureOr<LLVM::LLVMFuncOp> lookupOrCreateApFloat##OP##Fn(                   
\
-      OpBuilder &b, Operation *moduleOp,                                       
\
-      SymbolTableCollection *symbolTables = nullptr);
-
-APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
-
-#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
-
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is 
`printString`.
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp 
b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
new file mode 100644
index 0000000000000..bc451e88eb3bd
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -0,0 +1,136 @@
+//===- ArithToAPFloat.cpp - Arithmetic to APFloat impl conversion 
---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+// X-macro listing the binary APFloat operations for which a runtime function
+// name and a lookup helper are generated below.
+#define APFLOAT_BIN_OPS(X)                                                     \
+  X(add)                                                                       \
+  X(subtract)                                                                  \
+  X(multiply)                                                                  \
+  X(divide)                                                                    \
+  X(remainder)                                                                 \
+  X(mod)
+
+// Name of the constant holding the runtime symbol name for operation `OP`.
+#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
+
+// Defines that constant; the symbol name is "_mlir_apfloat_<OP>".
+#define APFLOAT_EXTERN_NAME(OP)                                                \
+  static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "_mlir_"             \
+                                                          "apfloat_" #OP;
+
+namespace mlir::func {
+// Forward declarations of lookupOrCreateApFloat<OP>Fn; these carry the default
+// argument for `symbolTables`.
+#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP)                                   \
+  FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn(                             \
+      OpBuilder &b, Operation *moduleOp,                                       \
+      SymbolTableCollection *symbolTables = nullptr);
+
+APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
+
+#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
+
+APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
+
+// Definitions of lookupOrCreateApFloat<OP>Fn: look up or declare a private
+// function `(i32, i64, i64) -> i64` named "_mlir_apfloat_<OP>" in `moduleOp`.
+#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP)                                   \
+  FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn(                             \
+      OpBuilder &b, Operation *moduleOp,                                       \
+      SymbolTableCollection *symbolTables) {                                   \
+    return lookupOrCreateFn(b, moduleOp, APFLOAT_EXTERN_K(OP),                 \
+                            {IntegerType::get(moduleOp->getContext(), 32),     \
+                             IntegerType::get(moduleOp->getContext(), 64),     \
+                             IntegerType::get(moduleOp->getContext(), 64)},    \
+                            {IntegerType::get(moduleOp->getContext(), 64)},    \
+                            /*setPrivate*/ true, symbolTables);                \
+  }
+
+APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
+#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
+
+// The remaining helper macros are not used past this point; undefine them so
+// they do not leak into the rest of the translation unit.
+#undef APFLOAT_EXTERN_NAME
+#undef APFLOAT_EXTERN_K
+#undef APFLOAT_BIN_OPS
+} // namespace mlir::func
+
+/// Rewrites `arith.addf` on small (4-, 6- and 8-bit) float types into a call
+/// to the APFloat adder function. The operands are bitcast to integers,
+/// zero-extended to i64 and passed together with the float semantics enum
+/// value; the i64 result is truncated and bitcast back to the float type.
+struct FancyAddFLowering : OpRewritePattern<arith::AddFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::AddFOp op,
+                                PatternRewriter &rewriter) const override {
+    // The adder is declared in the enclosing module, so one must exist.
+    auto parent = op->getParentOfType<ModuleOp>();
+    if (!parent)
+      return failure();
+    // Only the float types listed here are handled by this pattern.
+    if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
+                   Float8E5M2FNUZType, Float8E4M3FNUZType,
+                   Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
+                   Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
+            op.getType()))
+      return failure();
+
+    // Get APFloat adder function from runtime library. The lookup fails when
+    // a function of that name with a different signature already exists; the
+    // result must not be dereferenced in that case.
+    FailureOr<Operation *> adder =
+        lookupOrCreateApFloataddFn(rewriter, parent);
+    if (failed(adder))
+      return failure();
+
+    // Cast operands to 64-bit integers.
+    Location loc = op.getLoc();
+    auto floatTy = cast<FloatType>(op.getType());
+    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+    auto int64Type = rewriter.getI64Type();
+    Value lhsBits = arith::ExtUIOp::create(
+        rewriter, loc, int64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
+    Value rhsBits = arith::ExtUIOp::create(
+        rewriter, loc, int64Type,
+        arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
+
+    // Call software implementation of floating point addition.
+    int32_t sem =
+        llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+    Value semValue = arith::ConstantOp::create(
+        rewriter, loc, rewriter.getI32Type(),
+        rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+    SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+    auto resultOp =
+        func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+                             SymbolRefAttr::get(*adder), params);
+
+    // Truncate result to the original width.
+    Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+                                                  resultOp->getResult(0));
+    // Replace (and erase) the original op rather than only redirecting its
+    // uses, so the matched arith.addf does not linger in the IR.
+    Value result =
+        arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Registers the Arith-to-APFloat rewrite patterns (currently only
+/// FancyAddFLowering) in `patterns`.
+void arith::populateArithToAPFloatConversionPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<FancyAddFLowering>(context);
+}
+
+namespace {
+/// Pass that greedily applies the Arith-to-APFloat patterns to the operation
+/// it runs on and signals failure if the greedy driver reports failure.
+struct ArithToAPFloatConversionPass final
+    : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
+  using impl::ArithToAPFloatConversionPassBase<
+      ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
+
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    RewritePatternSet apFloatPatterns(root->getContext());
+    arith::populateArithToAPFloatConversionPatterns(apFloatPatterns);
+
+    LogicalResult status =
+        applyPatternsGreedily(root, std::move(apFloatPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt 
b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000000000..b0d1e46b3655f
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRArithToAPFloat
+  ArithToAPFloat.cpp
+
+  # Public headers of this conversion live under Conversion/ArithToAPFloat.
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAPFloat
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  # MLIRFuncUtils provides func::lookupOrCreateFn used by ArithToAPFloat.cpp.
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRArithTransforms
+  MLIRFuncDialect
+  MLIRFuncUtils
+  )
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp 
b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 370e048bb3af5..5972c8cc9987e 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -573,53 +573,6 @@ void mlir::arith::registerConvertArithToLLVMInterface(
   });
 }
 
-struct FancyAddFLowering : public ConvertOpToLLVMPattern<arith::AddFOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::AddFOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Get APFloat adder function from runtime library.
-    auto parent = op->getParentOfType<ModuleOp>();
-    if (!parent)
-      return failure();
-    if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
-                   Float8E5M2FNUZType, Float8E4M3FNUZType,
-                   Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
-                   Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
-            op.getType()))
-      return failure();
-    auto floatTy = cast<FloatType>(op.getType());
-    FailureOr<Operation *> adder =
-        LLVM::lookupOrCreateApFloatAddFFn(rewriter, parent);
-
-    // Cast operands to 64-bit integers.
-    Location loc = op.getLoc();
-    Value lhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
-                                         adaptor.getLhs());
-    Value rhsBits = LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(),
-                                         adaptor.getRhs());
-
-    // Call software implementation of floating point addition.
-    int32_t sem =
-        llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
-    Value semValue = LLVM::ConstantOp::create(
-        rewriter, loc, rewriter.getI32Type(),
-        rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
-    SmallVector<Value> params = {semValue, lhsBits, rhsBits};
-    auto resultOp =
-        LLVM::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
-                             SymbolRefAttr::get(*adder), params);
-
-    // Truncate result to the original width.
-    Value truncatedBits = LLVM::TruncOp::create(
-        rewriter, loc, rewriter.getIntegerType(floatTy.getWidth()),
-        resultOp->getResult(0));
-    rewriter.replaceOp(op, truncatedBits);
-    return success();
-  }
-};
-
 
//===----------------------------------------------------------------------===//
 // Pattern Population
 
//===----------------------------------------------------------------------===//
@@ -635,7 +588,6 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
   // clang-format off
   patterns.add<
     AddFOpLowering,
-    FancyAddFLowering,
     AddIOpLowering,
     AndIOpLowering,
     AddUIExtendedOpLowering,
diff --git a/mlir/lib/Conversion/CMakeLists.txt 
b/mlir/lib/Conversion/CMakeLists.txt
index bebf1b8fff3f9..613dc6d242ceb 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
 add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
 add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToAPFloat)
 add_subdirectory(ArithToArmSME)
 add_subdirectory(ArithToEmitC)
 add_subdirectory(ArithToLLVM)
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp 
b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index b4cb0932ef631..e187e62cf6555 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -254,3 +254,45 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
 
   return std::make_pair(*newFuncOpOrFailure, newCallOp);
 }
+
+FailureOr<func::FuncOp>
+func::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
+                       ArrayRef<Type> paramTypes, ArrayRef<Type> resultTypes,
+                       bool setPrivate, SymbolTableCollection *symbolTables) {
+  assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
+         "expected SymbolTable operation");
+  // Look up `name` in `moduleOp`, going through `symbolTables` if provided.
+  FuncOp func;
+  if (symbolTables) {
+    func = symbolTables->lookupSymbolIn<FuncOp>(
+        moduleOp, StringAttr::get(moduleOp->getContext(), name));
+  } else {
+    func = llvm::dyn_cast_or_null<FuncOp>(
+        SymbolTable::lookupSymbolIn(moduleOp, name));
+  }
+  // Build the function type the caller expects from the given types.
+  FunctionType funcT =
+      FunctionType::get(b.getContext(), paramTypes, resultTypes);
+  // Reuse an existing function only if its type matches; otherwise fail.
+  if (func) {
+    if (funcT != func.getFunctionType()) {
+      func.emitError("redefinition of function '")
+          << name << "' of different type " << funcT << " is prohibited";
+      return failure();
+    }
+    return func;
+  }
+  // No FuncOp named `name` was found: create one in the first block.
+  OpBuilder::InsertionGuard g(b); // Restores b's insertion point on exit.
+  assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
+  b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
+  FuncOp funcOp = FuncOp::create(b, moduleOp->getLoc(), name, funcT);
+  if (setPrivate)
+    funcOp.setPrivate();
+  if (symbolTables) { // Also register it in the collection's symbol table.
+    SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
+    symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
+  }
+
+  return funcOp;
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index cb6ee76f8cbfb..160b6ae89215c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -31,14 +31,6 @@ static constexpr llvm::StringRef kPrintBF16 = "printBF16";
 static constexpr llvm::StringRef kPrintF32 = "printF32";
 static constexpr llvm::StringRef kPrintF64 = "printF64";
 static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
-
-#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
-
-#define APFLOAT_EXTERN_NAME(OP)                                                \
-  static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "APFloat_" #OP;
-
-APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
-
 static constexpr llvm::StringRef kPrintString = "printString";
 static constexpr llvm::StringRef kPrintOpen = "printOpen";
 static constexpr llvm::StringRef kPrintClose = "printClose";
@@ -179,21 +171,6 @@ mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
       LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
 }
 
-#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP)                                   \
-  FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateApFloat##OP##Fn(       \
-      OpBuilder &b, Operation *moduleOp,                                       \
-      SymbolTableCollection *symbolTables) {                                   \
-    return lookupOrCreateReservedFn(                                           \
-        b, moduleOp, APFLOAT_EXTERN_K(OP),                                     \
-        {IntegerType::get(moduleOp->getContext(), 32),                         \
-         IntegerType::get(moduleOp->getContext(), 64),                         \
-         IntegerType::get(moduleOp->getContext(), 64)},                        \
-        IntegerType::get(moduleOp->getContext(), 64), symbolTables);           \
-  }
-
-APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
-#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
-
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
   return LLVM::LLVMPointerType::get(context);
 }
diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
index 8d2848bd7cf77..a5049436d03c1 100644
--- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
@@ -7,37 +7,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/APFloat.h"
-#include "llvm/Support/Debug.h"
 
 #include <iostream>
 
-#define DEBUG_TYPE "mlir-apfloat-wrapper"
-
 #if (defined(_WIN32) || defined(__CYGWIN__))
 #define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
 #else
 #define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
 #endif
 
-static std::string_view
-apFloatOpStatusToStr(llvm::APFloatBase::opStatus status) {
-  switch (status) {
-  case llvm::APFloatBase::opOK:
-    return "OK";
-  case llvm::APFloatBase::opInvalidOp:
-    return "InvalidOp";
-  case llvm::APFloatBase::opDivByZero:
-    return "DivByZero";
-  case llvm::APFloatBase::opOverflow:
-    return "Overflow";
-  case llvm::APFloatBase::opUnderflow:
-    return "Underflow";
-  case llvm::APFloatBase::opInexact:
-    return "Inexact";
-  }
-  llvm::report_fatal_error("unhandled llvm::APFloatBase::opStatus variant");
-}
-
 #define APFLOAT_BINARY_OP(OP)                                                  \
   int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP(                         \
       int32_t semantics, uint64_t a, uint64_t b) {                             \
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
index 5cd83688d1710..d4c2394474b15 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir
@@ -1,7 +1,7 @@
 // Check that the ceildivsi lowering is correct.
 // We do not check any poison or UB values, as it is not possible to catch them.
 
-// RUN: mlir-opt %s --convert-to-llvm
+// RUN: mlir-opt %s --convert-arith-to-apfloat | FileCheck %s
 
 // Put rhs into separate function so that it won't be constant-folded.
 func.func @foo() -> f4E2M1FN {
@@ -17,3 +17,22 @@ func.func @entry() {
   return
 }
 
+// CHECK-LABEL:   func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
+
+// CHECK-LABEL:   func.func @foo() -> f4E2M1FN {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 4.000000e+00 : f4E2M1FN
+// CHECK:           return %[[CONSTANT_0]] : f4E2M1FN
+// CHECK:         }
+
+// CHECK-LABEL:   func.func @entry() {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 18 : i32
+// CHECK:           %[[CONSTANT_1:.*]] = arith.constant 6 : i64
+// CHECK:           %[[VAL_0:.*]] = call @foo() : () -> f4E2M1FN
+// CHECK:           %[[BITCAST_0:.*]] = arith.bitcast %[[VAL_0]] : f4E2M1FN to i4
+// CHECK:           %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i4 to i64
+// CHECK:           %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_0]], %[[EXTUI_0]], %[[CONSTANT_1]]) : (i32, i64, i64) -> i64
+// CHECK:           %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i4
+// CHECK:           %[[BITCAST_1:.*]] = arith.bitcast %[[TRUNCI_0]] : i4 to f4E2M1FN
+// CHECK:           vector.print %[[BITCAST_1]] : f4E2M1FN
+// CHECK:           return
+// CHECK:         }
\ No newline at end of file

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to