llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-flang-openmp Author: Jason Van Beusekom (Jason-Van-Beusekom) <details> <summary>Changes</summary> This commit adds the MLIR MathToNVVM conversion pass to flang's NVPTX codegen, lowering Math and Arith operations to libdevice library calls. This allows support for calls to Fortran math intrinsics for OpenMP offload for NVIDIA targets. This is PR (2/2) to fix https://github.com/llvm/llvm-project/issues/147023 and https://github.com/llvm/llvm-project/issues/179347. Commit: 1f39bcf0643a73c840c39b7ee8136d21f4e5cd7e is not part of this PR and is reviewed in: https://github.com/llvm/llvm-project/pull/180058 --- Patch is 45.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/180060.diff 14 Files Affected: - (modified) clang/lib/Driver/ToolChains/Flang.cpp (+30) - (modified) clang/lib/Driver/ToolChains/Flang.h (+3) - (modified) flang/lib/Optimizer/CodeGen/CMakeLists.txt (+1) - (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+7-1) - (added) flang/test/Lower/OpenMP/math-nvptx.f90 (+338) - (modified) mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h (-6) - (added) mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h (+28) - (modified) mlir/include/mlir/Conversion/Passes.h (+1) - (modified) mlir/include/mlir/Conversion/Passes.td (+14) - (modified) mlir/lib/Conversion/CMakeLists.txt (+1) - (modified) mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt (+1) - (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+2-218) - (added) mlir/lib/Conversion/MathToNVVM/CMakeLists.txt (+26) - (added) mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp (+279) ``````````diff diff --git a/clang/lib/Driver/ToolChains/Flang.cpp b/clang/lib/Driver/ToolChains/Flang.cpp index 8425f8fec62a4..2ee6dbb85e132 100644 --- a/clang/lib/Driver/ToolChains/Flang.cpp +++ b/clang/lib/Driver/ToolChains/Flang.cpp @@ -8,6 +8,7 @@ #include "Flang.h" #include "Arch/RISCV.h" +#include "Cuda.h" #include 
"clang/Basic/CodeGenOptions.h" #include "clang/Driver/CommonArgs.h" @@ -519,6 +520,31 @@ void Flang::AddAMDGPUTargetArgs(const ArgList &Args, TC.addClangTargetOptions(Args, CmdArgs, Action::OffloadKind::OFK_OpenMP); } +void Flang::AddNVPTXTargetArgs(const ArgList &Args, + ArgStringList &CmdArgs) const { + + // we cannot use addClangTargetOptions, as it appends unsupported args for + // flang: -fcuda-is-device, -fno-threadsafe-statics, + // -fcuda-allow-variadic-functions and -target-sdk-version Instead we manually + // detect the CUDA installation and link libdevice + const ToolChain &TC = getToolChain(); + const Driver &D = TC.getDriver(); + const llvm::Triple &Triple = TC.getEffectiveTriple(); + + // Detect CUDA installation and link libdevice + CudaInstallationDetector CudaInstallation(D, Triple, Args); + if (CudaInstallation.isValid()) { + StringRef GpuArch = Args.getLastArgValue(options::OPT_march_EQ); + if (!GpuArch.empty()) { + std::string LibDeviceFile = CudaInstallation.getLibDeviceFile(GpuArch); + if (!LibDeviceFile.empty()) { + CmdArgs.push_back("-mlink-builtin-bitcode"); + CmdArgs.push_back(Args.MakeArgString(LibDeviceFile)); + } + } + } +} + void Flang::addTargetOptions(const ArgList &Args, ArgStringList &CmdArgs) const { const ToolChain &TC = getToolChain(); @@ -548,6 +574,10 @@ void Flang::addTargetOptions(const ArgList &Args, getTargetFeatures(D, Triple, Args, CmdArgs, /*ForAs*/ false); AddAMDGPUTargetArgs(Args, CmdArgs); break; + case llvm::Triple::nvptx: + case llvm::Triple::nvptx64: + AddNVPTXTargetArgs(Args, CmdArgs); + break; case llvm::Triple::riscv64: getTargetFeatures(D, Triple, Args, CmdArgs, /*ForAs*/ false); AddRISCVTargetArgs(Args, CmdArgs); diff --git a/clang/lib/Driver/ToolChains/Flang.h b/clang/lib/Driver/ToolChains/Flang.h index c0837b80c032e..62d2c6bb2a093 100644 --- a/clang/lib/Driver/ToolChains/Flang.h +++ b/clang/lib/Driver/ToolChains/Flang.h @@ -78,6 +78,9 @@ class LLVM_LIBRARY_VISIBILITY Flang : public Tool { void 
AddAMDGPUTargetArgs(const llvm::opt::ArgList &Args, llvm::opt::ArgStringList &CmdArgs) const; + void AddNVPTXTargetArgs(const llvm::opt::ArgList &Args, + llvm::opt::ArgStringList &CmdArgs) const; + /// Add specific options for LoongArch64 target. /// /// \param [in] Args The list of input driver arguments diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt index d5ea3c7a8e282..6977c737199e2 100644 --- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt +++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt @@ -41,6 +41,7 @@ add_flang_library(FIRCodeGen MLIRMathToFuncs MLIRMathToLLVM MLIRMathToLibm + MLIRMathToNVVM MLIRMathToROCDL MLIROpenMPToLLVM MLIROpenACCDialect diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 625701725003f..a01f1bf5034ab 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -44,6 +44,7 @@ #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MathToNVVM/MathToNVVM.h" #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" @@ -4336,6 +4337,7 @@ class FIRToLLVMLowering mlir::OpPassManager mathConversionPM("builtin.module"); bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN(); + bool isNVPTX = fir::getTargetTriple(mod).isNVPTX(); // If compiling for AMD target some math operations must be lowered to AMD // GPU library calls, the rest can be converted to LLVM intrinsics, which // is handled in the mathToLLVM conversion. 
The lowering to libm calls is @@ -4344,6 +4346,10 @@ class FIRToLLVMLowering mathConversionPM.addPass(mlir::createConvertMathToROCDL()); mathConversionPM.addPass(mlir::createConvertComplexToROCDLLibraryCalls()); } + // If compiling for NVIDIA target some math operations must be lowered to + // NVVM libdevice calls. + if (isNVPTX) + mathConversionPM.addPass(mlir::createConvertMathToNVVM()); // Convert math::FPowI operations to inline implementation // only if the exponent's width is greater than 32, otherwise, @@ -4398,7 +4404,7 @@ class FIRToLLVMLowering mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, pattern); // Math operations that have not been converted yet must be converted // to Libm. - if (!isAMDGCN) + if (!isAMDGCN && !isNVPTX) mlir::populateMathToLibmConversionPatterns(pattern); mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern); mlir::index::populateIndexToLLVMConversionPatterns(typeConverter, pattern); diff --git a/flang/test/Lower/OpenMP/math-nvptx.f90 b/flang/test/Lower/OpenMP/math-nvptx.f90 new file mode 100644 index 0000000000000..8569e26b759e8 --- /dev/null +++ b/flang/test/Lower/OpenMP/math-nvptx.f90 @@ -0,0 +1,338 @@ +!REQUIRES: nvptx-registered-target +!RUN: %flang_fc1 -triple nvptx64-nvidia-cuda -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s + +subroutine omp_pow_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_powf(float {{.*}}, float {{.*}}) + y = x ** x +end subroutine omp_pow_f32 + +subroutine omp_pow_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_pow(double {{.*}}, double {{.*}}) + y = x ** x +end subroutine omp_pow_f64 + +subroutine omp_sin_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_sinf(float {{.*}}) + y = sin(x) +end subroutine omp_sin_f32 + +subroutine omp_sin_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_sin(double {{.*}}) + y = sin(x) +end subroutine omp_sin_f64 + 
+subroutine omp_abs_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_fabsf(float {{.*}}) + y = abs(x) +end subroutine omp_abs_f32 + +subroutine omp_abs_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_fabs(double {{.*}}) + y = abs(x) +end subroutine omp_abs_f64 + +subroutine omp_atan_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_atanf(float {{.*}}) + y = atan(x) +end subroutine omp_atan_f32 + +subroutine omp_atan_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_atan(double {{.*}}) + y = atan(x) +end subroutine omp_atan_f64 + +subroutine omp_atanh_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_atanhf(float {{.*}}) + y = atanh(x) +end subroutine omp_atanh_f32 + +subroutine omp_atanh_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_atanh(double {{.*}}) + y = atanh(x) +end subroutine omp_atanh_f64 + +subroutine omp_atan2_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_atan2f(float {{.*}}, float {{.*}}) + y = atan2(x, x) +end subroutine omp_atan2_f32 + +subroutine omp_atan2_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_atan2(double {{.*}}, double {{.*}}) + y = atan2(x, x) +end subroutine omp_atan2_f64 + +subroutine omp_cos_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_cosf(float {{.*}}) + y = cos(x) +end subroutine omp_cos_f32 + +subroutine omp_cos_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_cos(double {{.*}}) + y = cos(x) +end subroutine omp_cos_f64 + +subroutine omp_erf_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_erff(float {{.*}}) + y = erf(x) +end subroutine omp_erf_f32 + +subroutine omp_erf_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_erf(double {{.*}}) + y = erf(x) +end subroutine omp_erf_f64 + +subroutine omp_erfc_f32(x, y) +!$omp 
declare target + real :: x, y +!CHECK: call float @__nv_erfcf(float {{.*}}) + y = erfc(x) +end subroutine omp_erfc_f32 + +subroutine omp_erfc_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_erfc(double {{.*}}) + y = erfc(x) +end subroutine omp_erfc_f64 + +subroutine omp_exp_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_expf(float {{.*}}) + y = exp(x) +end subroutine omp_exp_f32 + +subroutine omp_exp_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_exp(double {{.*}}) + y = exp(x) +end subroutine omp_exp_f64 + +subroutine omp_log_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_logf(float {{.*}}) + y = log(x) +end subroutine omp_log_f32 + +subroutine omp_log_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_log(double {{.*}}) + y = log(x) +end subroutine omp_log_f64 + +subroutine omp_log10_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_log10f(float {{.*}}) + y = log10(x) +end subroutine omp_log10_f32 + +subroutine omp_log10_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_log10(double {{.*}}) + y = log10(x) +end subroutine omp_log10_f64 + +subroutine omp_sqrt_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_sqrtf(float {{.*}}) + y = sqrt(x) +end subroutine omp_sqrt_f32 + +subroutine omp_sqrt_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_sqrt(double {{.*}}) + y = sqrt(x) +end subroutine omp_sqrt_f64 + +subroutine omp_tan_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_tanf(float {{.*}}) + y = tan(x) +end subroutine omp_tan_f32 + +subroutine omp_tan_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_tan(double {{.*}}) + y = tan(x) +end subroutine omp_tan_f64 + +subroutine omp_tanh_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_tanhf(float {{.*}}) + y = 
tanh(x) +end subroutine omp_tanh_f32 + +subroutine omp_tanh_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_tanh(double {{.*}}) + y = tanh(x) +end subroutine omp_tanh_f64 + +subroutine omp_acos_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_acosf(float {{.*}}) + y = acos(x) +end subroutine omp_acos_f32 + +subroutine omp_acos_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_acos(double {{.*}}) + y = acos(x) +end subroutine omp_acos_f64 + +subroutine omp_acosh_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_acoshf(float {{.*}}) + y = acosh(x) +end subroutine omp_acosh_f32 + +subroutine omp_acosh_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_acosh(double {{.*}}) + y = acosh(x) +end subroutine omp_acosh_f64 + +subroutine omp_asin_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_asinf(float {{.*}}) + y = asin(x) +end subroutine omp_asin_f32 + +subroutine omp_asin_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_asin(double {{.*}}) + y = asin(x) +end subroutine omp_asin_f64 + +subroutine omp_asinh_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_asinhf(float {{.*}}) + y = asinh(x) +end subroutine omp_asinh_f32 + +subroutine omp_asinh_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_asinh(double {{.*}}) + y = asinh(x) +end subroutine omp_asinh_f64 + +subroutine omp_cosh_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_coshf(float {{.*}}) + y = cosh(x) +end subroutine omp_cosh_f32 + +subroutine omp_cosh_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_cosh(double {{.*}}) + y = cosh(x) +end subroutine omp_cosh_f64 + +subroutine omp_sinh_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_sinhf(float {{.*}}) + y = sinh(x) +end subroutine omp_sinh_f32 + +subroutine 
omp_sinh_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_sinh(double {{.*}}) + y = sinh(x) +end subroutine omp_sinh_f64 + +subroutine omp_ceiling_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_ceilf(float {{.*}}) + y = ceiling(x) +end subroutine omp_ceiling_f32 + +subroutine omp_ceiling_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_ceil(double {{.*}}) + y = ceiling(x) +end subroutine omp_ceiling_f64 + +subroutine omp_floor_f32(x, y) +!$omp declare target + real :: x, y +!CHECK: call float @__nv_floorf(float {{.*}}) + y = floor(x) +end subroutine omp_floor_f32 + +subroutine omp_floor_f64(x, y) +!$omp declare target + real(8) :: x, y +!CHECK: call double @__nv_floor(double {{.*}}) + y = floor(x) +end subroutine omp_floor_f64 + +subroutine omp_sign_f32(x, y) +!$omp declare target + real :: x, y, z +!CHECK: call float @__nv_copysignf(float {{.*}}, float {{.*}}) + y = sign(x, z) +end subroutine omp_sign_f32 + +subroutine omp_sign_f64(x, y) +!$omp declare target + real(8) :: x, y, z +!CHECK: call double @__nv_copysign(double {{.*}}, double {{.*.}}) + y = sign(x, z) +end subroutine omp_sign_f64 diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 48982ac6efe7c..9d85f04b40c72 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -36,12 +36,6 @@ void configureGpuToNVVMConversionLegality(ConversionTarget &target); /// GPU dialect to NVVM. void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter); -/// Populate patterns that lower certain arith and math dialect ops to -/// libdevice calls. -void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns, - PatternBenefit benefit = 1); - /// Collect a set of patterns to convert from the GPU dialect to NVVM. 
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, diff --git a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h new file mode 100644 index 0000000000000..e0e2b2c2e08c3 --- /dev/null +++ b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h @@ -0,0 +1,28 @@ +//===- MathToNVVM.h - Utils to convert from the Math dialect to NVVM -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_ +#define MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include <memory> + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTMATHTONVVM +#include "mlir/Conversion/Passes.h.inc" + +/// Populate the given list with patterns that convert from Math to NVVM +/// libdevice calls. 
+void populateMathToNVVMConversionPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + PatternBenefit benefit = 1); +} // namespace mlir + +#endif // MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 7c2b450ca6710..a54b98004c3b6 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -49,6 +49,7 @@ #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MathToNVVM/MathToNVVM.h" #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h" #include "mlir/Conversion/MathToXeVM/MathToXeVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 1096338534416..fd9cbddbd7ab0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -827,6 +827,20 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "Chipset that these operations will run on">]; } +//===----------------------------------------------------------------------===// +// MathToNVVM +//===----------------------------------------------------------------------===// + +def ConvertMathToNVVM : Pass<"convert-math-to-nvvm", "ModuleOp"> { + let summary = "Convert Math dialect to NVVM libdevice calls"; + let description = [{ + This pass converts supported Math ops to NVVM libdevice calls. 
+ }]; + let dependentDialects = ["arith::ArithDialect", "func::FuncDialect", + "NVVM::NVVMDialect", "vector::VectorDialect", + ]; +} + //===----------------------------------------------------------------------===// // MathToSPIRV //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 2ed10effb53da..e17988b12cade 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -39,6 +39,7 @@ add_subdirectory(MathToEmitC) add_subdirectory(MathToFuncs) add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) +add_subdirectory(MathToNVVM) add_subdirectory(MathToROCDL) add_subdirectory(MathToSPIRV) add_subdirectory(MathToXeVM) diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt index 983aadf2c1517..681d788aa54dd 100644 --- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt @@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRGPUToNVVMTransforms MLIRGPUToGPURuntimeTransforms MLIRLLVMCommonConversion MLIRLLVMDialect + MLIRMathToNVVM MLIRMemRefToLLVM MLIRNVGPUDialect MLIRNVGPUToNVVM diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 5fdfc9fa8cdb6..4d963c1681511 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -19,6 +19,7 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MathToNVVM/MathToNVVM.h" #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -456,229 +457,12 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter 
&converter) { }); } -struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> { - using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value input = adaptor.getOperand(); - Type inputType = input.getType(); - auto convertedInput = maybeExt(input, rewriter); - auto computeType = convertedInput.getType(); - - StringRef sincosFunc; - if (isa<Float32Type>(computeType)) { - const arith::FastMathFlags flag = op.getFastmath(); - const bool useApprox = - mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn); - sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf"; - } else if (isa<Float64Type>(computeType)) { - sincosFunc = "__nv_sinco... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/180060 _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
