jakeh-gc created this revision.
Herald added subscribers: Moerafaat, zero9178, bzcheeseman, sdasgup3, 
wenzhicui, wrengr, cota, teijeong, rdzhabarov, tatianashp, msifontes, jurahul, 
Kayjukh, grosul1, Joonsoo, liufengdb, aartbik, mgester, arpith-jacob, 
antiagainst, shauheen, rriddle, mehdi_amini, hiraditya.
Herald added a reviewer: rriddle.
Herald added a reviewer: antiagainst.
Herald added a project: All.
jakeh-gc requested review of this revision.
Herald added subscribers: llvm-commits, cfe-commits, stephenneuendorffer, 
nicolasvasilache.
Herald added projects: clang, MLIR, LLVM.

Graphcore, AMD, and Qualcomm have proposed two new FP8 formats, Float8E4M3FZN 
and Float8E5M2FZN. These formats are presented in this paper: 
https://arxiv.org/abs/2206.02915. They are implemented in commercially 
available hardware and the ISA for this hardware is available here: 
https://docs.graphcore.ai/projects/isa-mk2-with-fp8/en/latest/_static/TileVertexISA-IPU21-1.3.1.pdf.

This patch adds support for these two types in MLIR and APFloat, alongside the 
previously added types Float8E4M3FN and Float8E5M2 (D133823 
<https://reviews.llvm.org/D133823>, D137760 <https://reviews.llvm.org/D137760>, 
RFC 
<https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279>).

Following the naming scheme of those existing types, the suffix "FZN" indicates 
that these types support only finite values (no infinities), an unsigned zero 
(no negative zero), and a NaN encoding. In both types NaN has exactly one 
encoding, `0b10000000`, i.e. the bit pattern that would otherwise be negative zero.

To support this behaviour I have added another value to the 
`fltNonfiniteBehavior` enum to represent this specific NaN encoding. I have 
also added a new field (`fltSignedZeroSupport`) to the `fltSemantics` struct to 
describe whether signed zero is supported.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D141432

Files:
  clang/lib/AST/MicrosoftMangle.cpp
  llvm/include/llvm/ADT/APFloat.h
  llvm/lib/Support/APFloat.cpp
  llvm/unittests/ADT/APFloatTest.cpp
  mlir/include/mlir-c/BuiltinTypes.h
  mlir/include/mlir/IR/Builders.h
  mlir/include/mlir/IR/BuiltinTypes.h
  mlir/include/mlir/IR/BuiltinTypes.td
  mlir/include/mlir/IR/OpBase.td
  mlir/include/mlir/IR/Types.h
  mlir/lib/AsmParser/TokenKinds.def
  mlir/lib/AsmParser/TypeParser.cpp
  mlir/lib/Bindings/Python/IRTypes.cpp
  mlir/lib/CAPI/IR/BuiltinTypes.cpp
  mlir/lib/IR/AsmPrinter.cpp
  mlir/lib/IR/Builders.cpp
  mlir/lib/IR/BuiltinTypes.cpp
  mlir/lib/IR/MLIRContext.cpp
  mlir/lib/IR/Types.cpp
  mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
  mlir/test/IR/attribute.mlir
  mlir/test/python/ir/builtin_types.py
  mlir/utils/lldb-scripts/mlirDataFormatters.py

Index: mlir/utils/lldb-scripts/mlirDataFormatters.py
===================================================================
--- mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -52,6 +52,8 @@
     "mlir::UnknownLoc": '"loc(unknown)"',
     "mlir::Float8E5M2Type": '"f8E5M2"',
     "mlir::Float8E4M3FNType": '"f8E4M3FN"',
+    "mlir::Float8E4M3FZNType": '"f8E4M3FZN"',
+    "mlir::Float8E5M2FZNType": '"f8E5M2FZN"',
     "mlir::BFloat16Type": '"bf16"',
     "mlir::Float16Type": '"f16"',
     "mlir::Float32Type": '"f32"',
Index: mlir/test/python/ir/builtin_types.py
===================================================================
--- mlir/test/python/ir/builtin_types.py
+++ mlir/test/python/ir/builtin_types.py
@@ -197,6 +197,10 @@
     print("float:", Float8E4M3FNType.get())
     # CHECK: float: f8E5M2
     print("float:", Float8E5M2Type.get())
+    # CHECK: float: f8E4M3FZN
+    print("float:", Float8E4M3FZNType.get())
+    # CHECK: float: f8E5M2FZN
+    print("float:", Float8E5M2FZNType.get())
     # CHECK: float: bf16
     print("float:", BF16Type.get())
     # CHECK: float: f16
Index: mlir/test/IR/attribute.mlir
===================================================================
--- mlir/test/IR/attribute.mlir
+++ mlir/test/IR/attribute.mlir
@@ -44,6 +44,14 @@
     // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
     float_attr = 2. : f8E4M3FN
   } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E4M3FZN
+    float_attr = 2. : f8E4M3FZN
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E5M2FZN
+    float_attr = 2. : f8E5M2FZN
+  } : () -> ()
   "test.float_attrs"() {
     // CHECK: float_attr = 2.000000e+00 : f16
     float_attr = 2. : f16
Index: mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
===================================================================
--- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -52,6 +52,8 @@
     "DictAttr",
     "Float8E4M3FNType",
     "Float8E5M2Type",
+    "Float8E4M3FZNType",
+    "Float8E5M2FZNType",
     "F16Type",
     "F32Type",
     "F64Type",
@@ -586,6 +588,20 @@
     @staticmethod
     def isinstance(arg: Any) -> bool: ...
 
+class Float8E5M2FZNType(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> Float8E5M2FZNType: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+
+class Float8E4M3FZNType(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> Float8E4M3FZNType: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+
 class Float8E5M2Type(Type):
     def __init__(self, cast_from_type: Type) -> None: ...
     @staticmethod
Index: mlir/lib/IR/Types.cpp
===================================================================
--- mlir/lib/IR/Types.cpp
+++ mlir/lib/IR/Types.cpp
@@ -20,6 +20,8 @@
 
 bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
 bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
+bool Type::isFloat8E4M3FZN() const { return isa<Float8E4M3FZNType>(); }
+bool Type::isFloat8E5M2FZN() const { return isa<Float8E5M2FZNType>(); }
 bool Type::isBF16() const { return isa<BFloat16Type>(); }
 bool Type::isF16() const { return isa<Float16Type>(); }
 bool Type::isF32() const { return isa<Float32Type>(); }
Index: mlir/lib/IR/MLIRContext.cpp
===================================================================
--- mlir/lib/IR/MLIRContext.cpp
+++ mlir/lib/IR/MLIRContext.cpp
@@ -208,6 +208,8 @@
   /// Cached Type Instances.
   Float8E5M2Type f8E5M2Ty;
   Float8E4M3FNType f8E4M3FNTy;
+  Float8E4M3FZNType f8E4M3FZNTy;
+  Float8E5M2FZNType f8E5M2FZNTy;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   Float32Type f32Ty;
@@ -280,6 +282,8 @@
   /// Floating-point Types.
   impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
   impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
+  impl->f8E4M3FZNTy = TypeUniquer::get<Float8E4M3FZNType>(this);
+  impl->f8E5M2FZNTy = TypeUniquer::get<Float8E5M2FZNType>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -866,6 +870,12 @@
 Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
   return context->getImpl().f8E4M3FNTy;
 }
+Float8E4M3FZNType Float8E4M3FZNType::get(MLIRContext *context) {
+  return context->getImpl().f8E4M3FZNTy;
+}
+Float8E5M2FZNType Float8E5M2FZNType::get(MLIRContext *context) {
+  return context->getImpl().f8E5M2FZNTy;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
Index: mlir/lib/IR/BuiltinTypes.cpp
===================================================================
--- mlir/lib/IR/BuiltinTypes.cpp
+++ mlir/lib/IR/BuiltinTypes.cpp
@@ -88,7 +88,8 @@
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
-  if (isa<Float8E5M2Type, Float8E4M3FNType>())
+  if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3FZNType,
+          Float8E5M2FZNType>())
     return 8;
   if (isa<Float16Type, BFloat16Type>())
     return 16;
@@ -109,6 +110,10 @@
     return APFloat::Float8E5M2();
   if (isa<Float8E4M3FNType>())
     return APFloat::Float8E4M3FN();
+  if (isa<Float8E4M3FZNType>())
+    return APFloat::Float8E4M3FZN();
+  if (isa<Float8E5M2FZNType>())
+    return APFloat::Float8E5M2FZN();
   if (isa<BFloat16Type>())
     return APFloat::BFloat();
   if (isa<Float16Type>())
Index: mlir/lib/IR/Builders.cpp
===================================================================
--- mlir/lib/IR/Builders.cpp
+++ mlir/lib/IR/Builders.cpp
@@ -41,6 +41,14 @@
   return FloatType::getFloat8E4M3FN(context);
 }
 
+FloatType Builder::getFloat8E4M3FZNType() {
+  return FloatType::getFloat8E4M3FZN(context);
+}
+
+FloatType Builder::getFloat8E5M2FZNType() {
+  return FloatType::getFloat8E5M2FZN(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
Index: mlir/lib/IR/AsmPrinter.cpp
===================================================================
--- mlir/lib/IR/AsmPrinter.cpp
+++ mlir/lib/IR/AsmPrinter.cpp
@@ -2412,6 +2412,8 @@
       .Case<IndexType>([&](Type) { os << "index"; })
       .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
       .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
+      .Case<Float8E4M3FZNType>([&](Type) { os << "f8E4M3FZN"; })
+      .Case<Float8E5M2FZNType>([&](Type) { os << "f8E5M2FZN"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<Float32Type>([&](Type) { os << "f32"; })
Index: mlir/lib/CAPI/IR/BuiltinTypes.cpp
===================================================================
--- mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -84,6 +84,22 @@
   return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
 }
 
+bool mlirTypeIsAFloat8E4M3FZN(MlirType type) {
+  return unwrap(type).isFloat8E4M3FZN();
+}
+
+MlirType mlirFloat8E4M3FZNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E4M3FZN(unwrap(ctx)));
+}
+
+bool mlirTypeIsAFloat8E5M2FZN(MlirType type) {
+  return unwrap(type).isFloat8E5M2FZN();
+}
+
+MlirType mlirFloat8E5M2FZNTypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E5M2FZN(unwrap(ctx)));
+}
+
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {
Index: mlir/lib/Bindings/Python/IRTypes.cpp
===================================================================
--- mlir/lib/Bindings/Python/IRTypes.cpp
+++ mlir/lib/Bindings/Python/IRTypes.cpp
@@ -138,6 +138,42 @@
   }
 };
 
+/// Floating Point Type subclass - Float8E4M3FZN (MLIR spelling: f8E4M3FZN).
+class PyFloat8E4M3FZNType : public PyConcreteType<PyFloat8E4M3FZNType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FZN;
+  static constexpr const char *pyClassName = "Float8E4M3FZNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E4M3FZNTypeGet(context->get());
+          return PyFloat8E4M3FZNType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create an f8E4M3FZN type.");
+  }
+};
+
+/// Floating Point Type subclass - Float8E5M2FZN (MLIR spelling: f8E5M2FZN).
+class PyFloat8E5M2FZNType : public PyConcreteType<PyFloat8E5M2FZNType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FZN;
+  static constexpr const char *pyClassName = "Float8E5M2FZNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirFloat8E5M2FZNTypeGet(context->get());
+          return PyFloat8E5M2FZNType(context->getRef(), t);
+        },
+        py::arg("context") = py::none(), "Create an f8E5M2FZN type.");
+  }
+};
+
 /// Floating Point Type subclass - BF16Type.
 class PyBF16Type : public PyConcreteType<PyBF16Type> {
 public:
@@ -701,6 +737,8 @@
   PyIndexType::bind(m);
   PyFloat8E4M3FNType::bind(m);
   PyFloat8E5M2Type::bind(m);
+  PyFloat8E4M3FZNType::bind(m);
+  PyFloat8E5M2FZNType::bind(m);
   PyBF16Type::bind(m);
   PyF16Type::bind(m);
   PyF32Type::bind(m);
Index: mlir/lib/AsmParser/TypeParser.cpp
===================================================================
--- mlir/lib/AsmParser/TypeParser.cpp
+++ mlir/lib/AsmParser/TypeParser.cpp
@@ -32,6 +32,8 @@
   case Token::inttype:
   case Token::kw_f8E5M2:
   case Token::kw_f8E4M3FN:
+  case Token::kw_f8E4M3FZN:
+  case Token::kw_f8E5M2FZN:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_f32:
@@ -294,6 +296,12 @@
   case Token::kw_f8E4M3FN:
     consumeToken(Token::kw_f8E4M3FN);
     return builder.getFloat8E4M3FNType();
+  case Token::kw_f8E4M3FZN:
+    consumeToken(Token::kw_f8E4M3FZN);
+    return builder.getFloat8E4M3FZNType();
+  case Token::kw_f8E5M2FZN:
+    consumeToken(Token::kw_f8E5M2FZN);
+    return builder.getFloat8E5M2FZNType();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();
Index: mlir/lib/AsmParser/TokenKinds.def
===================================================================
--- mlir/lib/AsmParser/TokenKinds.def
+++ mlir/lib/AsmParser/TokenKinds.def
@@ -95,6 +95,8 @@
 TOK_KEYWORD(f80)
 TOK_KEYWORD(f8E5M2)
 TOK_KEYWORD(f8E4M3FN)
+TOK_KEYWORD(f8E4M3FZN)
+TOK_KEYWORD(f8E5M2FZN)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)
Index: mlir/include/mlir/IR/Types.h
===================================================================
--- mlir/include/mlir/IR/Types.h
+++ mlir/include/mlir/IR/Types.h
@@ -125,6 +125,8 @@
   bool isIndex() const;
   bool isFloat8E5M2() const;
   bool isFloat8E4M3FN() const;
+  bool isFloat8E4M3FZN() const;
+  bool isFloat8E5M2FZN() const;
   bool isBF16() const;
   bool isF16() const;
   bool isF32() const;
Index: mlir/include/mlir/IR/OpBase.td
===================================================================
--- mlir/include/mlir/IR/OpBase.td
+++ mlir/include/mlir/IR/OpBase.td
@@ -494,6 +494,10 @@
                BuildableType<"$_builder.getFloat8E4M3FNType()">;
 def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
              BuildableType<"$_builder.getFloat8E5M2Type()">;
+def F8E4M3FZN : Type<CPred<"$_self.isFloat8E4M3FZN()">, "f8E4M3FZN type">,
+                BuildableType<"$_builder.getFloat8E4M3FZNType()">;
+def F8E5M2FZN : Type<CPred<"$_self.isFloat8E5M2FZN()">, "f8E5M2FZN type">,
+                BuildableType<"$_builder.getFloat8E5M2FZNType()">;
 
 def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
                       "complex-type", "::mlir::ComplexType">;
Index: mlir/include/mlir/IR/BuiltinTypes.td
===================================================================
--- mlir/include/mlir/IR/BuiltinTypes.td
+++ mlir/include/mlir/IR/BuiltinTypes.td
@@ -119,6 +119,50 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float8E4M3FZNType
+
+def Builtin_Float8E4M3FZN : Builtin_FloatType<"Float8E4M3FZN"> {
+  let summary = "8-bit floating point with 3 bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it follows
+    similar conventions, with the exception that there are no infinity values,
+    no negative zero, and only one NaN representation. This type has the
+    following characteristics:
+
+      * bit encoding: S1E4M3
+      * exponent bias: 8
+      * infinities: Not supported
+      * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
+      * denormals when exponent is 0
+
+    Described in: https://arxiv.org/abs/2206.02915
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Float8E5M2FZNType
+
+def Builtin_Float8E5M2FZN : Builtin_FloatType<"Float8E5M2FZN"> {
+  let summary = "8-bit floating point with 2 bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it follows
+    similar conventions, with the exception that there are no infinity values,
+    no negative zero, and only one NaN representation. This type has the
+    following characteristics:
+
+      * bit encoding: S1E5M2
+      * exponent bias: 16
+      * infinities: Not supported
+      * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
+      * denormals when exponent is 0
+
+    Described in: https://arxiv.org/abs/2206.02915
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
Index: mlir/include/mlir/IR/BuiltinTypes.h
===================================================================
--- mlir/include/mlir/IR/BuiltinTypes.h
+++ mlir/include/mlir/IR/BuiltinTypes.h
@@ -48,6 +48,8 @@
   static FloatType getF128(MLIRContext *ctx);
   static FloatType getFloat8E5M2(MLIRContext *ctx);
   static FloatType getFloat8E4M3FN(MLIRContext *ctx);
+  static FloatType getFloat8E4M3FZN(MLIRContext *ctx);
+  static FloatType getFloat8E5M2FZN(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -375,8 +377,9 @@
 }
 
 inline bool FloatType::classof(Type type) {
-  return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
-                  Float32Type, Float64Type, Float80Type, Float128Type>();
+  return type.isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3FZNType,
+                  Float8E5M2FZNType, BFloat16Type, Float16Type, Float32Type,
+                  Float64Type, Float80Type, Float128Type>();
 }
 
 inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -387,6 +390,14 @@
   return Float8E4M3FNType::get(ctx);
 }
 
+inline FloatType FloatType::getFloat8E4M3FZN(MLIRContext *ctx) {
+  return Float8E4M3FZNType::get(ctx);
+}
+
+inline FloatType FloatType::getFloat8E5M2FZN(MLIRContext *ctx) {
+  return Float8E5M2FZNType::get(ctx);
+}
+
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
   return BFloat16Type::get(ctx);
 }
Index: mlir/include/mlir/IR/Builders.h
===================================================================
--- mlir/include/mlir/IR/Builders.h
+++ mlir/include/mlir/IR/Builders.h
@@ -61,6 +61,8 @@
   // Types.
   FloatType getFloat8E5M2Type();
   FloatType getFloat8E4M3FNType();
+  FloatType getFloat8E4M3FZNType();
+  FloatType getFloat8E5M2FZNType();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getF32Type();
Index: mlir/include/mlir-c/BuiltinTypes.h
===================================================================
--- mlir/include/mlir-c/BuiltinTypes.h
+++ mlir/include/mlir-c/BuiltinTypes.h
@@ -81,6 +81,20 @@
 /// context.
 MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
 
+/// Checks whether the given type is an f8E4M3FZN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FZN(MlirType type);
+
+/// Creates an f8E4M3FZN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FZNTypeGet(MlirContext ctx);
+
+/// Checks whether the given type is an f8E5M2FZN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FZN(MlirType type);
+
+/// Creates an f8E5M2FZN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FZNTypeGet(MlirContext ctx);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 
Index: llvm/unittests/ADT/APFloatTest.cpp
===================================================================
--- llvm/unittests/ADT/APFloatTest.cpp
+++ llvm/unittests/ADT/APFloatTest.cpp
@@ -1735,6 +1735,10 @@
   EXPECT_EQ(3.402823466e+38f, APFloat::getLargest(APFloat::IEEEsingle()).convertToFloat());
   EXPECT_EQ(1.7976931348623158e+308, APFloat::getLargest(APFloat::IEEEdouble()).convertToDouble());
   EXPECT_EQ(448, APFloat::getLargest(APFloat::Float8E4M3FN()).convertToDouble());
+  EXPECT_EQ(57344,
+            APFloat::getLargest(APFloat::Float8E5M2FZN()).convertToDouble());
+  EXPECT_EQ(240,
+            APFloat::getLargest(APFloat::Float8E4M3FZN()).convertToDouble());
 }
 
 TEST(APFloatTest, getSmallest) {
@@ -1840,9 +1844,12 @@
       {&APFloat::Float8E5M2(), true, {0x80ULL, 0}, 1},
       {&APFloat::Float8E4M3FN(), false, {0, 0}, 1},
       {&APFloat::Float8E4M3FN(), true, {0x80ULL, 0}, 1},
+      // Not testing sign = true cases because negative zero isn't supported.
+      {&APFloat::Float8E5M2FZN(), false, {0, 0}, 1},
+      {&APFloat::Float8E4M3FZN(), false, {0, 0}, 1},
   };
-  const unsigned NumGetZeroTests = 12;
-  for (unsigned i = 0; i < NumGetZeroTests; ++i) {
+
+  for (unsigned i = 0; i < std::size(GetZeroTest); ++i) {
     APFloat test = APFloat::getZero(*GetZeroTest[i].semantics,
                                     GetZeroTest[i].sign);
     const char *pattern = GetZeroTest[i].sign? "-0x0p+0" : "0x0p+0";
@@ -5233,20 +5240,501 @@
   }
 }
 
+TEST(APFloatTest, Float8E4M3FZNFromString) {
+  // Exactly representable
+  EXPECT_EQ(240, APFloat(APFloat::Float8E4M3FZN(), "240").convertToDouble());
+  // Round down to maximum value
+  EXPECT_EQ(240, APFloat(APFloat::Float8E4M3FZN(), "244").convertToDouble());
+  // Round up (248 ties to even -> 256), causing overflow to NaN
+  EXPECT_TRUE(APFloat(APFloat::Float8E4M3FZN(), "248").isNaN());
+  // Overflow without rounding
+  EXPECT_TRUE(APFloat(APFloat::Float8E4M3FZN(), "480").isNaN());
+  // Inf converted to NaN
+  EXPECT_TRUE(APFloat(APFloat::Float8E4M3FZN(), "inf").isNaN());
+  // NaN converted to NaN
+  EXPECT_TRUE(APFloat(APFloat::Float8E4M3FZN(), "nan").isNaN());
+}
+
+TEST(APFloatTest, Float8E4M3FZNAdd) {
+  APFloat QNaN = APFloat::getNaN(APFloat::Float8E4M3FZN(), false);
+
+  auto FromStr = [](StringRef S) {
+    return APFloat(APFloat::Float8E4M3FZN(), S);
+  };
+
+  struct {
+    APFloat x;
+    APFloat y;
+    const char *result;
+    int status;
+    int category;
+    APFloat::roundingMode roundingMode = APFloat::rmNearestTiesToEven;
+  } AdditionTests[] = {
+      // Test addition operations involving NaN, overflow, and the max E4M3FZN
+      // value (240) because E4M3FZN differs from IEEE-754 types in these
+      // regards
+      {FromStr("240"), FromStr("4"), "240", APFloat::opInexact,
+       APFloat::fcNormal},
+      {FromStr("240"), FromStr("8"), "NaN",
+       APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN},
+      {FromStr("240"), FromStr("16"), "NaN",
+       APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN},
+      {FromStr("-240"), FromStr("-16"), "NaN",
+       APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN},
+      {QNaN, FromStr("-240"), "NaN", APFloat::opOK, APFloat::fcNaN},
+      {FromStr("240"), FromStr("-16"), "224", APFloat::opOK, APFloat::fcNormal},
+      {FromStr("240"), FromStr("0"), "240", APFloat::opOK, APFloat::fcNormal},
+      {FromStr("240"), FromStr("32"), "240", APFloat::opInexact,
+       APFloat::fcNormal, APFloat::rmTowardZero},
+      {FromStr("240"), FromStr("240"), "240", APFloat::opInexact,
+       APFloat::fcNormal, APFloat::rmTowardZero},
+  };
+
+  for (size_t i = 0; i < std::size(AdditionTests); ++i) {
+    APFloat x(AdditionTests[i].x);
+    APFloat y(AdditionTests[i].y);
+    APFloat::opStatus status = x.add(y, AdditionTests[i].roundingMode);
+
+    APFloat result(APFloat::Float8E4M3FZN(), AdditionTests[i].result);
+
+    EXPECT_TRUE(result.bitwiseIsEqual(x));
+    EXPECT_EQ(AdditionTests[i].status, (int)status);
+    EXPECT_EQ(AdditionTests[i].category, (int)x.getCategory());
+  }
+}
+
+TEST(APFloatTest, Float8E4M3FZNDivideByZero) {
+  APFloat x(APFloat::Float8E4M3FZN(), "1");
+  APFloat zero(APFloat::Float8E4M3FZN(), "0");
+  EXPECT_EQ(x.divide(zero, APFloat::rmNearestTiesToEven), APFloat::opDivByZero);
+  EXPECT_TRUE(x.isNaN());
+}
+
+TEST(APFloatTest, Float8E4M3FZNNext) {
+  APFloat test(APFloat::Float8E4M3FZN(), APFloat::uninitialized);
+  APFloat expected(APFloat::Float8E4M3FZN(), APFloat::uninitialized);
+
+  // nextUp on positive numbers
+  for (int i = 0; i < 127; i++) {
+    test = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i + 1));
+    EXPECT_EQ(test.next(false), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  }
+
+  // nextUp on negative numbers; 129 is skipped because 128 encodes NaN
+  for (int i = 130; i < 255; i++) {
+    test = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i - 1));
+    EXPECT_EQ(test.next(false), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected)) << i;
+  }
+
+  // nextUp on NaN
+  test = APFloat::getQNaN(APFloat::Float8E4M3FZN(), false);
+  expected = APFloat::getQNaN(APFloat::Float8E4M3FZN(), false);
+  EXPECT_EQ(test.next(false), APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  // nextDown on positive nonzero finite numbers
+  for (int i = 1; i < 127; i++) {
+    test = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i - 1));
+    EXPECT_EQ(test.next(true), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  }
+
+  // nextDown on (the only, positive) zero steps to the smallest negative value
+  test = APFloat::getZero(APFloat::Float8E4M3FZN(), false);
+  expected = APFloat::getSmallest(APFloat::Float8E4M3FZN(), true);
+  EXPECT_EQ(test.next(true), APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  // nextDown on negative finite numbers
+  for (int i = 129; i < 255; i++) {
+    test = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E4M3FZN(), APInt(8, i + 1));
+    EXPECT_EQ(test.next(true), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected)) << i;
+  }
+
+  // nextDown on NaN
+  test = APFloat::getQNaN(APFloat::Float8E4M3FZN(), false);
+  expected = APFloat::getQNaN(APFloat::Float8E4M3FZN(), false);
+  EXPECT_EQ(test.next(true), APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+}
+
+TEST(APFloatTest, Float8E4M3FZNExhaustive) {
+  // Test each of the 256 Float8E4M3FZN values.
+  for (int i = 0; i < 256; i++) {
+    APFloat test(APFloat::Float8E4M3FZN(), APInt(8, i));
+    SCOPED_TRACE("i=" + std::to_string(i));
+
+    // isLargest
+    if (i == 127 || i == 255) {
+      EXPECT_TRUE(test.isLargest());
+      EXPECT_EQ(abs(test).convertToDouble(), 240.);
+    } else {
+      EXPECT_FALSE(test.isLargest());
+    }
+
+    // isSmallest
+    if (i == 1 || i == 129) {
+      EXPECT_TRUE(test.isSmallest());
+      EXPECT_EQ(abs(test).convertToDouble(), 0x1p-10);
+    } else {
+      EXPECT_FALSE(test.isSmallest());
+    }
+
+    // convert to BFloat
+    APFloat test2 = test;
+    bool loses_info;
+    APFloat::opStatus status = test2.convert(
+        APFloat::BFloat(), APFloat::rmNearestTiesToEven, &loses_info);
+    EXPECT_EQ(status, APFloat::opOK);
+    EXPECT_FALSE(loses_info);
+    if (i == 128)
+      EXPECT_TRUE(test2.isNaN());
+    else
+      EXPECT_EQ(test.convertToFloat(), test2.convertToFloat());
+
+    // bitcastToAPInt
+    EXPECT_EQ(i, test.bitcastToAPInt());
+  }
+}
+
+TEST(APFloatTest, Float8E4M3FZNExhaustivePair) {
+  // Check every pair against an IEEEsingle reference rounded back to FZN.
+  for (int i = 0; i < 256; i++) {
+    for (int j = 0; j < 256; j++) {
+      SCOPED_TRACE("i=" + std::to_string(i) + ",j=" + std::to_string(j));
+      APFloat x(APFloat::Float8E4M3FZN(), APInt(8, i));
+      APFloat y(APFloat::Float8E4M3FZN(), APInt(8, j));
+
+      bool losesInfo;
+      APFloat x32 = x;
+      x32.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_FALSE(losesInfo);
+      APFloat y32 = y;
+      y32.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_FALSE(losesInfo);
+
+      // Add
+      APFloat z = x;
+      z.add(y, APFloat::rmNearestTiesToEven);
+      APFloat z32 = x32;
+      z32.add(y32, APFloat::rmNearestTiesToEven);
+      z32.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+
+      // Subtract
+      z = x;
+      z.subtract(y, APFloat::rmNearestTiesToEven);
+      z32 = x32;
+      z32.subtract(y32, APFloat::rmNearestTiesToEven);
+      z32.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+
+      // Multiply
+      z = x;
+      z.multiply(y, APFloat::rmNearestTiesToEven);
+      z32 = x32;
+      z32.multiply(y32, APFloat::rmNearestTiesToEven);
+      z32.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+
+      // Divide
+      z = x;
+      z.divide(y, APFloat::rmNearestTiesToEven);
+      z32 = x32;
+      z32.divide(y32, APFloat::rmNearestTiesToEven);
+      z32.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+
+      // Mod
+      z = x;
+      z.mod(y);
+      z32 = x32;
+      z32.mod(y32);
+      z32.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+
+      // Remainder
+      z = x;
+      z.remainder(y);
+      z32 = x32;
+      z32.remainder(y32);
+      z32.convert(APFloat::Float8E4M3FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+    }
+  }
+}
+
+TEST(APFloatTest, Float8E5M2FZNFromString) {
+  // Exactly representable
+  EXPECT_EQ(57344,
+            APFloat(APFloat::Float8E5M2FZN(), "57344").convertToDouble());
+  // Round down to maximum value
+  EXPECT_EQ(57344,
+            APFloat(APFloat::Float8E5M2FZN(), "59392").convertToDouble());
+  // Round up, causing overflow to NaN
+  EXPECT_TRUE(APFloat(APFloat::Float8E5M2FZN(), "61440").isNaN());
+  // Overflow without rounding
+  EXPECT_TRUE(APFloat(APFloat::Float8E5M2FZN(), "131072").isNaN());
+  // Inf converted to NaN
+  EXPECT_TRUE(APFloat(APFloat::Float8E5M2FZN(), "inf").isNaN());
+  // NaN converted to NaN
+  EXPECT_TRUE(APFloat(APFloat::Float8E5M2FZN(), "nan").isNaN());
+}
+
+TEST(APFloatTest, Float8E5M2FZNAdd) {
+  APFloat QNaN = APFloat::getNaN(APFloat::Float8E5M2FZN(), false);
+
+  auto FromStr = [](StringRef S) {
+    return APFloat(APFloat::Float8E5M2FZN(), S);
+  };
+
+  struct {
+    APFloat x;
+    APFloat y;
+    const char *result;
+    int status;
+    int category;
+    APFloat::roundingMode roundingMode = APFloat::rmNearestTiesToEven;
+  } AdditionTests[] = {
+      // Cover NaN operands, overflow and the largest finite value (57344):
+      // E5M2FZN has no infinity, so overflow produces NaN, or saturates to
+      // 57344 when rounding toward zero.
+      {FromStr("57344"), FromStr("2048"), "57344", APFloat::opInexact,
+       APFloat::fcNormal},
+      {FromStr("57344"), FromStr("4096"), "NaN",
+       APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN},
+      {FromStr("-57344"), FromStr("-4096"), "NaN",
+       APFloat::opOverflow | APFloat::opInexact, APFloat::fcNaN},
+      {QNaN, FromStr("-57344"), "NaN", APFloat::opOK, APFloat::fcNaN},
+      {FromStr("57344"), FromStr("-8192"), "49152", APFloat::opOK,
+       APFloat::fcNormal},
+      {FromStr("57344"), FromStr("0"), "57344", APFloat::opOK,
+       APFloat::fcNormal},
+      {FromStr("57344"), FromStr("4096"), "57344", APFloat::opInexact,
+       APFloat::fcNormal, APFloat::rmTowardZero},
+      {FromStr("57344"), FromStr("57344"), "57344", APFloat::opInexact,
+       APFloat::fcNormal, APFloat::rmTowardZero},
+  };
+
+  for (size_t i = 0; i < std::size(AdditionTests); ++i) {
+    SCOPED_TRACE("i=" + std::to_string(i));
+    const auto &Test = AdditionTests[i];
+    APFloat x(Test.x);
+    APFloat::opStatus status = x.add(Test.y, Test.roundingMode);
+
+    APFloat result(APFloat::Float8E5M2FZN(), Test.result);
+    EXPECT_TRUE(result.bitwiseIsEqual(x));
+    EXPECT_EQ(Test.status, (int)status);
+    EXPECT_EQ(Test.category, (int)x.getCategory());
+  }
+}
+
+TEST(APFloatTest, Float8E5M2FZNDivideByZero) {
+  APFloat x(APFloat::Float8E5M2FZN(), "1");
+  APFloat zero(APFloat::Float8E5M2FZN(), "0");
+  EXPECT_EQ(x.divide(zero, APFloat::rmNearestTiesToEven), APFloat::opDivByZero);
+  EXPECT_TRUE(x.isNaN()); // No infinity exists, so x / 0 must produce NaN.
+}
+
+TEST(APFloatTest, Float8E5M2FZNNext) {
+  APFloat test(APFloat::Float8E5M2FZN(), APFloat::uninitialized);
+  APFloat expected(APFloat::Float8E5M2FZN(), APFloat::uninitialized);
+
+  // nextUp on +0 and positive finite values 0x00..0x7E: encoding + 1
+  for (int i = 0; i < 127; i++) {
+    test = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i + 1));
+    EXPECT_EQ(test.next(false), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  }
+
+  // nextUp on negative values 0x82..0xFE (toward zero): encoding - 1
+  for (int i = 130; i < 255; i++) { // 0x81 - 1 would be 0x80, the NaN
+    test = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i - 1));
+    EXPECT_EQ(test.next(false), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected)) << i;
+  }
+
+  // nextUp on the NaN leaves it unchanged
+  test = APFloat::getQNaN(APFloat::Float8E5M2FZN(), false);
+  expected = APFloat::getQNaN(APFloat::Float8E5M2FZN(), false);
+  EXPECT_EQ(test.next(false), APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  // nextDown on positive nonzero finite values 0x01..0x7E: encoding - 1
+  for (int i = 1; i < 127; i++) {
+    test = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i - 1));
+    EXPECT_EQ(test.next(true), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected));
+  }
+
+  // nextDown on zero (getZero(..., true) still makes +0) gives -smallest
+  test = APFloat::getZero(APFloat::Float8E5M2FZN(), true);
+  expected = APFloat::getSmallest(APFloat::Float8E5M2FZN(), true);
+  EXPECT_EQ(test.next(true), APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+  // nextDown on negative values 0x81..0xFE (away from zero): encoding + 1
+  for (int i = 129; i < 255; i++) {
+    test = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i));
+    expected = APFloat(APFloat::Float8E5M2FZN(), APInt(8, i + 1));
+    EXPECT_EQ(test.next(true), APFloat::opOK);
+    EXPECT_TRUE(test.bitwiseIsEqual(expected)) << i;
+  }
+
+  // nextDown on the NaN leaves it unchanged
+  test = APFloat::getQNaN(APFloat::Float8E5M2FZN(), false);
+  expected = APFloat::getQNaN(APFloat::Float8E5M2FZN(), false);
+  EXPECT_EQ(test.next(true), APFloat::opOK);
+  EXPECT_TRUE(test.bitwiseIsEqual(expected));
+}
+
+TEST(APFloatTest, Float8E5M2FZNExhaustive) {
+  // Check every one of the 256 Float8E5M2FZN encodings.
+  for (int i = 0; i < 256; i++) {
+    APFloat test(APFloat::Float8E5M2FZN(), APInt(8, i));
+    SCOPED_TRACE("i=" + std::to_string(i));
+
+    // isLargest: +/-57344 are encoded as 0x7F and 0xFF
+    if (i == 127 || i == 255) {
+      EXPECT_TRUE(test.isLargest());
+      EXPECT_EQ(abs(test).convertToDouble(), 57344);
+    } else {
+      EXPECT_FALSE(test.isLargest());
+    }
+
+    // isSmallest: the +/-2^-17 denormals are encoded as 0x01 and 0x81
+    if (i == 1 || i == 129) {
+      EXPECT_TRUE(test.isSmallest());
+      EXPECT_EQ(abs(test).convertToDouble(), 0x1p-17);
+    } else {
+      EXPECT_FALSE(test.isSmallest());
+    }
+
+    // Converting to BFloat is exact; 0x80 is the format's only NaN
+    APFloat test2 = test;
+    bool loses_info;
+    APFloat::opStatus status = test2.convert(
+        APFloat::BFloat(), APFloat::rmNearestTiesToEven, &loses_info);
+    EXPECT_EQ(status, APFloat::opOK);
+    EXPECT_FALSE(loses_info);
+    if (i == 128)
+      EXPECT_TRUE(test2.isNaN());
+    else
+      EXPECT_EQ(test.convertToFloat(), test2.convertToFloat());
+
+    // bitcastToAPInt must reproduce the original encoding
+    EXPECT_EQ(i, test.bitcastToAPInt());
+  }
+}
+
+TEST(APFloatTest, Float8E5M2FZNExhaustivePair) {
+  // Check every operand pair against IEEEsingle arithmetic rounded back.
+  for (int i = 0; i < 256; i++) {
+    for (int j = 0; j < 256; j++) {
+      SCOPED_TRACE("i=" + std::to_string(i) + ",j=" + std::to_string(j));
+      APFloat x(APFloat::Float8E5M2FZN(), APInt(8, i));
+      APFloat y(APFloat::Float8E5M2FZN(), APInt(8, j));
+
+      bool losesInfo; // widening to IEEEsingle must be exact
+      APFloat x32 = x;
+      x32.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_FALSE(losesInfo);
+      APFloat y32 = y;
+      y32.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_FALSE(losesInfo);
+
+      // Add
+      APFloat z = x;
+      z.add(y, APFloat::rmNearestTiesToEven);
+      APFloat z32 = x32;
+      z32.add(y32, APFloat::rmNearestTiesToEven);
+      z32.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+
+      // Subtract
+      z = x;
+      z.subtract(y, APFloat::rmNearestTiesToEven);
+      z32 = x32;
+      z32.subtract(y32, APFloat::rmNearestTiesToEven);
+      z32.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+
+      // Multiply
+      z = x;
+      z.multiply(y, APFloat::rmNearestTiesToEven);
+      z32 = x32;
+      z32.multiply(y32, APFloat::rmNearestTiesToEven);
+      z32.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+
+      // Divide
+      z = x;
+      z.divide(y, APFloat::rmNearestTiesToEven);
+      z32 = x32;
+      z32.divide(y32, APFloat::rmNearestTiesToEven);
+      z32.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+
+      // Mod
+      z = x;
+      z.mod(y);
+      z32 = x32;
+      z32.mod(y32);
+      z32.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+
+      // Remainder
+      z = x;
+      z.remainder(y);
+      z32 = x32;
+      z32.remainder(y32);
+      z32.convert(APFloat::Float8E5M2FZN(), APFloat::rmNearestTiesToEven,
+                  &losesInfo);
+      EXPECT_TRUE(z.bitwiseIsEqual(z32));
+    }
+  }
+}
+
 TEST(APFloatTest, F8ToString) {
   for (APFloat::Semantics S :
-       {APFloat::S_Float8E5M2, APFloat::S_Float8E4M3FN}) {
+       {APFloat::S_Float8E5M2, APFloat::S_Float8E4M3FN,
+        APFloat::S_Float8E5M2FZN, APFloat::S_Float8E4M3FZN}) {
     SCOPED_TRACE("Semantics=" + std::to_string(S));
     for (int i = 0; i < 256; i++) {
       SCOPED_TRACE("i=" + std::to_string(i));
-      APFloat test(APFloat::Float8E5M2(), APInt(8, i));
+      APFloat test(APFloat::EnumToSemantics(S), APInt(8, i));
       llvm::SmallString<128> str;
       test.toString(str);
 
       if (test.isNaN()) {
         EXPECT_EQ(str, "NaN");
       } else {
-        APFloat test2(APFloat::Float8E5M2(), str);
+        APFloat test2(APFloat::EnumToSemantics(S), str);
         EXPECT_TRUE(test.bitwiseIsEqual(test2));
       }
     }
@@ -5458,6 +5946,56 @@
   EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
 }
 
+TEST(APFloatTest, Float8E5M2FZNToDouble) {
+  APFloat One(APFloat::Float8E5M2FZN(), "1.0");
+  EXPECT_EQ(1.0, One.convertToDouble());
+  APFloat Two(APFloat::Float8E5M2FZN(), "2.0");
+  EXPECT_EQ(2.0, Two.convertToDouble());
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2FZN(), false);
+  EXPECT_EQ(57344., PosLargest.convertToDouble()); // 1.75 * 2^15
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2FZN(), true);
+  EXPECT_EQ(-57344., NegLargest.convertToDouble());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2FZN(), false);
+  EXPECT_EQ(0x1.p-15, PosSmallest.convertToDouble()); // minExponent is -15
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2FZN(), true);
+  EXPECT_EQ(-0x1.p-15, NegSmallest.convertToDouble());
+
+  APFloat SmallestDenorm =
+      APFloat::getSmallest(APFloat::Float8E5M2FZN(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1p-17, SmallestDenorm.convertToDouble()); // 2^-15 * 0.25
+
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2FZN());
+  EXPECT_TRUE(std::isnan(QNaN.convertToDouble())); // the lone NaN (0x80)
+}
+
+TEST(APFloatTest, Float8E4M3FZNToDouble) {
+  APFloat One(APFloat::Float8E4M3FZN(), "1.0");
+  EXPECT_EQ(1.0, One.convertToDouble());
+  APFloat Two(APFloat::Float8E4M3FZN(), "2.0");
+  EXPECT_EQ(2.0, Two.convertToDouble());
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3FZN(), false);
+  EXPECT_EQ(240., PosLargest.convertToDouble()); // 1.875 * 2^7
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3FZN(), true);
+  EXPECT_EQ(-240., NegLargest.convertToDouble());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E4M3FZN(), false);
+  EXPECT_EQ(0x1.p-7, PosSmallest.convertToDouble()); // minExponent is -7
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E4M3FZN(), true);
+  EXPECT_EQ(-0x1.p-7, NegSmallest.convertToDouble());
+
+  APFloat SmallestDenorm =
+      APFloat::getSmallest(APFloat::Float8E4M3FZN(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1p-10, SmallestDenorm.convertToDouble()); // 2^-7 * 0.125
+
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3FZN());
+  EXPECT_TRUE(std::isnan(QNaN.convertToDouble())); // the lone NaN (0x80)
+}
+
 TEST(APFloatTest, IEEEsingleToFloat) {
   APFloat FPosZero(0.0F);
   APFloat FPosZeroToFloat(FPosZero.convertToFloat());
@@ -5638,4 +6176,74 @@
   EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
 }
 
+TEST(APFloatTest, Float8E5M2FZNToFloat) {
+  APFloat PosZero = APFloat::getZero(APFloat::Float8E5M2FZN());
+  APFloat PosZeroToFloat(PosZero.convertToFloat());
+  EXPECT_TRUE(PosZeroToFloat.isPosZero());
+
+  // Negative zero is not representable: requesting it yields +0.
+  APFloat NegZero = APFloat::getZero(APFloat::Float8E5M2FZN(), true);
+  APFloat NegZeroToFloat(NegZero.convertToFloat());
+  EXPECT_TRUE(NegZeroToFloat.isPosZero());
+
+  APFloat One(APFloat::Float8E5M2FZN(), "1.0");
+  EXPECT_EQ(1.0F, One.convertToFloat());
+  APFloat Two(APFloat::Float8E5M2FZN(), "2.0");
+  EXPECT_EQ(2.0F, Two.convertToFloat());
+
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2FZN(), false);
+  EXPECT_EQ(57344.F, PosLargest.convertToFloat());
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2FZN(), true);
+  EXPECT_EQ(-57344.F, NegLargest.convertToFloat());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2FZN(), false);
+  EXPECT_EQ(0x1.p-15F, PosSmallest.convertToFloat());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2FZN(), true);
+  EXPECT_EQ(-0x1.p-15F, NegSmallest.convertToFloat());
+
+  APFloat SmallestDenorm =
+      APFloat::getSmallest(APFloat::Float8E5M2FZN(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1.p-17F, SmallestDenorm.convertToFloat());
+
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2FZN());
+  EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
+}
+
+TEST(APFloatTest, Float8E4M3FZNToFloat) {
+  APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3FZN());
+  APFloat PosZeroToFloat(PosZero.convertToFloat());
+  EXPECT_TRUE(PosZeroToFloat.isPosZero());
+
+  // Negative zero is not representable: requesting it yields +0.
+  APFloat NegZero = APFloat::getZero(APFloat::Float8E4M3FZN(), true);
+  APFloat NegZeroToFloat(NegZero.convertToFloat());
+  EXPECT_TRUE(NegZeroToFloat.isPosZero());
+
+  APFloat One(APFloat::Float8E4M3FZN(), "1.0");
+  EXPECT_EQ(1.0F, One.convertToFloat());
+  APFloat Two(APFloat::Float8E4M3FZN(), "2.0");
+  EXPECT_EQ(2.0F, Two.convertToFloat());
+
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3FZN(), false);
+  EXPECT_EQ(240.F, PosLargest.convertToFloat());
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3FZN(), true);
+  EXPECT_EQ(-240.F, NegLargest.convertToFloat());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E4M3FZN(), false);
+  EXPECT_EQ(0x1.p-7F, PosSmallest.convertToFloat());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E4M3FZN(), true);
+  EXPECT_EQ(-0x1.p-7F, NegSmallest.convertToFloat());
+
+  APFloat SmallestDenorm =
+      APFloat::getSmallest(APFloat::Float8E4M3FZN(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1.p-10F, SmallestDenorm.convertToFloat());
+
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3FZN());
+  EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
+}
+
 } // namespace
Index: llvm/lib/Support/APFloat.cpp
===================================================================
--- llvm/lib/Support/APFloat.cpp
+++ llvm/lib/Support/APFloat.cpp
@@ -65,6 +65,25 @@
     // as non-signalling, although the paper does not state whether the NaN
     // values are signalling or not.
     NanOnly,
+
+    // Float8E5M2FZN and Float8E4M3FZN have this behavior. There is no Inf
+    // representation. A value is NaN if the exponent field and the mantissa
+    // field are all 0s, and sign bit is 1. This behavior matches the FP8 types
+    // described in https://arxiv.org/abs/2206.02915.
+    NanOnlyS1E0M0,
+  };
+
+  // Which signs of zero a format can encode: both +0 and -0, or only +0.
+  // makeZero() consults this and turns a requested -0 into +0.
+  enum class fltSignedZeroSupport {
+    // Standard IEEE-754 behavior: zero can be either positive or
+    // negative.
+    IEEE754,
+
+    // Only +0 exists (Float8E5M2FZN, Float8E4M3FZN); in those formats the
+    // sign=1, all-zero bit pattern encodes NaN instead. Matches the FP8 types
+    // described in https://arxiv.org/abs/2206.02915.
+    PositiveOnly,
   };
 
   /* Represents floating point arithmetic semantics.  */
@@ -86,6 +105,9 @@
 
     fltNonfiniteBehavior nonFiniteBehavior = fltNonfiniteBehavior::IEEE754;
 
+    /* The supported zero encodings. */
+    fltSignedZeroSupport zeroSupport = fltSignedZeroSupport::IEEE754;
+
     // Returns true if any number described by this semantics can be precisely
     // represented by the specified semantics. Does not take into account
     // the value of fltNonfiniteBehavior.
@@ -103,6 +125,14 @@
   static const fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
   static const fltSemantics semFloat8E4M3FN = {8, -6, 4, 8,
                                                fltNonfiniteBehavior::NanOnly};
