https://github.com/farzonl updated 
https://github.com/llvm/llvm-project/pull/173044

>From 90dc37f7225eedb10356fadcac0e27784188b523 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <[email protected]>
Date: Fri, 19 Dec 2025 11:17:20 -0500
Subject: [PATCH] [Matrix][HLSL] Implement Matrix single constant index swizzle

fixes #172805

When the row index is a compile-time constant, we can compute the flattened
element offsets from the swizzle indices returned by
`E->getEncodedElementAccess(Elts)`. To compute the right index we also had to
add the matrix column and row counts to LValue. The emitter for
`MatrixSingleSubscriptExpr` reads the counts off the matrix type and passes
them to `MakeMatrixRow`.
---
 clang/lib/CodeGen/CGExpr.cpp                  |  27 +++-
 clang/lib/CodeGen/CGValue.h                   |  18 ++-
 .../MatrixSingleSubscriptConstSwizzle.hlsl    | 118 +++++++++++++++++-
 3 files changed, 156 insertions(+), 7 deletions(-)

diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 27ee96cb6dc82..15f947dd397ca 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -4961,10 +4961,11 @@ LValue CodeGenFunction::EmitMatrixSingleSubscriptExpr(
     const MatrixSingleSubscriptExpr *E) {
   LValue Base = EmitLValue(E->getBase());
   llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx());
