https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/158722

>From 8c9a156aa4f682ca836403bd71608c5aa2352d46 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <akash.baner...@amd.com>
Date: Mon, 15 Sep 2025 20:35:29 +0100
Subject: [PATCH 1/3] Add complex.powi op.

---
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 20 ++--
 .../Transforms/ConvertComplexPow.cpp          | 92 +++++++++----------
 flang/test/Lower/HLFIR/binary-ops.f90         |  2 +-
 .../test/Lower/Intrinsics/pow_complex16i.f90  |  2 +-
 .../test/Lower/Intrinsics/pow_complex16k.f90  |  2 +-
 flang/test/Lower/amdgcn-complex.f90           |  9 ++
 flang/test/Lower/power-operator.f90           |  9 +-
 flang/test/Transforms/convert-complex-pow.fir | 42 +++------
 .../mlir/Dialect/Complex/IR/ComplexOps.td     | 26 ++++++
 .../ComplexToROCDLLibraryCalls.cpp            | 41 ++++++++-
 .../Transforms/AlgebraicSimplification.cpp    | 24 +++--
 .../Dialect/Math/Transforms/CMakeLists.txt    |  1 +
 .../complex-to-rocdl-library-calls.mlir       | 14 +++
 mlir/test/Dialect/Complex/powi-simplify.mlir  | 20 ++++
 14 files changed, 198 insertions(+), 106 deletions(-)
 create mode 100644 mlir/test/Dialect/Complex/powi-simplify.mlir

diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 466458c05dba7..74a4e8f85c8ff 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1331,14 +1331,20 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
     return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
   auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
   mlir::Value exp = args[1];
-  if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
-    auto realTy = complexTy.getElementType();
-    mlir::Value realExp = builder.createConvert(loc, realTy, exp);
-    mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
-    exp =
-        builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+  mlir::Value result;
+  if (mlir::isa<mlir::IntegerType>(exp.getType()) ||
+      mlir::isa<mlir::IndexType>(exp.getType())) {
+    result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp);
+  } else {
+    if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
+      auto realTy = complexTy.getElementType();
+      mlir::Value realExp = builder.createConvert(loc, realTy, exp);
+      mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+      exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp,
+                                                    zero);
+    }
+    result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
   }
-  mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
   result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
   return result;
 }
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index dced5f90d6924..42f5df160798c 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -61,63 +61,55 @@ void ConvertComplexPowPass::runOnOperation() {
 
   fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
 
-  mod.walk([&](complex::PowOp op) {
+  mod.walk([&](complex::PowiOp op) {
     builder.setInsertionPoint(op);
     Location loc = op.getLoc();
     auto complexTy = cast<ComplexType>(op.getType());
     auto elemTy = complexTy.getElementType();
-
     Value base = op.getLhs();
-    Value rhs = op.getRhs();
-
-    Value intExp;
-    if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
-      if (isZero(create.getImaginary())) {
-        if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
-          if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
-            intExp = conv.getValue();
-        }
-      }
-    }
-
+    Value intExp = op.getRhs();
     func::FuncOp callee;
-    SmallVector<Value> args;
-    if (intExp) {
-      unsigned realBits = cast<FloatType>(elemTy).getWidth();
-      unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
-      auto funcTy = builder.getFunctionType(
-          {complexTy, builder.getIntegerType(intBits)}, {complexTy});
-      if (realBits == 32 && intBits == 32)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
-      else if (realBits == 32 && intBits == 64)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
-      else if (realBits == 64 && intBits == 32)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
-      else if (realBits == 64 && intBits == 64)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
-      else if (realBits == 128 && intBits == 32)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
-      else if (realBits == 128 && intBits == 64)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
-      else
-        return;
-      args = {base, intExp};
-    } else {
-      unsigned realBits = cast<FloatType>(elemTy).getWidth();
-      auto funcTy =
-          builder.getFunctionType({complexTy, complexTy}, {complexTy});
-      if (realBits == 32)
-        callee = getOrDeclare(builder, loc, "cpowf", funcTy);
-      else if (realBits == 64)
-        callee = getOrDeclare(builder, loc, "cpow", funcTy);
-      else if (realBits == 128)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
-      else
-        return;
-      args = {base, rhs};
-    }
+    unsigned realBits = cast<FloatType>(elemTy).getWidth();
+    unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+    auto funcTy = builder.getFunctionType(
+        {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+    if (realBits == 32 && intBits == 32)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+    else if (realBits == 32 && intBits == 64)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+    else if (realBits == 64 && intBits == 32)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+    else if (realBits == 64 && intBits == 64)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+    else if (realBits == 128 && intBits == 32)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+    else if (realBits == 128 && intBits == 64)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+    else
+      return;
+    auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
+    op.replaceAllUsesWith(call.getResult(0));
+    op.erase();
+  });
 
-    auto call = fir::CallOp::create(builder, loc, callee, args);
+  mod.walk([&](complex::PowOp op) {
+    builder.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    auto complexTy = cast<ComplexType>(op.getType());
+    auto elemTy = complexTy.getElementType();
+    unsigned realBits = cast<FloatType>(elemTy).getWidth();
+    func::FuncOp callee;
+    auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy});
+    if (realBits == 32)
+      callee = getOrDeclare(builder, loc, "cpowf", funcTy);
+    else if (realBits == 64)
+      callee = getOrDeclare(builder, loc, "cpow", funcTy);
+    else if (realBits == 128)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
+    else
+      return;
+    auto call =
+        fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()});
     op.replaceAllUsesWith(call.getResult(0));
     op.erase();
   });
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 1fbd333db37c3..7e1691dd1587a 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK:  %[[VAL_8:.*]] = complex.pow
+! CHECK:  %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32
 
 subroutine extremum(c, n, l)
   integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 1827863a57f43..0b26024b02021 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -4,7 +4,7 @@
 ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
 
 ! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
   complex(16) :: a
   integer(4) :: b
   b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index 039dfd5152a06..90a9f5e03628d 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -4,7 +4,7 @@
 ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
 
 ! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
   complex(16) :: a
   integer(8) :: b
   b = a ** b
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index 4ee5de4d2842e..a28eaea82379b 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
    complex :: a, b, c
    a = b**c
 end subroutine pow_test
+
+! CHECK-LABEL: func @_QPpowi_test(
+! CHECK: complex.powi
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine powi_test(a, b, c)
+   complex :: a, b
+   integer :: i
+   b = a ** i
+end subroutine powi_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 3058927144248..9f74d172a6bb2 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
   complex :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
   ! PRECISE: fir.call @_FortranAcpowi
 end subroutine
 
@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
   complex :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
   ! PRECISE: fir.call @_FortranAcpowk
 end subroutine
 
@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
   complex(8) :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
   ! PRECISE: fir.call @_FortranAzpowi
 end subroutine
 
@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
   complex(8) :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
   ! PRECISE: fir.call @_FortranAzpowk
 end subroutine
 
@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
   ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
   ! PRECISE: fir.call @cpow
 end subroutine
-
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
index d980817aba9b9..4555fea61e496 100644
--- a/flang/test/Transforms/convert-complex-pow.fir
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -2,18 +2,12 @@
 
 module {
   func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
-    %c0 = arith.constant 0.000000e+00 : f32
-    %c1 = fir.convert %arg1 : (i32) -> f32
-    %c2 = complex.create %c1, %c0 : complex<f32>
-    %0 = complex.pow %arg0, %c2 : complex<f32>
+    %0 = complex.powi %arg0, %arg1 : complex<f32>, i32
     return %0 : complex<f32>
   }
 
   func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
-    %c0 = arith.constant 0.000000e+00 : f32
-    %c1 = fir.convert %arg1 : (i64) -> f32
-    %c2 = complex.create %c1, %c0 : complex<f32>
-    %0 = complex.pow %arg0, %c2 : complex<f32>
+    %0 = complex.powi %arg0, %arg1 : complex<f32>, i64
     return %0 : complex<f32>
   }
 
@@ -23,18 +17,12 @@ module {
   }
 
   func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
-    %c0 = arith.constant 0.000000e+00 : f64
-    %c1 = fir.convert %arg1 : (i32) -> f64
-    %c2 = complex.create %c1, %c0 : complex<f64>
-    %0 = complex.pow %arg0, %c2 : complex<f64>
+    %0 = complex.powi %arg0, %arg1 : complex<f64>, i32
     return %0 : complex<f64>
   }
 
   func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
-    %c0 = arith.constant 0.000000e+00 : f64
-    %c1 = fir.convert %arg1 : (i64) -> f64
-    %c2 = complex.create %c1, %c0 : complex<f64>
-    %0 = complex.pow %arg0, %c2 : complex<f64>
+    %0 = complex.powi %arg0, %arg1 : complex<f64>, i64
     return %0 : complex<f64>
   }
 
@@ -44,18 +32,12 @@ module {
   }
 
   func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
-    %c0 = arith.constant 0.000000e+00 : f128
-    %c1 = fir.convert %arg1 : (i32) -> f128
-    %c2 = complex.create %c1, %c0 : complex<f128>
-    %0 = complex.pow %arg0, %c2 : complex<f128>
+    %0 = complex.powi %arg0, %arg1 : complex<f128>, i32
     return %0 : complex<f128>
   }
 
   func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
-    %c0 = arith.constant 0.000000e+00 : f128
-    %c1 = fir.convert %arg1 : (i64) -> f128
-    %c2 = complex.create %c1, %c0 : complex<f128>
-    %0 = complex.pow %arg0, %c2 : complex<f128>
+    %0 = complex.powi %arg0, %arg1 : complex<f128>, i64
     return %0 : complex<f128>
   }
 
@@ -67,11 +49,11 @@ module {
 
 // CHECK-LABEL: func.func @pow_c4_i4(
 // CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c4_i8(
 // CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c4_c4(
 // CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex<f32>, complex<f32>) -> complex<f32>
@@ -79,11 +61,11 @@ module {
 
 // CHECK-LABEL: func.func @pow_c8_i4(
 // CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c8_i8(
 // CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c8_c8(
 // CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex<f64>, complex<f64>) -> complex<f64>
@@ -91,11 +73,11 @@ module {
 
 // CHECK-LABEL: func.func @pow_c16_i4(
 // CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c16_i8(
 // CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c16_c16(
 // CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex<f128>, complex<f128>) -> complex<f128>
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 44590406301eb..ca5103c16889c 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -443,6 +443,32 @@ def PowOp : ComplexArithmeticOp<"pow"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PowiOp
+//===----------------------------------------------------------------------===//
+
+def PowiOp : Complex_Op<"powi",
+    [Pure, Elementwise, SameOperandsAndResultShape,
+     AllTypesMatch<["lhs", "result"]>]> {
+  let summary = "complex number raised to integer power";
+  let description = [{
+    The `powi` operation takes a complex number and an integer exponent.
+
+    Example:
+
+    ```mlir
+    %a = complex.powi %b, %c : complex<f32>, i32
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$lhs,
+                       AnySignlessInteger:$rhs);
+  let results = (outs Complex<AnyFloat>:$result);
+
+  let assemblyFormat =
+      "$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)";
+}
+
 //===----------------------------------------------------------------------===//
 // ReOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 0372f32d6b6df..25e5ab49cdb8c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -71,10 +73,40 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
     return success();
   }
 };
+
+// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0))
+struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
+  using OpRewritePattern<complex::PowiOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(complex::PowiOp op,
+                                PatternRewriter &rewriter) const final {
+    auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType()));
+    Type elementType = complexType.getElementType();
+
+    Type exponentType = op.getRhs().getType();
+    Type exponentFloatType = elementType;
+    if (auto shapedType = dyn_cast<ShapedType>(exponentType))
+      exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
+
+    Location loc = op.getLoc();
+    Value exponentReal =
+        rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
+    Value zeroImag = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(exponentFloatType));
+    Value exponent = rewriter.create<complex::CreateOp>(
+        loc, op.getLhs().getType(), exponentReal, zeroImag);
+
+    rewriter
+        .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
+                                            exponent);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
     RewritePatternSet &patterns) {
+  patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
   patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
       patterns.getContext(), "__ocml_cabs_f32");
@@ -125,11 +157,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
   populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
 
   ConversionTarget target(getContext());
-  target.addLegalDialect<func::FuncDialect>();
-  target.addLegalOp<complex::MulOp>();
+  target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
+  target.addLegalOp<complex::CreateOp, complex::MulOp>();
   target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
-                      complex::LogOp, complex::PowOp, complex::SinOp,
-                      complex::SqrtOp, complex::TanOp, complex::TanhOp>();
+                      complex::LogOp, complex::PowOp, complex::PowiOp,
+                      complex::SinOp, complex::SqrtOp, complex::TanOp,
+                      complex::TanhOp>();
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 31785eb20a642..3711c112cc631 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
 
   Value one;
   Type opType = getElementTypeOrSelf(op.getType());
-  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
+  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
     one = arith::ConstantOp::create(rewriter, loc,
                                     rewriter.getFloatAttr(opType, 1.0));
-  else
+  } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
+    auto complexTy = cast<ComplexType>(opType);
+    Type elementType = complexTy.getElementType();
+    auto realPart = rewriter.getFloatAttr(elementType, 1.0);
+    auto imagPart = rewriter.getFloatAttr(elementType, 0.0);
+    one = rewriter.create<complex::ConstantOp>(
+        loc, complexTy, rewriter.getArrayAttr({realPart, imagPart}));
+  } else {
     one = arith::ConstantOp::create(rewriter, loc,
                                     rewriter.getIntegerAttr(opType, 1));
+  }
 
   // Replace `[fi]powi(x, 0)` with `1`.
   if (exponentValue == 0) {
@@ -224,9 +233,10 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
 
 void mlir::populateMathAlgebraicSimplificationPatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<PowFStrengthReduction,
-           PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
-           PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
-          patterns.getContext());
+  patterns.add<
+      PowFStrengthReduction,
+      PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
+      PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
+      PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
+      patterns.getContext(), /*exponentThreshold=*/8);
 }
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index d37a056e8e158..ff62b515533c3 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMathTransforms
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRComplexDialect
   MLIRDialectUtils
   MLIRIR
   MLIRMathDialect
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index 080ba4f0ff67b..cf177528e532c 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -68,6 +68,20 @@ func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
   return %r : complex<f32>
 }
 
+//CHECK-LABEL: @powi_caller
+//CHECK:          (%[[Z:.*]]: complex<f32>, %[[N:.*]]: i32)
+func.func @powi_caller(%z: complex<f32>, %n: i32) -> complex<f32> {
+  // CHECK: %[[N_FP:.*]] = arith.sitofp %[[N]] : i32 to f32
+  // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[N_COMPLEX:.*]] = complex.create %[[N_FP]], %[[ZERO]] : complex<f32>
+  // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]]) : (complex<f32>) -> complex<f32>
+  // CHECK: %[[MUL:.*]] = complex.mul %[[N_COMPLEX]], %[[LOG]] : complex<f32>
+  // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]]) : (complex<f32>) -> complex<f32>
+  // CHECK: return %[[EXP]] : complex<f32>
+  %r = complex.powi %z, %n : complex<f32>, i32
+  return %r : complex<f32>
+}
+
 //CHECK-LABEL: @sin_caller
 func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
   // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
diff --git a/mlir/test/Dialect/Complex/powi-simplify.mlir b/mlir/test/Dialect/Complex/powi-simplify.mlir
new file mode 100644
index 0000000000000..c7bb6a9d81479
--- /dev/null
+++ b/mlir/test/Dialect/Complex/powi-simplify.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s
+
+func.func @pow3(%arg0: complex<f32>) -> complex<f32> {
+  %c3 = arith.constant 3 : i32
+  %0 = complex.powi %arg0, %c3 : complex<f32>, i32
+  return %0 : complex<f32>
+}
+// CHECK-LABEL: func.func @pow3(
+// CHECK-NOT: complex.powi
+// CHECK: %[[M0:.+]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
+// CHECK: %[[M1:.+]] = complex.mul %[[M0]], %{{.*}} : complex<f32>
+// CHECK: return %[[M1]] : complex<f32>
+
+func.func @pow9(%arg0: complex<f32>) -> complex<f32> {
+  %c9 = arith.constant 9 : i32
+  %0 = complex.powi %arg0, %c9 : complex<f32>, i32
+  return %0 : complex<f32>
+}
+// CHECK-LABEL: func.func @pow9(
+// CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32

>From c8d369af21012d0d930110a34a8778372b313609 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <akash.baner...@amd.com>
Date: Mon, 15 Sep 2025 20:47:37 +0100
Subject: [PATCH 2/3] Fix clang-format.

---
 .../ComplexToROCDLLibraryCalls.cpp                           | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 25e5ab49cdb8c..f162adfdb64c2 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -96,9 +96,8 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
     Value exponent = rewriter.create<complex::CreateOp>(
         loc, op.getLhs().getType(), exponentReal, zeroImag);
 
-    rewriter
-        .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
-                                            exponent);
+    rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
+                                                exponent);
     return success();
   }
 };

>From 45a13315225621be42672a98e72bd5972c0d6b62 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <akash.baner...@amd.com>
Date: Mon, 15 Sep 2025 20:57:37 +0100
Subject: [PATCH 3/3] Remove unused function.

---
 flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index 42f5df160798c..ba3edabf26c04 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -47,13 +47,6 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
   return func;
 }
 
-static bool isZero(Value v) {
-  if (auto cst = v.getDefiningOp<arith::ConstantOp>())
-    if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
-      return attr.getValue().isZero();
-  return false;
-}
-
 void ConvertComplexPowPass::runOnOperation() {
   ModuleOp mod = getOperation();
   if (fir::getTargetTriple(mod).isAMDGCN())

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to