+  static const fltSemantics semFloat8E5M2FZN = {
+      15, -15, 3, 8, // bias 16, so normal exponents span [-15, 15]
+      fltNonfiniteBehavior::NanOnlyS1E0M0,
+      fltSignedZeroSupport::PositiveOnly};
+  static const fltSemantics semFloat8E4M3FZN = {
+      7, -7, 4, 8, // bias 8, so normal exponents span [-7, 7]
+      fltNonfiniteBehavior::NanOnlyS1E0M0,
+      fltSignedZeroSupport::PositiveOnly};
   static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
   static const fltSemantics semBogus = {0, 0, 0, 0};
 
@@ -162,6 +192,10 @@
       return Float8E5M2();
     case S_Float8E4M3FN:
       return Float8E4M3FN();
+    case S_Float8E5M2FZN:
+      return Float8E5M2FZN();
+    case S_Float8E4M3FZN:
+      return Float8E4M3FZN();
     case S_x87DoubleExtended:
       return x87DoubleExtended();
     }
@@ -186,6 +220,10 @@
       return S_Float8E5M2;
     else if (&Sem == &llvm::APFloat::Float8E4M3FN())
       return S_Float8E4M3FN;
+    else if (&Sem == &llvm::APFloat::Float8E5M2FZN())
+      return S_Float8E5M2FZN;
+    else if (&Sem == &llvm::APFloat::Float8E4M3FZN())
+      return S_Float8E4M3FZN;
     else if (&Sem == &llvm::APFloat::x87DoubleExtended())
       return S_x87DoubleExtended;
     else
@@ -210,6 +248,8 @@
   }
   const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; }
   const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; }
+  const fltSemantics &APFloatBase::Float8E5M2FZN() { return semFloat8E5M2FZN; }
+  const fltSemantics &APFloatBase::Float8E4M3FZN() { return semFloat8E4M3FZN; }
   const fltSemantics &APFloatBase::x87DoubleExtended() {
     return semX87DoubleExtended;
   }
@@ -805,6 +845,13 @@
     fill = &fill_storage;
   }
 
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) {
+    // There is only one valid NaN encoding.
+    sign = 1;
+    *significand = 0;
+    return;
+  }
+
   // Set the significand bits to the fill.
   if (!fill || fill->getNumWords() < numParts)
     APInt::tcSet(significand, 0, numParts);
@@ -1406,7 +1453,8 @@
       rounding_mode == rmNearestTiesToAway ||
       (rounding_mode == rmTowardPositive && !sign) ||
       (rounding_mode == rmTowardNegative && sign)) {
-    if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
+    if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
+        semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0)
       makeNaN(false, sign);
     else
       category = fcInfinity;
@@ -1546,6 +1594,12 @@
 
     /* Did the significand increment overflow?  */
     if (omsb == (unsigned) semantics->precision + 1) {
+      // NanOnlyS1E0M0 types can't overflow to infinity.
+      if (exponent == semantics->maxExponent &&
+          semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) {
+        return handleOverflow(rounding_mode);
+      }
+
       /* Renormalize by incrementing the exponent and shifting our
          significand right one.  However if we already have the
          maximum exponent we overflow to infinity.  */
@@ -1783,7 +1837,8 @@
     return opOK;
 
   case PackCategoriesIntoKey(fcNormal, fcZero):
-    if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
+    if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
+        semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0)
       makeNaN(false, sign);
     else
       category = fcInfinity;
@@ -2347,8 +2402,8 @@
 
   // If this is a truncation, perform the shift before we narrow the storage.
   if (shift < 0 && (isFiniteNonZero() ||
-                    (category == fcNaN && semantics->nonFiniteBehavior !=
-                                              fltNonfiniteBehavior::NanOnly)))
+                    (category == fcNaN && semantics->nonFiniteBehavior ==
+                                              fltNonfiniteBehavior::IEEE754)))
     lostFraction = shiftRight(significandParts(), oldPartCount, -shift);
 
   // Fix the storage so it can hold to new value.
@@ -2370,6 +2425,14 @@
     significand.part = newPart;
   }
 
+  // The NanOnlyS1E0M0 NaN has a zero significand (Inf in IEEE formats). Set
+  // the bit that becomes the destination's quiet bit once widening shifts it.
+  if (category == fcNaN &&
+      fromSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0 &&
+      toSemantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnlyS1E0M0)
+    APInt::tcSetBit(significandParts(), shift > 0 ? fromSemantics.precision - 2
+                                                  : toSemantics.precision - 2);
+
   // Now that we have the right storage, switch the semantics.
   semantics = &toSemantics;
 
@@ -2382,9 +2445,10 @@
     fs = normalize(rounding_mode, lostFraction);
     *losesInfo = (fs != opOK);
   } else if (category == fcNaN) {
-    if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) {
+    if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
+        semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) {
       *losesInfo =
-          fromSemantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly;
+          fromSemantics.nonFiniteBehavior != semantics->nonFiniteBehavior;
       makeNaN(false, sign);
       return is_signaling ? opInvalidOp : opOK;
     }
@@ -2406,7 +2470,9 @@
       fs = opOK;
     }
   } else if (category == fcInfinity &&
-             semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) {
+             (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
+              semantics->nonFiniteBehavior ==
+                  fltNonfiniteBehavior::NanOnlyS1E0M0)) {
     makeNaN(false, sign);
     *losesInfo = true;
     fs = opInexact;
@@ -3534,6 +3600,60 @@
                    (mysignificand & 0x7)));
 }
 
+APInt IEEEFloat::convertFloat8E5M2FZNAPFloatToAPInt() const {
+  assert(semantics == (const llvm::fltSemantics *)&semFloat8E5M2FZN);
+  assert(partCount() == 1);
+
+  uint32_t mysign, myexponent, mysignificand; // S1 E5 M2 bit fields
+
+  if (isFiniteNonZero()) {
+    mysign = sign;
+    myexponent = exponent + 16; // bias (1 - minExponent)
+    mysignificand = (uint32_t)*significandParts();
+    if (myexponent == 1 && !(mysignificand & 0x4))
+      myexponent = 0; // denormal: integer bit clear at minimum exponent
+  } else if (category == fcZero) {
+    mysign = 0; // only +0 is encodable; any zero sign is dropped
+    myexponent = 0;
+    mysignificand = 0;
+  } else {
+    assert(category == fcNaN && "Unknown category!");
+    mysign = 1; // the single NaN encoding is 0x80 (sign set, rest zero)
+    myexponent = 0;
+    mysignificand = 0;
+  }
+
+  return APInt(8, (((mysign & 1) << 7) | ((myexponent & 0x1f) << 2) |
+                   (mysignificand & 0x3)));
+}
+
+APInt IEEEFloat::convertFloat8E4M3FZNAPFloatToAPInt() const {
+  assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3FZN);
+  assert(partCount() == 1);
+
+  uint32_t mysign, myexponent, mysignificand; // S1 E4 M3 bit fields
+
+  if (isFiniteNonZero()) {
+    mysign = sign;
+    myexponent = exponent + 8; // bias (1 - minExponent)
+    mysignificand = (uint32_t)*significandParts();
+    if (myexponent == 1 && !(mysignificand & 0x8))
+      myexponent = 0; // denormal: integer bit clear at minimum exponent
+  } else if (category == fcZero) {
+    mysign = 0; // only +0 is encodable; any zero sign is dropped
+    myexponent = 0;
+    mysignificand = 0;
+  } else {
+    assert(category == fcNaN && "Unknown category!");
+    mysign = 1; // the single NaN encoding is 0x80 (sign set, rest zero)
+    myexponent = 0;
+    mysignificand = 0;
+  }
+
+  return APInt(8, (((mysign & 1) << 7) | ((myexponent & 0xf) << 3) |
+                   (mysignificand & 0x7)));
+}
+
 // This function creates an APInt that is just a bit map of the floating
 // point constant as it would appear in memory.  It is not a conversion,
 // and treating the result as a normal integer is unlikely to be useful.
@@ -3563,6 +3683,12 @@
   if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN)
     return convertFloat8E4M3FNAPFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2FZN)
+    return convertFloat8E5M2FZNAPFloatToAPInt();
+
+  if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FZN)
+    return convertFloat8E4M3FZNAPFloatToAPInt();
+
   assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
          "unknown format!");
   return convertF80LongDoubleAPFloatToAPInt();
