https://github.com/Jason-Van-Beusekom updated 
https://github.com/llvm/llvm-project/pull/180060

>From 744f2eedb47d7b391ec1130734f351633de1d3fc Mon Sep 17 00:00:00 2001
From: jason-van-beusekom <[email protected]>
Date: Thu, 5 Feb 2026 16:43:33 -0600
Subject: [PATCH] [FLANG][MLIR][OpenMP] Add MathToNVVM conversion pass to NVPTX
 MLIR. This commit adds the MLIR MathToNVVM conversion pass to flang's NVPTX
 codegen, lowering Math and Arith operations to libdevice library calls. This
 enables calls to Fortran math intrinsics in OpenMP offload code for NVIDIA
 targets.

---
 clang/lib/Driver/ToolChains/Flang.cpp      |  30 ++
 clang/lib/Driver/ToolChains/Flang.h        |   3 +
 flang/lib/Optimizer/CodeGen/CMakeLists.txt |   1 +
 flang/lib/Optimizer/CodeGen/CodeGen.cpp    |   8 +-
 flang/test/Lower/OpenMP/math-nvptx.f90     | 338 +++++++++++++++++++++
 5 files changed, 379 insertions(+), 1 deletion(-)
 create mode 100644 flang/test/Lower/OpenMP/math-nvptx.f90

diff --git a/clang/lib/Driver/ToolChains/Flang.cpp b/clang/lib/Driver/ToolChains/Flang.cpp
index 8425f8fec62a4..2ee6dbb85e132 100644
--- a/clang/lib/Driver/ToolChains/Flang.cpp
+++ b/clang/lib/Driver/ToolChains/Flang.cpp
@@ -8,6 +8,7 @@
 
 #include "Flang.h"
 #include "Arch/RISCV.h"
+#include "Cuda.h"
 
 #include "clang/Basic/CodeGenOptions.h"
 #include "clang/Driver/CommonArgs.h"
@@ -519,6 +520,37 @@ void Flang::AddAMDGPUTargetArgs(const ArgList &Args,
   TC.addClangTargetOptions(Args, CmdArgs, Action::OffloadKind::OFK_OpenMP);
 }
 
+void Flang::AddNVPTXTargetArgs(const ArgList &Args,
+                               ArgStringList &CmdArgs) const {
+  // We cannot use addClangTargetOptions, as it appends arguments that flang
+  // does not support: -fcuda-is-device, -fno-threadsafe-statics,
+  // -fcuda-allow-variadic-functions and -target-sdk-version. Instead we
+  // manually detect the CUDA installation and link libdevice.
+  const ToolChain &TC = getToolChain();
+  const Driver &D = TC.getDriver();
+
+  // CudaInstallationDetector expects the *host* triple, not the NVPTX device
+  // triple. For offload compilations the device toolchain is expected to
+  // report the host triple via getAuxTriple(); when it does not, fall back to
+  // this toolchain's own triple.
+  const llvm::Triple *AuxTriple = TC.getAuxTriple();
+  const llvm::Triple &HostTriple =
+      AuxTriple ? *AuxTriple : TC.getEffectiveTriple();
+
+  // Link the libdevice bitcode matching the GPU architecture given by -march.
+  CudaInstallationDetector CudaInstallation(D, HostTriple, Args);
+  if (CudaInstallation.isValid()) {
+    StringRef GpuArch = Args.getLastArgValue(options::OPT_march_EQ);
+    if (!GpuArch.empty()) {
+      std::string LibDeviceFile = CudaInstallation.getLibDeviceFile(GpuArch);
+      if (!LibDeviceFile.empty()) {
+        CmdArgs.push_back("-mlink-builtin-bitcode");
+        CmdArgs.push_back(Args.MakeArgString(LibDeviceFile));
+      }
+    }
+  }
+}
+
 void Flang::addTargetOptions(const ArgList &Args,
                              ArgStringList &CmdArgs) const {
   const ToolChain &TC = getToolChain();
@@ -548,6 +580,10 @@ void Flang::addTargetOptions(const ArgList &Args,
     getTargetFeatures(D, Triple, Args, CmdArgs, /*ForAs*/ false);
     AddAMDGPUTargetArgs(Args, CmdArgs);
     break;
+  case llvm::Triple::nvptx:
+  case llvm::Triple::nvptx64:
+    AddNVPTXTargetArgs(Args, CmdArgs);
+    break;
   case llvm::Triple::riscv64:
     getTargetFeatures(D, Triple, Args, CmdArgs, /*ForAs*/ false);
     AddRISCVTargetArgs(Args, CmdArgs);
diff --git a/clang/lib/Driver/ToolChains/Flang.h b/clang/lib/Driver/ToolChains/Flang.h
index c0837b80c032e..62d2c6bb2a093 100644
--- a/clang/lib/Driver/ToolChains/Flang.h
+++ b/clang/lib/Driver/ToolChains/Flang.h
@@ -78,6 +78,13 @@ class LLVM_LIBRARY_VISIBILITY Flang : public Tool {
   void AddAMDGPUTargetArgs(const llvm::opt::ArgList &Args,
                            llvm::opt::ArgStringList &CmdArgs) const;
 
+  /// Add specific options for NVPTX target.
+  ///
+  /// \param [in] Args The list of input driver arguments
+  /// \param [out] CmdArgs The list of output command arguments
+  void AddNVPTXTargetArgs(const llvm::opt::ArgList &Args,
+                          llvm::opt::ArgStringList &CmdArgs) const;
+
   /// Add specific options for LoongArch64 target.
   ///
   /// \param [in] Args The list of input driver arguments
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index d5ea3c7a8e282..6977c737199e2 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -41,6 +41,7 @@ add_flang_library(FIRCodeGen
   MLIRMathToFuncs
   MLIRMathToLLVM
   MLIRMathToLibm
+  MLIRMathToNVVM
   MLIRMathToROCDL
   MLIROpenMPToLLVM
   MLIROpenACCDialect
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 93f3806b18648..51db152f68850 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -44,6 +44,7 @@
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
@@ -4365,6 +4366,7 @@ class FIRToLLVMLowering
     mlir::OpPassManager mathConversionPM("builtin.module");
 
     bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
+    bool isNVPTX = fir::getTargetTriple(mod).isNVPTX();
     // If compiling for AMD target some math operations must be lowered to AMD
     // GPU library calls, the rest can be converted to LLVM intrinsics, which
     // is handled in the mathToLLVM conversion. The lowering to libm calls is
@@ -4373,6 +4375,10 @@ class FIRToLLVMLowering
       mathConversionPM.addPass(mlir::createConvertMathToROCDL());
       mathConversionPM.addPass(mlir::createConvertComplexToROCDLLibraryCalls());
     }
+    // If compiling for an NVIDIA target, some math operations must be lowered
+    // to libdevice (__nv_*) library calls. Libm lowering is skipped below.
+    if (isNVPTX)
+      mathConversionPM.addPass(mlir::createConvertMathToNVVM());
 
     // Convert math::FPowI operations to inline implementation
     // only if the exponent's width is greater than 32, otherwise,
@@ -4427,7 +4433,7 @@ class FIRToLLVMLowering
     mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, pattern);
     // Math operations that have not been converted yet must be converted
     // to Libm.
-    if (!isAMDGCN)
+    if (!isAMDGCN && !isNVPTX)
       mlir::populateMathToLibmConversionPatterns(pattern);
     mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
     mlir::index::populateIndexToLLVMConversionPatterns(typeConverter, pattern);
diff --git a/flang/test/Lower/OpenMP/math-nvptx.f90 b/flang/test/Lower/OpenMP/math-nvptx.f90
new file mode 100644
index 0000000000000..8569e26b759e8
--- /dev/null
+++ b/flang/test/Lower/OpenMP/math-nvptx.f90
@@ -0,0 +1,338 @@
+!REQUIRES: nvptx-registered-target
+!RUN: %flang_fc1 -triple nvptx64-nvidia-cuda -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+
+subroutine omp_pow_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_powf(float {{.*}}, float {{.*}})
+  y = x ** x
+end subroutine omp_pow_f32
+
+subroutine omp_pow_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_pow(double {{.*}}, double {{.*}})
+  y = x ** x
+end subroutine omp_pow_f64
+
+subroutine omp_sin_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_sinf(float {{.*}})
+  y = sin(x)
+end subroutine omp_sin_f32
+
+subroutine omp_sin_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_sin(double {{.*}})
+  y = sin(x)
+end subroutine omp_sin_f64
+
+subroutine omp_abs_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_fabsf(float {{.*}})
+  y = abs(x)
+end subroutine omp_abs_f32
+
+subroutine omp_abs_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_fabs(double {{.*}})
+  y = abs(x)
+end subroutine omp_abs_f64
+
+subroutine omp_atan_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_atanf(float {{.*}})
+  y = atan(x)
+end subroutine omp_atan_f32
+
+subroutine omp_atan_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_atan(double {{.*}})
+  y = atan(x)
+end subroutine omp_atan_f64
+
+subroutine omp_atanh_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_atanhf(float {{.*}})
+  y = atanh(x)
+end subroutine omp_atanh_f32
+
+subroutine omp_atanh_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_atanh(double {{.*}})
+  y = atanh(x)
+end subroutine omp_atanh_f64
+
+subroutine omp_atan2_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_atan2f(float {{.*}}, float {{.*}})
+  y = atan2(x, x)
+end subroutine omp_atan2_f32
+
+subroutine omp_atan2_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_atan2(double {{.*}}, double {{.*}})
+  y = atan2(x, x)
+end subroutine omp_atan2_f64
+
+subroutine omp_cos_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_cosf(float {{.*}})
+  y = cos(x)
+end subroutine omp_cos_f32
+
+subroutine omp_cos_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_cos(double {{.*}})
+  y = cos(x)
+end subroutine omp_cos_f64
+
+subroutine omp_erf_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_erff(float {{.*}})
+  y = erf(x)
+end subroutine omp_erf_f32
+
+subroutine omp_erf_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_erf(double {{.*}})
+  y = erf(x)
+end subroutine omp_erf_f64
+
+subroutine omp_erfc_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_erfcf(float {{.*}})
+  y = erfc(x)
+end subroutine omp_erfc_f32
+
+subroutine omp_erfc_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_erfc(double {{.*}})
+  y = erfc(x)
+end subroutine omp_erfc_f64
+
+subroutine omp_exp_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_expf(float {{.*}})
+  y = exp(x)
+end subroutine omp_exp_f32
+
+subroutine omp_exp_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_exp(double {{.*}})
+  y = exp(x)
+end subroutine omp_exp_f64
+
+subroutine omp_log_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_logf(float {{.*}})
+  y = log(x)
+end subroutine omp_log_f32
+
+subroutine omp_log_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_log(double {{.*}})
+  y = log(x)
+end subroutine omp_log_f64
+
+subroutine omp_log10_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_log10f(float {{.*}})
+  y = log10(x)
+end subroutine omp_log10_f32
+
+subroutine omp_log10_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_log10(double {{.*}})
+  y = log10(x)
+end subroutine omp_log10_f64
+
+subroutine omp_sqrt_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_sqrtf(float {{.*}})
+  y = sqrt(x)
+end subroutine omp_sqrt_f32
+
+subroutine omp_sqrt_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_sqrt(double {{.*}})
+  y = sqrt(x)
+end subroutine omp_sqrt_f64
+
+subroutine omp_tan_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_tanf(float {{.*}})
+  y = tan(x)
+end subroutine omp_tan_f32
+
+subroutine omp_tan_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_tan(double {{.*}})
+  y = tan(x)
+end subroutine omp_tan_f64
+
+subroutine omp_tanh_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_tanhf(float {{.*}})
+  y = tanh(x)
+end subroutine omp_tanh_f32
+
+subroutine omp_tanh_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_tanh(double {{.*}})
+  y = tanh(x)
+end subroutine omp_tanh_f64
+
+subroutine omp_acos_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_acosf(float {{.*}})
+  y = acos(x)
+end subroutine omp_acos_f32
+
+subroutine omp_acos_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_acos(double {{.*}})
+  y = acos(x)
+end subroutine omp_acos_f64
+
+subroutine omp_acosh_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_acoshf(float {{.*}})
+  y = acosh(x)
+end subroutine omp_acosh_f32
+
+subroutine omp_acosh_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_acosh(double {{.*}})
+  y = acosh(x)
+end subroutine omp_acosh_f64
+
+subroutine omp_asin_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_asinf(float {{.*}})
+  y = asin(x)
+end subroutine omp_asin_f32
+
+subroutine omp_asin_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_asin(double {{.*}})
+  y = asin(x)
+end subroutine omp_asin_f64
+
+subroutine omp_asinh_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_asinhf(float {{.*}})
+  y = asinh(x)
+end subroutine omp_asinh_f32
+
+subroutine omp_asinh_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_asinh(double {{.*}})
+  y = asinh(x)
+end subroutine omp_asinh_f64
+
+subroutine omp_cosh_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_coshf(float {{.*}})
+  y = cosh(x)
+end subroutine omp_cosh_f32
+
+subroutine omp_cosh_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_cosh(double {{.*}})
+  y = cosh(x)
+end subroutine omp_cosh_f64
+
+subroutine omp_sinh_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_sinhf(float {{.*}})
+  y = sinh(x)
+end subroutine omp_sinh_f32
+
+subroutine omp_sinh_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_sinh(double {{.*}})
+  y = sinh(x)
+end subroutine omp_sinh_f64
+
+subroutine omp_ceiling_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_ceilf(float {{.*}})
+  y = ceiling(x)
+end subroutine omp_ceiling_f32
+
+subroutine omp_ceiling_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_ceil(double {{.*}})
+  y = ceiling(x)
+end subroutine omp_ceiling_f64
+
+subroutine omp_floor_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__nv_floorf(float {{.*}})
+  y = floor(x)
+end subroutine omp_floor_f32
+
+subroutine omp_floor_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__nv_floor(double {{.*}})
+  y = floor(x)
+end subroutine omp_floor_f64
+
+subroutine omp_sign_f32(x, y, z)
+!$omp declare target
+  real :: x, y, z
+!CHECK: call float @__nv_copysignf(float {{.*}}, float {{.*}})
+  y = sign(x, z)
+end subroutine omp_sign_f32
+
+subroutine omp_sign_f64(x, y, z)
+!$omp declare target
+  real(8) :: x, y, z
+!CHECK: call double @__nv_copysign(double {{.*}}, double {{.*}})
+  y = sign(x, z)
+end subroutine omp_sign_f64

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to