================
@@ -47,74 +47,61 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
   return func;
 }
 
-static bool isZero(Value v) {
-  if (auto cst = v.getDefiningOp<arith::ConstantOp>())
-    if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
-      return attr.getValue().isZero();
-  return false;
-}
-
 void ConvertComplexPowPass::runOnOperation() {
   ModuleOp mod = getOperation();
   fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
 
-  mod.walk([&](complex::PowOp op) {
+  mod.walk([&](complex::PowiOp op) {
     builder.setInsertionPoint(op);
     Location loc = op.getLoc();
     auto complexTy = cast<ComplexType>(op.getType());
     auto elemTy = complexTy.getElementType();
-
     Value base = op.getLhs();
-    Value rhs = op.getRhs();
-
-    Value intExp;
-    if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
-      if (isZero(create.getImaginary())) {
-        if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
-          if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
-            intExp = conv.getValue();
-        }
-      }
-    }
-
+    Value intExp = op.getRhs();
     func::FuncOp callee;
-    SmallVector<Value> args;
-    if (intExp) {
-      unsigned realBits = cast<FloatType>(elemTy).getWidth();
-      unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
-      auto funcTy = builder.getFunctionType(
-          {complexTy, builder.getIntegerType(intBits)}, {complexTy});
-      if (realBits == 32 && intBits == 32)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
-      else if (realBits == 32 && intBits == 64)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
-      else if (realBits == 64 && intBits == 32)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
-      else if (realBits == 64 && intBits == 64)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
-      else if (realBits == 128 && intBits == 32)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
-      else if (realBits == 128 && intBits == 64)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
-      else
-        return;
-      args = {base, intExp};
-    } else {
-      unsigned realBits = cast<FloatType>(elemTy).getWidth();
-      auto funcTy =
-          builder.getFunctionType({complexTy, complexTy}, {complexTy});
-      if (realBits == 32)
-        callee = getOrDeclare(builder, loc, "cpowf", funcTy);
-      else if (realBits == 64)
-        callee = getOrDeclare(builder, loc, "cpow", funcTy);
-      else if (realBits == 128)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
-      else
-        return;
-      args = {base, rhs};
-    }
+    unsigned realBits = cast<FloatType>(elemTy).getWidth();
+    unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+    auto funcTy = builder.getFunctionType(
+        {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+    if (realBits == 32 && intBits == 32)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+    else if (realBits == 32 && intBits == 64)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+    else if (realBits == 64 && intBits == 32)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+    else if (realBits == 64 && intBits == 64)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+    else if (realBits == 128 && intBits == 32)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+    else if (realBits == 128 && intBits == 64)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+    else
+      return;
+    auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
+    if (auto fmf = op.getFastmathAttr())
+      call.setFastmathAttr(fmf);
+    op.replaceAllUsesWith(call.getResult(0));
+    op.erase();
+  });
 
-    auto call = fir::CallOp::create(builder, loc, callee, args);
+  mod.walk([&](complex::PowOp op) {
----------------
TIFitis wrote:

I've updated this.

https://github.com/llvm/llvm-project/pull/158722
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to