================
@@ -2320,6 +2330,65 @@ static bool interp__builtin_elementwise_sat(InterpState &S, CodePtr OpPC,
   return true;
 }
 
+static bool interp__builtin_elementwise_fma(InterpState &S, CodePtr OpPC,
+                                            const CallExpr *Call) {
+  assert(Call->getNumArgs() == 3);
+
+  llvm::RoundingMode RM = getRoundingMode(S, Call);
+
+  const QualType Arg1Type = Call->getArg(0)->getType();
+  const QualType Arg2Type = Call->getArg(1)->getType();
+  const QualType Arg3Type = Call->getArg(2)->getType();
+
+  // Non-vector floating point types.
+  if (!Arg1Type->isVectorType()) {
+    assert(!Arg2Type->isVectorType());
+    assert(!Arg3Type->isVectorType());
+
+    const Floating &Z = S.Stk.pop<Floating>();
+    const Floating &Y = S.Stk.pop<Floating>();
+    const Floating &X = S.Stk.pop<Floating>();
+
+    APFloat F = X.getAPFloat();
+    F.fusedMultiplyAdd(Y.getAPFloat(), Z.getAPFloat(), RM);
+    Floating Result = S.allocFloat(X.getSemantics());
+    Result.copy(F);
+    S.Stk.push<Floating>(Result);
+    return true;
+  }
+
+  // Vector type.
+  assert(Arg1Type->isVectorType() &&
+         Arg2Type->isVectorType() &&
+         Arg3Type->isVectorType());
+
+  const VectorType *VecT = Arg1Type->castAs<VectorType>();
+  const QualType ElemT = VecT->getElementType();
+  unsigned NumElems = VecT->getNumElements();
+
+  assert(ElemT == Arg2Type->castAs<VectorType>()->getElementType() &&
+         ElemT == Arg3Type->castAs<VectorType>()->getElementType());
+  assert(NumElems == Arg2Type->castAs<VectorType>()->getNumElements() &&
+         NumElems == Arg3Type->castAs<VectorType>()->getNumElements());
+  assert(ElemT->isRealFloatingType());
+
+  const Pointer &VZ = S.Stk.pop<Pointer>();
+  const Pointer &VY = S.Stk.pop<Pointer>();
+  const Pointer &VX = S.Stk.pop<Pointer>();
+  const Pointer &Dst = S.Stk.peek<Pointer>();
+
+  for (unsigned I = 0; I != NumElems; ++I) {
+    using T = PrimConv<PT_Float>::T;
+    APFloat X = VX.elem<T>(I).getAPFloat();
+    APFloat Y = VY.elem<T>(I).getAPFloat();
+    APFloat Z = VZ.elem<T>(I).getAPFloat();
+    (void)X.fusedMultiplyAdd(Y, Z, RM);
+    Dst.elem<T>(I) = static_cast<PrimConv<PT_Float>::T>(X);
----------------
ckoparkar wrote:

Done.

https://github.com/llvm/llvm-project/pull/152919
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to