https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/158722
>From 8c9a156aa4f682ca836403bd71608c5aa2352d46 Mon Sep 17 00:00:00 2001 From: Akash Banerjee <akash.baner...@amd.com> Date: Mon, 15 Sep 2025 20:35:29 +0100 Subject: [PATCH 1/3] Add complex.powi op. --- flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 20 ++-- .../Transforms/ConvertComplexPow.cpp | 92 +++++++++---------- flang/test/Lower/HLFIR/binary-ops.f90 | 2 +- .../test/Lower/Intrinsics/pow_complex16i.f90 | 2 +- .../test/Lower/Intrinsics/pow_complex16k.f90 | 2 +- flang/test/Lower/amdgcn-complex.f90 | 9 ++ flang/test/Lower/power-operator.f90 | 9 +- flang/test/Transforms/convert-complex-pow.fir | 42 +++------ .../mlir/Dialect/Complex/IR/ComplexOps.td | 26 ++++++ .../ComplexToROCDLLibraryCalls.cpp | 41 ++++++++- .../Transforms/AlgebraicSimplification.cpp | 24 +++-- .../Dialect/Math/Transforms/CMakeLists.txt | 1 + .../complex-to-rocdl-library-calls.mlir | 14 +++ mlir/test/Dialect/Complex/powi-simplify.mlir | 20 ++++ 14 files changed, 198 insertions(+), 106 deletions(-) create mode 100644 mlir/test/Dialect/Complex/powi-simplify.mlir diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 466458c05dba7..74a4e8f85c8ff 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1331,14 +1331,20 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, return genLibCall(builder, loc, mathOp, mathLibFuncType, args); auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0)); mlir::Value exp = args[1]; - if (!mlir::isa<mlir::ComplexType>(exp.getType())) { - auto realTy = complexTy.getElementType(); - mlir::Value realExp = builder.createConvert(loc, realTy, exp); - mlir::Value zero = builder.createRealConstant(loc, realTy, 0); - exp = - builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero); + mlir::Value result; + if (mlir::isa<mlir::IntegerType>(exp.getType()) || + mlir::isa<mlir::IndexType>(exp.getType())) { + result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp); + } else { + if (!mlir::isa<mlir::ComplexType>(exp.getType())) { + auto realTy = complexTy.getElementType(); + mlir::Value realExp = builder.createConvert(loc, realTy, exp); + mlir::Value zero = builder.createRealConstant(loc, realTy, 0); + exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, + zero); + } + result = builder.create<mlir::complex::PowOp>(loc, args[0], exp); } - mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp); result = builder.createConvert(loc, mathLibFuncType.getResult(0), result); return result; } diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp index dced5f90d6924..42f5df160798c 100644 --- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp +++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp @@ -61,63 +61,55 @@ void ConvertComplexPowPass::runOnOperation() { fir::FirOpBuilder builder(mod, fir::getKindMapping(mod)); - mod.walk([&](complex::PowOp op) { + mod.walk([&](complex::PowiOp op) { builder.setInsertionPoint(op); Location loc = op.getLoc(); auto complexTy = cast<ComplexType>(op.getType()); auto elemTy = complexTy.getElementType(); - Value base = op.getLhs(); - Value rhs = op.getRhs(); - - Value intExp; - if (auto create = rhs.getDefiningOp<complex::CreateOp>()) { - if (isZero(create.getImaginary())) { - if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) { - if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType())) - intExp = conv.getValue(); - } - } - } - + Value intExp = op.getRhs(); func::FuncOp callee; - SmallVector<Value> args; - if (intExp) { - unsigned realBits = cast<FloatType>(elemTy).getWidth(); - unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth(); - auto funcTy = builder.getFunctionType( - {complexTy, builder.getIntegerType(intBits)}, {complexTy}); - if (realBits == 32 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); - else if (realBits == 32 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); - else if (realBits == 64 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); - else if (realBits == 64 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); - else if (realBits == 128 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); - else if (realBits == 128 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); - else - return; - args = {base, intExp}; - } else { - unsigned realBits = cast<FloatType>(elemTy).getWidth(); - auto funcTy = - builder.getFunctionType({complexTy, complexTy}, {complexTy}); - if (realBits == 32) - callee = getOrDeclare(builder, loc, "cpowf", funcTy); - else if (realBits == 64) - callee = getOrDeclare(builder, loc, "cpow", funcTy); - else if (realBits == 128) - callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); - else - return; - args = {base, rhs}; - } + unsigned realBits = cast<FloatType>(elemTy).getWidth(); + unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth(); + auto funcTy = builder.getFunctionType( + {complexTy, builder.getIntegerType(intBits)}, {complexTy}); + if (realBits == 32 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); + else if (realBits == 32 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); + else if (realBits == 64 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); + else if (realBits == 64 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); + else if (realBits == 128 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); + else if (realBits == 128 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); + else + return; + auto call = fir::CallOp::create(builder, loc, callee, {base, intExp}); + op.replaceAllUsesWith(call.getResult(0)); + op.erase(); + }); - auto call = fir::CallOp::create(builder, loc, callee, args); + mod.walk([&](complex::PowOp op) { + builder.setInsertionPoint(op); + Location loc = op.getLoc(); + auto complexTy = cast<ComplexType>(op.getType()); + auto elemTy = complexTy.getElementType(); + unsigned realBits = cast<FloatType>(elemTy).getWidth(); + func::FuncOp callee; + auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy}); + if (realBits == 32) + callee = getOrDeclare(builder, loc, "cpowf", funcTy); + else if (realBits == 64) + callee = getOrDeclare(builder, loc, "cpow", funcTy); + else if (realBits == 128) + callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); + else + return; + auto call = + fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()}); op.replaceAllUsesWith(call.getResult(0)); op.erase(); }); diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90 index 1fbd333db37c3..7e1691dd1587a 100644 --- a/flang/test/Lower/HLFIR/binary-ops.f90 +++ b/flang/test/Lower/HLFIR/binary-ops.f90 @@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z) ! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>) ! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>> ! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32> -! CHECK: %[[VAL_8:.*]] = complex.pow +! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32 subroutine extremum(c, n, l) integer(8), intent(in) :: l diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90 index 1827863a57f43..0b26024b02021 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16i.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90 @@ -4,7 +4,7 @@ ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s ! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128> -! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128> +! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128> complex(16) :: a integer(4) :: b b = a ** b diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90 index 039dfd5152a06..90a9f5e03628d 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16k.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90 @@ -4,7 +4,7 @@ ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s ! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128> -! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128> +! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128> complex(16) :: a integer(8) :: b b = a ** b diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90 index 4ee5de4d2842e..a28eaea82379b 100644 --- a/flang/test/Lower/amdgcn-complex.f90 +++ b/flang/test/Lower/amdgcn-complex.f90 @@ -25,3 +25,12 @@ subroutine pow_test(a, b, c) complex :: a, b, c a = b**c end subroutine pow_test + +! CHECK-LABEL: func @_QPpowi_test( +! CHECK: complex.powi +! CHECK-NOT: fir.call @_FortranAcpowi +subroutine powi_test(a, b, c) + complex :: a, b + integer :: i + b = a ** i +end subroutine powi_test diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90 index 3058927144248..9f74d172a6bb2 100644 --- a/flang/test/Lower/power-operator.f90 +++ b/flang/test/Lower/power-operator.f90 @@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z) complex :: x, z integer :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32 ! PRECISE: fir.call @_FortranAcpowi end subroutine @@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z) complex :: x, z integer(8) :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64 ! PRECISE: fir.call @_FortranAcpowk end subroutine @@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z) complex(8) :: x, z integer :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32 ! PRECISE: fir.call @_FortranAzpowi end subroutine @@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z) complex(8) :: x, z integer(8) :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64 ! PRECISE: fir.call @_FortranAzpowk end subroutine @@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z) ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64> ! PRECISE: fir.call @cpow end subroutine - diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir index d980817aba9b9..4555fea61e496 100644 --- a/flang/test/Transforms/convert-complex-pow.fir +++ b/flang/test/Transforms/convert-complex-pow.fir @@ -2,18 +2,12 @@ module { func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> { - %c0 = arith.constant 0.000000e+00 : f32 - %c1 = fir.convert %arg1 : (i32) -> f32 - %c2 = complex.create %c1, %c0 : complex<f32> - %0 = complex.pow %arg0, %c2 : complex<f32> + %0 = complex.powi %arg0, %arg1 : complex<f32>, i32 return %0 : complex<f32> } func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> { - %c0 = arith.constant 0.000000e+00 : f32 - %c1 = fir.convert %arg1 : (i64) -> f32 - %c2 = complex.create %c1, %c0 : complex<f32> - %0 = complex.pow %arg0, %c2 : complex<f32> + %0 = complex.powi %arg0, %arg1 : complex<f32>, i64 return %0 : complex<f32> } @@ -23,18 +17,12 @@ module { } func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> { - %c0 = arith.constant 0.000000e+00 : f64 - %c1 = fir.convert %arg1 : (i32) -> f64 - %c2 = complex.create %c1, %c0 : complex<f64> - %0 = complex.pow %arg0, %c2 : complex<f64> + %0 = complex.powi %arg0, %arg1 : complex<f64>, i32 return %0 : complex<f64> } func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> { - %c0 = arith.constant 0.000000e+00 : f64 - %c1 = fir.convert %arg1 : (i64) -> f64 - %c2 = complex.create %c1, %c0 : complex<f64> - %0 = complex.pow %arg0, %c2 : complex<f64> + %0 = complex.powi %arg0, %arg1 : complex<f64>, i64 return %0 : complex<f64> } @@ -44,18 +32,12 @@ module { } func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> { - %c0 = arith.constant 0.000000e+00 : f128 - %c1 = fir.convert %arg1 : (i32) -> f128 - %c2 = complex.create %c1, %c0 : complex<f128> - %0 = complex.pow %arg0, %c2 : complex<f128> + %0 = complex.powi %arg0, %arg1 : complex<f128>, i32 return %0 : complex<f128> } func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> { - %c0 = arith.constant 0.000000e+00 : f128 - %c1 = fir.convert %arg1 : (i64) -> f128 - %c2 = complex.create %c1, %c0 : complex<f128> - %0 = complex.pow %arg0, %c2 : complex<f128> + %0 = complex.powi %arg0, %arg1 : complex<f128>, i64 return %0 : complex<f128> } @@ -67,11 +49,11 @@ module { // CHECK-LABEL: func.func @pow_c4_i4( // CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32> -// CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c4_i8( // CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32> -// CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c4_c4( // CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex<f32>, complex<f32>) -> complex<f32> @@ -79,11 +61,11 @@ module { // CHECK-LABEL: func.func @pow_c8_i4( // CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64> -// CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c8_i8( // CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64> -// CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c8_c8( // CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex<f64>, complex<f64>) -> complex<f64> @@ -91,11 +73,11 @@ module { // CHECK-LABEL: func.func @pow_c16_i4( // CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128> -// CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c16_i8( // CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128> -// CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c16_c16( // CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex<f128>, complex<f128>) -> complex<f128> diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index 44590406301eb..ca5103c16889c 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -443,6 +443,32 @@ def PowOp : ComplexArithmeticOp<"pow"> { }]; } +//===----------------------------------------------------------------------===// +// PowiOp +//===----------------------------------------------------------------------===// + +def PowiOp : Complex_Op<"powi", + [Pure, Elementwise, SameOperandsAndResultShape, + AllTypesMatch<["lhs", "result"]>]> { + let summary = "complex number raised to integer power"; + let description = [{ + The `powi` operation takes a complex number and an integer exponent. + + Example: + + ```mlir + %a = complex.powi %b, %c : complex<f32>, i32 + ``` + }]; + + let arguments = (ins Complex<AnyFloat>:$lhs, + AnySignlessInteger:$rhs); + let results = (outs Complex<AnyFloat>:$result); + + let assemblyFormat = + "$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)"; +} + //===----------------------------------------------------------------------===// // ReOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 0372f32d6b6df..25e5ab49cdb8c 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -7,9 +7,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -71,10 +73,40 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> { return success(); } }; + +// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0)) +struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> { + using OpRewritePattern<complex::PowiOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(complex::PowiOp op, + PatternRewriter &rewriter) const final { + auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType())); + Type elementType = complexType.getElementType(); + + Type exponentType = op.getRhs().getType(); + Type exponentFloatType = elementType; + if (auto shapedType = dyn_cast<ShapedType>(exponentType)) + exponentFloatType = shapedType.cloneWith(std::nullopt, elementType); + + Location loc = op.getLoc(); + Value exponentReal = + rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs()); + Value zeroImag = rewriter.create<arith::ConstantOp>( + loc, rewriter.getZeroAttr(exponentFloatType)); + Value exponent = rewriter.create<complex::CreateOp>( + loc, op.getLhs().getType(), exponentReal, zeroImag); + + rewriter + .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(), + exponent); + return success(); + } +}; } // namespace void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( RewritePatternSet &patterns) { + patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext()); patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext()); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>( patterns.getContext(), "__ocml_cabs_f32"); @@ -125,11 +157,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { populateComplexToROCDLLibraryCallsConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect<func::FuncDialect>(); - target.addLegalOp<complex::MulOp>(); + target.addLegalDialect<arith::ArithDialect, func::FuncDialect>(); + target.addLegalOp<complex::CreateOp, complex::MulOp>(); target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp, - complex::LogOp, complex::PowOp, complex::SinOp, - complex::SqrtOp, complex::TanOp, complex::TanhOp>(); + complex::LogOp, complex::PowOp, complex::PowiOp, + complex::SinOp, complex::SqrtOp, complex::TanOp, + complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index 31785eb20a642..3711c112cc631 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite( Value one; Type opType = getElementTypeOrSelf(op.getType()); - if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) + if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) { one = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr(opType, 1.0)); - else + } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) { + auto complexTy = cast<ComplexType>(opType); + Type elementType = complexTy.getElementType(); + auto realPart = rewriter.getFloatAttr(elementType, 1.0); + auto imagPart = rewriter.getFloatAttr(elementType, 0.0); + one = rewriter.create<complex::ConstantOp>( + loc, complexTy, rewriter.getArrayAttr({realPart, imagPart})); + } else { one = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(opType, 1)); + } // Replace `[fi]powi(x, 0)` with `1`. if (exponentValue == 0) { @@ -224,9 +233,10 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite( void mlir::populateMathAlgebraicSimplificationPatterns( RewritePatternSet &patterns) { - patterns - .add<PowFStrengthReduction, - PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>, - PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>( - patterns.getContext()); + patterns.add< + PowFStrengthReduction, + PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>, + PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>, + PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>( + patterns.getContext(), /*exponentThreshold=*/8); } diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index d37a056e8e158..ff62b515533c3 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMathTransforms LINK_LIBS PUBLIC MLIRArithDialect + MLIRComplexDialect MLIRDialectUtils MLIRIR MLIRMathDialect diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index 080ba4f0ff67b..cf177528e532c 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -68,6 +68,20 @@ func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> { return %r : complex<f32> } +//CHECK-LABEL: @powi_caller +//CHECK: (%[[Z:.*]]: complex<f32>, %[[N:.*]]: i32) +func.func @powi_caller(%z: complex<f32>, %n: i32) -> complex<f32> { + // CHECK: %[[N_FP:.*]] = arith.sitofp %[[N]] : i32 to f32 + // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[N_COMPLEX:.*]] = complex.create %[[N_FP]], %[[ZERO]] : complex<f32> + // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]]) : (complex<f32>) -> complex<f32> + // CHECK: %[[MUL:.*]] = complex.mul %[[N_COMPLEX]], %[[LOG]] : complex<f32> + // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]]) : (complex<f32>) -> complex<f32> + // CHECK: return %[[EXP]] : complex<f32> + %r = complex.powi %z, %n : complex<f32>, i32 + return %r : complex<f32> +} + //CHECK-LABEL: @sin_caller func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) diff --git a/mlir/test/Dialect/Complex/powi-simplify.mlir b/mlir/test/Dialect/Complex/powi-simplify.mlir new file mode 100644 index 0000000000000..c7bb6a9d81479 --- /dev/null +++ b/mlir/test/Dialect/Complex/powi-simplify.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s + +func.func @pow3(%arg0: complex<f32>) -> complex<f32> { + %c3 = arith.constant 3 : i32 + %0 = complex.powi %arg0, %c3 : complex<f32>, i32 + return %0 : complex<f32> +} +// CHECK-LABEL: func.func @pow3( +// CHECK-NOT: complex.powi +// CHECK: %[[M0:.+]] = complex.mul %{{.*}}, %{{.*}} : complex<f32> +// CHECK: %[[M1:.+]] = complex.mul %[[M0]], %{{.*}} : complex<f32> +// CHECK: return %[[M1]] : complex<f32> + +func.func @pow9(%arg0: complex<f32>) -> complex<f32> { + %c9 = arith.constant 9 : i32 + %0 = complex.powi %arg0, %c9 : complex<f32>, i32 + return %0 : complex<f32> +} +// CHECK-LABEL: func.func @pow9( +// CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32 >From c8d369af21012d0d930110a34a8778372b313609 Mon Sep 17 00:00:00 2001 From: Akash Banerjee <akash.baner...@amd.com> Date: Mon, 15 Sep 2025 20:47:37 +0100 Subject: [PATCH 2/3] Fix clang-format. --- .../ComplexToROCDLLibraryCalls.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 25e5ab49cdb8c..f162adfdb64c2 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -96,9 +96,8 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> { Value exponent = rewriter.create<complex::CreateOp>( loc, op.getLhs().getType(), exponentReal, zeroImag); - rewriter - .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(), - exponent); + rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(), + exponent); return success(); } }; >From 45a13315225621be42672a98e72bd5972c0d6b62 Mon Sep 17 00:00:00 2001 From: Akash Banerjee <akash.baner...@amd.com> Date: Mon, 15 Sep 2025 20:57:37 +0100 Subject: [PATCH 3/3] Remove unused function. --- flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp index 42f5df160798c..ba3edabf26c04 100644 --- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp +++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp @@ -47,13 +47,6 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc, return func; } -static bool isZero(Value v) { - if (auto cst = v.getDefiningOp<arith::ConstantOp>()) - if (auto attr = dyn_cast<FloatAttr>(cst.getValue())) - return attr.getValue().isZero(); - return false; -} - void ConvertComplexPowPass::runOnOperation() { ModuleOp mod = getOperation(); if (fir::getTargetTriple(mod).isAMDGCN()) _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits