================
@@ -0,0 +1,279 @@
+//===-- MathToNVVM.cpp - conversion from Math to NVVM libdevice calls ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTONVVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-nvvm"
+
+/// Adds the two lowering patterns used for a floating-point math op `OpTy`:
+/// a ScalarizeVectorOpLowering and an OpToFuncCallLowering configured with the
+/// given callee names.
+///
+/// \param converter      type converter handed to both patterns.
+/// \param patterns       set that receives the new patterns.
+/// \param benefit        benefit assigned to both patterns.
+/// \param f32Func        callee name used for f32.
+/// \param f64Func        callee name used for f64.
+/// \param f32ApproxFunc  optional f32 approximate-variant callee name
+///                       (empty by default).
+/// \param f16Func        optional f16 callee name (empty by default).
+template <typename OpTy>
+static void populateOpPatterns(const LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns,
+                               PatternBenefit benefit, StringRef f32Func,
+                               StringRef f64Func, StringRef f32ApproxFunc = "",
+                               StringRef f16Func = "") {
+  // Floating-point ops have no integer callee; pass an empty name for it.
+  const StringRef noI32Func = "";
+
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+                                           f32ApproxFunc, f16Func, noI32Func,
+                                           benefit);
+}
+
+/// Adds the two lowering patterns used for an integer math op `OpTy`: a
+/// ScalarizeVectorOpLowering and an OpToFuncCallLowering for which only the
+/// i32 callee name is set.
+///
+/// \param converter  type converter handed to both patterns.
+/// \param patterns   set that receives the new patterns.
+/// \param benefit    benefit assigned to both patterns.
+/// \param i32Func    callee name used for i32.
+template <typename OpTy>
+static void populateIntOpPatterns(const LLVMTypeConverter &converter,
+                                  RewritePatternSet &patterns,
+                                  PatternBenefit benefit, StringRef i32Func) {
+  // Every floating-point callee slot (f32, f64, f32 approx, f16) is left
+  // empty.
+  const StringRef unused = "";
+
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, unused, unused, unused,
+                                           unused, i32Func, benefit);
+}
+
+/// Adds the two lowering patterns for an op `OpTy` that only has f32 and f64
+/// callees: a ScalarizeVectorOpLowering and an OpToFuncCallLowering whose
+/// approximate, f16 and i32 callee names are all empty.
+///
+/// \param converter  type converter handed to both patterns.
+/// \param patterns   set that receives the new patterns.
+/// \param benefit    benefit assigned to both patterns.
+/// \param f32Func    callee name used for f32.
+/// \param f64Func    callee name used for f64.
+template <typename OpTy>
+static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
+                                       RewritePatternSet &patterns,
+                                       PatternBenefit benefit,
+                                       StringRef f32Func, StringRef f64Func) {
+  // Placeholder for the f32-approx, f16 and i32 callee slots.
+  const StringRef unused = "";
+
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, unused,
+                                           unused, unused, benefit);
+}
+
+/// Custom pattern for `math.sincos`, which produces two results.
+///
+/// The op is rewritten into a call to a `void(T, ptr, ptr)` function: two
+/// stack slots are allocated, their addresses are passed to the callee
+/// together with the input, and both results are loaded back from the slots
+/// after the call. f16 and bf16 inputs are extended to f32 for the call and
+/// the loaded results are truncated back to the original type.
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+  /// Replaces `op` with alloca + call + load (see the struct comment).
+  ///
+  /// Callee selection by compute type:
+  ///   - f32: `__nv_sincosf`, or `__nv_fast_sincosf` when the op's fastmath
+  ///     flags contain `afn`;
+  ///   - f64: `__nv_sincos`.
+  /// Any other compute type yields a match failure.
+  LogicalResult
+  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = adaptor.getOperand();
+    // Remember the incoming type so the results can be truncated back to it.
+    Type inputType = input.getType();
+    auto convertedInput = maybeExt(input, rewriter);
+    auto computeType = convertedInput.getType();
+
+    StringRef sincosFunc;
+    if (isa<Float32Type>(computeType)) {
+      const arith::FastMathFlags flag = op.getFastmath();
+      const bool useApprox =
+          mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
+      sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
+    } else if (isa<Float64Type>(computeType)) {
+      sincosFunc = "__nv_sincos";
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported operand type for sincos");
+    }
+
+    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+    Value sinPtr, cosPtr;
+    {
+      // The allocas are emitted at the start of the first block of the
+      // nearest enclosing automatic allocation scope rather than at the op
+      // itself; the guard restores the insertion point afterwards.
+      OpBuilder::InsertionGuard guard(rewriter);
+      auto *scope =
+          op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+      assert(scope && "Expected op to be inside automatic allocation scope");
+      rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
+      auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+                                          rewriter.getI32IntegerAttr(1));
+      sinPtr =
+          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
+      cosPtr =
+          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
+    }
+
+    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
+                     op);
+
+    auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
+    auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
+
+    rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
+                            maybeTrunc(cosResult, inputType, rewriter)});
+    return success();
+  }
+
+private:
+  /// Returns `operand` extended to f32 when its type is f16 or bf16;
+  /// otherwise returns `operand` unchanged.
+  Value maybeExt(Value operand, PatternRewriter &rewriter) const {
+    if (isa<Float16Type, BFloat16Type>(operand.getType()))
+      return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
+                                   Float32Type::get(rewriter.getContext()),
+                                   operand);
+    return operand;
+  }
+
+  /// Returns `operand` truncated to `type` when the two types differ;
+  /// otherwise returns `operand` unchanged.
+  Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
+    if (operand.getType() != type)
+      return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
+    return operand;
+  }
+
+  /// Emits `call funcName(input, sinPtr, cosPtr)` at the rewriter's current
+  /// insertion point.
+  ///
+  /// The callee is looked up by name starting from `op`. If no LLVM function
+  /// with that name is found, one with type `void(typeof(input), ptr, ptr)`
+  /// is created immediately before the function enclosing `op`.
+  void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
+                        StringRef funcName, Value input, Value sinPtr,
+                        Value cosPtr, Operation *op) const {
+    auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
+    auto ptrType = sinPtr.getType();
+
+    SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
+    auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
+
+    auto funcAttr = StringAttr::get(op->getContext(), funcName);
+    auto funcOp =
+        SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
+
+    if (!funcOp) {
+      auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+      assert(parentFunc && "expected there to be a parent function");
+      // A builder constructed from an operation inserts before that
+      // operation, so the new function lands next to the enclosing function.
+      OpBuilder b(parentFunc);
+
+      // Use the file/line/col part of the location (or an unknown location)
+      // for the new function rather than the full call-site location.
+      auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
+      funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
+    }
+
+    SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
+    LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
+  }
+};
+
+void mlir::populateMathToNVVMConversionPatterns(
----------------
Jason-Van-Beusekom wrote:

This PR is specifically for the flang changes; the MLIR pass is being reviewed
in https://github.com/llvm/llvm-project/pull/180058. Apologies for the
confusion: I thought it wise to split this change into two PRs, but since I do
not have write access I cannot do a traditional stacked PR.

The commit 
https://github.com/llvm/llvm-project/pull/180060/changes/c3fb230fa3632ed3d976e7b86d60f81e962d7504
 has the changes I want reviewed for this PR.

https://github.com/llvm/llvm-project/pull/180060
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to