================ @@ -0,0 +1,279 @@ +//===-- MathToNVVM.cpp - conversion from Math to NVVM libdevice calls ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToNVVM/MathToNVVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" + +#include "../GPUCommon/GPUOpsLowering.h" +#include "../GPUCommon/OpToFuncCallLowering.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMATHTONVVM +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +#define DEBUG_TYPE "math-to-nvvm" + +template <typename OpTy> +static void populateOpPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + PatternBenefit benefit, StringRef f32Func, + StringRef f64Func, StringRef f32ApproxFunc = "", + StringRef f16Func = "") { + patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit); + patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, + f32ApproxFunc, f16Func, + /*i32Func=*/"", benefit); +} + +template <typename OpTy> +static void populateIntOpPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + PatternBenefit benefit, StringRef i32Func) { + patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit); + patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func, + benefit); +} + +template <typename OpTy> +static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + 
PatternBenefit benefit, + StringRef f32Func, StringRef f64Func) { + patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit); + patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "", + /*i32Func=*/"", benefit); +} + +// Custom pattern for sincos since it returns two values +struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> { + using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = adaptor.getOperand(); + Type inputType = input.getType(); + auto convertedInput = maybeExt(input, rewriter); + auto computeType = convertedInput.getType(); + + StringRef sincosFunc; + if (isa<Float32Type>(computeType)) { + const arith::FastMathFlags flag = op.getFastmath(); + const bool useApprox = + mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn); + sincosFunc = useApprox ? 
"__nv_fast_sincosf" : "__nv_sincosf"; + } else if (isa<Float64Type>(computeType)) { + sincosFunc = "__nv_sincos"; + } else { + return rewriter.notifyMatchFailure(op, + "unsupported operand type for sincos"); + } + + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + + Value sinPtr, cosPtr; + { + OpBuilder::InsertionGuard guard(rewriter); + auto *scope = + op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>(); + assert(scope && "Expected op to be inside automatic allocation scope"); + rewriter.setInsertionPointToStart(&scope->getRegion(0).front()); + auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), + rewriter.getI32IntegerAttr(1)); + sinPtr = + LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0); + cosPtr = + LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0); + } + + createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr, + op); + + auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr); + auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr); + + rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter), + maybeTrunc(cosResult, inputType, rewriter)}); + return success(); + } + +private: + Value maybeExt(Value operand, PatternRewriter &rewriter) const { + if (isa<Float16Type, BFloat16Type>(operand.getType())) + return LLVM::FPExtOp::create(rewriter, operand.getLoc(), + Float32Type::get(rewriter.getContext()), + operand); + return operand; + } + + Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const { + if (operand.getType() != type) + return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand); + return operand; + } + + void createSincosCall(ConversionPatternRewriter &rewriter, Location loc, + StringRef funcName, Value input, Value sinPtr, + Value cosPtr, Operation *op) const { + auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext()); + auto ptrType = sinPtr.getType(); 
+ + SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType}; + auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes); + + auto funcAttr = StringAttr::get(op->getContext(), funcName); + auto funcOp = + SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr); + + if (!funcOp) { + auto parentFunc = op->getParentOfType<FunctionOpInterface>(); + assert(parentFunc && "expected there to be a parent function"); + OpBuilder b(parentFunc); + + auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>(); + funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType); + } + + SmallVector<Value> callOperands = {input, sinPtr, cosPtr}; + LLVM::CallOp::create(rewriter, loc, funcOp, callOperands); + } +}; + +void mlir::populateMathToNVVMConversionPatterns( ---------------- Jason-Van-Beusekom wrote:
This PR is specifically for the Flang changes; the MLIR pass is being reviewed in https://github.com/llvm/llvm-project/pull/180058. Apologies for the confusion: I thought it wise to split this change into two PRs, but since I do not have write access I cannot do a traditional stacked PR. The commit https://github.com/llvm/llvm-project/pull/180060/changes/c3fb230fa3632ed3d976e7b86d60f81e962d7504 has the changes I want reviewed for this PR.

https://github.com/llvm/llvm-project/pull/180060
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
