================ @@ -47,74 +47,61 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc, return func; } -static bool isZero(Value v) { - if (auto cst = v.getDefiningOp<arith::ConstantOp>()) - if (auto attr = dyn_cast<FloatAttr>(cst.getValue())) - return attr.getValue().isZero(); - return false; -} - void ConvertComplexPowPass::runOnOperation() { ModuleOp mod = getOperation(); fir::FirOpBuilder builder(mod, fir::getKindMapping(mod)); - mod.walk([&](complex::PowOp op) { + mod.walk([&](complex::PowiOp op) { builder.setInsertionPoint(op); Location loc = op.getLoc(); auto complexTy = cast<ComplexType>(op.getType()); auto elemTy = complexTy.getElementType(); - Value base = op.getLhs(); - Value rhs = op.getRhs(); - - Value intExp; - if (auto create = rhs.getDefiningOp<complex::CreateOp>()) { - if (isZero(create.getImaginary())) { - if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) { - if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType())) - intExp = conv.getValue(); - } - } - } - + Value intExp = op.getRhs(); func::FuncOp callee; - SmallVector<Value> args; - if (intExp) { - unsigned realBits = cast<FloatType>(elemTy).getWidth(); - unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth(); - auto funcTy = builder.getFunctionType( - {complexTy, builder.getIntegerType(intBits)}, {complexTy}); - if (realBits == 32 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); - else if (realBits == 32 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); - else if (realBits == 64 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); - else if (realBits == 64 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); - else if (realBits == 128 && intBits == 32) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); - else if (realBits == 128 && intBits == 64) - callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); - else - return; - args = {base, intExp}; - } else { - unsigned realBits = cast<FloatType>(elemTy).getWidth(); - auto funcTy = - builder.getFunctionType({complexTy, complexTy}, {complexTy}); - if (realBits == 32) - callee = getOrDeclare(builder, loc, "cpowf", funcTy); - else if (realBits == 64) - callee = getOrDeclare(builder, loc, "cpow", funcTy); - else if (realBits == 128) - callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); - else - return; - args = {base, rhs}; - } + unsigned realBits = cast<FloatType>(elemTy).getWidth(); + unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth(); + auto funcTy = builder.getFunctionType( + {complexTy, builder.getIntegerType(intBits)}, {complexTy}); + if (realBits == 32 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); + else if (realBits == 32 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); + else if (realBits == 64 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); + else if (realBits == 64 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); + else if (realBits == 128 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); + else if (realBits == 128 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); + else + return; + auto call = fir::CallOp::create(builder, loc, callee, {base, intExp}); + if (auto fmf = op.getFastmathAttr()) + call.setFastmathAttr(fmf); + op.replaceAllUsesWith(call.getResult(0)); + op.erase(); + }); - auto call = fir::CallOp::create(builder, loc, callee, args); + mod.walk([&](complex::PowOp op) { ---------------- TIFitis wrote:
I've updated this. https://github.com/llvm/llvm-project/pull/158722 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits