https://github.com/Lewuathe updated https://github.com/llvm/llvm-project/pull/79786
>From 8a7243c4c2be5db5e0a95535f36386557e68e18c Mon Sep 17 00:00:00 2001 From: Kai Sasaki <lewua...@gmail.com> Date: Fri, 15 Dec 2023 15:53:54 +0900 Subject: [PATCH] [mlir][complex] Prevent underflow in complex.abs --- .../ComplexToStandard/ComplexToStandard.cpp | 58 ++++++-- .../convert-to-standard.mlir | 125 +++++++++++++++--- .../ComplexToStandard/full-conversion.mlir | 27 +++- .../Dialect/Complex/CPU/correctness.mlir | 48 +++++++ 4 files changed, 218 insertions(+), 40 deletions(-) diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 4c9dad9e2c17312..6e0eddab0e3aae0 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -26,29 +26,59 @@ namespace mlir { using namespace mlir; namespace { +// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780. struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { using OpConversionPattern<complex::AbsOp>::OpConversionPattern; LogicalResult matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto type = op.getType(); + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value real = - rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); - Value imag = - rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); - Value realSqr = - rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue()); - Value imagSqr = - rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue()); - Value sqNorm = - rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue()); - - rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm); + Type elementType = op.getType(); + Value arg = adaptor.getComplex(); + + Value zero = + b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); + Value one = b.create<arith::ConstantOp>(elementType, + b.getFloatAttr(elementType, 1.0)); + + Value real = b.create<complex::ReOp>(elementType, arg); + Value imag = b.create<complex::ImOp>(elementType, arg); + + Value realIsZero = + b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); + Value imagIsZero = + b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); + + // Real > Imag + Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue()); + Value imagSq = + b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue()); + Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue()); + Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue()); + Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue()); + Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue()); + + // Real <= Imag + Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue()); + Value realSq = + b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue()); + Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue()); + Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue()); + Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue()); + Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue()); + + rewriter.replaceOpWithNewOp<arith::SelectOp>( + op, realIsZero, imag, + b.create<arith::SelectOp>( + imagIsZero, real, + b.create<arith::SelectOp>( + b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag), + absImag, absReal))); + return success(); } }; diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 8fa29ea43854a4b..de519f6a706ab03 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -7,13 +7,30 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 { %abs = complex.abs %arg: complex<f32> return %abs : f32 } + +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 -// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 -// CHECK: return %[[NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 +// CHECK: return %[[ABS3]] : f32 // ----- @@ -241,12 +258,28 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> { %log = complex.log %arg: complex<f32> return %log : complex<f32> } +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 -// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32> @@ -469,12 +502,28 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> { // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1 +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32 -// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL2]], %[[ABS1]] : f32 +// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG2]], %[[ABS2]] : f32 // CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32 // CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32> @@ -716,13 +765,29 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 { %abs = complex.abs %arg fastmath<nnan,contract> : complex<f32> return %abs : f32 } +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32 -// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 -// CHECK: return %[[NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 +// CHECK: return %[[ABS3]] : f32 // ----- @@ -807,12 +872,28 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> { %log = complex.log %arg fastmath<nnan,contract> : complex<f32> return %log : complex<f32> } +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32> -// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32 -// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32 -// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32 -// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32 +// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32 +// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32 +// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32 +// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32 +// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32 +// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32> // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32> diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir index 9983dd46f094334..c01f3698c73868c 100644 --- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir +++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir @@ -6,12 +6,31 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 { %abs = complex.abs %arg: complex<f32> return %abs : f32 } +// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 +// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 // CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]] // CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]] -// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32 -// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32 -// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32 -// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32 +// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32 +// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32 + +// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32 +// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32 +// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32 +// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32 +// CHECK: %[[REAL_ABS:.*]] = llvm.intr.fabs(%[[REAL]]) : (f32) -> f32 +// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL_ABS]] : f32 + +// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32 +// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]] : f32 +// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]]) : (f32) -> f32 +// CHECK: %[[IMAG_ABS:.*]] = llvm.intr.fabs(%[[IMAG]]) : (f32) -> f32 +// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG_ABS]] : f32 + +// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32 +// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32 +// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL]], %[[ABS1]] : i1, f32 +// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG]], %[[ABS2]] : i1, f32 // CHECK: llvm.return %[[NORM]] : f32 // CHECK-LABEL: llvm.func @complex_eq diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir index 349b92a7aefa2e3..efdf057f08df271 100644 --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -106,6 +106,27 @@ func.func @angle(%arg: complex<f32>) -> f32 { func.return %angle : f32 } +func.func @test_element_f64(%input: tensor<?xcomplex<f64>>, + %func: (complex<f64>) -> f64) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %size = tensor.dim %input, %c0: tensor<?xcomplex<f64>> + + scf.for %i = %c0 to %size step %c1 { + %elem = tensor.extract %input[%i]: tensor<?xcomplex<f64>> + + %val = func.call_indirect %func(%elem) : (complex<f64>) -> f64 + vector.print %val : f64 + scf.yield + } + func.return +} + +func.func @abs(%arg: complex<f64>) -> f64 { + %abs = complex.abs %arg : complex<f64> + func.return %abs : f64 +} + func.func @entry() { // complex.sqrt test %sqrt_test = arith.constant dense<[ @@ -300,5 +321,32 @@ func.func @entry() { call @test_element(%angle_test_cast, %angle_func) : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> () + // complex.abs test + %abs_test = arith.constant dense<[ + (1.0, 1.0), + // CHECK: 1.414 + (1.0e300, 1.0e300), + // CHECK-NEXT: 1.41421e+300 + (1.0e-300, 1.0e-300), + // CHECK-NEXT: 1.41421e-300 + (5.0, 0.0), + // CHECK-NEXT: 5 + (0.0, 6.0), + // CHECK-NEXT: 6 + (7.0, 8.0), + // CHECK-NEXT: 10.6301 + (-1.0, -1.0), + // CHECK-NEXT: 1.414 + (-1.0e300, -1.0e300) + // CHECK-NEXT: 1.41421e+300 + ]> : tensor<8xcomplex<f64>> + %abs_test_cast = tensor.cast %abs_test + : tensor<8xcomplex<f64>> to tensor<?xcomplex<f64>> + + %abs_func = func.constant @abs : (complex<f64>) -> f64 + + call @test_element_f64(%abs_test_cast, %abs_func) + : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> () + func.return } _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits