https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116176
>From 8f237ae7e1195cf6c906d4e9075f081d9c7e65eb Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Thu, 14 Nov 2024 08:47:06 +0100 Subject: [PATCH] [mlir][Parser] Add `nan` and `inf` keywords --- mlir/lib/AsmParser/AttributeParser.cpp | 30 +++++++---- mlir/lib/AsmParser/Parser.cpp | 22 ++++++++ mlir/lib/AsmParser/TokenKinds.def | 2 + mlir/test/Dialect/Arith/canonicalize.mlir | 10 ++-- mlir/test/IR/attribute.mlir | 54 +++++++++++++++++++ .../math-polynomial-approx.mlir | 36 ++++++------- 6 files changed, 122 insertions(+), 32 deletions(-) diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index ff616dac9625b4..cc038b6b50cb3f 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -21,8 +21,10 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Endian.h" +#include <cmath> #include <optional> using namespace mlir; @@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) { // Parse floating point and integer attributes. case Token::floatliteral: + case Token::kw_inf: + case Token::kw_nan: return parseFloatAttr(type, /*isNegative=*/false); case Token::integer: return parseDecOrHexAttr(type, /*isNegative=*/false); @@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) { consumeToken(Token::minus); if (getToken().is(Token::integer)) return parseDecOrHexAttr(type, /*isNegative=*/true); - if (getToken().is(Token::floatliteral)) + if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) || + getToken().is(Token::kw_nan)) return parseFloatAttr(type, /*isNegative=*/true); return (emitWrongTokenError( @@ -342,10 +347,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { /// Parse a float attribute. Attribute Parser::parseFloatAttr(Type type, bool isNegative) { - auto val = getToken().getFloatingPointValue(); - if (!val) - return (emitError("floating point value too large for attribute"), nullptr); - consumeToken(Token::floatliteral); + const Token tok = getToken(); + consumeToken(); if (!type) { // Default to F64 when no type is specified. if (!consumeIf(Token::colon)) @@ -353,10 +356,16 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) { else if (!(type = parseType())) return nullptr; } - if (!isa<FloatType>(type)) + auto floatType = dyn_cast<FloatType>(type); + if (!floatType) return (emitError("floating point value not valid for specified type"), nullptr); - return FloatAttr::get(type, isNegative ? -*val : *val); + auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); }; + FailureOr<APFloat> result = parseFloatFromLiteral( + emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics()); + if (failed(result)) + return Attribute(); + return FloatAttr::get(floatType, *result); } /// Construct an APint from a parsed value, a known attribute type and @@ -622,7 +631,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy, } // Check to see if floating point values were parsed. - if (token.is(Token::floatliteral)) { + if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) { return p.emitError(tokenLoc) << "expected integer elements, but parsed floating-point"; } @@ -729,6 +738,8 @@ ParseResult TensorLiteralParser::parseElement() { // Parse a boolean element. case Token::kw_true: case Token::kw_false: + case Token::kw_inf: + case Token::kw_nan: case Token::floatliteral: case Token::integer: storage.emplace_back(/*isNegative=*/false, p.getToken()); @@ -738,7 +749,8 @@ ParseResult TensorLiteralParser::parseElement() { // Parse a signed integer or a negative floating-point element. case Token::minus: p.consumeToken(Token::minus); - if (!p.getToken().isAny(Token::floatliteral, Token::integer)) + if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan, + Token::integer)) return p.emitError("expected integer or floating point literal"); storage.emplace_back(/*isNegative=*/true, p.getToken()); p.consumeToken(); diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index e3db248164672c..b347622502ac77 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -350,11 +350,33 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) { ParseResult Parser::parseFloatFromLiteral(std::optional<APFloat> &result, const Token &tok, bool isNegative, const llvm::fltSemantics &semantics) { + // Check for inf keyword. + if (tok.is(Token::kw_inf)) { + if (!APFloat::semanticsHasInf(semantics)) + return emitError(tok.getLoc()) + << "floating point type does not support infinity"; + result = APFloat::getInf(semantics, isNegative); + return success(); + } + + // Check for NaN keyword. + if (tok.is(Token::kw_nan)) { + if (!APFloat::semanticsHasNan(semantics)) + return emitError(tok.getLoc()) + << "floating point type does not support NaN"; + result = APFloat::getNaN(semantics, isNegative); + return success(); + } + // Check for a floating point value. if (tok.is(Token::floatliteral)) { auto val = tok.getFloatingPointValue(); if (!val) return emitError(tok.getLoc()) << "floating point value too large"; + if (std::fpclassify(*val) == FP_ZERO && + !APFloat::semanticsHasZero(semantics)) + return emitError(tok.getLoc()) + << "floating point type does not support zero"; result.emplace(isNegative ? -*val : *val); bool unused; diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def index 49da8c3dea5fa5..9208c8adddcfce 100644 --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv) TOK_KEYWORD(for) TOK_KEYWORD(func) TOK_KEYWORD(index) +TOK_KEYWORD(inf) TOK_KEYWORD(loc) TOK_KEYWORD(max) TOK_KEYWORD(memref) TOK_KEYWORD(min) TOK_KEYWORD(mod) +TOK_KEYWORD(nan) TOK_KEYWORD(none) TOK_KEYWORD(offset) TOK_KEYWORD(size) diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index a386a178b78995..c86b2b5f63f016 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 %c0 = arith.constant 0.0 : f32 - %inf = arith.constant 0x7F800000 : f32 + %inf = arith.constant inf : f32 %0 = arith.minimumf %c0, %arg0 : f32 %1 = arith.minimumf %arg0, %arg0 : f32 %2 = arith.minimumf %inf, %arg0 : f32 @@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 %c0 = arith.constant 0.0 : f32 - %-inf = arith.constant 0xFF800000 : f32 + %-inf = arith.constant -inf : f32 %0 = arith.maximumf %c0, %arg0 : f32 %1 = arith.maximumf %arg0, %arg0 : f32 %2 = arith.maximumf %-inf, %arg0 : f32 @@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 %c0 = arith.constant 0.0 : f32 - %inf = arith.constant 0x7F800000 : f32 + %inf = arith.constant inf : f32 %0 = arith.minnumf %c0, %arg0 : f32 %1 = arith.minnumf %arg0, %arg0 : f32 %2 = arith.minnumf %inf, %arg0 : f32 @@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 %c0 = arith.constant 0.0 : f32 - %-inf = arith.constant 0xFF800000 : f32 + %-inf = arith.constant -inf : f32 %0 = arith.maxnumf %c0, %arg0 : f32 %1 = arith.maxnumf %arg0, %arg0 : f32 %2 = arith.maxnumf %-inf, %arg0 : f32 @@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) { // CHECK-DAG: %[[T:.*]] = arith.constant true // CHECK-DAG: %[[F:.*]] = arith.constant false // CHECK: return %[[F]], %[[F]], %[[T]], %[[T]] - %nan = arith.constant 0x7fffffff : f32 + %nan = arith.constant nan : f32 %0 = arith.cmpf olt, %nan, %arg0 : f32 %1 = arith.cmpf olt, %arg0, %nan : f32 %2 = arith.cmpf ugt, %nan, %arg0 : f32 diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index a62de3f5004d73..1fbb4986ab2ea7 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -108,9 +108,63 @@ func.func @float_attrs_pass() { // CHECK: float_attr = 2.000000e+00 : f128 float_attr = 2. : f128 } : () -> () + "test.float_attrs"() { + // Note: nan/inf are printed in binary format because there may be multiple + // nan/inf representations. + // CHECK: float_attr = 0x7FC00000 : f32 + float_attr = nan : f32 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0x7C : f8E4M3 + float_attr = nan : f8E4M3 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0xFFC00000 : f32 + float_attr = -nan : f32 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0xFC : f8E4M3 + float_attr = -nan : f8E4M3 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0x7F800000 : f32 + float_attr = inf : f32 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0x78 : f8E4M3 + float_attr = inf : f8E4M3 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0xFF800000 : f32 + float_attr = -inf : f32 + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 0xF8 : f8E4M3 + float_attr = -inf : f8E4M3 + } : () -> () return } +// ----- + +func.func @float_nan_unsupported() { + "test.float_attrs"() { + // expected-error @below{{floating point type does not support NaN}} + float_attr = nan : f4E2M1FN + } : () -> () +} + +// ----- + +func.func @float_inf_unsupported() { + "test.float_attrs"() { + // expected-error @below{{floating point type does not support infinity}} + float_attr = inf : f4E2M1FN + } : () -> () +} + +// ----- + //===----------------------------------------------------------------------===// // Test integer attributes //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir index b8861198d596b0..28b656b0da5f1a 100644 --- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir +++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir @@ -41,7 +41,7 @@ func.func @tanh() { call @tanh_8xf32(%v2) : (vector<8xf32>) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @tanh_f32(%nan) : (f32) -> () return @@ -87,15 +87,15 @@ func.func @log() { call @log_f32(%zero) : (f32) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @log_f32(%nan) : (f32) -> () // CHECK: inf - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @log_f32(%inf) : (f32) -> () // CHECK: -inf, nan, inf, 0.693147 - %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32> + %special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32> call @log_4xf32(%special_vec) : (vector<4xf32>) -> () return @@ -141,11 +141,11 @@ func.func @log2() { call @log2_f32(%neg_one) : (f32) -> () // CHECK: inf - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @log2_f32(%inf) : (f32) -> () // CHECK: -inf, nan, inf, 1.58496 - %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32> + %special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32> call @log2_4xf32(%special_vec) : (vector<4xf32>) -> () return @@ -192,11 +192,11 @@ func.func @log1p() { call @log1p_f32(%neg_two) : (f32) -> () // CHECK: inf - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @log1p_f32(%inf) : (f32) -> () // CHECK: -inf, nan, inf, 9.99995e-06 - %special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32> + %special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32> call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> () return @@ -247,7 +247,7 @@ func.func @erf() { call @erf_f32(%val7) : (f32) -> () // CHECK: -1 - %negativeInf = arith.constant 0xff800000 : f32 + %negativeInf = arith.constant -inf : f32 call @erf_f32(%negativeInf) : (f32) -> () // CHECK: -1, -1, -0.913759, -0.731446 @@ -263,11 +263,11 @@ func.func @erf() { call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> () // CHECK: 1 - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @erf_f32(%inf) : (f32) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @erf_f32(%nan) : (f32) -> () return @@ -306,15 +306,15 @@ func.func @exp() { call @exp_4xf32(%special_vec) : (vector<4xf32>) -> () // CHECK: inf - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @exp_f32(%inf) : (f32) -> () // CHECK: 0 - %negative_inf = arith.constant 0xff800000 : f32 + %negative_inf = arith.constant -inf : f32 call @exp_f32(%negative_inf) : (f32) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @exp_f32(%nan) : (f32) -> () return @@ -358,19 +358,19 @@ func.func @expm1() { call @expm1_8xf32(%v2) : (vector<8xf32>) -> () // CHECK: -1 - %neg_inf = arith.constant 0xff800000 : f32 + %neg_inf = arith.constant -inf : f32 call @expm1_f32(%neg_inf) : (f32) -> () // CHECK: inf - %inf = arith.constant 0x7f800000 : f32 + %inf = arith.constant inf : f32 call @expm1_f32(%inf) : (f32) -> () // CHECK: -1, inf, 1e-10 - %special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32> + %special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32> call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> () // CHECK: nan - %nan = arith.constant 0x7fc00000 : f32 + %nan = arith.constant nan : f32 call @expm1_f32(%nan) : (f32) -> () return _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits