Author: Srinivasa Ravi Date: 2026-01-20T17:56:55+05:30 New Revision: 13205c51fc11ac44e36e684cd99c769cf31e27b5
URL: https://github.com/llvm/llvm-project/commit/13205c51fc11ac44e36e684cd99c769cf31e27b5 DIFF: https://github.com/llvm/llvm-project/commit/13205c51fc11ac44e36e684cd99c769cf31e27b5.diff LOG: [clang][NVPTX] Add missing half-precision add/mul/fma intrinsics (#170079) This change adds the following missing half-precision add/mul/fma intrinsics for the NVPTX target: - `llvm.nvvm.add.rn{.ftz}.sat.f16` - `llvm.nvvm.add.rn{.ftz}.sat.v2f16` - `llvm.nvvm.mul.rn{.ftz}.sat.f16` - `llvm.nvvm.mul.rn{.ftz}.sat.v2f16` - `llvm.nvvm.fma.rn.oob.*` We lower `fneg` followed by one of the above addition intrinsics to the corresponding `sub` instruction. This also removes some incorrect `bf16` fma intrinsics with no valid lowering. PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions Added: llvm/test/CodeGen/NVPTX/f16-add-sat.ll llvm/test/CodeGen/NVPTX/f16-mul-sat.ll llvm/test/CodeGen/NVPTX/f16-sub-sat.ll llvm/test/CodeGen/NVPTX/fma-oob.ll Modified: clang/include/clang/Basic/BuiltinsNVPTX.td clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp clang/test/CodeGen/builtins-nvptx.c llvm/docs/NVPTXUsage.rst llvm/include/llvm/IR/IntrinsicsNVVM.td llvm/lib/IR/AutoUpgrade.cpp llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp llvm/lib/Target/NVPTX/NVPTXIntrinsics.td llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp Removed: ################################################################################ diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td index 7ec3dfa4b059f..821c362d100c5 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.td +++ b/clang/include/clang/Basic/BuiltinsNVPTX.td @@ -382,16 +382,24 @@ def __nvvm_fma_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_53, PTX42>; def __nvvm_fma_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_53, PTX42>; def __nvvm_fma_rn_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>; def __nvvm_fma_rn_ftz_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16, __fp16)", SM_90, PTX81>; def __nvvm_fma_rn_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_ftz_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_ftz_sat_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; def __nvvm_fma_rn_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>; def __nvvm_fma_rn_ftz_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_f16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>, _Vector<2, __fp16>)", SM_90, PTX81>; def __nvvm_fma_rn_bf16 : 
NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>; def __nvvm_fma_rn_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf16)", SM_90, PTX81>; def __nvvm_fma_rn_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>; def __nvvm_fma_rn_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>; +def __nvvm_fma_rn_oob_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>; +def __nvvm_fma_rn_oob_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_90, PTX81>; def __nvvm_fma_rn_ftz_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rn_f : NVPTXBuiltin<"float(float, float, float)">; @@ -458,6 +466,11 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">; // Add +def __nvvm_add_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_add_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_add_rn_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_add_rn_ftz_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; + def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">; @@ -480,6 +493,13 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">; +// Mul + +def __nvvm_mul_rn_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_mul_rn_ftz_sat_f16 : NVPTXBuiltinSMAndPTX<"__fp16(__fp16, __fp16)", SM_53, PTX42>; +def __nvvm_mul_rn_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; +def __nvvm_mul_rn_ftz_sat_v2f16 : NVPTXBuiltinSMAndPTX<"_Vector<2, __fp16>(_Vector<2, __fp16>, _Vector<2, __fp16>)", SM_53, PTX42>; + // Convert def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">; diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp index a4486965a851a..b4f7342e23473 100644 --- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp @@ -415,6 +415,14 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID, return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF); } +static Value *MakeFMAOOB(unsigned IntrinsicID, llvm::Type *Ty, + const CallExpr *E, CodeGenFunction &CGF) { + return CGF.Builder.CreateCall(CGF.CGM.getIntrinsic(IntrinsicID, {Ty}), + {CGF.EmitScalarExpr(E->getArg(0)), + CGF.EmitScalarExpr(E->getArg(1)), + CGF.EmitScalarExpr(E->getArg(2))}); +} + } // namespace Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, @@ -963,6 +971,34 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, return 
MakeHalfType(Intrinsic::nvvm_fma_rn_sat_f16, BuiltinID, E, *this); case NVPTX::BI__nvvm_fma_rn_sat_f16x2: return MakeHalfType(Intrinsic::nvvm_fma_rn_sat_f16x2, BuiltinID, E, *this); + case NVPTX::BI__nvvm_fma_rn_oob_f16: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, Builder.getHalfTy(), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_f16x2: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, + llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_bf16: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, Builder.getBFloatTy(), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_bf16x2: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob, + llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_relu_f16: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getHalfTy(), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_relu_f16x2: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, + llvm::FixedVectorType::get(Builder.getHalfTy(), 2), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, Builder.getBFloatTy(), E, + *this); + case NVPTX::BI__nvvm_fma_rn_oob_relu_bf16x2: + return MakeFMAOOB(Intrinsic::nvvm_fma_rn_oob_relu, + llvm::FixedVectorType::get(Builder.getBFloatTy(), 2), E, + *this); case NVPTX::BI__nvvm_fmax_f16: return MakeHalfType(Intrinsic::nvvm_fmax_f16, BuiltinID, E, *this); case NVPTX::BI__nvvm_fmax_f16x2: diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index cd1447374d000..a739b66042f19 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -31,6 +31,9 @@ // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_80 -target-feature +ptx81 -DPTX=81 \ // RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \ // RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM80 %s +// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx81 -DPTX=81\ +// RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \ +// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM90 %s // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_90 -target-feature +ptx78 -DPTX=78 \ // RUN: -disable-llvm-optzns -fcuda-is-device -emit-llvm -o - -x cuda %s \ // RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX78_SM90 %s @@ -1579,3 +1582,55 @@ __device__ void nvvm_add_fma_f32_sat() { // CHECK: ret void } + +#define F16 (__fp16)0.1f +#define F16_2 (__fp16)0.2f +#define F16X2 {(__fp16)0.1f, (__fp16)0.1f} +#define F16X2_2 {(__fp16)0.2f, (__fp16)0.2f} + +// CHECK-LABEL: nvvm_add_mul_f16_sat +__device__ void nvvm_add_mul_f16_sat() { + // CHECK: call half @llvm.nvvm.add.rn.sat.f16 + __nvvm_add_rn_sat_f16(F16, F16_2); + // CHECK: call half @llvm.nvvm.add.rn.ftz.sat.f16 + __nvvm_add_rn_ftz_sat_f16(F16, F16_2); + // CHECK: call <2 x half> @llvm.nvvm.add.rn.sat.v2f16 + __nvvm_add_rn_sat_v2f16(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16 + __nvvm_add_rn_ftz_sat_v2f16(F16X2, F16X2_2); + + // CHECK: call half @llvm.nvvm.mul.rn.sat.f16 + __nvvm_mul_rn_sat_f16(F16, F16_2); + // CHECK: call half @llvm.nvvm.mul.rn.ftz.sat.f16 + __nvvm_mul_rn_ftz_sat_f16(F16, F16_2); + // CHECK: call <2 x half> @llvm.nvvm.mul.rn.sat.v2f16 + __nvvm_mul_rn_sat_v2f16(F16X2, F16X2_2); + // CHECK: call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.v2f16 + 
__nvvm_mul_rn_ftz_sat_v2f16(F16X2, F16X2_2); + + // CHECK: ret void +} + +// CHECK-LABEL: nvvm_fma_oob +__device__ void nvvm_fma_oob() { +#if __CUDA_ARCH__ >= 900 && (PTX >= 81) + // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.f16 + __nvvm_fma_rn_oob_f16(F16, F16_2, F16_2); + // CHECK_PTX81_SM90: call half @llvm.nvvm.fma.rn.oob.relu.f16 + __nvvm_fma_rn_oob_relu_f16(F16, F16_2, F16_2); + // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16 + __nvvm_fma_rn_oob_f16x2(F16X2, F16X2_2, F16X2_2); + // CHECK_PTX81_SM90: call <2 x half> @llvm.nvvm.fma.rn.oob.relu.v2f16 + __nvvm_fma_rn_oob_relu_f16x2(F16X2, F16X2_2, F16X2_2); + + // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.bf16 + __nvvm_fma_rn_oob_bf16(BF16, BF16_2, BF16_2); + // CHECK_PTX81_SM90: call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16 + __nvvm_fma_rn_oob_relu_bf16(BF16, BF16_2, BF16_2); + // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16 + __nvvm_fma_rn_oob_bf16x2(BF16X2, BF16X2_2, BF16X2_2); + // CHECK_PTX81_SM90: call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.v2bf16 + __nvvm_fma_rn_oob_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2); +#endif + // CHECK: ret void +} diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst index c8400115c3159..0920d1726c407 100644 --- a/llvm/docs/NVPTXUsage.rst +++ b/llvm/docs/NVPTXUsage.rst @@ -1188,6 +1188,106 @@ used in the '``llvm.nvvm.idp4a.[us].u``' variants, while sign-extension is used with '``llvm.nvvm.idp4a.[us].s``' variants. The dot product of these 4-element vectors is added to ``%c`` to produce the return. +'``llvm.nvvm.add.*``' Half-precision Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +.. code-block:: llvm + + declare half @llvm.nvvm.add.rn.sat.f16(half %a, half %b) + declare <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half> %a, <2 x half> %b) + + declare half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %b) + declare <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b) + +Overview: +""""""""" + +The '``llvm.nvvm.add.*``' intrinsics perform an addition operation with the +specified rounding mode and modifiers. + +Semantics: +"""""""""" + +The '``.sat``' modifier performs a saturating addition where the result is +clamped to ``[0.0, 1.0]`` and ``NaN`` results are flushed to ``+0.0f``. +The '``.ftz``' modifier flushes subnormal inputs and results to sign-preserving +zero. + +'``llvm.nvvm.mul.*``' Half-precision Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +.. code-block:: llvm + + declare half @llvm.nvvm.mul.rn.sat.f16(half %a, half %b) + declare <2 x half> @llvm.nvvm.mul.rn.sat.v2f16(<2 x half> %a, <2 x half> %b) + + declare half @llvm.nvvm.mul.rn.ftz.sat.f16(half %a, half %b) + declare <2 x half> @llvm.nvvm.mul.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b) + +Overview: +""""""""" + +The '``llvm.nvvm.mul.*``' intrinsics perform a multiplication operation with +the specified rounding mode and modifiers. + +Semantics: +"""""""""" + +The '``.sat``' modifier performs a saturating multiplication where the result is +clamped to ``[0.0, 1.0]`` and ``NaN`` results are flushed to ``+0.0f``. +The '``.ftz``' modifier flushes subnormal inputs and results to sign-preserving +zero. + +'``llvm.nvvm.fma.*``' Half-precision Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +.. 
code-block:: llvm + + declare half @llvm.nvvm.fma.rn{.ftz}.f16(half %a, half %b, half %c) + declare <2 x half> @llvm.nvvm.fma.rn{.ftz}.f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) + declare bfloat @llvm.nvvm.fma.rn.bf16(bfloat %a, bfloat %b, bfloat %c) + declare <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) + + declare half @llvm.nvvm.fma.rn{.ftz}.sat.f16(half %a, half %b, half %c) + declare <2 x half> @llvm.nvvm.fma.rn{.ftz}.sat.f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) + + declare half @llvm.nvvm.fma.rn{.ftz}.relu.f16(half %a, half %b, half %c) + declare <2 x half> @llvm.nvvm.fma.rn{.ftz}.relu.f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) + declare bfloat @llvm.nvvm.fma.rn.relu.bf16(bfloat %a, bfloat %b, bfloat %c) + declare <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) + + declare half @llvm.nvvm.fma.rn.oob{.relu}.f16(half %a, half %b, half %c) + declare <2 x half> @llvm.nvvm.fma.rn.oob{.relu}.v2f16(<2 x half> %a, <2 x half> %b, <2 x half> %c) + declare bfloat @llvm.nvvm.fma.rn.oob{.relu}.bf16(bfloat %a, bfloat %b, bfloat %c) + declare <2 x bfloat> @llvm.nvvm.fma.rn.oob{.relu}.v2bf16(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) + +Overview: +""""""""" + +The '``llvm.nvvm.fma.*``' intrinsics perform a fused multiply-add with no loss +of precision in the intermediate product and addition. + +Semantics: +"""""""""" + +The '``.sat``' modifier performs a saturating operation where the result is +clamped to ``[0.0, 1.0]`` and ``NaN`` results are flushed to ``+0.0f``. +The '``.ftz``' modifier flushes subnormal inputs and results to sign-preserving +zero. +The '``.relu``' modifier clamps negative results to ``0`` and flushes ``NaN`` +results to canonical ``NaN``. +The '``.oob``' modifier clamps the result to ``0`` if any of the operands is an +``OOB NaN`` value (defined under `Tensors <https://docs.nvidia.com/cuda/parallel-thread-execution/#tensors>`__).
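+
+For example, the ``.oob`` intrinsics are overloaded on the operand type, so a
+single intrinsic name covers the scalar and packed forms. A minimal sketch of
+calling them from LLVM IR (the value names are illustrative):
+
+.. code-block:: llvm
+
+    ; scalar f16 out-of-bounds FMA
+    %r = call half @llvm.nvvm.fma.rn.oob.f16(half %a, half %b, half %c)
+    ; packed f16x2 variant with relu clamping of the result
+    %v = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.v2f16(<2 x half> %a, <2 x half> %b, <2 x half> %c)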
+ Bit Manipulation Intrinsics --------------------------- diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 98a568461381d..3918290230a7c 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1365,6 +1365,14 @@ let TargetPrefix = "nvvm" in { def int_nvvm_mul_ # rnd # _d : NVVMBuiltin, DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>; } + + foreach ftz = ["", "_ftz"] in { + def int_nvvm_mul_rn # ftz # _sat_f16 : NVVMBuiltin, + DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; + + def int_nvvm_mul_rn # ftz # _sat_v2f16 : NVVMBuiltin, + DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; + } // ftz } // @@ -1496,16 +1504,23 @@ let TargetPrefix = "nvvm" in { def int_nvvm_fma_rn # ftz # variant # _f16x2 : PureIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty]>; - - def int_nvvm_fma_rn # ftz # variant # _bf16 : NVVMBuiltin, - PureIntrinsic<[llvm_bfloat_ty], - [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>; - - def int_nvvm_fma_rn # ftz # variant # _bf16x2 : NVVMBuiltin, - PureIntrinsic<[llvm_v2bf16_ty], - [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>; } // ftz } // variant + + foreach relu = ["", "_relu"] in { + def int_nvvm_fma_rn # relu # _bf16 : NVVMBuiltin, + PureIntrinsic<[llvm_bfloat_ty], + [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty]>; + + def int_nvvm_fma_rn # relu # _bf16x2 : NVVMBuiltin, + PureIntrinsic<[llvm_v2bf16_ty], + [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty]>; + + // oob (out-of-bounds) - clamps the result to 0 if any of the operands is + // an OOB NaN value. + def int_nvvm_fma_rn_oob # relu : PureIntrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; + } // relu foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in { foreach ftz = ["", "_ftz"] in { @@ -1575,6 +1590,7 @@ let TargetPrefix = "nvvm" in { // // Add // + let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative, IntrNoCreateUndefOrPoison] in { foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in { @@ -1586,7 +1602,16 @@ let TargetPrefix = "nvvm" in { } // ftz def int_nvvm_add # rnd # _d : NVVMBuiltin, DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>; - } // rnd + } + + foreach ftz = ["", "_ftz"] in { + def int_nvvm_add_rn # ftz # _sat_f16 : NVVMBuiltin, + DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty, llvm_half_ty]>; + + def int_nvvm_add_rn # ftz # _sat_v2f16 : NVVMBuiltin, + DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty, llvm_v2f16_ty]>; + + } // ftz } // diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp index 1c51a6fa980c3..b886a589d3ff1 100644 --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -1171,16 +1171,8 @@ static Intrinsic::ID shouldUpgradeNVPTXBF16Intrinsic(StringRef Name) { return StringSwitch<Intrinsic::ID>(Name) .Case("bf16", Intrinsic::nvvm_fma_rn_bf16) .Case("bf16x2", Intrinsic::nvvm_fma_rn_bf16x2) - .Case("ftz.bf16", Intrinsic::nvvm_fma_rn_ftz_bf16) - .Case("ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2) - .Case("ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16) - .Case("ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2) - .Case("ftz.sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16) - .Case("ftz.sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2) .Case("relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16) .Case("relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2) - .Case("sat.bf16", 
Intrinsic::nvvm_fma_rn_sat_bf16) - .Case("sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2) .Default(Intrinsic::not_intrinsic); if (Name.consume_front("fmax.")) diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 729e3139ca2ca..1be35a1c67457 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -867,15 +867,29 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand); // We have some custom DAG combine patterns for these nodes - setTargetDAGCombine( - {ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, - ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM, - ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM, - ISD::FMINIMUMNUM, ISD::MUL, ISD::SELECT, - ISD::SHL, ISD::SREM, ISD::UREM, - ISD::VSELECT, ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, - ISD::LOAD, ISD::STORE, ISD::ZERO_EXTEND, - ISD::SIGN_EXTEND}); + setTargetDAGCombine({ISD::ADD, + ISD::AND, + ISD::EXTRACT_VECTOR_ELT, + ISD::FADD, + ISD::FMAXNUM, + ISD::FMINNUM, + ISD::FMAXIMUM, + ISD::FMINIMUM, + ISD::FMAXIMUMNUM, + ISD::FMINIMUMNUM, + ISD::MUL, + ISD::SELECT, + ISD::SHL, + ISD::SREM, + ISD::UREM, + ISD::VSELECT, + ISD::BUILD_VECTOR, + ISD::ADDRSPACECAST, + ISD::LOAD, + ISD::STORE, + ISD::ZERO_EXTEND, + ISD::SIGN_EXTEND, + ISD::INTRINSIC_WO_CHAIN}); // setcc for f16x2 and bf16x2 needs special handling to prevent // legalizer's attempt to scalarize it due to v2i1 not being legal. @@ -6836,6 +6850,59 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain, } } +static unsigned getF16SubOpc(Intrinsic::ID AddIntrinsicID) { + switch (AddIntrinsicID) { + default: + break; + case Intrinsic::nvvm_add_rn_sat_f16: + case Intrinsic::nvvm_add_rn_sat_v2f16: + return NVPTXISD::SUB_RN_SAT; + case Intrinsic::nvvm_add_rn_ftz_sat_f16: + case Intrinsic::nvvm_add_rn_ftz_sat_v2f16: + return NVPTXISD::SUB_RN_FTZ_SAT; + } + llvm_unreachable("Invalid F16 add intrinsic"); +} + +static SDValue combineF16AddWithNeg(SDNode *N, SelectionDAG &DAG, + Intrinsic::ID AddIntrinsicID) { + SDValue Op1 = N->getOperand(1); + SDValue Op2 = N->getOperand(2); + + SDValue SubOp1, SubOp2; + + if (Op1.getOpcode() == ISD::FNEG) { + SubOp1 = Op2; + SubOp2 = Op1.getOperand(0); + } else if (Op2.getOpcode() == ISD::FNEG) { + SubOp1 = Op1; + SubOp2 = Op2.getOperand(0); + } else { + return SDValue(); + } + + SDLoc DL(N); + return DAG.getNode(getF16SubOpc(AddIntrinsicID), DL, N->getValueType(0), + SubOp1, SubOp2); +} + +static SDValue combineIntrinsicWOChain(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const NVPTXSubtarget &STI) { + unsigned IID = N->getConstantOperandVal(0); + + switch (IID) { + default: + break; + case Intrinsic::nvvm_add_rn_sat_f16: + case Intrinsic::nvvm_add_rn_ftz_sat_f16: + case Intrinsic::nvvm_add_rn_sat_v2f16: + case Intrinsic::nvvm_add_rn_ftz_sat_v2f16: + return combineF16AddWithNeg(N, DCI.DAG, IID); + } + return SDValue(); +} + static SDValue combineProxyReg(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { @@ -6904,6 +6971,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return PerformSELECTShiftCombine(N, DCI); case ISD::VSELECT: return PerformVSELECTCombine(N, DCI); + case ISD::INTRINSIC_WO_CHAIN: + return combineIntrinsicWOChain(N, DCI, STI); } return SDValue(); } diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index ad7031d088c79..ac9ce96da6cf0 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ 
b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1503,6 +1503,11 @@ def INT_NVVM_MUL_RP_D : F_MATH_2<"mul.rp.f64", B64, B64, B64, int_nvvm_mul_rp_d> def INT_NVVM_MUL24_I : F_MATH_2<"mul24.lo.s32", B32, B32, B32, int_nvvm_mul24_i>; def INT_NVVM_MUL24_UI : F_MATH_2<"mul24.lo.u32", B32, B32, B32, int_nvvm_mul24_ui>; +def INT_NVVM_MUL_RN_SAT_F16 : F_MATH_2<"mul.rn.sat.f16", B16, B16, B16, int_nvvm_mul_rn_sat_f16>; +def INT_NVVM_MUL_RN_FTZ_SAT_F16 : F_MATH_2<"mul.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_mul_rn_ftz_sat_f16>; +def INT_NVVM_MUL_RN_SAT_F16X2 : F_MATH_2<"mul.rn.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_sat_v2f16>; +def INT_NVVM_MUL_RN_FTZ_SAT_F16X2 : F_MATH_2<"mul.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_mul_rn_ftz_sat_v2f16>; + // // Div // @@ -1705,16 +1710,8 @@ multiclass FMA_INST { [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, B16, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, B16, - [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, B16, - [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, B16, - [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, B16, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, B16, - [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, B32, [hasPTX<42>, hasSM<53>]>, @@ -1728,10 +1725,11 @@ multiclass FMA_INST { [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2, B32, [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, B32, [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, B32, - [hasPTX<70>, hasSM<80>]> + [hasPTX<70>, hasSM<80>]>, ] in { def P.Variant : F_MATH_3<!strconcat("fma", !subst("_", ".", P.Variant)), @@ -1767,6 +1765,18 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in { (INT_NVVM_MIXED_FMA_rn_f32_bf16 B16:$a, B16:$b, B32:$c)>; } +foreach ty = [F16RT, F16X2RT, BF16RT, BF16X2RT] in { + foreach relu = ["", "_relu"] in { + defvar Intr = !cast<Intrinsic>("int_nvvm_fma_rn_oob" # relu); + defvar suffix = !subst("_", ".", relu # "_" # ty.PtxType); + def INT_NVVM_FMA_OOB # relu # ty.PtxType : + BasicNVPTXInst<(outs ty.RC:$dst), (ins ty.RC:$a, ty.RC:$b, ty.RC:$c), + "fma.rn.oob" # suffix, + [(set ty.Ty:$dst, (Intr ty.Ty:$a, ty.Ty:$b, ty.Ty:$c))]>, + Requires<[hasPTX<81>, hasSM<90>]>; + } +} + // // Rcp // @@ -1865,6 +1875,11 @@ let Predicates = [doRsqrtOpt] in { // Add // +def INT_NVVM_ADD_RN_SAT_F16 : F_MATH_2<"add.rn.sat.f16", B16, B16, B16, int_nvvm_add_rn_sat_f16>; +def INT_NVVM_ADD_RN_FTZ_SAT_F16 : F_MATH_2<"add.rn.ftz.sat.f16", B16, B16, B16, int_nvvm_add_rn_ftz_sat_f16>; +def INT_NVVM_ADD_RN_SAT_F16X2 : F_MATH_2<"add.rn.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_sat_v2f16>; +def INT_NVVM_ADD_RN_FTZ_SAT_F16X2 : F_MATH_2<"add.rn.ftz.sat.f16x2", B32, B32, B32, int_nvvm_add_rn_ftz_sat_v2f16>; + def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>; def INT_NVVM_ADD_RN_SAT_FTZ_F : F_MATH_2<"add.rn.sat.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f>; def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>; @@ -1914,6 +1929,21 @@ let Predicates = [hasSM<100>, hasPTX<86>, doNoF32FTZ] in { // Sub // +def sub_rn_sat : SDNode<"NVPTXISD::SUB_RN_SAT", SDTFPBinOp>; +def sub_rn_ftz_sat : + SDNode<"NVPTXISD::SUB_RN_FTZ_SAT", SDTFPBinOp>; + +class INT_NVVM_SUB_RN<RegTyInfo TyInfo, 
string variant> : + BasicNVPTXInst<(outs TyInfo.RC:$dst), (ins TyInfo.RC:$a, TyInfo.RC:$b), + !subst("_", ".", "sub.rn" # variant # "." # TyInfo.PtxType), + [(set TyInfo.Ty:$dst, + (!cast<SDNode>("sub_rn" # variant) TyInfo.Ty:$a, TyInfo.Ty:$b))]>; + +def INT_NVVM_SUB_RN_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_sat">; +def INT_NVVM_SUB_RN_FTZ_SAT_F16 : INT_NVVM_SUB_RN<F16RT, "_ftz_sat">; +def INT_NVVM_SUB_RN_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_sat">; +def INT_NVVM_SUB_RN_FTZ_SAT_F16X2 : INT_NVVM_SUB_RN<F16X2RT, "_ftz_sat">; + foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in { foreach ftz = ["", "_ftz"] in { foreach sat = ["", "_sat"] in { diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp index 334c2775007c7..c1fe9300785a3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp @@ -207,12 +207,8 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC, return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fma_rn_bf16: return {Intrinsic::fma, FTZ_MustBeOff, true}; - case Intrinsic::nvvm_fma_rn_ftz_bf16: - return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fma_rn_bf16x2: return {Intrinsic::fma, FTZ_MustBeOff, true}; - case Intrinsic::nvvm_fma_rn_ftz_bf16x2: - return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fmax_d: return {Intrinsic::maxnum, FTZ_Any}; case Intrinsic::nvvm_fmax_f: diff --git a/llvm/test/CodeGen/NVPTX/f16-add-sat.ll b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll new file mode 100644 index 0000000000000..c2ffc126694c4 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/f16-add-sat.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s +; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%} + +define half @add_rn_sat_f16(half %a, half %b) { +; CHECK-LABEL: add_rn_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [add_rn_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [add_rn_sat_f16_param_1]; +; CHECK-NEXT: add.rn.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @add_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: add_rn_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [add_rn_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [add_rn_sat_f16x2_param_1]; +; CHECK-NEXT: add.rn.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} + +define half @add_rn_ftz_sat_f16(half %a, half %b) { +; CHECK-LABEL: add_rn_ftz_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [add_rn_ftz_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [add_rn_ftz_sat_f16_param_1]; +; CHECK-NEXT: add.rn.ftz.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @add_rn_ftz_sat_f16x2(<2 x 
half> %a, <2 x half> %b) { +; CHECK-LABEL: add_rn_ftz_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [add_rn_ftz_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [add_rn_ftz_sat_f16x2_param_1]; +; CHECK-NEXT: add.rn.ftz.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} diff --git a/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll new file mode 100644 index 0000000000000..4bcc018f290d7 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/f16-mul-sat.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s +; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%} + +define half @mul_rn_sat_f16(half %a, half %b) { +; CHECK-LABEL: mul_rn_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [mul_rn_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [mul_rn_sat_f16_param_1]; +; CHECK-NEXT: mul.rn.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.mul.rn.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @mul_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: mul_rn_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [mul_rn_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [mul_rn_sat_f16x2_param_1]; +; CHECK-NEXT: mul.rn.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.mul.rn.sat.v2f16(<2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} + +define half @mul_rn_ftz_sat_f16(half %a, half %b) { +; CHECK-LABEL: mul_rn_ftz_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [mul_rn_ftz_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [mul_rn_ftz_sat_f16_param_1]; +; CHECK-NEXT: mul.rn.ftz.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.mul.rn.ftz.sat.f16(half %a, half %b) + ret half %1 +} + +define <2 x half> @mul_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: mul_rn_ftz_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [mul_rn_ftz_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [mul_rn_ftz_sat_f16x2_param_1]; +; CHECK-NEXT: mul.rn.ftz.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.mul.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %b) + ret <2 x half> %1 +} diff --git a/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll new file mode 100644 index 0000000000000..774ce7ccb2f95 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/f16-sub-sat.ll @@ -0,0 +1,69 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | FileCheck %s +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | 
FileCheck %s +; RUN: %if ptxas-isa-4.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx42 | %ptxas-verify%} +; RUN: %if ptxas-isa-6.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_53 -mattr=+ptx60 | %ptxas-verify%} + +define half @sub_rn_sat_f16(half %a, half %b) { +; CHECK-LABEL: sub_rn_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [sub_rn_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [sub_rn_sat_f16_param_1]; +; CHECK-NEXT: sub.rn.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = fneg half %b + %res = call half @llvm.nvvm.add.rn.sat.f16(half %a, half %1) + ret half %res +} + +define <2 x half> @sub_rn_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: sub_rn_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [sub_rn_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [sub_rn_sat_f16x2_param_1]; +; CHECK-NEXT: sub.rn.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = fneg <2 x half> %b + %res = call <2 x half> @llvm.nvvm.add.rn.sat.v2f16(<2 x half> %a, <2 x half> %1) + ret <2 x half> %res +} + +define half @sub_rn_ftz_sat_f16(half %a, half %b) { +; CHECK-LABEL: sub_rn_ftz_sat_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [sub_rn_ftz_sat_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [sub_rn_ftz_sat_f16_param_1]; +; CHECK-NEXT: sub.rn.ftz.sat.f16 %rs3, %rs1, %rs2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs3; +; CHECK-NEXT: ret; + %1 = fneg half %b + %res = call half @llvm.nvvm.add.rn.ftz.sat.f16(half %a, half %1) + ret half %res +} + +define <2 x half> @sub_rn_ftz_sat_f16x2(<2 x half> %a, <2 x half> %b) { +; CHECK-LABEL: sub_rn_ftz_sat_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [sub_rn_ftz_sat_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [sub_rn_ftz_sat_f16x2_param_1]; +; CHECK-NEXT: sub.rn.ftz.sat.f16x2 %r3, %r1, %r2; +; CHECK-NEXT: st.param.b32 [func_retval0], %r3; +; CHECK-NEXT: ret; + %1 = fneg <2 x half> %b + %res = call <2 x half> @llvm.nvvm.add.rn.ftz.sat.v2f16(<2 x half> %a, <2 x half> %1) + ret <2 x half> %res +} diff --git a/llvm/test/CodeGen/NVPTX/fma-oob.ll b/llvm/test/CodeGen/NVPTX/fma-oob.ll new file mode 100644 index 0000000000000..7fd9ae13d1998 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/fma-oob.ll @@ -0,0 +1,131 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81 | FileCheck %s +; RUN: %if ptxas-isa-8.1 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx81 | %ptxas-verify -arch=sm_90 %} + +define half @fma_oob_f16(half %a, half %b, half %c) { +; CHECK-LABEL: fma_oob_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_f16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_f16_param_2]; +; CHECK-NEXT: fma.rn.oob.f16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.fma.rn.oob.f16(half %a, half %b, half %c) + ret half %1 +} + +define half @fma_oob_relu_f16(half %a, half %b, half %c) { +; CHECK-LABEL: fma_oob_relu_f16( 
+; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_relu_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_relu_f16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_relu_f16_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.f16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call half @llvm.nvvm.fma.rn.oob.relu.f16(half %a, half %b, half %c) + ret half %1 +} + +define <2 x half> @fma_oob_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) { +; CHECK-LABEL: fma_oob_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_f16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_f16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.f16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16( <2 x half> %a, <2 x half> %b, <2 x half> %c) + ret <2 x half> %1 +} + +define <2 x half> @fma_oob_relu_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) { +; CHECK-LABEL: fma_oob_relu_f16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_relu_f16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_relu_f16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_relu_f16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.f16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x half> @llvm.nvvm.fma.rn.oob.relu.v2f16( <2 x half> %a, <2 x half> %b, <2 x half> %c) + ret <2 x half> %1 +} + +define bfloat @fma_oob_bf16(bfloat %a, bfloat %b, bfloat %c) { +; CHECK-LABEL: fma_oob_bf16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_bf16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_bf16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_bf16_param_2]; +; CHECK-NEXT: fma.rn.oob.bf16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call bfloat @llvm.nvvm.fma.rn.oob.bf16(bfloat %a, bfloat %b, bfloat %c) + ret bfloat %1 +} + +define bfloat @fma_oob_relu_bf16(bfloat %a, bfloat %b, bfloat %c) { +; CHECK-LABEL: fma_oob_relu_bf16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [fma_oob_relu_bf16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [fma_oob_relu_bf16_param_1]; +; CHECK-NEXT: ld.param.b16 %rs3, [fma_oob_relu_bf16_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.bf16 %rs4, %rs1, %rs2, %rs3; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs4; +; CHECK-NEXT: ret; + %1 = call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16(bfloat %a, bfloat %b, bfloat %c) + ret bfloat %1 +} + +define <2 x bfloat> @fma_oob_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) { +; CHECK-LABEL: fma_oob_bf16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_bf16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_bf16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_bf16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.bf16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16( <2 x bfloat> %a, <2 x bfloat> %b, <2 x 
bfloat> %c) + ret <2 x bfloat> %1 +} + +define <2 x bfloat> @fma_oob_relu_bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) { +; CHECK-LABEL: fma_oob_relu_bf16x2( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<5>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [fma_oob_relu_bf16x2_param_0]; +; CHECK-NEXT: ld.param.b32 %r2, [fma_oob_relu_bf16x2_param_1]; +; CHECK-NEXT: ld.param.b32 %r3, [fma_oob_relu_bf16x2_param_2]; +; CHECK-NEXT: fma.rn.oob.relu.bf16x2 %r4, %r1, %r2, %r3; +; CHECK-NEXT: st.param.b32 [func_retval0], %r4; +; CHECK-NEXT: ret; + %1 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.v2bf16( <2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) + ret <2 x bfloat> %1 +}
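For illustration, a minimal CUDA sketch exercising the new builtins (the wrapper names here are hypothetical; per the builtin definitions above, the saturating add/mul builtins require sm_53/PTX 4.2 and the oob variants require sm_90/PTX 8.1):

    // add.rn.sat.f16: round-to-nearest-even; the result is clamped to
    // [0.0, 1.0] and NaN results are flushed to +0.0.
    __device__ __fp16 add_sat(__fp16 a, __fp16 b) {
      return __nvvm_add_rn_sat_f16(a, b);
    }

    // mul.rn.ftz.sat.f16: as above, with subnormal inputs and results
    // flushed to sign-preserving zero.
    __device__ __fp16 mul_sat_ftz(__fp16 a, __fp16 b) {
      return __nvvm_mul_rn_ftz_sat_f16(a, b);
    }

    // fma.rn.oob.f16: the result is clamped to 0 if any operand is an
    // OOB NaN value.
    __device__ __fp16 fma_oob(__fp16 a, __fp16 b, __fp16 c) {
      return __nvvm_fma_rn_oob_f16(a, b, c);
    }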
