https://github.com/adams381 updated 
https://github.com/llvm/llvm-project/pull/195879

>From 2edda6b97394782a4617582ef716c777e2c590e9 Mon Sep 17 00:00:00 2001
From: Adam Smith <[email protected]>
Date: Tue, 5 May 2026 09:26:10 -0700
Subject: [PATCH] [CIR] Add Direct coerce-in-registers + cir.reinterpret_cast
 op

Fourth PR in the split of #192119/#192124. Implements the
Direct-with-coercion path in CallConvLowering and picks off
andykaylor's five inline review comments from the original PR.

The new cir.reinterpret_cast op is for same-bit-width in-register
reinterpretation (vector<2 x float> <-> complex<float>).
emitCoercion uses it when source and destination differ only in
vector-vs-non-vector shape and have identical bit width, instead
of going through memory.  For everything else (records, or shape
doesn't match) the helper still does alloca/store/ptr-cast/load.

Andy's comments, in order:
- Temporary alloca alignment is now max(srcAlign, dstAlign) from
  DataLayout instead of hardcoded.
- The alloca lives in the entry block via InsertionGuard so it
  composes with HoistAllocas regardless of pipeline order.
- isVolatile kept as UnitAttr-absence with an inline comment.
- vector<->complex now uses cir.reinterpret_cast.
- Memory path has three new .cir tests covering it.

In-body coercion (insertArgCoercion / insertReturnCoercion) folds
into the existing per-function rewriteFunctionDefinition method
introduced by the prior Direct/Ignore PR's review fixup.  It runs
ahead of the Ignore-arg drop in the same per-function inner-loop
window the pass driver already uses, so the per-function
window-of-invalidity invariant is unchanged: F's signature and
F's body coerce together; F's callers update inside the same
inner loop iteration.  The Ignore-arg drop reuses the existing
poison-stub idiom (the alloca/load-fallback that earlier drafts
used is unnecessary once the drop happens after coercion in the
same per-function window).

LowerToLLVM gets a lowering for the new op: bitcast when both
converted types are scalars or vectors.  Note that !cir.complex
converts to an LLVM literal struct, so the vector<->complex
coercion that CallConvLowering emits does reach the aggregate
side of this lowering; that path is not only reachable from
hand-written IR.

Co-authored-by: Cursor <[email protected]>
---
 clang/include/clang/CIR/Dialect/IR/CIROps.td  |  48 ++++
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       |  22 ++
 .../Transforms/CallConvLoweringPass.cpp       |   2 +-
 .../TargetLowering/CIRABIRewriteContext.cpp   | 228 ++++++++++++++++--
 .../TargetLowering/CIRABIRewriteContext.h     |  15 +-
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp |  39 +++
 clang/test/CIR/IR/reinterpret-cast.cir        |  28 +++
 .../abi-lowering/coerce-int-to-record.cir     |  59 +++++
 .../abi-lowering/coerce-record-to-int.cir     |  50 ++++
 .../coerce-record-to-record-via-memory.cir    |  34 +++
 .../coerce-vector-to-complex-reinterpret.cir  |  42 ++++
 11 files changed, 537 insertions(+), 30 deletions(-)
 create mode 100644 clang/test/CIR/IR/reinterpret-cast.cir
 create mode 100644 
clang/test/CIR/Transforms/abi-lowering/coerce-int-to-record.cir
 create mode 100644 
clang/test/CIR/Transforms/abi-lowering/coerce-record-to-int.cir
 create mode 100644 
clang/test/CIR/Transforms/abi-lowering/coerce-record-to-record-via-memory.cir
 create mode 100644 
clang/test/CIR/Transforms/abi-lowering/coerce-vector-to-complex-reinterpret.cir

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td 
b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index d55439b8618a5..5774507395b08 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -288,6 +288,54 @@ def CIR_CastOp : CIR_Op<"cast", [
 
 }
 
+//===----------------------------------------------------------------------===//
+// ReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+def CIR_ReinterpretCastOp : CIR_Op<"reinterpret_cast", [Pure]> {
+  let summary = "Reinterpret a value as a different same-bit-width type";
+  let description = [{
+    The `cir.reinterpret_cast` operation reinterprets the bits of its source
+    value as a different type without a round trip through memory.  It is
+    used by the calling-convention lowering pass to coerce between
+    same-bit-width types whose in-register representation is identical but
+    whose CIR types (and converted LLVM types) differ -- for example,
+    `!cir.vector<2 x !cir.float>` and `!cir.complex<!cir.float>`, which lower
+    to `<2 x float>` and `{float, float}` respectively.
+
+    Unlike `cir.cast bitcast`, which is overloaded for pointer-to-pointer
+    bitcasts and several other use cases, `cir.reinterpret_cast` is reserved
+    for in-register value reinterpretation only.  The result type should
+    differ from the source type: an identity cast is accepted by the verifier
+    but carries no meaning, and the folder removes it.
+
+    **Invariant** (not currently enforced by the verifier): the source and
+    destination types must have the same bit width per the module's
+    DataLayout, and they must use the same in-register lane order on the
+    target.  Producers (e.g. CallConvLowering's coerce-in-registers path)
+    are responsible for ensuring this; a follow-up patch will move the
+    bit-width check into the verifier once the design question of
+    DataLayout-aware op verifiers is resolved.
+
+    Example:
+
+    ```
+    %c = cir.reinterpret_cast %v
+       : !cir.vector<2 x !cir.float> -> !cir.complex<!cir.float>
+    ```
+  }];
+
+  let arguments = (ins CIR_AnyType:$src);
+  let results = (outs CIR_AnyType:$result);
+
+  let assemblyFormat = [{
+    $src `:` type($src) `->` type($result) attr-dict
+  }];
+
+  let hasVerifier = 1;
+  let hasFolder = 1;
+}
+
 
//===----------------------------------------------------------------------===//
 // DynamicCastOp
 
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp 
b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 3c465f94723f1..15adc2fccf5bf 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -891,6 +891,41 @@ static Value tryFoldCastChain(cir::CastOp op) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// ReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult cir::ReinterpretCastOp::verify() {
+  // The op is meaningless for identical types -- the folder is the right
+  // way to remove it -- but we accept it at the verifier level so that
+  // peephole code (e.g. pattern rewriters that round-trip values) doesn't
+  // need a type-equality guard.  Producers should still avoid emitting
+  // it for matching types.
+  //
+  // The same-bit-width invariant is documented on the op but not yet
+  // checked here; see the op description for the rationale.
+  return success();
+}
+
+OpFoldResult cir::ReinterpretCastOp::fold(FoldAdaptor adaptor) {
+  // An identity reinterpret carries no information.
+  if (getSrc().getType() == getType())
+    return getSrc();
+
+  // Every reinterpret preserves the bits of its operand, so a chain of them
+  // is equivalent to a single reinterpret from the chain's first source.
+  if (auto inner = getSrc().getDefiningOp<cir::ReinterpretCastOp>()) {
+    Value innerSrc = inner.getSrc();
+    // A round trip (A -> B -> A) yields the original value.
+    if (innerSrc.getType() == getType())
+      return innerSrc;
+    // A -> B -> C: reinterpret A directly as C, updating this op in place.
+    getOperation()->setOperand(0, innerSrc);
+    return getResult();
+  }
+  return {};
+}
+
 OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
   if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
     // Propagate poison value
diff --git a/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp 
b/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp
index 838125037afd5..c00947593517e 100644
--- a/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp
@@ -137,7 +137,7 @@ void CallConvLoweringPass::runOnOperation() {
   }
 
   DataLayout dl(moduleOp);
-  CIRABIRewriteContext rewriteCtx(moduleOp);
+  CIRABIRewriteContext rewriteCtx(moduleOp, dl);
   SymbolTable symbolTable(moduleOp);
 
   // Classify every cir.func up front.  No IR mutation happens here, so
diff --git 
a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp 
b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
index 2c29c83b999ba..29ac1d371cd64 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
@@ -54,12 +54,11 @@ LogicalResult buildNewArgTypes(ArrayRef<Type> oldArgTypes,
     Type origTy = oldArgTypes[idx];
     switch (ac.kind) {
     case ArgKind::Direct:
-      if (ac.coercedType) {
-        emitError() << "Direct with coerced type at arg " << idx
-                    << " not yet implemented in CallConvLowering";
-        return failure();
-      }
-      newArgTypes.push_back(origTy);
+      // Direct with a coerced type means the wire signature uses the
+      // coerced type; the body still expects origTy and we'll insert a
+      // reinterpret/coercion at the entry block.  Direct without a
+      // coerced type is a true pass-through.
+      newArgTypes.push_back(ac.coercedType ? ac.coercedType : origTy);
       break;
     case ArgKind::Ignore:
       break;
@@ -93,12 +92,9 @@ Type computeNewReturnType(Type origRetTy, const 
ArgClassification &retInfo,
                           function_ref<InFlightDiagnostic()> emitError) {
   switch (retInfo.kind) {
   case ArgKind::Direct:
-    if (retInfo.coercedType) {
-      emitError() << "Direct return with coerced type not yet implemented "
-                  << "in CallConvLowering";
-      return nullptr;
-    }
-    return origRetTy;
+    // Direct return with a coerced type uses the coerced type on the wire;
+    // the rewriter inserts a coercion before each cir.return.
+    return retInfo.coercedType ? retInfo.coercedType : origRetTy;
   case ArgKind::Ignore:
     return cir::VoidType::get(ctx);
   case ArgKind::Expand:
@@ -176,6 +172,193 @@ ArrayAttr updateResAttrs(MLIRContext *ctx, ArrayAttr existingResAttrs,
   return ArrayAttr::get(ctx, {DictionaryAttr::get(ctx, attrs)});
 }
 
+/// Return true if \p srcTy can be coerced to \p dstTy with a single
+/// cir.reinterpret_cast instead of a memory roundtrip.
+///
+/// Exactly one side must be a vector, and both sides must have the same bit
+/// width.  The non-vector side must not be a record, array, or pointer:
+/// those convert to LLVM aggregates or pointers, which llvm.bitcast cannot
+/// reinterpret as a vector.  A complex non-vector side (an LLVM struct after
+/// conversion) is only accepted when its element type is the vector's
+/// element type and the vector has exactly two lanes, so the lanes line up
+/// one-to-one with the real and imaginary parts.
+bool canReinterpretInRegisters(Type srcTy, Type dstTy, const DataLayout &dl) {
+  auto srcVecTy = dyn_cast<cir::VectorType>(srcTy);
+  auto dstVecTy = dyn_cast<cir::VectorType>(dstTy);
+  if (static_cast<bool>(srcVecTy) == static_cast<bool>(dstVecTy))
+    return false;
+
+  cir::VectorType vecTy = srcVecTy ? srcVecTy : dstVecTy;
+  Type otherTy = srcVecTy ? dstTy : srcTy;
+  if (isa<cir::RecordType, cir::ArrayType, cir::PointerType>(otherTy))
+    return false;
+  auto complexTy = dyn_cast<cir::ComplexType>(otherTy);
+  if (complexTy && (complexTy.getElementType() != vecTy.getElementType() ||
+                    vecTy.getSize() != 2))
+    return false;
+
+  return dl.getTypeSizeInBits(srcTy) == dl.getTypeSizeInBits(dstTy);
+}
+
+/// Coerce \p src to type \p dstTy at the current builder insertion point.
+///
+/// Three strategies, in order of preference:
+///   - If src and dst are the same type, return src unchanged and leave
+///     \p createdOps empty.
+///   - If canReinterpretInRegisters() accepts the pair (e.g.
+///     !cir.vector<2 x !cir.float> <-> !cir.complex<!cir.float>), emit a
+///     single cir.reinterpret_cast.
+///   - Otherwise go through memory: allocate a slot of whichever type is
+///     larger, aligned to max(srcAlign, dstAlign) so both accesses are
+///     satisfied; store the source through a source-typed pointer and load
+///     the destination back through a destination-typed pointer.  Sizing the
+///     slot by the larger type keeps the load in bounds when the destination
+///     is wider than the source; bytes past the stored source are
+///     unspecified.
+///
+/// The temporary alloca is placed at the start of the enclosing function's
+/// entry block so that it composes correctly with the HoistAllocas pass
+/// regardless of pipeline ordering.  \p funcOp must be non-null whenever
+/// the memory path is taken.
+///
+/// Any operations the helper creates are added to \p createdOps so the
+/// caller can pass them to replaceAllUsesExcept and avoid clobbering the
+/// store's value operand when later rewiring the source value.
+Value emitCoercion(OpBuilder &rewriter, Location loc, Type dstTy, Value src,
+                   FunctionOpInterface funcOp, const DataLayout &dl,
+                   SmallPtrSetImpl<Operation *> &createdOps) {
+  Type srcTy = src.getType();
+  if (srcTy == dstTy)
+    return src;
+
+  if (canReinterpretInRegisters(srcTy, dstTy, dl)) {
+    auto reinterpret =
+        cir::ReinterpretCastOp::create(rewriter, loc, dstTy, src);
+    createdOps.insert(reinterpret);
+    return reinterpret;
+  }
+
+  assert(funcOp && "memory coercion needs an enclosing function");
+
+  // Memory path: alloca + store + load, with pointer bitcasts wherever an
+  // access type differs from the slot type.
+  bool slotIsSrc = dl.getTypeSize(srcTy).getFixedValue() >=
+                   dl.getTypeSize(dstTy).getFixedValue();
+  Type slotTy = slotIsSrc ? srcTy : dstTy;
+  uint64_t allocaAlign = std::max(dl.getTypeABIAlignment(srcTy),
+                                  dl.getTypeABIAlignment(dstTy));
+
+  auto srcPtrTy = cir::PointerType::get(srcTy);
+  auto dstPtrTy = cir::PointerType::get(dstTy);
+  cir::PointerType slotPtrTy = slotIsSrc ? srcPtrTy : dstPtrTy;
+
+  cir::AllocaOp alloca;
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    Block &entry = funcOp->getRegion(0).front();
+    rewriter.setInsertionPointToStart(&entry);
+    alloca = cir::AllocaOp::create(rewriter, loc, slotPtrTy, slotTy,
+                                   rewriter.getStringAttr("coerce"),
+                                   rewriter.getI64IntegerAttr(allocaAlign));
+  }
+  createdOps.insert(alloca);
+
+  // View the slot through a pointer of the given type, emitting a bitcast
+  // only when that type differs from the slot's own pointer type.
+  auto viewSlotAs = [&](cir::PointerType ptrTy) -> Value {
+    if (ptrTy == slotPtrTy)
+      return alloca;
+    auto ptrCast = cir::CastOp::create(rewriter, loc, ptrTy,
+                                       cir::CastKind::bitcast, alloca);
+    createdOps.insert(ptrCast);
+    return ptrCast;
+  };
+
+  // A null UnitAttr means "not volatile": volatility is modeled as the
+  // presence of the unit attribute.
+  auto store = cir::StoreOp::create(rewriter, loc, src, viewSlotAs(srcPtrTy),
+                                    /*isVolatile=*/UnitAttr(),
+                                    /*alignment=*/IntegerAttr(),
+                                    /*sync_scope=*/cir::SyncScopeKindAttr(),
+                                    /*mem_order=*/cir::MemOrderAttr());
+  createdOps.insert(store);
+
+  auto load = cir::LoadOp::create(rewriter, loc, dstTy, viewSlotAs(dstPtrTy),
+                                  /*isDeref=*/UnitAttr(),
+                                  /*isVolatile=*/UnitAttr(),
+                                  /*alignment=*/IntegerAttr(),
+                                  /*sync_scope=*/cir::SyncScopeKindAttr(),
+                                  /*mem_order=*/cir::MemOrderAttr());
+  createdOps.insert(load);
+  return load;
+}
+
+/// Convenience overload for callers that don't need the createdOps set
+/// (e.g. call-site coercion where we don't replaceAllUsesExcept).
+Value emitCoercion(OpBuilder &rewriter, Location loc, Type dstTy, Value src,
+                   FunctionOpInterface funcOp, const DataLayout &dl) {
+  SmallPtrSet<Operation *, 4> ignored;
+  return emitCoercion(rewriter, loc, dstTy, src, funcOp, dl, ignored);
+}
+
+/// Insert coercion before each cir.return so the returned value matches the
+/// new (coerced) return type.
+void insertReturnCoercion(FunctionOpInterface funcOp, Type origRetTy,
+                          Type coercedRetTy, OpBuilder &rewriter,
+                          const DataLayout &dl) {
+  SmallVector<cir::ReturnOp> returns;
+  funcOp.walk([&](cir::ReturnOp r) { returns.push_back(r); });
+  for (cir::ReturnOp r : returns) {
+    if (r.getInput().empty())
+      continue;
+    Value origVal = r.getInput()[0];
+    if (origVal.getType() == coercedRetTy)
+      continue;
+    rewriter.setInsertionPoint(r);
+    Value coerced =
+        emitCoercion(rewriter, r.getLoc(), coercedRetTy, origVal, funcOp, dl);
+    r->setOperand(0, coerced);
+  }
+}
+
+/// For each Direct arg with a coerced type, change the block argument's type
+/// to the coerced type and insert a coercion at function entry that maps it
+/// back to the original type for body uses.
+void insertArgCoercion(FunctionOpInterface funcOp,
+                       const FunctionClassification &fc, OpBuilder &rewriter,
+                       const DataLayout &dl) {
+  Region &body = funcOp->getRegion(0);
+  if (body.empty())
+    return;
+  Block &entry = body.front();
+
+  for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
+    if (ac.kind != ArgKind::Direct || !ac.coercedType)
+      continue;
+    if (idx >= entry.getNumArguments())
+      continue;
+
+    BlockArgument blockArg = entry.getArgument(idx);
+    Type oldArgTy = blockArg.getType();
+    Type newArgTy = ac.coercedType;
+    if (oldArgTy == newArgTy)
+      continue;
+
+    blockArg.setType(newArgTy);
+
+    rewriter.setInsertionPointToStart(&entry);
+    SmallPtrSet<Operation *, 4> coercionOps;
+    Value adapted = emitCoercion(rewriter, funcOp.getLoc(), oldArgTy, blockArg,
+                                 funcOp, dl, coercionOps);
+
+    // Replace blockArg uses with the adapted value, except inside the helper
+    // ops we just created.  This is critical: the coercion ops themselves
+    // consume blockArg (e.g. as the store's value operand), and a naive
+    // replaceAllUses would feed them the original-typed adapted value.
+    blockArg.replaceAllUsesExcept(adapted, coercionOps);
+  }
+}
 } // namespace
 
 LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition(
@@ -217,6 +364,23 @@ LogicalResult 
CIRABIRewriteContext::rewriteFunctionDefinition(
   if (funcOp.isDefinition()) {
     Region &body = funcOp->getRegion(0);
     if (!body.empty()) {
+      // In-body coercion for Direct args with a coerced type: change the
+      // block-arg types to the coerced types and insert a coercion at the
+      // top of the entry block (cir.reinterpret_cast or a memory roundtrip,
+      // see emitCoercion) that converts each coerced value back to its
+      // original type, then route existing body uses (including in-body
+      // cir.call operands) through it.  Done before the Ignore-drop below
+      // so the entry block argument indices used here are still original.
+      insertArgCoercion(funcOp, fc, builder, dl);
+
+      // Direct return with coerced type: insert a coercion at every
+      // cir.return so the returned value matches the (coerced) return
+      // type in the new function signature set below.
+      if (fc.returnInfo.kind == ArgKind::Direct && fc.returnInfo.coercedType &&
+          !oldResultTypes.empty() && fc.returnInfo.coercedType != origRetTy)
+        insertReturnCoercion(funcOp, origRetTy, fc.returnInfo.coercedType,
+                             builder, dl);
+
       Block &entry = body.front();
 
       // For each Ignored argument: drop the block argument and, if the
@@ -302,23 +466,19 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
            << "indirect call not yet implemented in CallConvLowering";
 
   MLIRContext *ctx = callOp->getContext();
+  auto enclosingFunc = call->getParentOfType<FunctionOpInterface>();
 
   for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
     switch (ac.kind) {
     case ArgKind::Direct:
-      if (ac.coercedType)
-        return call.emitOpError()
-               << "Direct with coerced type at call-site arg " << idx
-               << " not yet implemented in CallConvLowering";
-      break;
     case ArgKind::Ignore:
       break;
     case ArgKind::Expand:
       return call.emitOpError() << "Expand at call-site arg " << idx
                                 << " not yet implemented in CallConvLowering";
     case ArgKind::Extend:
-      // Extend at the call site is just an attribute change (llvm.signext /
-      // llvm.zeroext on the call's arg_attrs); no IR-level cast.
+      // Extend is attribute-only here (llvm.signext/llvm.zeroext on the
+      // call's arg_attrs).  Direct coercion is emitted when building newArgs.
       break;
     case ArgKind::Indirect:
       return call.emitOpError() << "Indirect at call-site arg " << idx
@@ -326,6 +486,8 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
     }
   }
 
+  builder.setInsertionPoint(call);
+
   SmallVector<Value> newArgs;
   ValueRange argOperands = call.getArgOperands();
   newArgs.reserve(argOperands.size());
@@ -337,7 +499,12 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
   for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
     if (ac.kind == ArgKind::Ignore)
       continue;
-    newArgs.push_back(argOperands[idx]);
+    Value arg = argOperands[idx];
+    if (ac.kind == ArgKind::Direct && ac.coercedType &&
+        arg.getType() != ac.coercedType)
+      arg = emitCoercion(builder, call.getLoc(), ac.coercedType, arg,
+                         enclosingFunc, dl);
+    newArgs.push_back(arg);
   }
 
   bool hasResult = call.getNumResults() > 0;
@@ -346,10 +513,11 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
   Type callRetTy = origRetTy;
   if (fc.returnInfo.kind == ArgKind::Ignore && hasResult)
     callRetTy = cir::VoidType::get(ctx);
-  if (fc.returnInfo.kind == ArgKind::Direct && fc.returnInfo.coercedType)
-    return call.emitOpError() << "Direct return with coerced type at "
-                              << "call-site not yet implemented in "
-                              << "CallConvLowering";
+  bool returnNeedsCoercion =
+      hasResult && fc.returnInfo.kind == ArgKind::Direct &&
+      fc.returnInfo.coercedType && fc.returnInfo.coercedType != origRetTy;
+  if (returnNeedsCoercion)
+    callRetTy = fc.returnInfo.coercedType;
 
   builder.setInsertionPoint(call);
   auto newCall = cir::CallOp::create(builder, call.getLoc(),
@@ -358,6 +526,15 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
     if (!newCall->hasAttr(attr.getName()))
       newCall->setAttr(attr.getName(), attr.getValue());
 
+  // Direct return with coercion: the new call returns the coerced type;
+  // emit a coercion back to the original type for the call's existing uses.
+  if (returnNeedsCoercion) {
+    builder.setInsertionPointAfter(newCall);
+    Value coercedBack = emitCoercion(builder, call.getLoc(), origRetTy,
+                                     newCall.getResult(), enclosingFunc, dl);
+    call.getResult().replaceAllUsesWith(coercedBack);
+  }
+
   // Layer llvm.signext / llvm.zeroext onto the new call's arg_attrs and
   // res_attrs for Extend args/return.  Ignore args also require a rebuild
   // because their slots are dropped from the output array.
@@ -384,7 +561,8 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
       Value poison = createIgnoredValue(builder, call.getLoc(), origRetTy);
       call.getResult().replaceAllUsesWith(poison);
     }
-  } else if (hasResult) {
+  } else if (hasResult && !returnNeedsCoercion) {
+    // returnNeedsCoercion already wired up the coerced result above.
     call.getResult().replaceAllUsesWith(newCall.getResult());
   }
 
diff --git 
a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h 
b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
index 7a0c0b8a2f22c..038e81026784c 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
@@ -11,9 +11,10 @@
 // rewrites a cir.func signature, the function body, and call sites to match
 // the ABI-lowered shape.
 //
-// This file currently handles only Direct (pass-through) and Ignore.  Other
-// ArgKind handlers (Extend, Direct-with-coercion, Indirect, Expand) are
-// added by subsequent PRs in the calling-convention-lowering split series.
+// This file currently handles Direct (pass-through and coerce-in-registers),
+// Extend, and Ignore.  The remaining ArgKind handlers (Indirect, Expand)
+// are added by subsequent PRs in the calling-convention-lowering split
+// series.
 //
 
//===----------------------------------------------------------------------===//
 
@@ -22,6 +23,7 @@
 
 #include "mlir/ABI/ABIRewriteContext.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
 
 namespace cir {
@@ -31,9 +33,13 @@ namespace cir {
 /// The driver pass (CallConvLoweringPass) computes a FunctionClassification
 /// for each cir.func / cir.call and dispatches to this class to perform the
 /// actual IR rewriting using cir dialect operations.
+///
+/// Holds a reference to the module's DataLayout for coercion alignment
+/// queries.  The DataLayout must outlive the rewrite context.
 class CIRABIRewriteContext : public mlir::abi::ABIRewriteContext {
 public:
-  explicit CIRABIRewriteContext(mlir::ModuleOp module) : module(module) {}
+  CIRABIRewriteContext(mlir::ModuleOp module, const mlir::DataLayout &dl)
+      : module(module), dl(dl) {}
 
   mlir::LogicalResult
   rewriteFunctionDefinition(mlir::FunctionOpInterface funcOp,
@@ -49,6 +55,7 @@ class CIRABIRewriteContext : public 
mlir::abi::ABIRewriteContext {
 
 private:
   mlir::ModuleOp module;
+  const mlir::DataLayout &dl;
 };
 
 } // namespace cir
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp 
b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index c4e98e299dfc1..0b11cc4547178 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1659,6 +1659,98 @@ mlir::LogicalResult CIRToLLVMReturnOpLowering::matchAndRewrite(
   return mlir::LogicalResult::success();
 }
 
+/// Return true if \p aggTy is an LLVM struct or array with exactly one
+/// element per lane of \p vecTy, each of the vector's element type.  Such a
+/// pair can be reinterpreted lane by lane without going through memory.
+static bool isLaneCompatibleAggregate(mlir::VectorType vecTy,
+                                      mlir::Type aggTy) {
+  if (vecTy.isScalable() || vecTy.getRank() != 1)
+    return false;
+  mlir::Type laneTy = vecTy.getElementType();
+  int64_t numLanes = vecTy.getNumElements();
+  if (auto structTy = mlir::dyn_cast<mlir::LLVM::LLVMStructType>(aggTy)) {
+    llvm::ArrayRef<mlir::Type> fields = structTy.getBody();
+    return static_cast<int64_t>(fields.size()) == numLanes &&
+           llvm::all_of(fields, [&](mlir::Type t) { return t == laneTy; });
+  }
+  if (auto arrayTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(aggTy))
+    return static_cast<int64_t>(arrayTy.getNumElements()) == numLanes &&
+           arrayTy.getElementType() == laneTy;
+  return false;
+}
+
+mlir::LogicalResult CIRToLLVMReinterpretCastOpLowering::matchAndRewrite(
+    cir::ReinterpretCastOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  // After type conversion, source and destination LLVM types may be:
+  //   (a) Identical: the op was a CIR-level type rename only; replace it
+  //       with its source value.
+  //   (b) Both non-aggregate (scalar or vector): emit llvm.bitcast.
+  //   (c) A vector and a lane-compatible struct/array.  This is what
+  //       !cir.vector<2 x T> <-> !cir.complex<T> becomes, because complex
+  //       converts to an LLVM literal struct.  llvm.bitcast does not accept
+  //       aggregates, so move the lanes individually.
+  //   (d) Any other aggregate pairing: report an error; producers must
+  //       coerce through memory instead.
+  mlir::Type llvmDstTy = getTypeConverter()->convertType(op.getType());
+  mlir::Value llvmSrc = adaptor.getSrc();
+  mlir::Type llvmSrcTy = llvmSrc.getType();
+
+  if (llvmSrcTy == llvmDstTy) {
+    rewriter.replaceOp(op, llvmSrc);
+    return mlir::success();
+  }
+
+  auto isAggregate = [](mlir::Type t) {
+    return mlir::isa<mlir::LLVM::LLVMStructType, mlir::LLVM::LLVMArrayType>(t);
+  };
+  if (!isAggregate(llvmSrcTy) && !isAggregate(llvmDstTy)) {
+    rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llvmDstTy, llvmSrc);
+    return mlir::success();
+  }
+
+  mlir::Location loc = op.getLoc();
+  auto laneIndex = [&](int64_t lane) -> mlir::Value {
+    return mlir::LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+                                          rewriter.getI64IntegerAttr(lane));
+  };
+
+  // Vector -> aggregate: extractelement each lane, insertvalue it into a
+  // poison aggregate.
+  auto srcVecTy = mlir::dyn_cast<mlir::VectorType>(llvmSrcTy);
+  if (srcVecTy && isLaneCompatibleAggregate(srcVecTy, llvmDstTy)) {
+    mlir::Value result = mlir::LLVM::PoisonOp::create(rewriter, loc, llvmDstTy);
+    for (int64_t lane = 0, e = srcVecTy.getNumElements(); lane < e; ++lane) {
+      mlir::Value elt = mlir::LLVM::ExtractElementOp::create(
+          rewriter, loc, srcVecTy.getElementType(), llvmSrc, laneIndex(lane));
+      result =
+          mlir::LLVM::InsertValueOp::create(rewriter, loc, result, elt, lane);
+    }
+    rewriter.replaceOp(op, result);
+    return mlir::success();
+  }
+
+  // Aggregate -> vector: extractvalue each element, insertelement it into a
+  // poison vector.
+  auto dstVecTy = mlir::dyn_cast<mlir::VectorType>(llvmDstTy);
+  if (dstVecTy && isLaneCompatibleAggregate(dstVecTy, llvmSrcTy)) {
+    mlir::Value result = mlir::LLVM::PoisonOp::create(rewriter, loc, dstVecTy);
+    for (int64_t lane = 0, e = dstVecTy.getNumElements(); lane < e; ++lane) {
+      mlir::Value elt =
+          mlir::LLVM::ExtractValueOp::create(rewriter, loc, llvmSrc, lane);
+      result = mlir::LLVM::InsertElementOp::create(
+          rewriter, loc, dstVecTy, result, elt, laneIndex(lane));
+    }
+    rewriter.replaceOp(op, result);
+    return mlir::success();
+  }
+
+  return op.emitOpError()
+         << "lowering cir.reinterpret_cast to LLVM between these aggregate "
+         << "types is not supported; the producer should coerce through "
+         << "memory instead";
+}
+
 mlir::LogicalResult CIRToLLVMRotateOpLowering::matchAndRewrite(
     cir::RotateOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/test/CIR/IR/reinterpret-cast.cir 
b/clang/test/CIR/IR/reinterpret-cast.cir
new file mode 100644
index 0000000000000..94742e15cda42
--- /dev/null
+++ b/clang/test/CIR/IR/reinterpret-cast.cir
@@ -0,0 +1,26 @@
+// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
+
+module {
+  // Vector <-> complex same-bit-width reinterpret (the canonical use case
+  // from cir-call-conv-lowering's coerce-in-registers path).
+  cir.func @vec_to_complex(%v : !cir.vector<2 x !cir.float>)
+      -> !cir.complex<!cir.float> {
+    %c = cir.reinterpret_cast %v
+       : !cir.vector<2 x !cir.float> -> !cir.complex<!cir.float>
+    cir.return %c : !cir.complex<!cir.float>
+  }
+
+  // Reverse direction.
+  cir.func @complex_to_vec(%c : !cir.complex<!cir.float>)
+      -> !cir.vector<2 x !cir.float> {
+    %v = cir.reinterpret_cast %c
+       : !cir.complex<!cir.float> -> !cir.vector<2 x !cir.float>
+    cir.return %v : !cir.vector<2 x !cir.float>
+  }
+}
+
+// CHECK-LABEL: cir.func{{.*}} @vec_to_complex
+// CHECK:         cir.reinterpret_cast %{{.*}} : !cir.vector<2 x !cir.float> -> !cir.complex<!cir.float>
+
+// CHECK-LABEL: cir.func{{.*}} @complex_to_vec
+// CHECK:         cir.reinterpret_cast %{{.*}} : !cir.complex<!cir.float> -> !cir.vector<2 x !cir.float>
diff --git a/clang/test/CIR/Transforms/abi-lowering/coerce-int-to-record.cir 
b/clang/test/CIR/Transforms/abi-lowering/coerce-int-to-record.cir
new file mode 100644
index 0000000000000..f90427bf68b4c
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/coerce-int-to-record.cir
@@ -0,0 +1,60 @@
+// Direct return with a coerced type, going from a small record to a
+// same-bit-width integer.  Despite the file name, this is the return-side
+// counterpart of coerce-record-to-int.cir: every cir.return gets its record
+// operand coerced to the integer type before being returned, and every call
+// site coerces the integer result back to the record type.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN:   | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+!s64i = !cir.int<s, 64>
+!rec_Pair = !cir.record<struct "Pair" {!s32i, !s32i}>
+
+#coerce_pair_return_to_i64 = {
+  return = { kind = "direct", coerced_type = !s64i },
+  args   = [ ]
+}
+
+module attributes {
+  dlti.dl_spec = #dlti.dl_spec<
+    #dlti.dl_entry<i32, dense<32>: vector<2xi64>>,
+    #dlti.dl_entry<i64, dense<64>: vector<2xi64>>>
+} {
+
+  cir.func @returns_pair() -> !rec_Pair
+      attributes { test_classify = #coerce_pair_return_to_i64 } {
+    %0 = cir.const #cir.zero : !rec_Pair
+    cir.return %0 : !rec_Pair
+  }
+
+  // Signature changes to !s64i; the cir.return's record operand gets
+  // coerced via a memory roundtrip before being returned.  The alloca is
+  // placed at the start of the entry block, so it sits ahead of the const
+  // that produces the value.
+  // CHECK:      cir.func{{.*}} @returns_pair() -> !s64i
+  // CHECK:        %[[SLOT:.*]] = cir.alloca !rec_Pair, !cir.ptr<!rec_Pair>, ["coerce"]
+  // CHECK:        %[[VAL:.*]] = cir.const #cir.zero : !rec_Pair
+  // CHECK:        cir.store %[[VAL]], %[[SLOT]] : !rec_Pair, !cir.ptr<!rec_Pair>
+  // CHECK:        %[[CAST:.*]] = cir.cast bitcast %[[SLOT]] : !cir.ptr<!rec_Pair> -> !cir.ptr<!s64i>
+  // CHECK:        %[[COERCED:.*]] = cir.load %[[CAST]] : !cir.ptr<!s64i>, !s64i
+  // CHECK:        cir.return %[[COERCED]] : !s64i
+
+  cir.func @caller() -> !rec_Pair
+      attributes { test_classify = #coerce_pair_return_to_i64 } {
+    %0 = cir.call @returns_pair() : () -> !rec_Pair
+    cir.return %0 : !rec_Pair
+  }
+
+  // At the call site the lowered call returns !s64i; the rewriter coerces
+  // it back to !rec_Pair for downstream uses.  The caller's own return is
+  // also Direct-with-coerce, so that record is then coerced forward to
+  // !s64i again before being returned.
+  // CHECK:      cir.func{{.*}} @caller() -> !s64i
+  // CHECK:        %[[CALL:.*]] = cir.call @returns_pair() : () -> !s64i
+  // CHECK:        cir.store %[[CALL]], %{{.*}} : !s64i, !cir.ptr<!s64i>
+  // CHECK:        %[[BACK:.*]] = cir.load %{{.*}} : !cir.ptr<!rec_Pair>, !rec_Pair
+  // CHECK:        cir.store %[[BACK]], %{{.*}} : !rec_Pair, !cir.ptr<!rec_Pair>
+  // CHECK:        %[[FWD:.*]] = cir.load %{{.*}} : !cir.ptr<!s64i>, !s64i
+  // CHECK:        cir.return %[[FWD]] : !s64i
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-int.cir 
b/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-int.cir
new file mode 100644
index 0000000000000..f31f09181710e
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-int.cir
@@ -0,0 +1,50 @@
+// Direct with coerced type going from a small record to a same-bit-width
+// integer.  The shapes don't match (record vs scalar) so the rewriter
+// emits a memory roundtrip: alloca in the entry block + store + ptr-cast +
+// load.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN:   | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+!s64i = !cir.int<s, 64>
+!rec_Pair = !cir.record<struct "Pair" {!s32i, !s32i}>
+
+#coerce_pair_to_i64 = {
+  return = { kind = "direct" },
+  args   = [ { kind = "direct", coerced_type = !s64i } ]
+}
+
+module attributes {
+  dlti.dl_spec = #dlti.dl_spec<
+    #dlti.dl_entry<i32, dense<32>: vector<2xi64>>,
+    #dlti.dl_entry<i64, dense<64>: vector<2xi64>>>
+} {
+
+  cir.func @takes_pair(%arg0: !rec_Pair)
+      attributes { test_classify = #coerce_pair_to_i64 } {
+    cir.return
+  }
+
+  // Signature changes to !s64i; entry block grows an alloca + store + cast
+  // + load chain that recovers the original !rec_Pair value.  The alloca
+  // lands at the very start of the entry block so this composes correctly
+  // with cir-hoist-allocas regardless of pipeline ordering.
+  // CHECK:      cir.func{{.*}} @takes_pair(%[[ARG:.*]]: !s64i)
+  // CHECK:        %[[SLOT:.*]] = cir.alloca !s64i, !cir.ptr<!s64i>, ["coerce"]
+  // CHECK:        cir.store %[[ARG]], %[[SLOT]] : !s64i, !cir.ptr<!s64i>
+  // CHECK:        %[[CAST:.*]] = cir.cast bitcast %[[SLOT]] : !cir.ptr<!s64i> -> !cir.ptr<!rec_Pair>
+  // CHECK:        %{{.*}} = cir.load %[[CAST]] : !cir.ptr<!rec_Pair>, !rec_Pair
+
+  cir.func @caller(%arg0: !rec_Pair)
+      attributes { test_classify = #coerce_pair_to_i64 } {
+    cir.call @takes_pair(%arg0) : (!rec_Pair) -> ()
+    cir.return
+  }
+
+  // At the call site, the original !rec_Pair gets coerced to !s64i via the
+  // same memory roundtrip before being passed.  The caller's own arg-coercion
+  // chain runs first (same pattern), then the call; only types are checked.
+  // CHECK:      cir.func{{.*}} @caller(%[[ARG:.*]]: !s64i)
+  // CHECK:        cir.call @takes_pair(%{{.*}}) : (!s64i) -> ()
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-record-via-memory.cir b/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-record-via-memory.cir
new file mode 100644
index 0000000000000..1669bf1232d28
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-record-via-memory.cir
@@ -0,0 +1,34 @@
+// Direct with a coerced type that is a different record (record-to-record).
+// Neither side is a vector and both are records, so the rewriter takes the
+// memory-roundtrip path: records are never reinterpreted in registers.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN:   | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+!s64i = !cir.int<s, 64>
+!rec_Pair  = !cir.record<struct "Pair"  {!s32i, !s32i}>
+!rec_Single = !cir.record<struct "Single" {!s64i}>
+
+#coerce_pair_to_single = {
+  return = { kind = "direct" },
+  args   = [ { kind = "direct", coerced_type = !rec_Single } ]
+}
+
+module attributes {
+  dlti.dl_spec = #dlti.dl_spec<
+    #dlti.dl_entry<i32, dense<32>: vector<2xi64>>,
+    #dlti.dl_entry<i64, dense<64>: vector<2xi64>>>
+} {
+
+  cir.func @takes_pair(%arg0: !rec_Pair)
+      attributes { test_classify = #coerce_pair_to_single } {
+    cir.return
+  }
+
+  // CHECK: cir.func{{.*}} @takes_pair(%[[ARG:.*]]: !rec_Single)
+  // CHECK:   %[[SLOT:.*]] = cir.alloca !rec_Single, !cir.ptr<!rec_Single>, ["coerce"]
+  // CHECK:   cir.store %[[ARG]], %[[SLOT]] : !rec_Single, !cir.ptr<!rec_Single>
+  // CHECK:   %[[CAST:.*]] = cir.cast bitcast %[[SLOT]] : !cir.ptr<!rec_Single> -> !cir.ptr<!rec_Pair>
+  // CHECK:   %{{.*}} = cir.load %[[CAST]] : !cir.ptr<!rec_Pair>, !rec_Pair
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/coerce-vector-to-complex-reinterpret.cir b/clang/test/CIR/Transforms/abi-lowering/coerce-vector-to-complex-reinterpret.cir
new file mode 100644
index 0000000000000..ceb1f9e364466
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/coerce-vector-to-complex-reinterpret.cir
@@ -0,0 +1,42 @@
+// Direct with coerced type that differs from the original only in
+// vector-vs-non-vector shape (same total bit width, neither side a record):
+// the rewriter emits cir.reinterpret_cast instead of going through memory.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN:   | FileCheck %s
+
+#coerce_complex_to_vec2 = {
+  return = { kind = "direct" },
+  args   = [ { kind = "direct",
+               coerced_type = !cir.vector<2 x !cir.float> } ]
+}
+
+module attributes {
+  dlti.dl_spec = #dlti.dl_spec<
+    #dlti.dl_entry<f32, dense<32>: vector<2xi64>>>
+} {
+
+  cir.func @takes_complex(%arg0: !cir.complex<!cir.float>)
+      attributes { test_classify = #coerce_complex_to_vec2 } {
+    cir.return
+  }
+
+  // The signature changes to the coerced (vector) type; the body still
+  // expects the complex, so a reinterpret_cast lands at function entry to
+  // adapt the new block argument back to the original type.
+  // CHECK: cir.func{{.*}} @takes_complex(%[[ARG:.*]]: !cir.vector<2 x !cir.float>)
+  // CHECK:   %{{.*}} = cir.reinterpret_cast %[[ARG]] : !cir.vector<2 x !cir.float> -> !cir.complex<!cir.float>
+
+  cir.func @caller(%arg0: !cir.complex<!cir.float>)
+      attributes { test_classify = #coerce_complex_to_vec2 } {
+    cir.call @takes_complex(%arg0) : (!cir.complex<!cir.float>) -> ()
+    cir.return
+  }
+
+  // At the call site the rewriter coerces the original (complex) value to
+  // the vector type before passing it through.
+  // CHECK: cir.func{{.*}} @caller(%[[ARG:.*]]: !cir.vector<2 x !cir.float>)
+  // CHECK:   %[[COMPLEX:.*]] = cir.reinterpret_cast %[[ARG]] : !cir.vector<2 x !cir.float> -> !cir.complex<!cir.float>
+  // CHECK:   %[[COERCED:.*]] = cir.reinterpret_cast %[[COMPLEX]] : !cir.complex<!cir.float> -> !cir.vector<2 x !cir.float>
+  // CHECK:   cir.call @takes_complex(%[[COERCED]]) : (!cir.vector<2 x !cir.float>) -> ()
+
+}

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to