https://github.com/NeKon69 updated https://github.com/llvm/llvm-project/pull/185304
>From 3ff62c7e5f75e2fc6160119a52a74e617c6d90a9 Mon Sep 17 00:00:00 2001 From: NeKon69 <[email protected]> Date: Sun, 8 Mar 2026 20:15:48 +0300 Subject: [PATCH 1/5] [hlsl][dxil][spirv] Add support for `fma` intrinsic --- clang/include/clang/Basic/Builtins.td | 6 + .../clang/Basic/DiagnosticSemaKinds.td | 6 + clang/lib/CodeGen/CGHLSLBuiltins.cpp | 16 ++ .../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 16 ++ clang/lib/Sema/SemaHLSL.cpp | 61 +++++++ .../test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll | 0 clang/test/CodeGenHLSL/builtins/fma.hlsl | 151 ++++++++++++++++++ .../Sema/incompatible-function-to-ptr-decay.c | 18 +++ clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl | 145 +++++++++++++++++ llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 + llvm/include/llvm/IR/IntrinsicsSPIRV.td | 2 +- llvm/lib/Target/DirectX/DXIL.td | 10 ++ llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 11 +- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 2 + .../DirectX/ShaderFlags/double-extensions.ll | 8 + .../test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll | 53 ++++++ 16 files changed, 505 insertions(+), 2 deletions(-) create mode 100644 clang/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll create mode 100644 clang/test/CodeGenHLSL/builtins/fma.hlsl create mode 100644 clang/test/Sema/incompatible-function-to-ptr-decay.c create mode 100644 clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index 531c3702161f2..542249f829424 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -5318,6 +5318,12 @@ def HLSLNormalize : LangBuiltin<"HLSL_LANG"> { let Prototype = "void(...)"; } +def HLSLFma : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_elementwise_fma"]; + let Attributes = [NoThrow, Const, CustomTypeChecking]; + let Prototype = "void(...)"; +} + def HLSLRcp : LangBuiltin<"HLSL_LANG"> { let Spellings = 
["__builtin_hlsl_elementwise_rcp"]; let Attributes = [NoThrow, Const, CustomTypeChecking]; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 8882ac9b8c0a8..787cd7bcc61bb 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -13229,6 +13229,12 @@ def err_builtin_invalid_arg_type: Error< "%plural{0:|: }3" "%plural{[0,3]:type|:types}1 (was %4)">; +def err_builtin_requires_double_type: Error< + "%ordinal0 argument must be a scalar, vector, or matrix of double type (was %1)">; + +def err_builtin_requires_fp_scalar_or_vector_type: Error< + "%ordinal0 argument must be a scalar or vector of floating-point type (was %1)">; + def err_bswapg_invalid_bit_width : Error< "_BitInt type %0 (%1 bits) must be a multiple of 16 bits for byte swapping">; diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 70891eac39425..bb5eaf12c93cc 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -979,6 +979,22 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, retType, CGM.getHLSLRuntime().getIsNaNIntrinsic(), ArrayRef<Value *>{Op0}, nullptr, "hlsl.isnan"); } + case Builtin::BI__builtin_hlsl_elementwise_fma: { + Value *M = EmitScalarExpr(E->getArg(0)); + Value *A = EmitScalarExpr(E->getArg(1)); + Value *B = EmitScalarExpr(E->getArg(2)); + if (CGM.getTarget().getTriple().isDXIL()) + return Builder.CreateIntrinsic(M->getType(), Intrinsic::dx_fma, + ArrayRef<Value *>{M, A, B}, nullptr, + "dx.fma"); + + if (CGM.getTarget().getTriple().isSPIRV()) + return Builder.CreateIntrinsic(M->getType(), Intrinsic::spv_fma, + ArrayRef<Value *>{M, A, B}, nullptr, + "spv.fma"); + + break; + } case Builtin::BI__builtin_hlsl_mad: { Value *M = EmitScalarExpr(E->getArg(0)); Value *A = EmitScalarExpr(E->getArg(1)); diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h 
b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index 2543401bdfbf9..ab5c6edd6d555 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -1891,6 +1891,22 @@ float3 pow(float3, float3); _HLSL_BUILTIN_ALIAS(__builtin_elementwise_pow) float4 pow(float4, float4); +//===----------------------------------------------------------------------===// +// fused multiply-add builtins +//===----------------------------------------------------------------------===// + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_fma) +double fma(double, double, double); + +template <int s> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_fma) +vector<double, s> fma(vector<double, s>, vector<double, s>, vector<double, s>); + +template <int w, int h> +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_fma) +matrix<double, w, h> fma(matrix<double, w, h>, matrix<double, w, h>, + matrix<double, w, h>); + //===----------------------------------------------------------------------===// // reversebits builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 804ea70aaddce..624f621b532a1 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -31,6 +31,7 @@ #include "clang/Basic/TargetInfo.h" #include "clang/Sema/Initialization.h" #include "clang/Sema/Lookup.h" +#include "clang/Sema/Ownership.h" #include "clang/Sema/ParsedAttr.h" #include "clang/Sema/Sema.h" #include "clang/Sema/Template.h" @@ -3040,6 +3041,36 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc, return false; } +static bool CheckFloatOrHalfOrDoubleRepresentation(Sema *S, SourceLocation Loc, + int ArgOrdinal, + clang::QualType PassedType) { + clang::QualType BaseType = + PassedType->isVectorType() + ? 
PassedType->castAs<clang::VectorType>()->getElementType() + : PassedType; + if (!BaseType->isFloatingType()) + return S->Diag(Loc, diag::err_builtin_requires_fp_scalar_or_vector_type) + << ArgOrdinal << PassedType; + return false; +} + +static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc, + int ArgOrdinal, + clang::QualType PassedType) { + clang::QualType BaseType = + PassedType->isVectorType() + ? PassedType->castAs<clang::VectorType>()->getElementType() + : PassedType->isMatrixType() + ? PassedType->castAs<clang::MatrixType>()->getElementType() + : PassedType; + if (!BaseType->isDoubleType()) { + return S->Diag(Loc, diag::err_builtin_requires_double_type) + << ArgOrdinal << PassedType; + } + + return false; +} + static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, unsigned ArgIndex) { auto *Arg = TheCall->getArg(ArgIndex); @@ -3787,6 +3818,35 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { TheCall->setType(ArgTyA); break; } + case Builtin::BI__builtin_hlsl_elementwise_fma: { + if (SemaRef.checkArgCount(TheCall, 3)) { + return true; + } + const llvm::Triple &TT = getASTContext().getTargetInfo().getTriple(); + // The accepted argument types differ per backend, so validate them here + // with a target-specific diagnostic. A single generic error (e.g. "accepts
+ // only floating-point types") would not be enough: a backend would still + // reject some of those types later, and the user would see two different + // error messages for one mistake. + + if (TT.isSPIRV()) { + // SPIR-V accepts any floating-point scalar or vector (no matrices). + if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, + CheckFloatOrHalfOrDoubleRepresentation)) + return true; + } else if (TT.isDXIL()) { + // DXIL accepts only double scalars, vectors, or matrices. + if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, + CheckAnyDoubleRepresentation)) + return true; + } + + ExprResult A = TheCall->getArg(0); + QualType ArgTyA = A.get()->getType(); + // return type is the same as input type + TheCall->setType(ArgTyA); + break; + } case Builtin::BI__builtin_hlsl_elementwise_sign: { if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall)) return true; @@ -3936,6 +3996,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { case Builtin::BI__builtin_elementwise_exp10: case Builtin::BI__builtin_elementwise_floor: case Builtin::BI__builtin_elementwise_fmod: + case Builtin::BI__builtin_elementwise_fma: case Builtin::BI__builtin_elementwise_log: case Builtin::BI__builtin_elementwise_log2: case Builtin::BI__builtin_elementwise_log10: diff --git a/clang/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll b/clang/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/clang/test/CodeGenHLSL/builtins/fma.hlsl b/clang/test/CodeGenHLSL/builtins/fma.hlsl new file mode 100644 index 0000000000000..88b8e27c37043 --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/fma.hlsl @@ -0,0 +1,151 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -DTEST_DXIL \ +// RUN: -fmatrix-memory-layout=row-major -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,DXIL_CHECK -DTARGET=dx +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: 
spirv-unknown-vulkan-compute %s -DTEST_SPIRV \ +// RUN: -fmatrix-memory-layout=row-major -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,SPIRV_CHECK -DTARGET=spv +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -DTEST_SPIRV_HALF -fnative-half-type \ +// RUN: -fmatrix-memory-layout=row-major -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefix=SPIRV_HALF_CHECK + +// CHECK-LABEL: define {{.*}} double @{{.*}}fma_double{{.*}}( +// CHECK: %[[P0:.*]] = load double, ptr %{{.*}}, align 8 +// CHECK: %[[P1:.*]] = load double, ptr %{{.*}}, align 8 +// CHECK: %[[P2:.*]] = load double, ptr %{{.*}}, align 8 +// CHECK: %{{dx|spv}}.fma = call reassoc nnan ninf nsz arcp afn double @llvm.[[TARGET]].fma.f64(double %[[P0]], double %[[P1]], double %[[P2]]) +// CHECK: ret double %{{dx|spv}}.fma +double dxil_fma_double(double a, double b, double c) { return fma(a, b, c); } + +// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2{{.*}}( +// CHECK: %[[P0:.*]] = load <2 x double>, ptr %{{.*}}, align 16 +// CHECK: %[[P1:.*]] = load <2 x double>, ptr %{{.*}}, align 16 +// CHECK: %[[P2:.*]] = load <2 x double>, ptr %{{.*}}, align 16 +// CHECK: %{{dx|spv}}.fma = call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.[[TARGET]].fma.v2f64(<2 x double> %[[P0]], <2 x double> %[[P1]], <2 x double> %[[P2]]) +// CHECK: ret <2 x double> %{{dx|spv}}.fma +double2 dxil_fma_double2(double2 a, double2 b, double2 c) { return fma(a, b, c); } + +// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3{{.*}}( +// CHECK: %[[P0:.*]] = load <3 x double>, ptr %{{.*}}, align 32 +// CHECK: %[[P1:.*]] = load <3 x double>, ptr %{{.*}}, align 32 +// CHECK: %[[P2:.*]] = load <3 x double>, ptr %{{.*}}, align 32 +// CHECK: %{{dx|spv}}.fma = call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.[[TARGET]].fma.v3f64(<3 x double> %[[P0]], <3 x double> %[[P1]], <3 x double> %[[P2]]) +// CHECK: ret 
<3 x double> %{{dx|spv}}.fma +double3 dxil_fma_double3(double3 a, double3 b, double3 c) { return fma(a, b, c); } + +// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4{{.*}}( +// CHECK: %[[P0:.*]] = load <4 x double>, ptr %{{.*}}, align 32 +// CHECK: %[[P1:.*]] = load <4 x double>, ptr %{{.*}}, align 32 +// CHECK: %[[P2:.*]] = load <4 x double>, ptr %{{.*}}, align 32 +// CHECK: %{{dx|spv}}.fma = call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.[[TARGET]].fma.v4f64(<4 x double> %[[P0]], <4 x double> %[[P1]], <4 x double> %[[P2]]) +// CHECK: ret <4 x double> %{{dx|spv}}.fma +double4 dxil_fma_double4(double4 a, double4 b, double4 c) { return fma(a, b, c); } + +#ifdef TEST_DXIL + +// DXIL_CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}dxil_fma_double1x4{{.*}}( +// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.dx.fma.v4f64( +// DXIL_CHECK: ret <4 x double> %dx.fma +double1x4 dxil_fma_double1x4(double1x4 a, double1x4 b, double1x4 c) { return fma(a, b, c); } + +// DXIL_CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}dxil_fma_double4x1{{.*}}( +// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.dx.fma.v4f64( +// DXIL_CHECK: ret <4 x double> %dx.fma +double4x1 dxil_fma_double4x1(double4x1 a, double4x1 b, double4x1 c) { return fma(a, b, c); } + +// DXIL_CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}dxil_fma_double2x2{{.*}}( +// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.dx.fma.v4f64( +// DXIL_CHECK: ret <4 x double> %dx.fma +double2x2 dxil_fma_double2x2(double2x2 a, double2x2 b, double2x2 c) { return fma(a, b, c); } + +// DXIL_CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}dxil_fma_double2x3{{.*}}( +// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.dx.fma.v6f64( +// DXIL_CHECK: ret <6 x double> %dx.fma +double2x3 dxil_fma_double2x3(double2x3 a, double2x3 b, double2x3 c) { return fma(a, b, c); } + +// DXIL_CHECK-LABEL: define {{.*}} <6 x 
double> @{{.*}}dxil_fma_double3x2{{.*}}( +// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.dx.fma.v6f64( +// DXIL_CHECK: ret <6 x double> %dx.fma +double3x2 dxil_fma_double3x2(double3x2 a, double3x2 b, double3x2 c) { return fma(a, b, c); } + +// DXIL_CHECK-LABEL: define {{.*}} <9 x double> @{{.*}}dxil_fma_double3x3{{.*}}( +// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <9 x double> @llvm.dx.fma.v9f64( +// DXIL_CHECK: ret <9 x double> %dx.fma +double3x3 dxil_fma_double3x3(double3x3 a, double3x3 b, double3x3 c) { return fma(a, b, c); } + +// DXIL_CHECK-LABEL: define {{.*}} <16 x double> @{{.*}}dxil_fma_double4x4{{.*}}( +// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <16 x double> @llvm.dx.fma.v16f64( +// DXIL_CHECK: ret <16 x double> %dx.fma +double4x4 dxil_fma_double4x4(double4x4 a, double4x4 b, double4x4 c) { return fma(a, b, c); } +#endif + +#ifdef TEST_SPIRV +// SPIRV_CHECK-LABEL: define {{.*}} float @{{.*}}spv_fma_float{{.*}}( +// SPIRV_CHECK: %[[P0:.*]] = load float, ptr %{{.*}}, align 4 +// SPIRV_CHECK: %[[P1:.*]] = load float, ptr %{{.*}}, align 4 +// SPIRV_CHECK: %[[P2:.*]] = load float, ptr %{{.*}}, align 4 +// SPIRV_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn float @llvm.spv.fma.f32(float %[[P0]], float %[[P1]], float %[[P2]]) +// SPIRV_CHECK: ret float %spv.fma +float spv_fma_float(float a, float b, float c) { return fma(a, b, c); } + +// SPIRV_CHECK-LABEL: define {{.*}} <2 x float> @{{.*}}spv_fma_float2{{.*}}( +// SPIRV_CHECK: %[[P0:.*]] = load <2 x float>, ptr %{{.*}}, align 8 +// SPIRV_CHECK: %[[P1:.*]] = load <2 x float>, ptr %{{.*}}, align 8 +// SPIRV_CHECK: %[[P2:.*]] = load <2 x float>, ptr %{{.*}}, align 8 +// SPIRV_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.spv.fma.v2f32(<2 x float> %[[P0]], <2 x float> %[[P1]], <2 x float> %[[P2]]) +// SPIRV_CHECK: ret <2 x float> %spv.fma +float2 spv_fma_float2(float2 a, float2 b, float2 c) { return fma(a, b, c); } 
+ +// SPIRV_CHECK-LABEL: define {{.*}} <3 x float> @{{.*}}spv_fma_float3{{.*}}( +// SPIRV_CHECK: %[[P0:.*]] = load <3 x float>, ptr %{{.*}}, align 16 +// SPIRV_CHECK: %[[P1:.*]] = load <3 x float>, ptr %{{.*}}, align 16 +// SPIRV_CHECK: %[[P2:.*]] = load <3 x float>, ptr %{{.*}}, align 16 +// SPIRV_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.spv.fma.v3f32(<3 x float> %[[P0]], <3 x float> %[[P1]], <3 x float> %[[P2]]) +// SPIRV_CHECK: ret <3 x float> %spv.fma +float3 spv_fma_float3(float3 a, float3 b, float3 c) { return fma(a, b, c); } + +// SPIRV_CHECK-LABEL: define {{.*}} <4 x float> @{{.*}}spv_fma_float4{{.*}}( +// SPIRV_CHECK: %[[P0:.*]] = load <4 x float>, ptr %{{.*}}, align 16 +// SPIRV_CHECK: %[[P1:.*]] = load <4 x float>, ptr %{{.*}}, align 16 +// SPIRV_CHECK: %[[P2:.*]] = load <4 x float>, ptr %{{.*}}, align 16 +// SPIRV_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.spv.fma.v4f32(<4 x float> %[[P0]], <4 x float> %[[P1]], <4 x float> %[[P2]]) +// SPIRV_CHECK: ret <4 x float> %spv.fma +float4 spv_fma_float4(float4 a, float4 b, float4 c) { return fma(a, b, c); } + +#endif + +#ifdef TEST_SPIRV_HALF +// SPIRV_HALF_CHECK-LABEL: define {{.*}} half @{{.*}}spv_fma_half{{.*}}( +// SPIRV_HALF_CHECK: %[[P0:.*]] = load half, ptr %{{.*}}, align 2 +// SPIRV_HALF_CHECK: %[[P1:.*]] = load half, ptr %{{.*}}, align 2 +// SPIRV_HALF_CHECK: %[[P2:.*]] = load half, ptr %{{.*}}, align 2 +// SPIRV_HALF_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fma.f16(half %[[P0]], half %[[P1]], half %[[P2]]) +// SPIRV_HALF_CHECK: ret half %spv.fma +half spv_fma_half(half a, half b, half c) { return fma(a, b, c); } + +// SPIRV_HALF_CHECK-LABEL: define {{.*}} <2 x half> @{{.*}}spv_fma_half2{{.*}}( +// SPIRV_HALF_CHECK: %[[P0:.*]] = load <2 x half>, ptr %{{.*}}, align 4 +// SPIRV_HALF_CHECK: %[[P1:.*]] = load <2 x half>, ptr %{{.*}}, align 4 +// SPIRV_HALF_CHECK: %[[P2:.*]] = load <2 x half>, ptr %{{.*}}, align 4 +// 
SPIRV_HALF_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.spv.fma.v2f16(<2 x half> %[[P0]], <2 x half> %[[P1]], <2 x half> %[[P2]]) +// SPIRV_HALF_CHECK: ret <2 x half> %spv.fma +half2 spv_fma_half2(half2 a, half2 b, half2 c) { return fma(a, b, c); } + +// SPIRV_HALF_CHECK-LABEL: define {{.*}} <3 x half> @{{.*}}spv_fma_half3{{.*}}( +// SPIRV_HALF_CHECK: %[[P0:.*]] = load <3 x half>, ptr %{{.*}}, align 8 +// SPIRV_HALF_CHECK: %[[P1:.*]] = load <3 x half>, ptr %{{.*}}, align 8 +// SPIRV_HALF_CHECK: %[[P2:.*]] = load <3 x half>, ptr %{{.*}}, align 8 +// SPIRV_HALF_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.spv.fma.v3f16(<3 x half> %[[P0]], <3 x half> %[[P1]], <3 x half> %[[P2]]) +// SPIRV_HALF_CHECK: ret <3 x half> %spv.fma +half3 spv_fma_half3(half3 a, half3 b, half3 c) { return fma(a, b, c); } + +// SPIRV_HALF_CHECK-LABEL: define {{.*}} <4 x half> @{{.*}}spv_fma_half4{{.*}}( +// SPIRV_HALF_CHECK: %[[P0:.*]] = load <4 x half>, ptr %{{.*}}, align 8 +// SPIRV_HALF_CHECK: %[[P1:.*]] = load <4 x half>, ptr %{{.*}}, align 8 +// SPIRV_HALF_CHECK: %[[P2:.*]] = load <4 x half>, ptr %{{.*}}, align 8 +// SPIRV_HALF_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <4 x half> @llvm.spv.fma.v4f16(<4 x half> %[[P0]], <4 x half> %[[P1]], <4 x half> %[[P2]]) +// SPIRV_HALF_CHECK: ret <4 x half> %spv.fma +half4 spv_fma_half4(half4 a, half4 b, half4 c) { return fma(a, b, c); } +#endif diff --git a/clang/test/Sema/incompatible-function-to-ptr-decay.c b/clang/test/Sema/incompatible-function-to-ptr-decay.c new file mode 100644 index 0000000000000..240b5b8763a23 --- /dev/null +++ b/clang/test/Sema/incompatible-function-to-ptr-decay.c @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 -fsyntax-only -fexperimental-overflow-behavior-types -verify %s + +// Issue 182534 +int foo(); + +void bar(__attribute__((opencl_global)) int*); // #cldecl +void baz(__ob_wrap int*); // #ofdecl + +void a() { + bar(foo); + // expected-error@-1 {{passing 'int (*)()' to 
parameter of type '__global int *' changes address space of pointer}} + // expected-note@#cldecl {{passing argument to parameter here}} + __ob_trap int val[10]; + baz(val); + // expected-error@-1 {{assigning to '__ob_wrap int *' from '__ob_trap int *' with incompatible overflow behavior types ('__ob_wrap' and '__ob_trap')}} + // expected-note@#ofdecl {{passing argument to parameter here}} +} + diff --git a/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl new file mode 100644 index 0000000000000..1ed7b34b4396f --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl @@ -0,0 +1,145 @@ +// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \ +// RUN: -triple dxil-pc-shadermodel6.6-library %s -DTEST_DXIL \ +// RUN: -emit-llvm-only -disable-llvm-passes -verify=dxil +// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \ +// RUN: -triple spirv-unknown-vulkan-compute %s -DTEST_SPIRV \ +// RUN: -emit-llvm-only -disable-llvm-passes -verify=spv + +#ifdef TEST_DXIL +float dxil_fma_float(float a, float b, float c) { + return fma(a, b, c); + // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float')}} +} + +float2 dxil_fma_float2(float2 a, float2 b, float2 c) { + return fma(a, b, c); + // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2' (aka 'vector<float, 2>'))}} +} + +float4 dxil_fma_float4(float4 a, float4 b, float4 c) { + return fma(a, b, c); + // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float4' (aka 'vector<float, 4>'))}} +} + +float2x2 dxil_fma_float2x2(float2x2 a, float2x2 b, float2x2 c) { + return fma(a, b, c); + // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}} +} + +double dxil_fma_bad_second(double a, float b, double c) { + return fma(a, b, c); + // dxil-error@-1 {{2nd argument 
must be a scalar, vector, or matrix of double type (was 'float')}} +} + +double dxil_fma_bad_third(double a, double b, half c) { + return fma(a, b, c); + // dxil-error@-1 {{3rd argument must be a scalar, vector, or matrix of double type (was 'half')}} +} + +double2 dxil_fma_bad_second_vec(double2 a, float2 b, double2 c) { + return fma(a, b, c); + // dxil-error@-1 {{2nd argument must be a scalar, vector, or matrix of double type (was 'float2' (aka 'vector<float, 2>'))}} +} + +double2x2 dxil_fma_bad_third_mat(double2x2 a, double2x2 b, float2x2 c) { + return fma(a, b, c); + // dxil-error@-1 {{3rd argument must be a scalar, vector, or matrix of double type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}} +} + +half dxil_fma_half(half a, half b, half c) { + return fma(a, b, c); + // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half')}} +} + +half2 dxil_fma_half2(half2 a, half2 b, half2 c) { + return fma(a, b, c); + // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2' (aka 'vector<half, 2>'))}} +} + +int dxil_fma_int(int a, int b, int c) { + return fma(a, b, c); + // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'int')}} +} + +bool dxil_fma_bool(bool a, bool b, bool c) { + return fma(a, b, c); + // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'bool')}} +} +#endif + +#ifdef TEST_SPIRV +int spv_fma_int(int a, int b, int c) { + return fma(a, b, c); + // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'int')}} +} + +int2 spv_fma_int2(int2 a, int2 b, int2 c) { + return fma(a, b, c); + // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'int2' (aka 'vector<int, 2>'))}} +} + +bool spv_fma_bool(bool a, bool b, bool c) { + return fma(a, b, c); + // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'bool')}} +} + +float 
spv_fma_bad_second(float a, int b, float c) { + return fma(a, b, c); + // spv-error@-1 {{2nd argument must be a scalar or vector of floating-point type (was 'int')}} +} + +float spv_fma_bad_third(float a, float b, bool c) { + return fma(a, b, c); + // spv-error@-1 {{3rd argument must be a scalar or vector of floating-point type (was 'bool')}} +} + +float2 spv_fma_bad_second_vec(float2 a, int2 b, float2 c) { + return fma(a, b, c); + // spv-error@-1 {{2nd argument must be a scalar or vector of floating-point type (was 'int2' (aka 'vector<int, 2>'))}} +} + +double2 spv_fma_bad_third_vec(double2 a, double2 b, int2 c) { + return fma(a, b, c); + // spv-error@-1 {{3rd argument must be a scalar or vector of floating-point type (was 'int2' (aka 'vector<int, 2>'))}} +} + +float2x2 spv_fma_float2x2(float2x2 a, float2x2 b, float2x2 c) { + return fma(a, b, c); + // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}} +} + +float2 spv_fma_bad_second_mat(float2 a, float2x2 b, float2 c) { + return fma(a, b, c); + // spv-error@-1 {{2nd argument must be a scalar or vector of floating-point type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}} +} + +double2 spv_fma_bad_third_mat(double2 a, double2 b, double2x2 c) { + return fma(a, b, c); + // spv-error@-1 {{3rd argument must be a scalar or vector of floating-point type (was 'double2x2' (aka 'matrix<double, 2, 2>'))}} +} + +float2x3 spv_fma_float2x3(float2x3 a, float2x3 b, float2x3 c) { + return fma(a, b, c); + // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'float2x3' (aka 'matrix<float, 2, 3>'))}} +} + +float3x2 spv_fma_float3x2(float3x2 a, float3x2 b, float3x2 c) { + return fma(a, b, c); + // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'float3x2' (aka 'matrix<float, 3, 2>'))}} +} + +float4x4 spv_fma_float4x4(float4x4 a, float4x4 b, float4x4 c) { + return fma(a, b, c); + // spv-error@-1 
{{1st argument must be a scalar or vector of floating-point type (was 'float4x4' (aka 'matrix<float, 4, 4>'))}} +} + +double2x2 spv_fma_double2x2(double2x2 a, double2x2 b, double2x2 c) { + return fma(a, b, c); + // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'double2x2' (aka 'matrix<double, 2, 2>'))}} +} + +half2x2 spv_fma_half2x2(half2x2 a, half2x2 b, half2x2 c) { + return fma(a, b, c); + // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'half2x2' (aka 'matrix<half, 2, 2>'))}} +} +#endif diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 909482d72aa88..1d2b8faa90f8a 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -259,4 +259,6 @@ def int_dx_store_output [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i8_ty, llvm_i32_ty, llvm_any_ty], [IntrConvergent]>; +// There is no `llvm_anydouble_ty`, so this intrinsic is overloaded on `llvm_anyfloat_ty`. Only double is valid: SemaHLSL.cpp rejects non-double arguments for DXIL targets, and the DXIL Fma op (DXIL.td) has only a double overload. 
+def int_dx_fma : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; } diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 9819f881b5c30..d4b2736e0577c 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -292,5 +292,5 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] def int_spv_unpackhalf2x16 : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [llvm_i32_ty], [IntrNoMem]>; def int_spv_packhalf2x16 : DefaultAttrsIntrinsic<[llvm_anyint_ty], [llvm_anyfloat_ty], [IntrNoMem]>; - + def int_spv_fma : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; } diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 59a5b7fe4d508..bf9e881041f85 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -782,6 +782,16 @@ def FMad : DXILOp<46, tertiary> { let attributes = [Attributes<DXIL1_0, [ReadNone]>]; } +def Fma : DXILOp<47, tertiary> { + let Doc = "Double-precision fused multiply-add. fma(a,b,c) = a * b + c."; + let intrinsics = [IntrinSelect<int_dx_fma>]; + let arguments = [OverloadTy, OverloadTy, OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads<DXIL1_0, [DoubleTy]>]; + let stages = [Stages<DXIL1_0, [all_stages]>]; + let attributes = [Attributes<DXIL1_0, [ReadNone]>]; +} + def IMad : DXILOp<48, tertiary> { let Doc = "Signed integer arithmetic multiply/add operation. 
imad(m,a,b) = m " "* a + b."; diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index 7e16dcda87a57..b8a9f03c92844 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -106,6 +106,15 @@ static bool checkWaveOps(Intrinsic::ID IID) { } } +static bool checkFmaOps(Intrinsic::ID IID) { + switch (IID) { + default: + return false; + case Intrinsic::dx_fma: + return true; + } +} + static bool isOptimizationDisabled(const Module &M) { const StringRef Key = "dx.disable_optimizations"; if (auto *Flag = mdconst::extract_or_null<ConstantInt>(M.getModuleFlag(Key))) @@ -245,7 +254,7 @@ void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF, // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554 - + CSF.DX11_1_DoubleExtensions |= checkFmaOps(CI->getIntrinsicID()); CSF.WaveOps |= checkWaveOps(CI->getIntrinsicID()); } } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index dd0830bb879f5..73d2ad23c673d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -4033,6 +4033,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectAll(ResVReg, ResType, I); case Intrinsic::spv_any: return selectAny(ResVReg, ResType, I); + case Intrinsic::spv_fma: + return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma); case Intrinsic::spv_cross: return selectExtInst(ResVReg, ResType, I, CL::cross, GL::Cross); case Intrinsic::spv_distance: diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll index dd8ea5f5b1aec..f71ae7bb4a299 100644 --- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll +++ 
b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll @@ -26,6 +26,12 @@ define double @test_fdiv_double(double %a, double %b) #0 { ret double %res } +; CHECK: ; Function test_fma_double : 0x00000044 +define double @test_fma_double(double %a, double %b, double %c) #0 { + %r = call double @llvm.dx.fma.f64(double %a, double %b, double %c) + ret double %r +} + ; CHECK: ; Function test_uitofp_i64 : 0x00100044 define double @test_uitofp_i64(i64 %a) #0 { %r = uitofp i64 %a to double @@ -50,4 +56,6 @@ define i64 @test_fptosi_i64(double %a) #0 { ret i64 %r } +declare double @llvm.dx.fma.f64(double, double, double) + attributes #0 = { convergent norecurse nounwind "hlsl.export"} diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll new file mode 100644 index 0000000000000..28e7bfa36f591 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll @@ -0,0 +1,53 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %} + +; CHECK: OpExtInstImport "GLSL.std.450" + +define noundef half @fma_half(half noundef %a, half noundef %b, half noundef %c) { +entry: +; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] + %r = call half @llvm.spv.fma.f16(half %a, half %b, half %c) + ret half %r +} + +define noundef float @fma_float(float noundef %a, float noundef %b, float noundef %c) { +entry: +; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] + %r = call float @llvm.spv.fma.f32(float %a, float %b, float %c) + ret float %r +} + +define noundef double @fma_double(double noundef %a, double noundef %b, double noundef %c) { +entry: +; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] + %r = call double @llvm.spv.fma.f64(double %a, double %b, double %c) + ret double %r +} + +define noundef <4 x half> @fma_half4(<4 x half> noundef %a, <4 x half> 
noundef %b, <4 x half> noundef %c) { +entry: +; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] + %r = call <4 x half> @llvm.spv.fma.v4f16(<4 x half> %a, <4 x half> %b, <4 x half> %c) + ret <4 x half> %r +} + +define noundef <4 x float> @fma_float4(<4 x float> noundef %a, <4 x float> noundef %b, <4 x float> noundef %c) { +entry: +; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] + %r = call <4 x float> @llvm.spv.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) + ret <4 x float> %r +} + +define noundef <4 x double> @fma_double4(<4 x double> noundef %a, <4 x double> noundef %b, <4 x double> noundef %c) { +entry: +; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] + %r = call <4 x double> @llvm.spv.fma.v4f64(<4 x double> %a, <4 x double> %b, <4 x double> %c) + ret <4 x double> %r +} + +declare half @llvm.spv.fma.f16(half, half, half) +declare float @llvm.spv.fma.f32(float, float, float) +declare double @llvm.spv.fma.f64(double, double, double) +declare <4 x half> @llvm.spv.fma.v4f16(<4 x half>, <4 x half>, <4 x half>) +declare <4 x float> @llvm.spv.fma.v4f32(<4 x float>, <4 x float>, <4 x float>) +declare <4 x double> @llvm.spv.fma.v4f64(<4 x double>, <4 x double>, <4 x double>) >From c9e8f690eb3b585f5e479851b58ca7d94c7f769d Mon Sep 17 00:00:00 2001 From: NeKon69 <[email protected]> Date: Sun, 8 Mar 2026 21:14:16 +0300 Subject: [PATCH 2/5] cleanup local stuff --- clang/lib/Sema/SemaHLSL.cpp | 1 - clang/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll | 0 2 files changed, 1 deletion(-) delete mode 100644 clang/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 624f621b532a1..ed60b65a9fd1f 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3996,7 +3996,6 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { case Builtin::BI__builtin_elementwise_exp10: case Builtin::BI__builtin_elementwise_floor: 
case Builtin::BI__builtin_elementwise_fmod: - case Builtin::BI__builtin_elementwise_fma: case Builtin::BI__builtin_elementwise_log: case Builtin::BI__builtin_elementwise_log2: case Builtin::BI__builtin_elementwise_log10: diff --git a/clang/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll b/clang/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll deleted file mode 100644 index e69de29bb2d1d..0000000000000 >From f808aff6aa913cb0c73d7c708b05e0f18065c8af Mon Sep 17 00:00:00 2001 From: NeKon69 <[email protected]> Date: Sun, 8 Mar 2026 21:15:37 +0300 Subject: [PATCH 3/5] final cleanup --- .../Sema/incompatible-function-to-ptr-decay.c | 18 ------------------ llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 2 -- 2 files changed, 20 deletions(-) delete mode 100644 clang/test/Sema/incompatible-function-to-ptr-decay.c diff --git a/clang/test/Sema/incompatible-function-to-ptr-decay.c b/clang/test/Sema/incompatible-function-to-ptr-decay.c deleted file mode 100644 index 240b5b8763a23..0000000000000 --- a/clang/test/Sema/incompatible-function-to-ptr-decay.c +++ /dev/null @@ -1,18 +0,0 @@ -// RUN: %clang_cc1 -fsyntax-only -fexperimental-overflow-behavior-types -verify %s - -// Issue 182534 -int foo(); - -void bar(__attribute__((opencl_global)) int*); // #cldecl -void baz(__ob_wrap int*); // #ofdecl - -void a() { - bar(foo); - // expected-error@-1 {{passing 'int (*)()' to parameter of type '__global int *' changes address space of pointer}} - // expected-note@#cldecl {{passing argument to parameter here}} - __ob_trap int val[10]; - baz(val); - // expected-error@-1 {{assigning to '__ob_wrap int *' from '__ob_trap int *' with incompatible overflow behavior types ('__ob_wrap' and '__ob_trap')}} - // expected-note@#ofdecl {{passing argument to parameter here}} -} - diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index b8a9f03c92844..439b5eaf8756a 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ 
b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -252,8 +252,6 @@ void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF, if (FunctionFlags.contains(CF)) CSF.merge(FunctionFlags[CF]); - // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic - // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554 CSF.DX11_1_DoubleExtensions |= checkFmaOps(CI->getIntrinsicID()); CSF.WaveOps |= checkWaveOps(CI->getIntrinsicID()); } >From 87a67d77383a6d29d7aa13f33d36e6009de183ce Mon Sep 17 00:00:00 2001 From: NeKon69 <[email protected]> Date: Sun, 8 Mar 2026 21:50:04 +0300 Subject: [PATCH 4/5] [hlsl] add a check that all arguments are of the same type --- clang/lib/Sema/SemaHLSL.cpp | 3 +- clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl | 60 ++++++++++++++++---- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index ed60b65a9fd1f..d9421dac06213 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3819,7 +3819,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { break; } case Builtin::BI__builtin_hlsl_elementwise_fma: { - if (SemaRef.checkArgCount(TheCall, 3)) { + if (SemaRef.checkArgCount(TheCall, 3) || + CheckAllArgsHaveSameType(&SemaRef, TheCall)) { return true; } const llvm::Triple &TT = getASTContext().getTargetInfo().getTriple(); diff --git a/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl index 1ed7b34b4396f..c2ae8bc00d3c2 100644 --- a/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl @@ -28,22 +28,42 @@ float2x2 dxil_fma_float2x2(float2x2 a, float2x2 b, float2x2 c) { double dxil_fma_bad_second(double a, float b, double c) { return fma(a, b, c); - // dxil-error@-1 {{2nd argument must be a scalar, vector, or matrix of double type (was 'float')}} + // dxil-error@-1 {{all arguments to 'fma' must have the same type}} } 
double dxil_fma_bad_third(double a, double b, half c) { return fma(a, b, c); - // dxil-error@-1 {{3rd argument must be a scalar, vector, or matrix of double type (was 'half')}} + // dxil-error@-1 {{all arguments to 'fma' must have the same type}} } double2 dxil_fma_bad_second_vec(double2 a, float2 b, double2 c) { return fma(a, b, c); - // dxil-error@-1 {{2nd argument must be a scalar, vector, or matrix of double type (was 'float2' (aka 'vector<float, 2>'))}} + // dxil-error@-1 {{all arguments to 'fma' must have the same type}} } double2x2 dxil_fma_bad_third_mat(double2x2 a, double2x2 b, float2x2 c) { return fma(a, b, c); - // dxil-error@-1 {{3rd argument must be a scalar, vector, or matrix of double type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}} + // dxil-error@-1 {{all arguments to 'fma' must have the same type}} +} + +double2 dxil_fma_mismatch_second(double2 a, double b, double2 c) { + return fma(a, b, c); + // dxil-error@-1 {{all arguments to 'fma' must have the same type}} +} + +double2 dxil_fma_mismatch_third(double2 a, double2 b, double c) { + return fma(a, b, c); + // dxil-error@-1 {{all arguments to 'fma' must have the same type}} +} + +double2x2 dxil_fma_mismatch_second_mat(double2x2 a, double2 b, double2x2 c) { + return fma(a, b, c); + // dxil-error@-1 {{all arguments to 'fma' must have the same type}} +} + +double2x2 dxil_fma_mismatch_third_mat(double2x2 a, double2x2 b, double2 c) { + return fma(a, b, c); + // dxil-error@-1 {{all arguments to 'fma' must have the same type}} } half dxil_fma_half(half a, half b, half c) { @@ -85,22 +105,42 @@ bool spv_fma_bool(bool a, bool b, bool c) { float spv_fma_bad_second(float a, int b, float c) { return fma(a, b, c); - // spv-error@-1 {{2nd argument must be a scalar or vector of floating-point type (was 'int')}} + // spv-error@-1 {{all arguments to 'fma' must have the same type}} } float spv_fma_bad_third(float a, float b, bool c) { return fma(a, b, c); - // spv-error@-1 {{3rd argument must be a scalar or 
vector of floating-point type (was 'bool')}} + // spv-error@-1 {{all arguments to 'fma' must have the same type}} } float2 spv_fma_bad_second_vec(float2 a, int2 b, float2 c) { return fma(a, b, c); - // spv-error@-1 {{2nd argument must be a scalar or vector of floating-point type (was 'int2' (aka 'vector<int, 2>'))}} + // spv-error@-1 {{all arguments to 'fma' must have the same type}} } double2 spv_fma_bad_third_vec(double2 a, double2 b, int2 c) { return fma(a, b, c); - // spv-error@-1 {{3rd argument must be a scalar or vector of floating-point type (was 'int2' (aka 'vector<int, 2>'))}} + // spv-error@-1 {{all arguments to 'fma' must have the same type}} +} + +float2 spv_fma_mismatch_second(float2 a, float b, float2 c) { + return fma(a, b, c); + // spv-error@-1 {{all arguments to 'fma' must have the same type}} +} + +float2 spv_fma_mismatch_third(float2 a, float2 b, float c) { + return fma(a, b, c); + // spv-error@-1 {{all arguments to 'fma' must have the same type}} +} + +double2 spv_fma_mismatch_second_double(double2 a, double b, double2 c) { + return fma(a, b, c); + // spv-error@-1 {{all arguments to 'fma' must have the same type}} +} + +double2 spv_fma_mismatch_third_double(double2 a, double2 b, double c) { + return fma(a, b, c); + // spv-error@-1 {{all arguments to 'fma' must have the same type}} } float2x2 spv_fma_float2x2(float2x2 a, float2x2 b, float2x2 c) { @@ -110,12 +150,12 @@ float2x2 spv_fma_float2x2(float2x2 a, float2x2 b, float2x2 c) { float2 spv_fma_bad_second_mat(float2 a, float2x2 b, float2 c) { return fma(a, b, c); - // spv-error@-1 {{2nd argument must be a scalar or vector of floating-point type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}} + // spv-error@-1 {{all arguments to 'fma' must have the same type}} } double2 spv_fma_bad_third_mat(double2 a, double2 b, double2x2 c) { return fma(a, b, c); - // spv-error@-1 {{3rd argument must be a scalar or vector of floating-point type (was 'double2x2' (aka 'matrix<double, 2, 2>'))}} + // spv-error@-1 
{{all arguments to 'fma' must have the same type}} } float2x3 spv_fma_float2x3(float2x3 a, float2x3 b, float2x3 c) { >From 8bf157d597542a04224e4c34836fc67bcc05637e Mon Sep 17 00:00:00 2001 From: NeKon69 <[email protected]> Date: Tue, 10 Mar 2026 19:48:01 +0300 Subject: [PATCH 5/5] [HLSL][DirectX][SPIRV] refactor to use llvm.fma instead of defining intrinsics for each backend and using them --- clang/include/clang/Basic/Builtins.td | 6 - .../clang/Basic/DiagnosticSemaKinds.td | 4 +- clang/lib/CodeGen/CGHLSLBuiltins.cpp | 16 -- .../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 69 ++++- clang/lib/Sema/SemaChecking.cpp | 14 +- clang/lib/Sema/SemaHLSL.cpp | 38 +-- clang/test/CodeGenHLSL/builtins/fma.hlsl | 269 +++++++++--------- clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl | 168 ++++------- llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 - llvm/include/llvm/IR/IntrinsicsSPIRV.td | 2 +- llvm/lib/Target/DirectX/DXIL.td | 2 +- llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 2 +- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 2 - .../test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll | 53 ---- 14 files changed, 249 insertions(+), 398 deletions(-) delete mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index 542249f829424..531c3702161f2 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -5318,12 +5318,6 @@ def HLSLNormalize : LangBuiltin<"HLSL_LANG"> { let Prototype = "void(...)"; } -def HLSLFma : LangBuiltin<"HLSL_LANG"> { - let Spellings = ["__builtin_hlsl_elementwise_fma"]; - let Attributes = [NoThrow, Const, CustomTypeChecking]; - let Prototype = "void(...)"; -} - def HLSLRcp : LangBuiltin<"HLSL_LANG"> { let Spellings = ["__builtin_hlsl_elementwise_rcp"]; let Attributes = [NoThrow, Const, CustomTypeChecking]; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 
787cd7bcc61bb..531871776b51f 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -13232,8 +13232,8 @@ def err_builtin_invalid_arg_type: Error< def err_builtin_requires_double_type: Error< "%ordinal0 argument must be a scalar, vector, or matrix of double type (was %1)">; -def err_builtin_requires_fp_scalar_or_vector_type: Error< - "%ordinal0 argument must be a scalar or vector of floating-point type (was %1)">; +def err_builtin_requires_any_fp_type: Error< + "%ordinal0 argument must be a scalar, vector, or matrix of any floating-point type (was %1)">; def err_bswapg_invalid_bit_width : Error< "_BitInt type %0 (%1 bits) must be a multiple of 16 bits for byte swapping">; diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index bb5eaf12c93cc..70891eac39425 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -979,22 +979,6 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, retType, CGM.getHLSLRuntime().getIsNaNIntrinsic(), ArrayRef<Value *>{Op0}, nullptr, "hlsl.isnan"); } - case Builtin::BI__builtin_hlsl_elementwise_fma: { - Value *M = EmitScalarExpr(E->getArg(0)); - Value *A = EmitScalarExpr(E->getArg(1)); - Value *B = EmitScalarExpr(E->getArg(2)); - if (CGM.getTarget().getTriple().isDXIL()) - return Builder.CreateIntrinsic(M->getType(), Intrinsic::dx_fma, - ArrayRef<Value *>{M, A, B}, nullptr, - "dx.fma"); - - if (CGM.getTarget().getTriple().isSPIRV()) - return Builder.CreateIntrinsic(M->getType(), Intrinsic::spv_fma, - ArrayRef<Value *>{M, A, B}, nullptr, - "spv.fma"); - - break; - } case Builtin::BI__builtin_hlsl_mad: { Value *M = EmitScalarExpr(E->getArg(0)); Value *A = EmitScalarExpr(E->getArg(1)); diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index ab5c6edd6d555..22d3f955eedcf 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h 
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -1229,6 +1229,60 @@ float3 floor(float3); _HLSL_BUILTIN_ALIAS(__builtin_elementwise_floor) float4 floor(float4); +//===----------------------------------------------------------------------===// +// fused multiply-add builtins +//===----------------------------------------------------------------------===// + +/// \fn double fma(double a, double b, double c) +/// \brief Returns the double-precision fused multiply-addition of a * b + c. +/// \param a The first value in the fused multiply-addition. +/// \param b The second value in the fused multiply-addition. +/// \param c The third value in the fused multiply-addition. + +// double scalars and vectors +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double fma(double, double, double); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double2 fma(double2, double2, double2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double3 fma(double3, double3, double3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double4 fma(double4, double4, double4); + +// double matrices +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double1x1 fma(double1x1, double1x1, double1x1); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double1x2 fma(double1x2, double1x2, double1x2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double1x3 fma(double1x3, double1x3, double1x3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double1x4 fma(double1x4, double1x4, double1x4); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double2x1 fma(double2x1, double2x1, double2x1); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double2x2 fma(double2x2, double2x2, double2x2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double2x3 fma(double2x3, double2x3, double2x3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double2x4 fma(double2x4, double2x4, double2x4); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double3x1 fma(double3x1, double3x1, double3x1);
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double3x2 fma(double3x2, double3x2, double3x2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double3x3 fma(double3x3, double3x3, double3x3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double3x4 fma(double3x4, double3x4, double3x4); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double4x1 fma(double4x1, double4x1, double4x1); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double4x2 fma(double4x2, double4x2, double4x2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double4x3 fma(double4x3, double4x3, double4x3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double4x4 fma(double4x4, double4x4, double4x4); + //===----------------------------------------------------------------------===// // frac builtins //===----------------------------------------------------------------------===// @@ -1891,21 +1945,6 @@ float3 pow(float3, float3); _HLSL_BUILTIN_ALIAS(__builtin_elementwise_pow) float4 pow(float4, float4); -//===----------------------------------------------------------------------===// -// fused multiply-add builtins -//===----------------------------------------------------------------------===// - -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_fma) -double fma(double, double, double); - -template <int s> -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_fma) -vector<double, s> fma(vector<double, s>, vector<double, s>, vector<double, s>); - -template <int w, int h> -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_fma) -matrix<double, w, h> fma(matrix<double, w, h>, matrix<double, w, h>, - matrix<double, w, h>); //===----------------------------------------------------------------------===// // reversebits builtins diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index aea5b722738aa..a9710d7053e09 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -2178,9 +2178,12 @@ static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc, 
QualType ArgTy, Sema::EltwiseBuiltinArgTyRestriction ArgTyRestr, int ArgOrdinal) { - QualType EltTy = ArgTy; - if (auto *VecTy = EltTy->getAs<VectorType>()) - EltTy = VecTy->getElementType(); + clang::QualType EltTy = + ArgTy->isVectorType() + ? ArgTy->castAs<clang::VectorType>()->getElementType() + : ArgTy->isMatrixType() + ? ArgTy->castAs<clang::MatrixType>()->getElementType() + : ArgTy; switch (ArgTyRestr) { case Sema::EltwiseBuiltinArgTyRestriction::None: @@ -2192,9 +2195,8 @@ checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy, break; case Sema::EltwiseBuiltinArgTyRestriction::FloatTy: if (!EltTy->isRealFloatingType()) { - return S.Diag(Loc, diag::err_builtin_invalid_arg_type) - << ArgOrdinal << /* scalar or vector */ 5 << /* no int */ 0 - << /* floating-point */ 1 << ArgTy; + return S.Diag(Loc, diag::err_builtin_requires_any_fp_type) + << ArgOrdinal << ArgTy; } break; case Sema::EltwiseBuiltinArgTyRestriction::IntegerTy: diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index d9421dac06213..db9db159c3d99 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -47,6 +47,7 @@ #include "llvm/Support/DXILABI.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" #include <cmath> #include <cstddef> @@ -3041,19 +3042,6 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc, return false; } -static bool CheckFloatOrHalfOrDoubleRepresentation(Sema *S, SourceLocation Loc, - int ArgOrdinal, - clang::QualType PassedType) { - clang::QualType BaseType = - PassedType->isVectorType() - ? 
PassedType->castAs<clang::VectorType>()->getElementType() - : PassedType; - if (!BaseType->isFloatingType()) - return S->Diag(Loc, diag::err_builtin_requires_fp_scalar_or_vector_type) - << ArgOrdinal << PassedType; - return false; -} - static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType) { @@ -3818,29 +3806,15 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { TheCall->setType(ArgTyA); break; } - case Builtin::BI__builtin_hlsl_elementwise_fma: { + case Builtin::BI__builtin_elementwise_fma: { if (SemaRef.checkArgCount(TheCall, 3) || CheckAllArgsHaveSameType(&SemaRef, TheCall)) { return true; } - const llvm::Triple &TT = getASTContext().getTargetInfo().getTriple(); - // This check is here because emitting a general error for both backends - // here (like for exmaple "Accepts only floating points") won't end really - // good. after that we still need to check if the types satisfy - // backends constrains, so we better check everything now rather than - // confusing user with 2 different error messages - - if (TT.isSPIRV()) { - // SPIR-V accept any float (besides matrices) - if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, - CheckFloatOrHalfOrDoubleRepresentation)) - return true; - } else if (TT.isDXIL()) { - // while DirectX accepts only double - if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, - CheckAnyDoubleRepresentation)) - return true; - } + + if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, + CheckAnyDoubleRepresentation)) + return true; ExprResult A = TheCall->getArg(0); QualType ArgTyA = A.get()->getType(); diff --git a/clang/test/CodeGenHLSL/builtins/fma.hlsl b/clang/test/CodeGenHLSL/builtins/fma.hlsl index 88b8e27c37043..3d9549197035d 100644 --- a/clang/test/CodeGenHLSL/builtins/fma.hlsl +++ b/clang/test/CodeGenHLSL/builtins/fma.hlsl @@ -1,151 +1,138 @@ // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ -// RUN: dxil-pc-shadermodel6.3-library %s 
-DTEST_DXIL \ -// RUN: -fmatrix-memory-layout=row-major -emit-llvm -disable-llvm-passes -o - | \ -// RUN: FileCheck %s --check-prefixes=CHECK,DXIL_CHECK -DTARGET=dx +// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm \ +// RUN: -disable-llvm-passes -o - | FileCheck %s // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ -// RUN: spirv-unknown-vulkan-compute %s -DTEST_SPIRV \ -// RUN: -fmatrix-memory-layout=row-major -emit-llvm -disable-llvm-passes -o - | \ -// RUN: FileCheck %s --check-prefixes=CHECK,SPIRV_CHECK -DTARGET=spv -// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ -// RUN: spirv-unknown-vulkan-compute %s -DTEST_SPIRV_HALF -fnative-half-type \ -// RUN: -fmatrix-memory-layout=row-major -emit-llvm -disable-llvm-passes -o - | \ -// RUN: FileCheck %s --check-prefix=SPIRV_HALF_CHECK +// RUN: spirv-unknown-vulkan-compute %s -emit-llvm \ +// RUN: -disable-llvm-passes -o - | FileCheck %s // CHECK-LABEL: define {{.*}} double @{{.*}}fma_double{{.*}}( -// CHECK: %[[P0:.*]] = load double, ptr %{{.*}}, align 8 -// CHECK: %[[P1:.*]] = load double, ptr %{{.*}}, align 8 -// CHECK: %[[P2:.*]] = load double, ptr %{{.*}}, align 8 -// CHECK: %{{dx|spv}}.fma = call reassoc nnan ninf nsz arcp afn double @llvm.[[TARGET]].fma.f64(double %[[P0]], double %[[P1]], double %[[P2]]) -// CHECK: ret double %{{dx|spv}}.fma -double dxil_fma_double(double a, double b, double c) { return fma(a, b, c); } +// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.fma.f64(double +// CHECK: ret double +double fma_double(double a, double b, double c) { return fma(a, b, c); } // CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2{{.*}}( -// CHECK: %[[P0:.*]] = load <2 x double>, ptr %{{.*}}, align 16 -// CHECK: %[[P1:.*]] = load <2 x double>, ptr %{{.*}}, align 16 -// CHECK: %[[P2:.*]] = load <2 x double>, ptr %{{.*}}, align 16 -// CHECK: %{{dx|spv}}.fma = call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.[[TARGET]].fma.v2f64(<2 x double> %[[P0]], <2 x double> 
%[[P1]], <2 x double> %[[P2]]) -// CHECK: ret <2 x double> %{{dx|spv}}.fma -double2 dxil_fma_double2(double2 a, double2 b, double2 c) { return fma(a, b, c); } +// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double> +// CHECK: ret <2 x double> +double2 fma_double2(double2 a, double2 b, double2 c) { return fma(a, b, c); } // CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3{{.*}}( -// CHECK: %[[P0:.*]] = load <3 x double>, ptr %{{.*}}, align 32 -// CHECK: %[[P1:.*]] = load <3 x double>, ptr %{{.*}}, align 32 -// CHECK: %[[P2:.*]] = load <3 x double>, ptr %{{.*}}, align 32 -// CHECK: %{{dx|spv}}.fma = call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.[[TARGET]].fma.v3f64(<3 x double> %[[P0]], <3 x double> %[[P1]], <3 x double> %[[P2]]) -// CHECK: ret <3 x double> %{{dx|spv}}.fma -double3 dxil_fma_double3(double3 a, double3 b, double3 c) { return fma(a, b, c); } +// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double> +// CHECK: ret <3 x double> +double3 fma_double3(double3 a, double3 b, double3 c) { return fma(a, b, c); } // CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4{{.*}}( -// CHECK: %[[P0:.*]] = load <4 x double>, ptr %{{.*}}, align 32 -// CHECK: %[[P1:.*]] = load <4 x double>, ptr %{{.*}}, align 32 -// CHECK: %[[P2:.*]] = load <4 x double>, ptr %{{.*}}, align 32 -// CHECK: %{{dx|spv}}.fma = call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.[[TARGET]].fma.v4f64(<4 x double> %[[P0]], <4 x double> %[[P1]], <4 x double> %[[P2]]) -// CHECK: ret <4 x double> %{{dx|spv}}.fma -double4 dxil_fma_double4(double4 a, double4 b, double4 c) { return fma(a, b, c); } - -#ifdef TEST_DXIL - -// DXIL_CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}dxil_fma_double1x4{{.*}}( -// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.dx.fma.v4f64( -// DXIL_CHECK: ret <4 x double> %dx.fma -double1x4 dxil_fma_double1x4(double1x4 a, double1x4 b, double1x4 c) { return 
fma(a, b, c); } - -// DXIL_CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}dxil_fma_double4x1{{.*}}( -// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.dx.fma.v4f64( -// DXIL_CHECK: ret <4 x double> %dx.fma -double4x1 dxil_fma_double4x1(double4x1 a, double4x1 b, double4x1 c) { return fma(a, b, c); } - -// DXIL_CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}dxil_fma_double2x2{{.*}}( -// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.dx.fma.v4f64( -// DXIL_CHECK: ret <4 x double> %dx.fma -double2x2 dxil_fma_double2x2(double2x2 a, double2x2 b, double2x2 c) { return fma(a, b, c); } - -// DXIL_CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}dxil_fma_double2x3{{.*}}( -// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.dx.fma.v6f64( -// DXIL_CHECK: ret <6 x double> %dx.fma -double2x3 dxil_fma_double2x3(double2x3 a, double2x3 b, double2x3 c) { return fma(a, b, c); } - -// DXIL_CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}dxil_fma_double3x2{{.*}}( -// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.dx.fma.v6f64( -// DXIL_CHECK: ret <6 x double> %dx.fma -double3x2 dxil_fma_double3x2(double3x2 a, double3x2 b, double3x2 c) { return fma(a, b, c); } - -// DXIL_CHECK-LABEL: define {{.*}} <9 x double> @{{.*}}dxil_fma_double3x3{{.*}}( -// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <9 x double> @llvm.dx.fma.v9f64( -// DXIL_CHECK: ret <9 x double> %dx.fma -double3x3 dxil_fma_double3x3(double3x3 a, double3x3 b, double3x3 c) { return fma(a, b, c); } - -// DXIL_CHECK-LABEL: define {{.*}} <16 x double> @{{.*}}dxil_fma_double4x4{{.*}}( -// DXIL_CHECK: %dx.fma = call reassoc nnan ninf nsz arcp afn <16 x double> @llvm.dx.fma.v16f64( -// DXIL_CHECK: ret <16 x double> %dx.fma -double4x4 dxil_fma_double4x4(double4x4 a, double4x4 b, double4x4 c) { return fma(a, b, c); } -#endif - -#ifdef TEST_SPIRV -// SPIRV_CHECK-LABEL: define {{.*}} float 
@{{.*}}spv_fma_float{{.*}}( -// SPIRV_CHECK: %[[P0:.*]] = load float, ptr %{{.*}}, align 4 -// SPIRV_CHECK: %[[P1:.*]] = load float, ptr %{{.*}}, align 4 -// SPIRV_CHECK: %[[P2:.*]] = load float, ptr %{{.*}}, align 4 -// SPIRV_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn float @llvm.spv.fma.f32(float %[[P0]], float %[[P1]], float %[[P2]]) -// SPIRV_CHECK: ret float %spv.fma -float spv_fma_float(float a, float b, float c) { return fma(a, b, c); } - -// SPIRV_CHECK-LABEL: define {{.*}} <2 x float> @{{.*}}spv_fma_float2{{.*}}( -// SPIRV_CHECK: %[[P0:.*]] = load <2 x float>, ptr %{{.*}}, align 8 -// SPIRV_CHECK: %[[P1:.*]] = load <2 x float>, ptr %{{.*}}, align 8 -// SPIRV_CHECK: %[[P2:.*]] = load <2 x float>, ptr %{{.*}}, align 8 -// SPIRV_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.spv.fma.v2f32(<2 x float> %[[P0]], <2 x float> %[[P1]], <2 x float> %[[P2]]) -// SPIRV_CHECK: ret <2 x float> %spv.fma -float2 spv_fma_float2(float2 a, float2 b, float2 c) { return fma(a, b, c); } - -// SPIRV_CHECK-LABEL: define {{.*}} <3 x float> @{{.*}}spv_fma_float3{{.*}}( -// SPIRV_CHECK: %[[P0:.*]] = load <3 x float>, ptr %{{.*}}, align 16 -// SPIRV_CHECK: %[[P1:.*]] = load <3 x float>, ptr %{{.*}}, align 16 -// SPIRV_CHECK: %[[P2:.*]] = load <3 x float>, ptr %{{.*}}, align 16 -// SPIRV_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.spv.fma.v3f32(<3 x float> %[[P0]], <3 x float> %[[P1]], <3 x float> %[[P2]]) -// SPIRV_CHECK: ret <3 x float> %spv.fma -float3 spv_fma_float3(float3 a, float3 b, float3 c) { return fma(a, b, c); } - -// SPIRV_CHECK-LABEL: define {{.*}} <4 x float> @{{.*}}spv_fma_float4{{.*}}( -// SPIRV_CHECK: %[[P0:.*]] = load <4 x float>, ptr %{{.*}}, align 16 -// SPIRV_CHECK: %[[P1:.*]] = load <4 x float>, ptr %{{.*}}, align 16 -// SPIRV_CHECK: %[[P2:.*]] = load <4 x float>, ptr %{{.*}}, align 16 -// SPIRV_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.spv.fma.v4f32(<4 x float> 
%[[P0]], <4 x float> %[[P1]], <4 x float> %[[P2]]) -// SPIRV_CHECK: ret <4 x float> %spv.fma -float4 spv_fma_float4(float4 a, float4 b, float4 c) { return fma(a, b, c); } - -#endif - -#ifdef TEST_SPIRV_HALF -// SPIRV_HALF_CHECK-LABEL: define {{.*}} half @{{.*}}spv_fma_half{{.*}}( -// SPIRV_HALF_CHECK: %[[P0:.*]] = load half, ptr %{{.*}}, align 2 -// SPIRV_HALF_CHECK: %[[P1:.*]] = load half, ptr %{{.*}}, align 2 -// SPIRV_HALF_CHECK: %[[P2:.*]] = load half, ptr %{{.*}}, align 2 -// SPIRV_HALF_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fma.f16(half %[[P0]], half %[[P1]], half %[[P2]]) -// SPIRV_HALF_CHECK: ret half %spv.fma -half spv_fma_half(half a, half b, half c) { return fma(a, b, c); } - -// SPIRV_HALF_CHECK-LABEL: define {{.*}} <2 x half> @{{.*}}spv_fma_half2{{.*}}( -// SPIRV_HALF_CHECK: %[[P0:.*]] = load <2 x half>, ptr %{{.*}}, align 4 -// SPIRV_HALF_CHECK: %[[P1:.*]] = load <2 x half>, ptr %{{.*}}, align 4 -// SPIRV_HALF_CHECK: %[[P2:.*]] = load <2 x half>, ptr %{{.*}}, align 4 -// SPIRV_HALF_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.spv.fma.v2f16(<2 x half> %[[P0]], <2 x half> %[[P1]], <2 x half> %[[P2]]) -// SPIRV_HALF_CHECK: ret <2 x half> %spv.fma -half2 spv_fma_half2(half2 a, half2 b, half2 c) { return fma(a, b, c); } - -// SPIRV_HALF_CHECK-LABEL: define {{.*}} <3 x half> @{{.*}}spv_fma_half3{{.*}}( -// SPIRV_HALF_CHECK: %[[P0:.*]] = load <3 x half>, ptr %{{.*}}, align 8 -// SPIRV_HALF_CHECK: %[[P1:.*]] = load <3 x half>, ptr %{{.*}}, align 8 -// SPIRV_HALF_CHECK: %[[P2:.*]] = load <3 x half>, ptr %{{.*}}, align 8 -// SPIRV_HALF_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.spv.fma.v3f16(<3 x half> %[[P0]], <3 x half> %[[P1]], <3 x half> %[[P2]]) -// SPIRV_HALF_CHECK: ret <3 x half> %spv.fma -half3 spv_fma_half3(half3 a, half3 b, half3 c) { return fma(a, b, c); } - -// SPIRV_HALF_CHECK-LABEL: define {{.*}} <4 x half> @{{.*}}spv_fma_half4{{.*}}( -// SPIRV_HALF_CHECK: 
%[[P0:.*]] = load <4 x half>, ptr %{{.*}}, align 8 -// SPIRV_HALF_CHECK: %[[P1:.*]] = load <4 x half>, ptr %{{.*}}, align 8 -// SPIRV_HALF_CHECK: %[[P2:.*]] = load <4 x half>, ptr %{{.*}}, align 8 -// SPIRV_HALF_CHECK: %spv.fma = call reassoc nnan ninf nsz arcp afn <4 x half> @llvm.spv.fma.v4f16(<4 x half> %[[P0]], <4 x half> %[[P1]], <4 x half> %[[P2]]) -// SPIRV_HALF_CHECK: ret <4 x half> %spv.fma -half4 spv_fma_half4(half4 a, half4 b, half4 c) { return fma(a, b, c); } -#endif +// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double> +// CHECK: ret <4 x double> +double4 fma_double4(double4 a, double4 b, double4 c) { return fma(a, b, c); } + +// CHECK-LABEL: define {{.*}} <1 x double> @{{.*}}fma_double1x1{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <1 x double> @llvm.fma.v1f64(<1 x double> +// CHECK: ret <1 x double> +double1x1 fma_double1x1(double1x1 a, double1x1 b, double1x1 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double1x2{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double> +// CHECK: ret <2 x double> +double1x2 fma_double1x2(double1x2 a, double1x2 b, double1x2 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double1x3{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double> +// CHECK: ret <3 x double> +double1x3 fma_double1x3(double1x3 a, double1x3 b, double1x3 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double1x4{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double> +// CHECK: ret <4 x double> +double1x4 fma_double1x4(double1x4 a, double1x4 b, double1x4 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2x1{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double> +// CHECK: ret <2 x 
double> +double2x1 fma_double2x1(double2x1 a, double2x1 b, double2x1 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double2x2{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double> +// CHECK: ret <4 x double> +double2x2 fma_double2x2(double2x2 a, double2x2 b, double2x2 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}fma_double2x3{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.fma.v6f64(<6 x double> +// CHECK: ret <6 x double> +double2x3 fma_double2x3(double2x3 a, double2x3 b, double2x3 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <8 x double> @{{.*}}fma_double2x4{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <8 x double> @llvm.fma.v8f64(<8 x double> +// CHECK: ret <8 x double> +double2x4 fma_double2x4(double2x4 a, double2x4 b, double2x4 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3x1{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double> +// CHECK: ret <3 x double> +double3x1 fma_double3x1(double3x1 a, double3x1 b, double3x1 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}fma_double3x2{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.fma.v6f64(<6 x double> +// CHECK: ret <6 x double> +double3x2 fma_double3x2(double3x2 a, double3x2 b, double3x2 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <9 x double> @{{.*}}fma_double3x3{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <9 x double> @llvm.fma.v9f64(<9 x double> +// CHECK: ret <9 x double> +double3x3 fma_double3x3(double3x3 a, double3x3 b, double3x3 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <12 x double> @{{.*}}fma_double3x4{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <12 x double> @llvm.fma.v12f64(<12 x double> +// CHECK: ret <12 x double> 
+double3x4 fma_double3x4(double3x4 a, double3x4 b, double3x4 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4x1{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double> +// CHECK: ret <4 x double> +double4x1 fma_double4x1(double4x1 a, double4x1 b, double4x1 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <8 x double> @{{.*}}fma_double4x2{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <8 x double> @llvm.fma.v8f64(<8 x double> +// CHECK: ret <8 x double> +double4x2 fma_double4x2(double4x2 a, double4x2 b, double4x2 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <12 x double> @{{.*}}fma_double4x3{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <12 x double> @llvm.fma.v12f64(<12 x double> +// CHECK: ret <12 x double> +double4x3 fma_double4x3(double4x3 a, double4x3 b, double4x3 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <16 x double> @{{.*}}fma_double4x4{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <16 x double> @llvm.fma.v16f64(<16 x double> +// CHECK: ret <16 x double> +double4x4 fma_double4x4(double4x4 a, double4x4 b, double4x4 c) { + return fma(a, b, c); +} diff --git a/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl index c2ae8bc00d3c2..e7d4323ae8adc 100644 --- a/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl @@ -1,185 +1,113 @@ // RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \ -// RUN: -triple dxil-pc-shadermodel6.6-library %s -DTEST_DXIL \ -// RUN: -emit-llvm-only -disable-llvm-passes -verify=dxil +// RUN: -triple dxil-pc-shadermodel6.6-library %s \ +// RUN: -emit-llvm-only -disable-llvm-passes -verify \ +// RUN: -verify-ignore-unexpected=note // RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \ -// RUN: -triple spirv-unknown-vulkan-compute %s -DTEST_SPIRV \ -// RUN: 
-emit-llvm-only -disable-llvm-passes -verify=spv +// RUN: -triple spirv-unknown-vulkan-compute %s \ +// RUN: -emit-llvm-only -disable-llvm-passes -verify \ +// RUN: -verify-ignore-unexpected=note -#ifdef TEST_DXIL -float dxil_fma_float(float a, float b, float c) { +float bad_float(float a, float b, float c) { return fma(a, b, c); - // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float')}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float')}} } -float2 dxil_fma_float2(float2 a, float2 b, float2 c) { +float2 bad_float2(float2 a, float2 b, float2 c) { return fma(a, b, c); - // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2' (aka 'vector<float, 2>'))}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2' (aka 'vector<float, 2>'))}} } -float4 dxil_fma_float4(float4 a, float4 b, float4 c) { +float2x2 bad_float2x2(float2x2 a, float2x2 b, float2x2 c) { return fma(a, b, c); - // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float4' (aka 'vector<float, 4>'))}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}} } -float2x2 dxil_fma_float2x2(float2x2 a, float2x2 b, float2x2 c) { +half bad_half(half a, half b, half c) { return fma(a, b, c); - // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half')}} } -double dxil_fma_bad_second(double a, float b, double c) { +half2 bad_half2(half2 a, half2 b, half2 c) { return fma(a, b, c); - // dxil-error@-1 {{all arguments to 'fma' must have the same type}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2' (aka 
'vector<half, 2>'))}} } -double dxil_fma_bad_third(double a, double b, half c) { +half2x2 bad_half2x2(half2x2 a, half2x2 b, half2x2 c) { return fma(a, b, c); - // dxil-error@-1 {{all arguments to 'fma' must have the same type}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2x2' (aka 'matrix<half, 2, 2>'))}} } -double2 dxil_fma_bad_second_vec(double2 a, float2 b, double2 c) { +double mixed_bad_second(double a, float b, double c) { return fma(a, b, c); - // dxil-error@-1 {{all arguments to 'fma' must have the same type}} + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} } -double2x2 dxil_fma_bad_third_mat(double2x2 a, double2x2 b, float2x2 c) { +double mixed_bad_third(double a, double b, half c) { return fma(a, b, c); - // dxil-error@-1 {{all arguments to 'fma' must have the same type}} + // expected-error@-1 {{arguments are of different types ('double' vs 'half')}} } -double2 dxil_fma_mismatch_second(double2 a, double b, double2 c) { +double2 mixed_bad_second_vec(double2 a, float2 b, double2 c) { return fma(a, b, c); - // dxil-error@-1 {{all arguments to 'fma' must have the same type}} + // expected-error@-1 {{arguments are of different types ('vector<double, [...]>' vs 'vector<float, [...]>')}} } -double2 dxil_fma_mismatch_third(double2 a, double2 b, double c) { +double2 mixed_bad_third_vec(double2 a, double2 b, float2 c) { return fma(a, b, c); - // dxil-error@-1 {{all arguments to 'fma' must have the same type}} + // expected-error@-1 {{arguments are of different types ('vector<double, [...]>' vs 'vector<float, [...]>')}} } -double2x2 dxil_fma_mismatch_second_mat(double2x2 a, double2 b, double2x2 c) { +double2x2 mixed_bad_second_mat(double2x2 a, float2x2 b, double2x2 c) { return fma(a, b, c); - // dxil-error@-1 {{all arguments to 'fma' must have the same type}} + // expected-error@-1 {{arguments are of different types ('matrix<double, [2 * ...]>' vs 'matrix<float, [2 * ...]>')}} } 
-double2x2 dxil_fma_mismatch_third_mat(double2x2 a, double2x2 b, double2 c) { +double2x2 mixed_bad_third_mat(double2x2 a, double2x2 b, half2x2 c) { return fma(a, b, c); - // dxil-error@-1 {{all arguments to 'fma' must have the same type}} + // expected-error@-1 {{arguments are of different types ('matrix<double, [2 * ...]>' vs 'matrix<half, [2 * ...]>')}} } -half dxil_fma_half(half a, half b, half c) { +double shape_mismatch_second(double a, double2 b, double c) { return fma(a, b, c); - // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half')}} + // expected-error@-1 {{call to 'fma' is ambiguous}} } -half2 dxil_fma_half2(half2 a, half2 b, half2 c) { +double2 shape_mismatch_third(double2 a, double2 b, double c) { return fma(a, b, c); - // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2' (aka 'vector<half, 2>'))}} + // expected-error@-1 {{call to 'fma' is ambiguous}} } -int dxil_fma_int(int a, int b, int c) { +double2x2 shape_mismatch_scalar_mat(double2x2 a, double b, double2x2 c) { return fma(a, b, c); - // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'int')}} + // expected-error@-1 {{call to 'fma' is ambiguous}} } -bool dxil_fma_bool(bool a, bool b, bool c) { +double2x2 shape_mismatch_vec_mat(double2x2 a, double2 b, double2x2 c) { return fma(a, b, c); - // dxil-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'bool')}} + // expected-error@-1 {{arguments are of different types ('double2x2' (aka 'matrix<double, 2, 2>') vs 'double2' (aka 'vector<double, 2>'))}} } -#endif -#ifdef TEST_SPIRV -int spv_fma_int(int a, int b, int c) { +int bad_int(int a, int b, int c) { return fma(a, b, c); - // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'int')}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of any floating-point type (was 'int')}} } -int2 
spv_fma_int2(int2 a, int2 b, int2 c) { +int2 bad_int2(int2 a, int2 b, int2 c) { return fma(a, b, c); - // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'int2' (aka 'vector<int, 2>'))}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of any floating-point type (was 'int2' (aka 'vector<int, 2>'))}} } -bool spv_fma_bool(bool a, bool b, bool c) { +bool bad_bool(bool a, bool b, bool c) { return fma(a, b, c); - // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'bool')}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of any floating-point type (was 'bool')}} } -float spv_fma_bad_second(float a, int b, float c) { +bool2 bad_bool2(bool2 a, bool2 b, bool2 c) { return fma(a, b, c); - // spv-error@-1 {{all arguments to 'fma' must have the same type}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of any floating-point type (was 'bool2' (aka 'vector<bool, 2>'))}} } -float spv_fma_bad_third(float a, float b, bool c) { +bool2x2 bad_bool2x2(bool2x2 a, bool2x2 b, bool2x2 c) { return fma(a, b, c); - // spv-error@-1 {{all arguments to 'fma' must have the same type}} + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of any floating-point type (was 'bool2x2' (aka 'matrix<bool, 2, 2>'))}} } - -float2 spv_fma_bad_second_vec(float2 a, int2 b, float2 c) { - return fma(a, b, c); - // spv-error@-1 {{all arguments to 'fma' must have the same type}} -} - -double2 spv_fma_bad_third_vec(double2 a, double2 b, int2 c) { - return fma(a, b, c); - // spv-error@-1 {{all arguments to 'fma' must have the same type}} -} - -float2 spv_fma_mismatch_second(float2 a, float b, float2 c) { - return fma(a, b, c); - // spv-error@-1 {{all arguments to 'fma' must have the same type}} -} - -float2 spv_fma_mismatch_third(float2 a, float2 b, float c) { - return fma(a, b, c); - // spv-error@-1 {{all arguments to 'fma' must have the same type}} 
-} - -double2 spv_fma_mismatch_second_double(double2 a, double b, double2 c) { - return fma(a, b, c); - // spv-error@-1 {{all arguments to 'fma' must have the same type}} -} - -double2 spv_fma_mismatch_third_double(double2 a, double2 b, double c) { - return fma(a, b, c); - // spv-error@-1 {{all arguments to 'fma' must have the same type}} -} - -float2x2 spv_fma_float2x2(float2x2 a, float2x2 b, float2x2 c) { - return fma(a, b, c); - // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}} -} - -float2 spv_fma_bad_second_mat(float2 a, float2x2 b, float2 c) { - return fma(a, b, c); - // spv-error@-1 {{all arguments to 'fma' must have the same type}} -} - -double2 spv_fma_bad_third_mat(double2 a, double2 b, double2x2 c) { - return fma(a, b, c); - // spv-error@-1 {{all arguments to 'fma' must have the same type}} -} - -float2x3 spv_fma_float2x3(float2x3 a, float2x3 b, float2x3 c) { - return fma(a, b, c); - // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'float2x3' (aka 'matrix<float, 2, 3>'))}} -} - -float3x2 spv_fma_float3x2(float3x2 a, float3x2 b, float3x2 c) { - return fma(a, b, c); - // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'float3x2' (aka 'matrix<float, 3, 2>'))}} -} - -float4x4 spv_fma_float4x4(float4x4 a, float4x4 b, float4x4 c) { - return fma(a, b, c); - // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'float4x4' (aka 'matrix<float, 4, 4>'))}} -} - -double2x2 spv_fma_double2x2(double2x2 a, double2x2 b, double2x2 c) { - return fma(a, b, c); - // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'double2x2' (aka 'matrix<double, 2, 2>'))}} -} - -half2x2 spv_fma_half2x2(half2x2 a, half2x2 b, half2x2 c) { - return fma(a, b, c); - // spv-error@-1 {{1st argument must be a scalar or vector of floating-point type (was 'half2x2' (aka 'matrix<half, 2, 
2>'))}} -} -#endif diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 1d2b8faa90f8a..909482d72aa88 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -259,6 +259,4 @@ def int_dx_store_output [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i8_ty, llvm_i32_ty, llvm_any_ty], [IntrConvergent]>; -// We reject any non-double types in SemaHLSL.cpp so hopefully they won't fall through here. as we don't have `llvm_anydouble_ty` we have to rely on Sema to do its job and filter out all non-double types. -def int_dx_fma : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; } diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index d4b2736e0577c..9819f881b5c30 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -292,5 +292,5 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] def int_spv_unpackhalf2x16 : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [llvm_i32_ty], [IntrNoMem]>; def int_spv_packhalf2x16 : DefaultAttrsIntrinsic<[llvm_anyint_ty], [llvm_anyfloat_ty], [IntrNoMem]>; - def int_spv_fma : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; + } diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index bf9e881041f85..297a61921949e 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -784,7 +784,7 @@ def FMad : DXILOp<46, tertiary> { def Fma : DXILOp<47, tertiary> { let Doc = "Double-precision fused multiply-add. 
fma(a,b,c) = a * b + c."; - let intrinsics = [IntrinSelect<int_dx_fma>]; + let intrinsics = [IntrinSelect<int_fma>]; let arguments = [OverloadTy, OverloadTy, OverloadTy]; let result = OverloadTy; let overloads = [Overloads<DXIL1_0, [DoubleTy]>]; diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index 439b5eaf8756a..fb246b594c5e6 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -110,7 +110,7 @@ static bool checkFmaOps(Intrinsic::ID IID) { switch (IID) { default: return false; - case Intrinsic::dx_fma: + case Intrinsic::fma: return true; } } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 73d2ad23c673d..dd0830bb879f5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -4033,8 +4033,6 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectAll(ResVReg, ResType, I); case Intrinsic::spv_any: return selectAny(ResVReg, ResType, I); - case Intrinsic::spv_fma: - return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma); case Intrinsic::spv_cross: return selectExtInst(ResVReg, ResType, I, CL::cross, GL::Cross); case Intrinsic::spv_distance: diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll deleted file mode 100644 index 28e7bfa36f591..0000000000000 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fma.ll +++ /dev/null @@ -1,53 +0,0 @@ -; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s -; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %} - -; CHECK: OpExtInstImport "GLSL.std.450" - -define noundef half @fma_half(half noundef %a, half noundef %b, half noundef %c) { -entry: -; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] - %r = call 
half @llvm.spv.fma.f16(half %a, half %b, half %c) - ret half %r -} - -define noundef float @fma_float(float noundef %a, float noundef %b, float noundef %c) { -entry: -; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] - %r = call float @llvm.spv.fma.f32(float %a, float %b, float %c) - ret float %r -} - -define noundef double @fma_double(double noundef %a, double noundef %b, double noundef %c) { -entry: -; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] - %r = call double @llvm.spv.fma.f64(double %a, double %b, double %c) - ret double %r -} - -define noundef <4 x half> @fma_half4(<4 x half> noundef %a, <4 x half> noundef %b, <4 x half> noundef %c) { -entry: -; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] - %r = call <4 x half> @llvm.spv.fma.v4f16(<4 x half> %a, <4 x half> %b, <4 x half> %c) - ret <4 x half> %r -} - -define noundef <4 x float> @fma_float4(<4 x float> noundef %a, <4 x float> noundef %b, <4 x float> noundef %c) { -entry: -; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] - %r = call <4 x float> @llvm.spv.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c) - ret <4 x float> %r -} - -define noundef <4 x double> @fma_double4(<4 x double> noundef %a, <4 x double> noundef %b, <4 x double> noundef %c) { -entry: -; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] Fma %[[#]] %[[#]] %[[#]] - %r = call <4 x double> @llvm.spv.fma.v4f64(<4 x double> %a, <4 x double> %b, <4 x double> %c) - ret <4 x double> %r -} - -declare half @llvm.spv.fma.f16(half, half, half) -declare float @llvm.spv.fma.f32(float, float, float) -declare double @llvm.spv.fma.f64(double, double, double) -declare <4 x half> @llvm.spv.fma.v4f16(<4 x half>, <4 x half>, <4 x half>) -declare <4 x float> @llvm.spv.fma.v4f32(<4 x float>, <4 x float>, <4 x float>) -declare <4 x double> @llvm.spv.fma.v4f64(<4 x double>, <4 x double>, <4 x double>) _______________________________________________ cfe-commits mailing list [email protected] 
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
