https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/158722
>From 6976910364aa2fe18603aefcb27b10bd0120513d Mon Sep 17 00:00:00 2001 From: Akash Banerjee <akash.baner...@amd.com> Date: Mon, 15 Sep 2025 20:35:29 +0100 Subject: [PATCH 1/6] Add complex.powi op. --- flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 20 ++-- .../Transforms/ConvertComplexPow.cpp | 94 +++++++++---------- flang/test/Lower/HLFIR/binary-ops.f90 | 2 +- .../test/Lower/Intrinsics/pow_complex16i.f90 | 2 +- .../test/Lower/Intrinsics/pow_complex16k.f90 | 2 +- flang/test/Lower/amdgcn-complex.f90 | 9 ++ flang/test/Lower/power-operator.f90 | 9 +- .../mlir/Dialect/Complex/IR/ComplexOps.td | 26 +++++ .../ComplexToROCDLLibraryCalls.cpp | 41 +++++++- .../Transforms/AlgebraicSimplification.cpp | 24 +++-- .../Dialect/Math/Transforms/CMakeLists.txt | 1 + .../complex-to-rocdl-library-calls.mlir | 14 +++ mlir/test/Dialect/Complex/powi-simplify.mlir | 20 ++++ 13 files changed, 188 insertions(+), 76 deletions(-) create mode 100644 mlir/test/Dialect/Complex/powi-simplify.mlir diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 466458c05dba7..74a4e8f85c8ff 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1331,14 +1331,20 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, return genLibCall(builder, loc, mathOp, mathLibFuncType, args); auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0)); mlir::Value exp = args[1]; - if (!mlir::isa<mlir::ComplexType>(exp.getType())) { - auto realTy = complexTy.getElementType(); - mlir::Value realExp = builder.createConvert(loc, realTy, exp); - mlir::Value zero = builder.createRealConstant(loc, realTy, 0); - exp = - builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero); + mlir::Value result; + if (mlir::isa<mlir::IntegerType>(exp.getType()) || + mlir::isa<mlir::IndexType>(exp.getType())) { + result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp); + } else { + if (!mlir::isa<mlir::ComplexType>(exp.getType())) { + auto realTy = complexTy.getElementType(); + mlir::Value realExp = builder.createConvert(loc, realTy, exp); + mlir::Value zero = builder.createRealConstant(loc, realTy, 0); + exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, + zero); + } + result = builder.create<mlir::complex::PowOp>(loc, args[0], exp); } - mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp); result = builder.createConvert(loc, mathLibFuncType.getResult(0), result); return result; } diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp index 78f9d9e4f639a..d76451459def9 100644 --- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp +++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp @@ -58,63 +58,57 @@ void ConvertComplexPowPass::runOnOperation() { ModuleOp mod = getOperation(); fir::FirOpBuilder builder(mod, fir::getKindMapping(mod)); - mod.walk([&](complex::PowOp op) { + mod.walk([&](complex::PowiOp op) { builder.setInsertionPoint(op); Location loc = op.getLoc(); auto complexTy = cast<ComplexType>(op.getType()); auto elemTy = complexTy.getElementType(); - Value base = op.getLhs(); - Value rhs = op.getRhs(); - - Value intExp; - if (auto create = rhs.getDefiningOp<complex::CreateOp>()) { - if (isZero(create.getImaginary())) { - if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) { - if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType())) - intExp = conv.getValue(); - } - } - } - + Value intExp = op.getRhs(); func::FuncOp callee; - SmallVector<Value> args; - if (intExp) { - unsigned realBits = cast<FloatType>(elemTy).getWidth(); - unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth(); - auto funcTy = builder.getFunctionType( - {complexTy, builder.getIntegerType(intBits)}, {complexTy}); - if (realBits == 32 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); - else if (realBits == 32 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); - else if (realBits == 64 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); - else if (realBits == 64 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); - else if (realBits == 128 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); - else if (realBits == 128 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); - else - return; - args = {base, intExp}; - } else { - unsigned realBits = cast<FloatType>(elemTy).getWidth(); - auto funcTy = - builder.getFunctionType({complexTy, complexTy}, {complexTy}); - if (realBits == 32) - callee = getOrDeclare(builder, loc, "cpowf", funcTy); - else if (realBits == 64) - callee = getOrDeclare(builder, loc, "cpow", funcTy); - else if (realBits == 128) - callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); - else - return; - args = {base, rhs}; - } + unsigned realBits = cast<FloatType>(elemTy).getWidth(); + unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth(); + auto funcTy = builder.getFunctionType( + {complexTy, builder.getIntegerType(intBits)}, {complexTy}); + if (realBits == 32 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); + else if (realBits == 32 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); + else if (realBits == 64 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); + else if (realBits == 64 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); + else if (realBits == 128 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); + else if (realBits == 128 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); + else + return; + auto call = fir::CallOp::create(builder, loc, callee, {base, intExp}); + if (auto fmf = op.getFastmathAttr()) + call.setFastmathAttr(fmf); + op.replaceAllUsesWith(call.getResult(0)); + op.erase(); + }); - auto call = fir::CallOp::create(builder, loc, callee, args); + mod.walk([&](complex::PowOp op) { + builder.setInsertionPoint(op); + Location loc = op.getLoc(); + auto complexTy = cast<ComplexType>(op.getType()); + auto elemTy = complexTy.getElementType(); + unsigned realBits = cast<FloatType>(elemTy).getWidth(); + func::FuncOp callee; + auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy}); + if (realBits == 32) + callee = getOrDeclare(builder, loc, "cpowf", funcTy); + else if (realBits == 64) + callee = getOrDeclare(builder, loc, "cpow", funcTy); + else if (realBits == 128) + callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); + else + return; + auto call = + fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()}); if (auto fmf = op.getFastmathAttr()) call.setFastmathAttr(fmf); op.replaceAllUsesWith(call.getResult(0)); diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90 index 1fbd333db37c3..7e1691dd1587a 100644 --- a/flang/test/Lower/HLFIR/binary-ops.f90 +++ b/flang/test/Lower/HLFIR/binary-ops.f90 @@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z) ! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>) ! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>> ! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32> -! CHECK: %[[VAL_8:.*]] = complex.pow +! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32 subroutine extremum(c, n, l) integer(8), intent(in) :: l diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90 index 1827863a57f43..0b26024b02021 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16i.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90 @@ -4,7 +4,7 @@ ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s ! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128> -! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128> +! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128> complex(16) :: a integer(4) :: b b = a ** b diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90 index 039dfd5152a06..90a9f5e03628d 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16k.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90 @@ -4,7 +4,7 @@ ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s ! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128> -! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128> +! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128> complex(16) :: a integer(8) :: b b = a ** b diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90 index 4ee5de4d2842e..a28eaea82379b 100644 --- a/flang/test/Lower/amdgcn-complex.f90 +++ b/flang/test/Lower/amdgcn-complex.f90 @@ -25,3 +25,12 @@ subroutine pow_test(a, b, c) complex :: a, b, c a = b**c end subroutine pow_test + +! CHECK-LABEL: func @_QPpowi_test( +! CHECK: complex.powi +! CHECK-NOT: fir.call @_FortranAcpowi +subroutine powi_test(a, b, c) + complex :: a, b + integer :: i + b = a ** i +end subroutine powi_test diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90 index 3058927144248..9f74d172a6bb2 100644 --- a/flang/test/Lower/power-operator.f90 +++ b/flang/test/Lower/power-operator.f90 @@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z) complex :: x, z integer :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32 ! PRECISE: fir.call @_FortranAcpowi end subroutine @@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z) complex :: x, z integer(8) :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64 ! PRECISE: fir.call @_FortranAcpowk end subroutine @@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z) complex(8) :: x, z integer :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32 ! PRECISE: fir.call @_FortranAzpowi end subroutine @@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z) complex(8) :: x, z integer(8) :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64 ! PRECISE: fir.call @_FortranAzpowk end subroutine @@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z) ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64> ! PRECISE: fir.call @cpow end subroutine - diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index 44590406301eb..ca5103c16889c 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -443,6 +443,32 @@ def PowOp : ComplexArithmeticOp<"pow"> { }]; } +//===----------------------------------------------------------------------===// +// PowiOp +//===----------------------------------------------------------------------===// + +def PowiOp : Complex_Op<"powi", + [Pure, Elementwise, SameOperandsAndResultShape, + AllTypesMatch<["lhs", "result"]>]> { + let summary = "complex number raised to integer power"; + let description = [{ + The `powi` operation takes a complex number and an integer exponent. + + Example: + + ```mlir + %a = complex.powi %b, %c : complex<f32>, i32 + ``` + }]; + + let arguments = (ins Complex<AnyFloat>:$lhs, + AnySignlessInteger:$rhs); + let results = (outs Complex<AnyFloat>:$result); + + let assemblyFormat = + "$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)"; +} + //===----------------------------------------------------------------------===// // ReOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 72b1fa6e833f9..361e422ce1468 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -7,9 +7,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -74,10 +76,40 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> { return success(); } }; + +// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0)) +struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> { + using OpRewritePattern<complex::PowiOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(complex::PowiOp op, + PatternRewriter &rewriter) const final { + auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType())); + Type elementType = complexType.getElementType(); + + Type exponentType = op.getRhs().getType(); + Type exponentFloatType = elementType; + if (auto shapedType = dyn_cast<ShapedType>(exponentType)) + exponentFloatType = shapedType.cloneWith(std::nullopt, elementType); + + Location loc = op.getLoc(); + Value exponentReal = + rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs()); + Value zeroImag = rewriter.create<arith::ConstantOp>( + loc, rewriter.getZeroAttr(exponentFloatType)); + Value exponent = rewriter.create<complex::CreateOp>( + loc, op.getLhs().getType(), exponentReal, zeroImag); + + rewriter + .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(), + exponent); + return success(); + } +}; } // namespace void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( RewritePatternSet &patterns) { + patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext()); patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext()); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>( patterns.getContext(), "__ocml_cabs_f32"); @@ -128,11 +160,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { populateComplexToROCDLLibraryCallsConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect<func::FuncDialect>(); - target.addLegalOp<complex::MulOp>(); + target.addLegalDialect<arith::ArithDialect, func::FuncDialect>(); + target.addLegalOp<complex::CreateOp, complex::MulOp>(); target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp, - complex::LogOp, complex::PowOp, complex::SinOp, - complex::SqrtOp, complex::TanOp, complex::TanhOp>(); + complex::LogOp, complex::PowOp, complex::PowiOp, + complex::SinOp, complex::SqrtOp, complex::TanOp, + complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index 31785eb20a642..3711c112cc631 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite( Value one; Type opType = getElementTypeOrSelf(op.getType()); - if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) + if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) { one = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr(opType, 1.0)); - else + } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) { + auto complexTy = cast<ComplexType>(opType); + Type elementType = complexTy.getElementType(); + auto realPart = rewriter.getFloatAttr(elementType, 1.0); + auto imagPart = rewriter.getFloatAttr(elementType, 0.0); + one = rewriter.create<complex::ConstantOp>( + loc, complexTy, rewriter.getArrayAttr({realPart, imagPart})); + } else { one = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(opType, 1)); + } // Replace `[fi]powi(x, 0)` with `1`. if (exponentValue == 0) { @@ -224,9 +233,10 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite( void mlir::populateMathAlgebraicSimplificationPatterns( RewritePatternSet &patterns) { - patterns - .add<PowFStrengthReduction, - PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>, - PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>( - patterns.getContext()); + patterns.add< + PowFStrengthReduction, + PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>, + PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>, + PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>( + patterns.getContext(), /*exponentThreshold=*/8); } diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index d37a056e8e158..ff62b515533c3 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMathTransforms LINK_LIBS PUBLIC MLIRArithDialect + MLIRComplexDialect MLIRDialectUtils MLIRIR MLIRMathDialect diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index 080ba4f0ff67b..cf177528e532c 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -68,6 +68,20 @@ func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> { return %r : complex<f32> } +//CHECK-LABEL: @powi_caller +//CHECK: (%[[Z:.*]]: complex<f32>, %[[N:.*]]: i32) +func.func @powi_caller(%z: complex<f32>, %n: i32) -> complex<f32> { + // CHECK: %[[N_FP:.*]] = arith.sitofp %[[N]] : i32 to f32 + // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[N_COMPLEX:.*]] = complex.create %[[N_FP]], %[[ZERO]] : complex<f32> + // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]]) : (complex<f32>) -> complex<f32> + // CHECK: %[[MUL:.*]] = complex.mul %[[N_COMPLEX]], %[[LOG]] : complex<f32> + // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]]) : (complex<f32>) -> complex<f32> + // CHECK: return %[[EXP]] : complex<f32> + %r = complex.powi %z, %n : complex<f32>, i32 + return %r : complex<f32> +} + //CHECK-LABEL: @sin_caller func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) diff --git a/mlir/test/Dialect/Complex/powi-simplify.mlir b/mlir/test/Dialect/Complex/powi-simplify.mlir new file mode 100644 index 0000000000000..c7bb6a9d81479 --- /dev/null +++ b/mlir/test/Dialect/Complex/powi-simplify.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s + +func.func @pow3(%arg0: complex<f32>) -> complex<f32> { + %c3 = arith.constant 3 : i32 + %0 = complex.powi %arg0, %c3 : complex<f32>, i32 + return %0 : complex<f32> +} +// CHECK-LABEL: func.func @pow3( +// CHECK-NOT: complex.powi +// CHECK: %[[M0:.+]] = complex.mul %{{.*}}, %{{.*}} : complex<f32> +// CHECK: %[[M1:.+]] = complex.mul %[[M0]], %{{.*}} : complex<f32> +// CHECK: return %[[M1]] : complex<f32> + +func.func @pow9(%arg0: complex<f32>) -> complex<f32> { + %c9 = arith.constant 9 : i32 + %0 = complex.powi %arg0, %c9 : complex<f32>, i32 + return %0 : complex<f32> +} +// CHECK-LABEL: func.func @pow9( +// CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32 >From 8f71488583c15d68c5fd2bf6e86a280698f09624 Mon Sep 17 00:00:00 2001 From: Akash Banerjee <akash.baner...@amd.com> Date: Mon, 15 Sep 2025 20:47:37 +0100 Subject: [PATCH 2/6] Fix clang-format. --- .../ComplexToROCDLLibraryCalls.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 361e422ce1468..dbb26377fc3c4 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -99,9 +99,8 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> { Value exponent = rewriter.create<complex::CreateOp>( loc, op.getLhs().getType(), exponentReal, zeroImag); - rewriter - .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(), - exponent); + rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(), + exponent); return success(); } }; >From 52182f113bde37682749b5b8723a2bd7802300bf Mon Sep 17 00:00:00 2001 From: Akash Banerjee <akash.baner...@amd.com> Date: Mon, 15 Sep 2025 20:57:37 +0100 Subject: [PATCH 3/6] Remove unused function. --- flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp index d76451459def9..1c251883cf707 100644 --- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp +++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp @@ -47,13 +47,6 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc, return func; } -static bool isZero(Value v) { - if (auto cst = v.getDefiningOp<arith::ConstantOp>()) - if (auto attr = dyn_cast<FloatAttr>(cst.getValue())) - return attr.getValue().isZero(); - return false; -} - void ConvertComplexPowPass::runOnOperation() { ModuleOp mod = getOperation(); fir::FirOpBuilder builder(mod, fir::getKindMapping(mod)); >From 6faad70f8ef516996ccd4436e7c7cc3ec29310f6 Mon Sep 17 00:00:00 2001 From: Akash Banerjee <akash.baner...@amd.com> Date: Wed, 17 Sep 2025 22:15:26 +0100 Subject: [PATCH 4/6] Add fastmath attribute. Update op description. Update tests. --- flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 10 ++-- flang/test/Lower/HLFIR/binary-ops.f90 | 2 +- .../test/Lower/Intrinsics/pow_complex16i.f90 | 2 +- .../test/Lower/Intrinsics/pow_complex16k.f90 | 2 +- flang/test/Transforms/convert-complex-pow.fir | 60 +++++++++---------- .../mlir/Dialect/Complex/IR/ComplexOps.td | 14 +++-- .../ComplexToROCDLLibraryCalls.cpp | 2 +- .../Transforms/AlgebraicSimplification.cpp | 18 +++++- 8 files changed, 63 insertions(+), 47 deletions(-) diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 74a4e8f85c8ff..c7cbf162db786 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1332,9 +1332,11 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0)); mlir::Value exp = args[1]; mlir::Value result; - if (mlir::isa<mlir::IntegerType>(exp.getType()) || - mlir::isa<mlir::IndexType>(exp.getType())) { - result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp); + auto fmfAttr = mlir::arith::FastMathFlagsAttr::get( + builder.getContext(), builder.getFastMathFlags()); + if (mlir::isa<mlir::IntegerType>(exp.getType())) { + result = builder.create<mlir::complex::PowiOp>( + loc, mathLibFuncType.getResult(0), args[0], args[1], fmfAttr); } else { if (!mlir::isa<mlir::ComplexType>(exp.getType())) { auto realTy = complexTy.getElementType(); @@ -1343,7 +1345,7 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero); } - result = builder.create<mlir::complex::PowOp>(loc, args[0], exp); + result = builder.create<mlir::complex::PowOp>(loc, args[0], exp, fmfAttr); } result = builder.createConvert(loc, mathLibFuncType.getResult(0), result); return result; diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90 index 7e1691dd1587a..b7695a761a0b8 100644 --- a/flang/test/Lower/HLFIR/binary-ops.f90 +++ b/flang/test/Lower/HLFIR/binary-ops.f90 @@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z) ! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>) ! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>> ! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32> -! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32 +! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>, i32 subroutine extremum(c, n, l) integer(8), intent(in) :: l diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90 index 0b26024b02021..ea18d67b75460 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16i.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90 @@ -4,7 +4,7 @@ ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s ! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128> -! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128> +! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128> complex(16) :: a integer(4) :: b b = a ** b diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90 index 90a9f5e03628d..d2b70185bda9f 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16k.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90 @@ -4,7 +4,7 @@ ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s ! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128> -! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128> +! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128> complex(16) :: a integer(8) :: b b = a ** b diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir index e09fa7316c4b0..23316ed46d40f 100644 --- a/flang/test/Transforms/convert-complex-pow.fir +++ b/flang/test/Transforms/convert-complex-pow.fir @@ -2,51 +2,38 @@ module { func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> { - %c0 = arith.constant 0.0 : f32 - %0 = fir.convert %arg1 : (i32) -> f32 - %1 = complex.create %0, %c0 : complex<f32> - %2 = complex.pow %arg0, %1 : complex<f32> - return %2 : complex<f32> + %0 = complex.powi %arg0, %arg1 : complex<f32>, i32 + return %0 : complex<f32> + } + + func.func @pow_c4_i4_fast(%arg0: complex<f32>, %arg1: i32) -> complex<f32> { + %0 = complex.powi %arg0, %arg1 fastmath<fast> : complex<f32>, i32 + return %0 : complex<f32> } func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> { - %c0 = arith.constant 0.0 : f32 - %0 = fir.convert %arg1 : (i64) -> f32 - %1 = complex.create %0, %c0 : complex<f32> - %2 = complex.pow %arg0, %1 : complex<f32> - return %2 : complex<f32> + %0 = complex.powi %arg0, %arg1 : complex<f32>, i64 + return %0 : complex<f32> } func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> { - %c0 = arith.constant 0.0 : f64 - %0 = fir.convert %arg1 : (i32) -> f64 - %1 = complex.create %0, %c0 : complex<f64> - %2 = complex.pow %arg0, %1 : complex<f64> - return %2 : complex<f64> + %0 = complex.powi %arg0, %arg1 : complex<f64>, i32 + return %0 : complex<f64> } func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> { - %c0 = arith.constant 0.0 : f64 - %0 = fir.convert %arg1 : (i64) -> f64 - %1 = complex.create %0, %c0 : complex<f64> - %2 = complex.pow %arg0, %1 : complex<f64> - return %2 : complex<f64> + %0 = complex.powi %arg0, %arg1 : complex<f64>, i64 + return %0 : complex<f64> } func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> { - %c0 = arith.constant 0.0 : f128 - %0 = fir.convert %arg1 : (i32) -> f128 - %1 = complex.create %0, %c0 : complex<f128> - %2 = complex.pow %arg0, %1 : complex<f128> - return %2 : complex<f128> + %0 = complex.powi %arg0, %arg1 : complex<f128>, i32 + return %0 : complex<f128> } func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> { - %c0 = arith.constant 0.0 : f128 - %0 = fir.convert %arg1 : (i64) -> f128 - %1 = complex.create %0, %c0 : complex<f128> - %2 = complex.pow %arg0, %1 : complex<f128> - return %2 : complex<f128> + %0 = complex.powi %arg0, %arg1 : complex<f128>, i64 + return %0 : complex<f128> } func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> { @@ -74,26 +61,37 @@ module { // CHECK-LABEL: func.func @pow_c4_i4( // CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32> // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi + +// CHECK-LABEL: func.func @pow_c4_i4_fast( +// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath<fast> : (complex<f32>, i32) -> complex<f32> +// CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c4_i8( // CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32> // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c8_i4( // CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64> // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c8_i8( // CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64> // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c16_i4( // CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128> // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c16_i8( // CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128> // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c4_fast( // CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32> @@ -108,4 +106,4 @@ module { // CHECK-LABEL: func.func @pow_c16_complex( // CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128> // CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128> -// CHECK-NOT: complex.pow \ No newline at end of file +// CHECK-NOT: complex.pow diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index ca5103c16889c..828379ded14b3 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -449,10 +449,13 @@ def PowOp : ComplexArithmeticOp<"pow"> { def PowiOp : Complex_Op<"powi", [Pure, Elementwise, SameOperandsAndResultShape, - AllTypesMatch<["lhs", "result"]>]> { - let summary = "complex number raised to integer power"; + AllTypesMatch<["lhs", "result"]>, + DeclareOpInterfaceMethods<ArithFastMathInterface>]> { + let summary = "complex number raised to signed integer power"; let description = [{ - The `powi` operation takes a complex number and an integer exponent. + The `powi` operation takes a `base` operand of complex type and a `power` + operand of signed integer type and returns one result of the same type + as `base`. The result is `base` raised to the power of `power`. Example: @@ -462,11 +465,12 @@ def PowiOp : Complex_Op<"powi", }]; let arguments = (ins Complex<AnyFloat>:$lhs, - AnySignlessInteger:$rhs); + AnySignlessInteger:$rhs, + OptionalAttr<Arith_FastMathAttr>:$fastmath); let results = (outs Complex<AnyFloat>:$result); let assemblyFormat = - "$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)"; + "$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result) `,` type($rhs)"; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index dbb26377fc3c4..42099aaa6b574 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -100,7 +100,7 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> { loc, op.getLhs().getType(), exponentReal, zeroImag); rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(), - exponent); + exponent, op.getFastmathAttr()); return success(); } }; diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index 3711c112cc631..fffccf130a571 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -217,13 +217,25 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite( // `[fi]powi(x, negative_exponent)` // with: // (1 / x) * (1 / x) * (1 / x) * ... + auto buildMul = [&](Value lhs, Value rhs) { + if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) + return rewriter.create<MulOpTy>(loc, op.getType(), lhs, rhs, + op.getFastmathAttr()); + else + return MulOpTy::create(rewriter, loc, lhs, rhs); + }; for (unsigned i = 1; i < exponentValue; ++i) - result = MulOpTy::create(rewriter, loc, result, base); + result = buildMul(result, base); // Inverse the base for negative exponent, i.e. for // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`. - if (exponentIsNegative) - result = DivOpTy::create(rewriter, loc, bcast(one), result); + if (exponentIsNegative) { + if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) + result = rewriter.create<DivOpTy>(loc, op.getType(), bcast(one), result, + op.getFastmathAttr()); + else + result = DivOpTy::create(rewriter, loc, bcast(one), result); + } rewriter.replaceOp(op, result); return success(); >From f659924576565e89b98ad381a8fd54020515592b Mon Sep 17 00:00:00 2001 From: Akash Banerjee <akash.baner...@amd.com> Date: Wed, 17 Sep 2025 23:04:41 +0100 Subject: [PATCH 5/6] Remove genComplexPow, use genMathOp instead. Add complex.powi->complex.pow conversion in ComplexToStandard pass. --- flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 59 +++++++------------ .../ComplexToStandard/ComplexToStandard.cpp | 25 ++++++++ .../convert-to-standard.mlir | 30 ++++++++++ 3 files changed, 76 insertions(+), 38 deletions(-) diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index c7cbf162db786..9e7ed8f4d3129 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1272,7 +1272,18 @@ mlir::Value genMathOp(fir::FirOpBuilder &builder, mlir::Location loc, LLVM_DEBUG(llvm::dbgs() << "Generating '" << mathLibFuncName << "' operation with type "; mathLibFuncType.dump(); llvm::dbgs() << "\n"); - result = T::create(builder, loc, args); + if constexpr (std::is_same_v<T, mlir::complex::PowOp>) { + auto resultType = mathLibFuncType.getResult(0); + result = T::create(builder, loc, resultType, args); + } else if constexpr (std::is_same_v<T, mlir::complex::PowiOp>) { + auto resultType = mathLibFuncType.getResult(0); + auto fmfAttr = mlir::arith::FastMathFlagsAttr::get( + builder.getContext(), builder.getFastMathFlags()); + result = builder.create<mlir::complex::PowiOp>(loc, resultType, args[0], + args[1], fmfAttr); + } else { + result = T::create(builder, loc, args); + } } LLVM_DEBUG(result.dump(); llvm::dbgs() << "\n"); return result; @@ -1323,34 +1334,6 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc, return result; } -mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, - const MathOperation &mathOp, - mlir::FunctionType mathLibFuncType, - llvm::ArrayRef<mlir::Value> args) { - if (mathRuntimeVersion == preciseVersion) - return genLibCall(builder, loc, mathOp, mathLibFuncType, args); - auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0)); - mlir::Value exp = args[1]; - mlir::Value result; - auto fmfAttr = mlir::arith::FastMathFlagsAttr::get( - builder.getContext(), builder.getFastMathFlags()); - if (mlir::isa<mlir::IntegerType>(exp.getType())) { - result = builder.create<mlir::complex::PowiOp>( - loc, mathLibFuncType.getResult(0), args[0], args[1], fmfAttr); - } else { - if (!mlir::isa<mlir::ComplexType>(exp.getType())) { - auto realTy = complexTy.getElementType(); - mlir::Value realExp = builder.createConvert(loc, realTy, exp); - mlir::Value zero = builder.createRealConstant(loc, realTy, 0); - exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, - zero); - } - result = builder.create<mlir::complex::PowOp>(loc, args[0], exp, fmfAttr); - } - result = builder.createConvert(loc, mathLibFuncType.getResult(0), result); - return result; -} - /// Mapping between mathematical intrinsic operations and MLIR operations /// of some appropriate dialect (math, complex, etc.) or libm calls. /// TODO: support remaining Fortran math intrinsics. @@ -1676,11 +1659,11 @@ static constexpr MathOperation mathOperations[] = { {"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call}, {"pow", "cpowf", genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>, - genComplexPow}, + genMathOp<mlir::complex::PowOp>}, {"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>, - genComplexPow}, + genMathOp<mlir::complex::PowOp>}, {"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16, - genComplexPow}, + genMathOp<mlir::complex::PowOp>}, {"pow", RTNAME_STRING(FPow4i), genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>, genMathOp<mlir::math::FPowIOp>}, @@ -1701,20 +1684,20 @@ static constexpr MathOperation mathOperations[] = { genMathOp<mlir::math::FPowIOp>}, {"pow", RTNAME_STRING(cpowi), genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, - genComplexPow}, + genMathOp<mlir::complex::PowiOp>}, {"pow", RTNAME_STRING(zpowi), genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, - genComplexPow}, + genMathOp<mlir::complex::PowiOp>}, {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4, - genComplexPow}, + genMathOp<mlir::complex::PowiOp>}, {"pow", RTNAME_STRING(cpowk), genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, - genComplexPow}, + genMathOp<mlir::complex::PowiOp>}, {"pow", RTNAME_STRING(zpowk), genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, - genComplexPow}, + genMathOp<mlir::complex::PowiOp>}, {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8, - genComplexPow}, + genMathOp<mlir::complex::PowiOp>}, {"pow-unsigned", RTNAME_STRING(UPow1), genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall}, {"pow-unsigned", RTNAME_STRING(UPow2), diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 5ad514d0f48e7..5613e021cd709 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -926,6 +926,30 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, return cutoff4; } +struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> { + using OpConversionPattern<complex::PowiOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder builder(op.getLoc(), rewriter); + auto type = cast<ComplexType>(op.getType()); + auto elementType = cast<FloatType>(type.getElementType()); + + Value floatExponent = + builder.create<arith::SIToFPOp>(elementType, adaptor.getRhs()); + Value zero = arith::ConstantOp::create( + builder, elementType, builder.getFloatAttr(elementType, 0.0)); + Value complexExponent = + complex::CreateOp::create(builder, type, floatExponent, zero); + + auto pow = builder.create<complex::PowOp>( + type, adaptor.getLhs(), complexExponent, op.getFastmathAttr()); + rewriter.replaceOp(op, pow.getResult()); + return success(); + } +}; + struct PowOpConversion : public OpConversionPattern<complex::PowOp> { using OpConversionPattern<complex::PowOp>::OpConversionPattern; @@ -1070,6 +1094,7 @@ void mlir::populateComplexToStandardConversionPatterns( SqrtOpConversion, TanTanhOpConversion<complex::TanOp>, TanTanhOpConversion<complex::TanhOp>, + PowiOpConversion, PowOpConversion, RsqrtOpConversion >(patterns.getContext()); diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index a4ddabbd0821a..dec62f92c7b2e 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -700,6 +700,36 @@ func.func @complex_pow_with_fmf(%lhs: complex<f32>, // ----- +// CHECK-LABEL: func.func @complex_powi +// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[EXP:.*]]: i32 +func.func @complex_powi(%lhs: complex<f32>, %rhs: i32) -> complex<f32> { + %pow = complex.powi %lhs, %rhs : complex<f32>, i32 + return %pow : complex<f32> +} + +// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32 +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex<f32> +// CHECK: math.atan2 +// CHECK-NOT: complex.powi + +// ----- + +// CHECK-LABEL: func.func @complex_powi_with_fmf +// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[EXP:.*]]: i32 +func.func @complex_powi_with_fmf(%lhs: complex<f32>, %rhs: i32) -> complex<f32> { + %pow = complex.powi %lhs, %rhs fastmath<nnan,contract> : complex<f32>, i32 + return %pow : complex<f32> +} + +// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32 +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex<f32> +// CHECK: math.atan2 {{.*}} fastmath<nnan,contract> : f32 +// CHECK-NOT: complex.powi + +// ----- + // CHECK-LABEL: func.func @complex_rsqrt func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> { %rsqrt = complex.rsqrt %arg : complex<f32> >From 9902e0850bcd8c81d0715e966cd1e7307538a748 Mon Sep 17 00:00:00 2001 From: Akash Banerjee <akash.baner...@amd.com> Date: Thu, 18 Sep 2025 16:19:44 +0100 Subject: [PATCH 6/6] Convert both ops in single walk. --- .../Transforms/ConvertComplexPow.cpp | 111 +++++++++--------- 1 file changed, 57 insertions(+), 54 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp index 1c251883cf707..127f8720ae524 100644 --- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp +++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp @@ -51,60 +51,63 @@ void ConvertComplexPowPass::runOnOperation() { ModuleOp mod = getOperation(); fir::FirOpBuilder builder(mod, fir::getKindMapping(mod)); - mod.walk([&](complex::PowiOp op) { - builder.setInsertionPoint(op); - Location loc = op.getLoc(); - auto complexTy = cast<ComplexType>(op.getType()); - auto elemTy = complexTy.getElementType(); - Value base = op.getLhs(); - Value intExp = op.getRhs(); - func::FuncOp callee; - unsigned realBits = cast<FloatType>(elemTy).getWidth(); - unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth(); - auto funcTy = builder.getFunctionType( - {complexTy, builder.getIntegerType(intBits)}, {complexTy}); - if (realBits == 32 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); - else if (realBits == 32 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); - else if (realBits == 64 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); - else if (realBits == 64 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); - else if (realBits == 128 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); - else if (realBits == 128 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); - else - return; - auto call = fir::CallOp::create(builder, loc, callee, {base, intExp}); - if (auto fmf = op.getFastmathAttr()) - call.setFastmathAttr(fmf); - op.replaceAllUsesWith(call.getResult(0)); - op.erase(); - }); + mod.walk([&](Operation *op) { + if (auto powIop = dyn_cast<complex::PowiOp>(op)) { + builder.setInsertionPoint(powIop); + Location loc = powIop.getLoc(); + auto complexTy = cast<ComplexType>(powIop.getType()); + auto elemTy = complexTy.getElementType(); + Value base = powIop.getLhs(); + Value intExp = powIop.getRhs(); + func::FuncOp callee; + unsigned realBits = cast<FloatType>(elemTy).getWidth(); + unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth(); + auto funcTy = builder.getFunctionType( + {complexTy, builder.getIntegerType(intBits)}, {complexTy}); + if (realBits == 32 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); + else if (realBits == 32 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); + else if (realBits == 64 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); + else if (realBits == 64 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); + else if (realBits == 128 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); + else if (realBits == 128 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); + else + return; + auto call = fir::CallOp::create(builder, loc, callee, {base, intExp}); + if (auto fmf = powIop.getFastmathAttr()) + call.setFastmathAttr(fmf); + powIop.replaceAllUsesWith(call.getResult(0)); + powIop.erase(); + } - mod.walk([&](complex::PowOp op) { - builder.setInsertionPoint(op); - Location loc = op.getLoc(); - auto complexTy = cast<ComplexType>(op.getType()); - auto elemTy = complexTy.getElementType(); - unsigned realBits = cast<FloatType>(elemTy).getWidth(); - func::FuncOp callee; - auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy}); - if (realBits == 32) - callee = getOrDeclare(builder, loc, "cpowf", funcTy); - else if (realBits == 64) - callee = getOrDeclare(builder, loc, "cpow", funcTy); - else if (realBits == 128) - callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); - else - return; - auto call = - fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()}); - if (auto fmf = op.getFastmathAttr()) - call.setFastmathAttr(fmf); - op.replaceAllUsesWith(call.getResult(0)); - op.erase(); + if (auto powOp = dyn_cast<complex::PowOp>(op)) { + builder.setInsertionPoint(powOp); + Location loc = powOp.getLoc(); + auto complexTy = cast<ComplexType>(powOp.getType()); + auto elemTy = complexTy.getElementType(); + unsigned realBits = cast<FloatType>(elemTy).getWidth(); + func::FuncOp callee; + auto funcTy = + builder.getFunctionType({complexTy, complexTy}, {complexTy}); + if (realBits == 32) + callee = getOrDeclare(builder, loc, "cpowf", funcTy); + else if (realBits == 64) + callee = getOrDeclare(builder, loc, "cpow", funcTy); + else if (realBits == 128) + callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); + else + return; + auto call = fir::CallOp::create(builder, loc, callee, + {powOp.getLhs(), powOp.getRhs()}); + if (auto fmf = powOp.getFastmathAttr()) + call.setFastmathAttr(fmf); + powOp.replaceAllUsesWith(call.getResult(0)); + powOp.erase(); + } }); } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits