================
@@ -554,6 +556,77 @@ emitCallMaybeConstrainedBuiltin(CIRGenBuilderTy &builder, 
mlir::Location loc,
   return builder.emitIntrinsicCallOp(loc, intrName, retTy, ops);
 }
 
+static mlir::Value emitVectorFmaLaneSource(CIRGenBuilderTy &builder,
+                                           mlir::Location loc,
+                                           const CallExpr *expr,
+                                           ASTContext &ctx,
+                                           mlir::Value laneSource,
+                                           cir::VectorType ty,
+                                           cir::VectorType sourceTy) {
+  if (laneSource.getType() != sourceTy)
+    laneSource = builder.createBitcast(loc, laneSource, sourceTy);
+
+  auto vecTy = mlir::cast<cir::VectorType>(ty);
+  int64_t lane = expr->getArg(3)->EvaluateKnownConstInt(ctx).getSExtValue();
+  llvm::SmallVector<int64_t> mask(vecTy.getSize(), lane);
+  return builder.createVecShuffle(loc, laneSource, mask);
+}
+
+static mlir::Value emitVectorFmaBuiltin(CIRGenFunction &cgf,
+                                        mlir::Location loc,
+                                        llvm::SmallVectorImpl<mlir::Value> 
&ops,
+                                        const CallExpr *expr) {
+  cir::VectorType ty = 
mlir::cast<cir::VectorType>(cgf.convertType(expr->getType()));
+  if (ops[0].getType() != ty)
+    ops[0] = cgf.getBuilder().createBitcast(loc, ops[0], ty);
+  if (ops[1].getType() != ty)
+    ops[1] = cgf.getBuilder().createBitcast(loc, ops[1], ty);
+  if (ops[2].getType() != ty)
+    ops[2] = cgf.getBuilder().createBitcast(loc, ops[2], ty);
+  std::rotate(ops.begin(), ops.begin() + 1, ops.end());
+  return emitCallMaybeConstrainedBuiltin(cgf.getBuilder(), loc, "fma", ty, 
ops);
+}
+
+static mlir::Value emitVectorFmaLaneBuiltin(CIRGenFunction &cgf,
+                                            unsigned builtinID,
+                                            NeonTypeFlags type,
+                                            mlir::Location loc,
+                                            const CallExpr *expr,
+                                            llvm::SmallVectorImpl<mlir::Value> 
&ops) {
+  cir::VectorType ty = getNeonType(&cgf, type, loc);
+  if (!ty)
+    return nullptr;
+
+  auto vecTy = mlir::cast<cir::VectorType>(ty);
+  cir::VectorType sourceTy = ty;
+  unsigned vectorFmaBuiltin = NEON::BI__builtin_neon_vfma_v;
+
+  switch (builtinID) {
+  case NEON::BI__builtin_neon_vfmaq_lane_v:
+    sourceTy = cir::VectorType::get(vecTy.getElementType(), vecTy.getSize() / 
2);
+    vectorFmaBuiltin = NEON::BI__builtin_neon_vfmaq_v;
+    break;
+  case NEON::BI__builtin_neon_vfma_laneq_v:
+    sourceTy = cir::VectorType::get(vecTy.getElementType(), vecTy.getSize() * 
2);
+    break;
+  case NEON::BI__builtin_neon_vfmaq_laneq_v:
+    vectorFmaBuiltin = NEON::BI__builtin_neon_vfmaq_v;
+    break;
+  case NEON::BI__builtin_neon_vfma_lane_v:
+    break;
+  default:
+    llvm_unreachable("unexpected vfma lane builtin");
+  }
+
+  llvm::SmallVector<mlir::Value> fmaOps(ops.begin(), ops.end() - 1);
+  fmaOps[2] = emitVectorFmaLaneSource(cgf.getBuilder(), loc, expr,
+                                      cgf.getContext(), ops[2], ty, sourceTy);
+  const ARMVectorIntrinsicInfo *info = findARMVectorIntrinsicInMap(
+      AArch64SIMDIntrinsicMap, vectorFmaBuiltin,
+      aarch64SIMDIntrinsicsProvenSorted);
+  return emitCommonNeonBuiltinExpr(cgf, *info, fmaOps, expr);
+}
----------------
banach-space wrote:

We didn't need these extra helpers in either the incubator or the
original code-gen.

https://github.com/llvm/llvm-project/pull/188190
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to