-
+  const auto *MatTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
   return LValue::MakeMatrixRow(
       MaybeConvertMatrixAddress(Base.getAddress(), *this), RowIdx,
-      E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
+      MatTy->getNumColumns(), MatTy->getNumRows(), E->getBase()->getType(),
+      Base.getBaseInfo(), TBAAAccessInfo());
 }
 
 LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) {
@@ -5239,8 +5240,28 @@ EmitExtVectorElementExpr(const ExtVectorElementExpr *E) {
     return LValue::MakeExtVectorElt(Base.getAddress(), CV, type,
                                     Base.getBaseInfo(), TBAAAccessInfo());
   }
-  if (Base.isMatrixRow())
+  if (Base.isMatrixRow()) {
+    if (auto *RowIdx =
+            llvm::dyn_cast<llvm::ConstantInt>(Base.getMatrixRowIdx())) {
+      // Constant row: map each swizzled column to its flattened matrix slot
+      // (Col * NumRows + Row). Iterate over the swizzle's own length, which
+      // may differ from the number of matrix columns (e.g. `M[1].rg`).
+      llvm::SmallVector<llvm::Constant *, 8> MatIndices;
+      const unsigned NumRows = Base.getMatrixNumRows();
+      const unsigned Row = RowIdx->getZExtValue();
+      MatIndices.reserve(Indices.size());
+      for (unsigned Col : Indices) {
+        assert(Col < Base.getMatrixNumCols() && "swizzle column out of range");
+        MatIndices.push_back(
+            llvm::ConstantInt::get(Int32Ty, Col * NumRows + Row));
+      }
+
+      llvm::Constant *ConstIdxs = llvm::ConstantVector::get(MatIndices);
+      return LValue::MakeExtVectorElt(Base.getMatrixAddress(), ConstIdxs, type,
+                                      Base.getBaseInfo(), TBAAAccessInfo());
+    }
     return EmitUnsupportedLValue(E, "Matrix single index swizzle");
+  }
 
   assert(Base.isExtVectorElt() && "Can only subscript lvalue vec elts here!");
 
diff --git a/clang/lib/CodeGen/CGValue.h b/clang/lib/CodeGen/CGValue.h
index e1b6ec37be3d4..c3ae130014192 100644
--- a/clang/lib/CodeGen/CGValue.h
+++ b/clang/lib/CodeGen/CGValue.h
@@ -210,6 +210,9 @@ class LValue {
     const CGBitFieldInfo *BitFieldInfo;
   };
 
+  // Matrix shape set by MakeMatrixRow(); only valid when isMatrixRow().
+  unsigned NumCols, NumRows;
+
   QualType Type;
 
   // 'const' is unused here
@@ -391,7 +394,7 @@ class LValue {
   }
 
   Address getMatrixAddress() const {
-    assert(isMatrixElt());
+    assert(isMatrixElt() || isMatrixRow());
     return Addr;
   }
   llvm::Value *getMatrixPointer() const {
@@ -408,6 +411,16 @@ class LValue {
     return MatrixRowIdx;
   }
 
+  unsigned getMatrixNumRows() const {
+    assert(isMatrixRow());
+    return NumRows;
+  }
+
+  unsigned getMatrixNumCols() const {
+    assert(isMatrixRow());
+    return NumCols;
+  }
+
   // extended vector elements.
   Address getExtVectorAddress() const {
     assert(isExtVectorElt());
@@ -497,11 +510,14 @@ class LValue {
   }
 
   static LValue MakeMatrixRow(Address Addr, llvm::Value *RowIdx,
+                              unsigned NumCols, unsigned NumRows,
                               QualType MatrixTy, LValueBaseInfo BaseInfo,
                               TBAAAccessInfo TBAAInfo) {
     LValue LV;
     LV.LVType = MatrixRow;
     LV.MatrixRowIdx = RowIdx; // store the row index here
+    LV.NumCols = NumCols;
+    LV.NumRows = NumRows;
     LV.Initialize(MatrixTy, MatrixTy.getQualifiers(), Addr, BaseInfo, 
TBAAInfo);
     return LV;
   }
diff --git 
a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl 
b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl
index edf831e3833d5..896b4d287ecba 100644
--- 
a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl
+++ 
b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl
@@ -1,9 +1,6 @@
 // NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py 
UTC_ARGS: --version 6
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.7-library -disable-llvm-passes 
-emit-llvm -finclude-default-header -o - %s | FileCheck %s
 
-// BUG: https://github.com/llvm/llvm-project/issues/172805
-// XFAIL: *
-
 // CHECK-LABEL: define hidden void 
@_Z10setMatrix1Ru11matrix_typeILm4ELm4EfEDv4_f(
 // CHECK-SAME: ptr noalias noundef nonnull align 4 dereferenceable(64) 
[[M:%.*]], <4 x float> noundef nofpclass(nan inf) [[V:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
@@ -57,6 +54,121 @@ void setMatrix1(out float4x4 M, float4 V) {
 void setMatrix2(out int4x4 M, int4 V) {
     M[2].rgba = V;
 }
+
+// CHECK-LABEL: define hidden void 
@_Z22setMatrixVectorSwizzleRu11matrix_typeILm2ELm3EiEDv3_i(
+// CHECK-SAME: ptr noalias noundef nonnull align 4 dereferenceable(24) 
[[M:%.*]], <3 x i32> noundef [[V:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M_ADDR:%.*]] = alloca ptr, align 4
+// CHECK-NEXT:    [[V_ADDR:%.*]] = alloca <3 x i32>, align 16
+// CHECK-NEXT:    store ptr [[M]], ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    store <3 x i32> [[V]], ptr [[V_ADDR]], align 16
+// CHECK-NEXT:    [[TMP0:%.*]] = load <3 x i32>, ptr [[V_ADDR]], align 16
+// CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <3 x i32> [[TMP0]], <3 x i32> 
poison, <3 x i32> <i32 2, i32 1, i32 0>
+// CHECK-NEXT:    [[TMP2:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull 
[[META3]], !align [[META4]]
+// CHECK-NEXT:    [[MATRIX_LOAD:%.*]] = load <6 x i32>, ptr [[TMP2]], align 4
+// CHECK-NEXT:    [[TMP3:%.*]] = extractelement <3 x i32> [[TMP1]], i32 0
+// CHECK-NEXT:    [[TMP4:%.*]] = insertelement <6 x i32> [[MATRIX_LOAD]], i32 
[[TMP3]], i32 0
+// CHECK-NEXT:    [[TMP5:%.*]] = extractelement <3 x i32> [[TMP1]], i32 1
+// CHECK-NEXT:    [[TMP6:%.*]] = insertelement <6 x i32> [[TMP4]], i32 
[[TMP5]], i32 2
+// CHECK-NEXT:    [[TMP7:%.*]] = extractelement <3 x i32> [[TMP1]], i32 2
+// CHECK-NEXT:    [[TMP8:%.*]] = insertelement <6 x i32> [[TMP6]], i32 
[[TMP7]], i32 4
+// CHECK-NEXT:    store <6 x i32> [[TMP8]], ptr [[TMP2]], align 4
+// CHECK-NEXT:    ret void
+//
+void setMatrixVectorSwizzle(out int2x3 M, int3 V) {
+    M[0] = V.bgr;
+}
+
+// CHECK-LABEL: define hidden void 
@_Z24setVectorOnMatrixSwizzleRu11matrix_typeILm2ELm3EiEDv3_i(
+// CHECK-SAME: ptr noalias noundef nonnull align 4 dereferenceable(24) 
[[M:%.*]], <3 x i32> noundef [[V:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M_ADDR:%.*]] = alloca ptr, align 4
+// CHECK-NEXT:    [[V_ADDR:%.*]] = alloca <3 x i32>, align 16
+// CHECK-NEXT:    store ptr [[M]], ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    store <3 x i32> [[V]], ptr [[V_ADDR]], align 16
+// CHECK-NEXT:    [[TMP0:%.*]] = load <3 x i32>, ptr [[V_ADDR]], align 16
+// CHECK-NEXT:    [[TMP1:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull 
[[META3]], !align [[META4]]
+// CHECK-NEXT:    [[TMP2:%.*]] = extractelement <3 x i32> [[TMP0]], i32 0
+// CHECK-NEXT:    [[TMP3:%.*]] = getelementptr <6 x i32>, ptr [[TMP1]], i32 0, 
i32 1
+// CHECK-NEXT:    store i32 [[TMP2]], ptr [[TMP3]], align 4
+// CHECK-NEXT:    [[TMP4:%.*]] = extractelement <3 x i32> [[TMP0]], i32 1
+// CHECK-NEXT:    [[TMP5:%.*]] = getelementptr <6 x i32>, ptr [[TMP1]], i32 0, 
i32 5
+// CHECK-NEXT:    store i32 [[TMP4]], ptr [[TMP5]], align 4
+// CHECK-NEXT:    [[TMP6:%.*]] = extractelement <3 x i32> [[TMP0]], i32 2
+// CHECK-NEXT:    [[TMP7:%.*]] = getelementptr <6 x i32>, ptr [[TMP1]], i32 0, 
i32 3
+// CHECK-NEXT:    store i32 [[TMP6]], ptr [[TMP7]], align 4
+// CHECK-NEXT:    ret void
+//
+void setVectorOnMatrixSwizzle(out int2x3 M, int3 V) {
+    M[1].rbg = V;
+}
+
+// CHECK-LABEL: define hidden void 
@_Z19setMatrixFromMatrixRu11matrix_typeILm2ELm3EiES_i(
+// CHECK-SAME: ptr noalias noundef nonnull align 4 dereferenceable(24) 
[[M:%.*]], <6 x i32> noundef [[N:%.*]], i32 noundef [[MINDEX:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M_ADDR:%.*]] = alloca ptr, align 4
+// CHECK-NEXT:    [[N_ADDR:%.*]] = alloca [6 x i32], align 4
+// CHECK-NEXT:    [[MINDEX_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    store ptr [[M]], ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    store <6 x i32> [[N]], ptr [[N_ADDR]], align 4
+// CHECK-NEXT:    store i32 [[MINDEX]], ptr [[MINDEX_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <6 x i32>, ptr [[N_ADDR]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <6 x i32> [[TMP0]], <6 x i32> 
poison, <3 x i32> <i32 3, i32 5, i32 1>
+// CHECK-NEXT:    [[TMP2:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull 
[[META3]], !align [[META4]]
+// CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr [[MINDEX_ADDR]], align 4
+// CHECK-NEXT:    [[MATRIX_LOAD:%.*]] = load <6 x i32>, ptr [[TMP2]], align 4
+// CHECK-NEXT:    [[TMP4:%.*]] = add i32 0, [[TMP3]]
+// CHECK-NEXT:    [[TMP5:%.*]] = extractelement <3 x i32> [[TMP1]], i32 0
+// CHECK-NEXT:    [[TMP6:%.*]] = insertelement <6 x i32> [[MATRIX_LOAD]], i32 
[[TMP5]], i32 [[TMP4]]
+// CHECK-NEXT:    [[TMP7:%.*]] = add i32 2, [[TMP3]]
+// CHECK-NEXT:    [[TMP8:%.*]] = extractelement <3 x i32> [[TMP1]], i32 1
+// CHECK-NEXT:    [[TMP9:%.*]] = insertelement <6 x i32> [[TMP6]], i32 
[[TMP8]], i32 [[TMP7]]
+// CHECK-NEXT:    [[TMP10:%.*]] = add i32 4, [[TMP3]]
+// CHECK-NEXT:    [[TMP11:%.*]] = extractelement <3 x i32> [[TMP1]], i32 2
+// CHECK-NEXT:    [[TMP12:%.*]] = insertelement <6 x i32> [[TMP9]], i32 
[[TMP11]], i32 [[TMP10]]
+// CHECK-NEXT:    store <6 x i32> [[TMP12]], ptr [[TMP2]], align 4
+// CHECK-NEXT:    ret void
+//
+void setMatrixFromMatrix(out int2x3 M, int2x3 N, int MIndex) {
+    M[MIndex] = N[1].gbr;
+}
+
+// CHECK-LABEL: define hidden void 
@_Z26setMatrixSwizzleFromMatrixRu11matrix_typeILm2ELm3EiES_i(
+// CHECK-SAME: ptr noalias noundef nonnull align 4 dereferenceable(24) 
[[M:%.*]], <6 x i32> noundef [[N:%.*]], i32 noundef [[NINDEX:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M_ADDR:%.*]] = alloca ptr, align 4
+// CHECK-NEXT:    [[N_ADDR:%.*]] = alloca [6 x i32], align 4
+// CHECK-NEXT:    [[NINDEX_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    store ptr [[M]], ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    store <6 x i32> [[N]], ptr [[N_ADDR]], align 4
+// CHECK-NEXT:    store i32 [[NINDEX]], ptr [[NINDEX_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[NINDEX_ADDR]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load <6 x i32>, ptr [[N_ADDR]], align 4
+// CHECK-NEXT:    [[TMP2:%.*]] = add i32 0, [[TMP0]]
+// CHECK-NEXT:    [[MATRIX_ELEM:%.*]] = extractelement <6 x i32> [[TMP1]], i32 
[[TMP2]]
+// CHECK-NEXT:    [[MATRIX_ROW_INS:%.*]] = insertelement <3 x i32> poison, i32 
[[MATRIX_ELEM]], i32 0
+// CHECK-NEXT:    [[TMP3:%.*]] = add i32 2, [[TMP0]]
+// CHECK-NEXT:    [[MATRIX_ELEM1:%.*]] = extractelement <6 x i32> [[TMP1]], 
i32 [[TMP3]]
+// CHECK-NEXT:    [[MATRIX_ROW_INS2:%.*]] = insertelement <3 x i32> 
[[MATRIX_ROW_INS]], i32 [[MATRIX_ELEM1]], i32 1
+// CHECK-NEXT:    [[TMP4:%.*]] = add i32 4, [[TMP0]]
+// CHECK-NEXT:    [[MATRIX_ELEM3:%.*]] = extractelement <6 x i32> [[TMP1]], 
i32 [[TMP4]]
+// CHECK-NEXT:    [[MATRIX_ROW_INS4:%.*]] = insertelement <3 x i32> 
[[MATRIX_ROW_INS2]], i32 [[MATRIX_ELEM3]], i32 2
+// CHECK-NEXT:    [[TMP5:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull 
[[META3]], !align [[META4]]
+// CHECK-NEXT:    [[TMP6:%.*]] = extractelement <3 x i32> [[MATRIX_ROW_INS4]], 
i32 0
+// CHECK-NEXT:    [[TMP7:%.*]] = getelementptr <6 x i32>, ptr [[TMP5]], i32 0, 
i32 5
+// CHECK-NEXT:    store i32 [[TMP6]], ptr [[TMP7]], align 4
+// CHECK-NEXT:    [[TMP8:%.*]] = extractelement <3 x i32> [[MATRIX_ROW_INS4]], 
i32 1
+// CHECK-NEXT:    [[TMP9:%.*]] = getelementptr <6 x i32>, ptr [[TMP5]], i32 0, 
i32 1
+// CHECK-NEXT:    store i32 [[TMP8]], ptr [[TMP9]], align 4
+// CHECK-NEXT:    [[TMP10:%.*]] = extractelement <3 x i32> 
[[MATRIX_ROW_INS4]], i32 2
+// CHECK-NEXT:    [[TMP11:%.*]] = getelementptr <6 x i32>, ptr [[TMP5]], i32 
0, i32 3
+// CHECK-NEXT:    store i32 [[TMP10]], ptr [[TMP11]], align 4
+// CHECK-NEXT:    ret void
+//
+void setMatrixSwizzleFromMatrix(out int2x3 M, int2x3 N, int NIndex) {
+    M[1].brg = N[NIndex];
+}
+
 //.
 // CHECK: [[META3]] = !{}
 // CHECK: [[META4]] = !{i64 4}

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to