@@ -3844,6 +3970,66 @@
   }
 }
 
+void IEEEFloat::initFromFloat8E5M2FZNAPInt(const APInt &api) {
+  uint32_t i = (uint32_t)*api.getRawData();
+  uint32_t mysign = (i >> 7) & 1;
+  uint32_t myexponent = (i >> 2) & 0x1f;
+  uint32_t mysignificand = i & 0x3;
+
+  initialize(&semFloat8E5M2FZN);
+  assert(partCount() == 1);
+
+  if (mysign == 0 && myexponent == 0 && mysignificand == 0) {
+    // 0x00 is the only zero; this format has no -0 encoding.
+    makeZero(false);
+  } else if (mysign == 1 && myexponent == 0 && mysignificand == 0) {
+    category = fcNaN; // 0x80: the pattern IEEE-754 would use for -0
+    sign = 1;
+    exponent = exponentNaN();
+    *significandParts() = mysignificand;
+  } else {
+    category = fcNormal;
+    sign = mysign;
+    exponent = myexponent - 16; // bias
+    *significandParts() = mysignificand;
+
+    if (myexponent == 0) // denormal: minExponent, no integer bit
+      exponent = -15;
+    else
+      *significandParts() |= 0x4; // integer bit
+  }
+}
+
+void IEEEFloat::initFromFloat8E4M3FZNAPInt(const APInt &api) {
+  uint32_t i = (uint32_t)*api.getRawData();
+  uint32_t mysign = (i >> 7) & 1;
+  uint32_t myexponent = (i >> 3) & 0xf;
+  uint32_t mysignificand = i & 0x7;
+
+  initialize(&semFloat8E4M3FZN);
+  assert(partCount() == 1);
+
+  if (mysign == 0 && myexponent == 0 && mysignificand == 0) {
+    // 0x00 is the only zero; this format has no -0 encoding.
+    makeZero(false);
+  } else if (mysign == 1 && myexponent == 0 && mysignificand == 0) {
+    category = fcNaN; // 0x80: the pattern IEEE-754 would use for -0
+    sign = 1;
+    exponent = exponentNaN();
+    *significandParts() = mysignificand;
+  } else {
+    category = fcNormal;
+    sign = mysign;
+    exponent = myexponent - 8; // bias
+    *significandParts() = mysignificand;
+
+    if (myexponent == 0) // denormal: minExponent, no integer bit
+      exponent = -7;
+    else
+      *significandParts() |= 0x8; // integer bit
+  }
+}
+
 /// Treat api as containing the bits of a floating point number.
 void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
   assert(api.getBitWidth() == Sem->sizeInBits);
@@ -3865,6 +4051,10 @@
     return initFromFloat8E5M2APInt(api);
   if (Sem == &semFloat8E4M3FN)
     return initFromFloat8E4M3FNAPInt(api);
+  if (Sem == &semFloat8E5M2FZN)
+    return initFromFloat8E5M2FZNAPInt(api);
+  if (Sem == &semFloat8E4M3FZN)
+    return initFromFloat8E4M3FZNAPInt(api);
 
   llvm_unreachable(nullptr);
 }
@@ -4274,6 +4464,8 @@
     return false;
   if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
     return false;
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0)
+    return false;
 
   // IEEE-754R 2008 6.2.1: A signaling NaN bit string should be encoded with the
   // first bit of the trailing significand being 0.
@@ -4325,7 +4517,8 @@
     }
 
     if (isLargest() && !isNegative()) {
-      if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) {
+      if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
+          semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) {
         // nextUp(getLargest()) == NAN
         makeNaN();
         break;
@@ -4409,6 +4602,8 @@
 APFloatBase::ExponentType IEEEFloat::exponentNaN() const {
   if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly)
     return semantics->maxExponent;
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0)
+    return 0;
   return semantics->maxExponent + 1;
 }
 
@@ -4421,7 +4616,8 @@
 }
 
 void IEEEFloat::makeInf(bool Negative) {
-  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) {
+  if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
+      semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnlyS1E0M0) {
     // There is no Inf, so make NaN instead.
     makeNaN(false, Negative);
     return;
@@ -4433,6 +4629,10 @@
 }
 
 void IEEEFloat::makeZero(bool Negative) {
+  if (semantics->zeroSupport == fltSignedZeroSupport::PositiveOnly) {
+    Negative = false;
+  }
+
   category = fcZero;
   sign = Negative;
   exponent = exponentZero();
@@ -4441,7 +4641,8 @@
 
 void IEEEFloat::makeQuiet() {
   assert(isNaN());
-  if (semantics->nonFiniteBehavior != fltNonfiniteBehavior::NanOnly)
+  if (semantics->nonFiniteBehavior != fltNonfiniteBehavior::NanOnly &&
+      semantics->nonFiniteBehavior != fltNonfiniteBehavior::NanOnlyS1E0M0)
     APInt::tcSetBit(significandParts(), semantics->precision - 2);
 }
 
Index: llvm/include/llvm/ADT/APFloat.h
===================================================================
--- llvm/include/llvm/ADT/APFloat.h
+++ llvm/include/llvm/ADT/APFloat.h
@@ -163,6 +163,18 @@
     // Unlike IEEE-754 types, there are no infinity values, and NaN is
     // represented with the exponent and mantissa bits set to all 1s.
     S_Float8E4M3FN,
+    // 8-bit floating point number mostly following IEEE-754 conventions with
+    // bit layout S1E5M2 as described in https://arxiv.org/abs/2206.02915.
+    // Unlike IEEE-754 types, there are no infinity values, there is no
+    // negative zero, and NaN is represented with the exponent and mantissa
+    // bits set to all 0s and the sign bit set to 1.
+    S_Float8E5M2FZN,
+    // 8-bit floating point number mostly following IEEE-754 conventions with
+    // bit layout S1E4M3 as described in https://arxiv.org/abs/2206.02915.
+    // Unlike IEEE-754 types, there are no infinity values, there is no
+    // negative zero, and NaN is represented with the exponent and mantissa
+    // bits set to all 0s and the sign bit set to 1.
+    S_Float8E4M3FZN,
     S_x87DoubleExtended,
     S_MaxSemantics = S_x87DoubleExtended,
   };
@@ -178,6 +190,8 @@
   static const fltSemantics &PPCDoubleDouble() LLVM_READNONE;
   static const fltSemantics &Float8E5M2() LLVM_READNONE;
   static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
+  static const fltSemantics &Float8E5M2FZN() LLVM_READNONE;
+  static const fltSemantics &Float8E4M3FZN() LLVM_READNONE;
   static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
 
   /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -570,6 +584,8 @@
   APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
   APInt convertFloat8E5M2APFloatToAPInt() const;
   APInt convertFloat8E4M3FNAPFloatToAPInt() const;
+  APInt convertFloat8E5M2FZNAPFloatToAPInt() const;
+  APInt convertFloat8E4M3FZNAPFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   void initFromHalfAPInt(const APInt &api);
   void initFromBFloatAPInt(const APInt &api);
@@ -580,6 +596,8 @@
   void initFromPPCDoubleDoubleAPInt(const APInt &api);
   void initFromFloat8E5M2APInt(const APInt &api);
   void initFromFloat8E4M3FNAPInt(const APInt &api);
+  void initFromFloat8E5M2FZNAPInt(const APInt &api);
+  void initFromFloat8E4M3FZNAPInt(const APInt &api);
 
   void assign(const IEEEFloat &);
   void copySignificand(const IEEEFloat &);
Index: clang/lib/AST/MicrosoftMangle.cpp
===================================================================
--- clang/lib/AST/MicrosoftMangle.cpp
+++ clang/lib/AST/MicrosoftMangle.cpp
@@ -840,6 +840,8 @@
   case APFloat::S_PPCDoubleDouble: Out << 'Z'; break;
   case APFloat::S_Float8E5M2:
   case APFloat::S_Float8E4M3FN:
+  case APFloat::S_Float8E5M2FZN:
+  case APFloat::S_Float8E4M3FZN:
     llvm_unreachable("Tried to mangle unexpected APFloat semantics");
   }
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to