https://github.com/emias11 created 
https://github.com/llvm/llvm-project/pull/177884

Fixes #176468

Extract common HLSL-specific code from `isVectorElt()`, `isMatrixElt()`, 
and `isMatrixRow()` cases in `EmitStoreThroughLValue` into two helper functions:
- `emitExtendForHLSLStore`
- `emitHLSLElementStore`

From 8656beb392758b535cdad3d1d09ff50c4c7199c2 Mon Sep 17 00:00:00 2001
From: Mia Stapleton <[email protected]>
Date: Sun, 25 Jan 2026 22:14:24 +0000
Subject: [PATCH] HLSL Refactor to fix #176468

---
 clang/lib/CodeGen/CGExpr.cpp | 66 ++++++++++++++++++++----------------
 1 file changed, 36 insertions(+), 30 deletions(-)

diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 76a3939cd28eb..126cd7634bc70 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2637,6 +2637,33 @@ RValue CodeGenFunction::EmitLoadOfGlobalRegLValue(LValue LV) {
   return RValue::get(Call);
 }
 
+/// Zero-extend Val (scalar or vector) to ElemTy's scalar width if narrower.
+static llvm::Value *emitExtendForHLSLStore(CGBuilderTy &Builder,
+                                           llvm::Value *Val,
+                                           llvm::Type *ElemTy) {
+  if (Val->getType()->getScalarSizeInBits() >= ElemTy->getScalarSizeInBits())
+    return Val; // Already at least as wide as the element type.
+
+  llvm::Type *TargetTy = ElemTy->getScalarType();
+  if (auto *VecTy = dyn_cast<llvm::FixedVectorType>(Val->getType()))
+    TargetTy = llvm::FixedVectorType::get(TargetTy, VecTy->getNumElements());
+
+  return Builder.CreateZExt(Val, TargetTy);
+}
+
+/// Store Val to the element of DstAddr selected by Idx, via a {0, Idx} GEP.
+static void emitHLSLElementStore(CGBuilderTy &Builder,
+                                 llvm::Value *Val,
+                                 Address DstAddr,
+                                 llvm::Value *Idx,
+                                 llvm::Type *DestAddrTy,
+                                 CharUnits ElemAlign,
+                                 bool IsVolatile) {
+  llvm::Value *Zero = Builder.getInt32(0);
+  Address DstElemAddr = Builder.CreateGEP(DstAddr, {Zero, Idx}, DestAddrTy, 
ElemAlign);
+  Builder.CreateStore(Val, DstElemAddr, IsVolatile);
+}
+
 /// EmitStoreThroughLValue - Store the specified rvalue into the specified
 /// lvalue, where both are guaranteed to the have the same type, and that type
 /// is 'Ty'.
@@ -2657,16 +2684,10 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
         assert(ElemTy->getScalarSizeInBits() >= 8 &&
                "vector element type must be at least byte-sized");
 
-        llvm::Value *Val = Src.getScalarVal();
-        if (Val->getType()->getPrimitiveSizeInBits() <
-            ElemTy->getScalarSizeInBits())
-          Val = Builder.CreateZExt(Val, ElemTy->getScalarType());
+        llvm::Value *Val = emitExtendForHLSLStore(Builder, Src.getScalarVal(), 
ElemTy);
 
-        llvm::Value *Idx = Dst.getVectorIdx();
-        llvm::Value *Zero = llvm::ConstantInt::get(Int32Ty, 0);
-        Address DstElemAddr =
-            Builder.CreateGEP(DstAddr, {Zero, Idx}, DestAddrTy, ElemAlign);
-        Builder.CreateStore(Val, DstElemAddr, Dst.isVolatileQualified());
+        emitHLSLElementStore(Builder, Val, DstAddr, Dst.getVectorIdx(), 
+                             DestAddrTy, ElemAlign, Dst.isVolatileQualified());
         return;
       }
 
@@ -2730,16 +2751,10 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
         assert(ElemTy->getScalarSizeInBits() >= 8 &&
                "matrix element type must be at least byte-sized");
 
-        llvm::Value *Val = Src.getScalarVal();
-        if (Val->getType()->getPrimitiveSizeInBits() <
-            ElemTy->getScalarSizeInBits())
-          Val = Builder.CreateZExt(Val, ElemTy->getScalarType());
+        llvm::Value *Val = emitExtendForHLSLStore(Builder, Src.getScalarVal(), 
ElemTy);
 
-        llvm::Value *Idx = Dst.getMatrixIdx();
-        llvm::Value *Zero = llvm::ConstantInt::get(Int32Ty, 0);
-        Address DstElemAddr =
-            Builder.CreateGEP(DstAddr, {Zero, Idx}, DestAddrTy, ElemAlign);
-        Builder.CreateStore(Val, DstElemAddr, Dst.isVolatileQualified());
+        emitHLSLElementStore(Builder, Val, DstAddr, Dst.getMatrixIdx(), 
+                             DestAddrTy, ElemAlign, Dst.isVolatileQualified());
         return;
       }
 
@@ -2780,14 +2795,7 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
       assert(ElemTy->getScalarSizeInBits() >= 8 &&
              "matrix element type must be at least byte-sized");
 
-      llvm::Value *RowVal = Src.getScalarVal();
-      if (RowVal->getType()->getScalarType()->getPrimitiveSizeInBits() <
-          ElemTy->getScalarSizeInBits()) {
-        auto *RowValVecTy = cast<llvm::FixedVectorType>(RowVal->getType());
-        llvm::Type *StorageElmTy = llvm::FixedVectorType::get(
-            ElemTy->getScalarType(), RowValVecTy->getNumElements());
-        RowVal = Builder.CreateZExt(RowVal, StorageElmTy);
-      }
+      llvm::Value *RowVal = emitExtendForHLSLStore(Builder, 
Src.getScalarVal(), ElemTy);
 
       llvm::MatrixBuilder MB(Builder);
 
@@ -2811,11 +2819,9 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
         llvm::Value *EltIndex =
             MB.CreateIndex(Row, ColIdx, NumRows, NumCols, IsMatrixRowMajor);
         llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
-        llvm::Value *Zero = llvm::ConstantInt::get(Int32Ty, 0);
         llvm::Value *NewElt = Builder.CreateExtractElement(RowVal, Lane);
-        Address DstElemAddr =
-            Builder.CreateGEP(DstAddr, {Zero, EltIndex}, DestAddrTy, 
ElemAlign);
-        Builder.CreateStore(NewElt, DstElemAddr, Dst.isVolatileQualified());
+        emitHLSLElementStore(Builder, NewElt, DstAddr, EltIndex, 
+                             DestAddrTy, ElemAlign, Dst.isVolatileQualified());
       }
 
       return;

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to