https://github.com/Icohedron updated https://github.com/llvm/llvm-project/pull/176216
>From 99fff5b751847ed98511346fadaf315b25b27751 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Thu, 15 Jan 2026 10:10:03 -0800 Subject: [PATCH 1/4] Update indexed matrix elements individually for HLSL --- clang/lib/CodeGen/CGExpr.cpp | 71 ++++++--- .../MatrixSingleSubscriptConstSwizzle.hlsl | 22 +-- .../MatrixSingleSubscriptDynamicSwizzle.hlsl | 25 +-- .../MatrixSingleSubscriptSetter.hlsl | 142 ++++++++++-------- .../BasicFeatures/matrix-type-indexing.hlsl | 7 +- clang/test/CodeGenHLSL/BoolMatrix.hlsl | 17 +-- 6 files changed, 165 insertions(+), 119 deletions(-) diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 896c60b13c160..6f08d0ef1d3b1 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -2716,6 +2716,37 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, return EmitStoreThroughGlobalRegLValue(Src, Dst); if (Dst.isMatrixElt()) { + if (getLangOpts().HLSL) { + // HLSL allows direct access to matrix elements, so storing to + // individual elements of a matrix through MatrixElt is handled as + // separate store instructions. + Address DstAddr = Dst.getMatrixAddress(); + llvm::Type *DestAddrTy = DstAddr.getElementType(); + llvm::Type *ElemTy = DestAddrTy->getScalarType(); + CharUnits ElemAlign = CharUnits::fromQuantity( + CGM.getDataLayout().getPrefTypeAlign(ElemTy)); + + assert(ElemTy->getScalarSizeInBits() >= 8 && + "matrix element type must be at least byte-sized"); + + llvm::Value *Val = Src.getScalarVal(); + if (Val->getType()->getPrimitiveSizeInBits() < + ElemTy->getScalarSizeInBits()) + Val = Builder.CreateZExt(Val, ElemTy->getScalarType()); + + llvm::Value *Idx = Dst.getMatrixIdx(); + if (CGM.getCodeGenOpts().OptimizationLevel > 0) { + const auto *const MatTy = Dst.getType()->castAs<ConstantMatrixType>(); + llvm::MatrixBuilder MB(Builder); + MB.CreateIndexAssumption(Idx, MatTy->getNumElementsFlattened()); + } + llvm::Value *Zero = llvm::ConstantInt::get(Int32Ty, 0); + Address DstElemAddr = + Builder.CreateGEP(DstAddr, {Zero, Idx}, DestAddrTy, ElemAlign); + Builder.CreateStore(Val, DstElemAddr, Dst.isVolatileQualified()); + return; + } + llvm::Value *Idx = Dst.getMatrixIdx(); if (CGM.getCodeGenOpts().OptimizationLevel > 0) { const auto *const MatTy = Dst.getType()->castAs<ConstantMatrixType>(); @@ -2724,10 +2755,6 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, } llvm::Instruction *Load = Builder.CreateLoad(Dst.getMatrixAddress()); llvm::Value *InsertVal = Src.getScalarVal(); - if (getLangOpts().HLSL && InsertVal->getType()->isIntegerTy(1)) { - llvm::Type *StorageElmTy = Load->getType()->getScalarType(); - InsertVal = Builder.CreateZExt(InsertVal, StorageElmTy); - } llvm::Value *Vec = Builder.CreateInsertElement(Load, InsertVal, Idx, "matins"); auto *I = Builder.CreateStore(Vec, Dst.getMatrixAddress(), @@ -2736,6 +2763,11 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, return; } if (Dst.isMatrixRow()) { + // NOTE: Since there are no other languages that implement matrix single + // subscripting, the logic here is specific to HLSL which allows stores to + // indivdual rows of matrices. + assert(getLangOpts().HLSL && + "Store through matrix row LValues is only implemented for HLSL!"); QualType MatTy = Dst.getType(); const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>(); @@ -2743,21 +2775,21 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, unsigned NumCols = MT->getNumColumns(); unsigned NumLanes = NumCols; - llvm::Value *MatrixVec = - Builder.CreateLoad(Dst.getAddress(), "matrix.load"); + Address DstAddr = Dst.getMatrixAddress(); + llvm::Type *DestAddrTy = DstAddr.getElementType(); + llvm::Type *ElemTy = DestAddrTy->getScalarType(); + CharUnits ElemAlign = CharUnits::fromQuantity( + CGM.getDataLayout().getPrefTypeAlign(ElemTy)); - llvm::Value *Row = Dst.getMatrixRowIdx(); - llvm::Value *RowVal = Src.getScalarVal(); // <NumCols x T> - - if (RowVal->getType()->isIntOrIntVectorTy(1)) { - // NOTE: If matrix single subscripting becomes a feature in languages - // other than HLSL, the following assert should be removed and the - // assert condition should be made part of the enclosing if-statement - // condition as is the case for similar logic for Dst.isMatrixElt() - assert(getLangOpts().HLSL); + assert(ElemTy->getScalarSizeInBits() >= 8 && + "matrix element type must be at least byte-sized"); + + llvm::Value *RowVal = Src.getScalarVal(); + if (RowVal->getType()->getScalarType()->getPrimitiveSizeInBits() < + ElemTy->getScalarSizeInBits()) { auto *RowValVecTy = cast<llvm::FixedVectorType>(RowVal->getType()); llvm::Type *StorageElmTy = - llvm::FixedVectorType::get(MatrixVec->getType()->getScalarType(), + llvm::FixedVectorType::get(ElemTy->getScalarType(), RowValVecTy->getNumElements()); RowVal = Builder.CreateZExt(RowVal, StorageElmTy); } @@ -2772,6 +2804,7 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, ->getNumElements(); } + llvm::Value *Row = Dst.getMatrixRowIdx(); for (unsigned Col = 0; Col < NumLanes; ++Col) { llvm::Value *ColIdx; if (ColConstsIndices) @@ -2783,11 +2816,13 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows, NumCols, IsMatrixRowMajor); llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col); + llvm::Value *Zero = llvm::ConstantInt::get(Int32Ty, 0); llvm::Value *NewElt = Builder.CreateExtractElement(RowVal, Lane); - MatrixVec = Builder.CreateInsertElement(MatrixVec, NewElt, EltIndex); + Address DstElemAddr = + Builder.CreateGEP(DstAddr, {Zero, EltIndex}, DestAddrTy, ElemAlign); + Builder.CreateStore(NewElt, DstElemAddr, Dst.isVolatileQualified()); } - Builder.CreateStore(MatrixVec, Dst.getAddress()); return; } diff --git a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl index 896b4d287ecba..02885d153697a 100644 --- a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl +++ b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl @@ -65,14 +65,15 @@ void setMatrix2(out int4x4 M, int4 V) { // CHECK-NEXT: [[TMP0:%.*]] = load <3 x i32>, ptr [[V_ADDR]], align 16 // CHECK-NEXT: [[TMP1:%.*]] = shufflevector <3 x i32> [[TMP0]], <3 x i32> poison, <3 x i32> <i32 2, i32 1, i32 0> // CHECK-NEXT: [[TMP2:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3]], !align [[META4]] -// CHECK-NEXT: [[MATRIX_LOAD:%.*]] = load <6 x i32>, ptr [[TMP2]], align 4 // CHECK-NEXT: [[TMP3:%.*]] = extractelement <3 x i32> [[TMP1]], i32 0 -// CHECK-NEXT: [[TMP4:%.*]] = insertelement <6 x i32> [[MATRIX_LOAD]], i32 [[TMP3]], i32 0 +// CHECK-NEXT: [[TMP4:%.*]] = getelementptr <6 x i32>, ptr [[TMP2]], i32 0, i32 0 +// CHECK-NEXT: store i32 [[TMP3]], ptr [[TMP4]], align 4 // CHECK-NEXT: [[TMP5:%.*]] = extractelement <3 x i32> [[TMP1]], i32 1 -// CHECK-NEXT: [[TMP6:%.*]] = insertelement <6 x i32> [[TMP4]], i32 [[TMP5]], i32 2 +// CHECK-NEXT: [[TMP6:%.*]] = getelementptr <6 x i32>, ptr [[TMP2]], i32 0, i32 2 +// CHECK-NEXT: store i32 [[TMP5]], ptr [[TMP6]], align 4 // CHECK-NEXT: [[TMP7:%.*]] = extractelement <3 x i32> [[TMP1]], i32 2 -// CHECK-NEXT: [[TMP8:%.*]] = insertelement <6 x i32> [[TMP6]], i32 [[TMP7]], i32 4 -// CHECK-NEXT: store <6 x i32> [[TMP8]], ptr [[TMP2]], align 4 +// CHECK-NEXT: [[TMP8:%.*]] = getelementptr <6 x i32>, ptr [[TMP2]], i32 0, i32 4 +// CHECK-NEXT: store i32 [[TMP7]], ptr [[TMP8]], align 4 // CHECK-NEXT: ret void // void setMatrixVectorSwizzle(out int2x3 M, int3 V) { @@ -116,17 +117,18 @@ void setVectorOnMatrixSwizzle(out int2x3 M, int3 V) { // CHECK-NEXT: [[TMP1:%.*]] = shufflevector <6 x i32> [[TMP0]], <6 x i32> poison, <3 x i32> <i32 3, i32 5, i32 1> // CHECK-NEXT: [[TMP2:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3]], !align [[META4]] // CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[MINDEX_ADDR]], align 4 -// CHECK-NEXT: [[MATRIX_LOAD:%.*]] = load <6 x i32>, ptr [[TMP2]], align 4 // CHECK-NEXT: [[TMP4:%.*]] = add i32 0, [[TMP3]] // CHECK-NEXT: [[TMP5:%.*]] = extractelement <3 x i32> [[TMP1]], i32 0 -// CHECK-NEXT: [[TMP6:%.*]] = insertelement <6 x i32> [[MATRIX_LOAD]], i32 [[TMP5]], i32 [[TMP4]] +// CHECK-NEXT: [[TMP6:%.*]] = getelementptr <6 x i32>, ptr [[TMP2]], i32 0, i32 [[TMP4]] +// CHECK-NEXT: store i32 [[TMP5]], ptr [[TMP6]], align 4 // CHECK-NEXT: [[TMP7:%.*]] = add i32 2, [[TMP3]] // CHECK-NEXT: [[TMP8:%.*]] = extractelement <3 x i32> [[TMP1]], i32 1 -// CHECK-NEXT: [[TMP9:%.*]] = insertelement <6 x i32> [[TMP6]], i32 [[TMP8]], i32 [[TMP7]] +// CHECK-NEXT: [[TMP9:%.*]] = getelementptr <6 x i32>, ptr [[TMP2]], i32 0, i32 [[TMP7]] +// CHECK-NEXT: store i32 [[TMP8]], ptr [[TMP9]], align 4 // CHECK-NEXT: [[TMP10:%.*]] = add i32 4, [[TMP3]] // CHECK-NEXT: [[TMP11:%.*]] = extractelement <3 x i32> [[TMP1]], i32 2 -// CHECK-NEXT: [[TMP12:%.*]] = insertelement <6 x i32> [[TMP9]], i32 [[TMP11]], i32 [[TMP10]] -// CHECK-NEXT: store <6 x i32> [[TMP12]], ptr [[TMP2]], align 4 +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr <6 x i32>, ptr [[TMP2]], i32 0, i32 [[TMP10]] +// CHECK-NEXT: store i32 [[TMP11]], ptr [[TMP12]], align 4 // CHECK-NEXT: ret void // void setMatrixFromMatrix(out int2x3 M, int2x3 N, int MIndex) { diff --git a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptDynamicSwizzle.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptDynamicSwizzle.hlsl index bfd6e68af8775..97ce63f545cff 100644 --- a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptDynamicSwizzle.hlsl +++ b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptDynamicSwizzle.hlsl @@ -13,20 +13,22 @@ // CHECK-NEXT: [[TMP0:%.*]] = load <4 x float>, ptr [[V_ADDR]], align 16 // CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3:![0-9]+]], !align [[META4:![0-9]+]] // CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[INDEX_ADDR]], align 4 -// CHECK-NEXT: [[MATRIX_LOAD:%.*]] = load <16 x float>, ptr [[TMP1]], align 4 // CHECK-NEXT: [[TMP3:%.*]] = add i32 12, [[TMP2]] // CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[TMP0]], i32 0 -// CHECK-NEXT: [[TMP5:%.*]] = insertelement <16 x float> [[MATRIX_LOAD]], float [[TMP4]], i32 [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = getelementptr <16 x float>, ptr [[TMP1]], i32 0, i32 [[TMP3]] +// CHECK-NEXT: store float [[TMP4]], ptr [[TMP5]], align 4 // CHECK-NEXT: [[TMP6:%.*]] = add i32 8, [[TMP2]] // CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x float> [[TMP0]], i32 1 -// CHECK-NEXT: [[TMP8:%.*]] = insertelement <16 x float> [[TMP5]], float [[TMP7]], i32 [[TMP6]] +// CHECK-NEXT: [[TMP8:%.*]] = getelementptr <16 x float>, ptr [[TMP1]], i32 0, i32 [[TMP6]] +// CHECK-NEXT: store float [[TMP7]], ptr [[TMP8]], align 4 // CHECK-NEXT: [[TMP9:%.*]] = add i32 4, [[TMP2]] // CHECK-NEXT: [[TMP10:%.*]] = extractelement <4 x float> [[TMP0]], i32 2 -// CHECK-NEXT: [[TMP11:%.*]] = insertelement <16 x float> [[TMP8]], float [[TMP10]], i32 [[TMP9]] +// CHECK-NEXT: [[TMP11:%.*]] = getelementptr <16 x float>, ptr [[TMP1]], i32 0, i32 [[TMP9]] +// CHECK-NEXT: store float [[TMP10]], ptr [[TMP11]], align 4 // CHECK-NEXT: [[TMP12:%.*]] = add i32 0, [[TMP2]] // CHECK-NEXT: [[TMP13:%.*]] = extractelement <4 x float> [[TMP0]], i32 3 -// CHECK-NEXT: [[TMP14:%.*]] = insertelement <16 x float> [[TMP11]], float [[TMP13]], i32 [[TMP12]] -// CHECK-NEXT: store <16 x float> [[TMP14]], ptr [[TMP1]], align 4 +// CHECK-NEXT: [[TMP14:%.*]] = getelementptr <16 x float>, ptr [[TMP1]], i32 0, i32 [[TMP12]] +// CHECK-NEXT: store float [[TMP13]], ptr [[TMP14]], align 4 // CHECK-NEXT: ret void // void setMatrix(out float4x4 M, int index, float4 V) { @@ -131,17 +133,18 @@ int3 getMatrixSwizzle2x3(out int2x3 M, int index) { // CHECK-NEXT: [[MATRIX_ROW_INS4:%.*]] = insertelement <3 x i32> [[MATRIX_ROW_INS2]], i32 [[MATRIX_ELEM3]], i32 2 // CHECK-NEXT: [[TMP5:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3]], !align [[META4]] // CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr [[INDEX_ADDR]], align 4 -// CHECK-NEXT: [[MATRIX_LOAD:%.*]] = load <6 x i32>, ptr [[TMP5]], align 4 // CHECK-NEXT: [[TMP7:%.*]] = add i32 4, [[TMP6]] // CHECK-NEXT: [[TMP8:%.*]] = extractelement <3 x i32> [[MATRIX_ROW_INS4]], i32 0 -// CHECK-NEXT: [[TMP9:%.*]] = insertelement <6 x i32> [[MATRIX_LOAD]], i32 [[TMP8]], i32 [[TMP7]] +// CHECK-NEXT: [[TMP9:%.*]] = getelementptr <6 x i32>, ptr [[TMP5]], i32 0, i32 [[TMP7]] +// CHECK-NEXT: store i32 [[TMP8]], ptr [[TMP9]], align 4 // CHECK-NEXT: [[TMP10:%.*]] = add i32 0, [[TMP6]] // CHECK-NEXT: [[TMP11:%.*]] = extractelement <3 x i32> [[MATRIX_ROW_INS4]], i32 1 -// CHECK-NEXT: [[TMP12:%.*]] = insertelement <6 x i32> [[TMP9]], i32 [[TMP11]], i32 [[TMP10]] +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr <6 x i32>, ptr [[TMP5]], i32 0, i32 [[TMP10]] +// CHECK-NEXT: store i32 [[TMP11]], ptr [[TMP12]], align 4 // CHECK-NEXT: [[TMP13:%.*]] = add i32 2, [[TMP6]] // CHECK-NEXT: [[TMP14:%.*]] = extractelement <3 x i32> [[MATRIX_ROW_INS4]], i32 2 -// CHECK-NEXT: [[TMP15:%.*]] = insertelement <6 x i32> [[TMP12]], i32 [[TMP14]], i32 [[TMP13]] -// CHECK-NEXT: store <6 x i32> [[TMP15]], ptr [[TMP5]], align 4 +// CHECK-NEXT: [[TMP15:%.*]] = getelementptr <6 x i32>, ptr [[TMP5]], i32 0, i32 [[TMP13]] +// CHECK-NEXT: store i32 [[TMP14]], ptr [[TMP15]], align 4 // CHECK-NEXT: ret void // void setMatrixSwizzleFromMatrix(out int2x3 M, int2x3 N, int index) { diff --git a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl index d314f3a87d619..15861b3211606 100644 --- a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl +++ b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl @@ -13,20 +13,22 @@ // CHECK-NEXT: [[TMP0:%.*]] = load <4 x float>, ptr [[V_ADDR]], align 16 // CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3:![0-9]+]], !align [[META4:![0-9]+]] // CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[INDEX_ADDR]], align 4 -// CHECK-NEXT: [[MATRIX_LOAD:%.*]] = load <16 x float>, ptr [[TMP1]], align 4 // CHECK-NEXT: [[TMP3:%.*]] = add i32 0, [[TMP2]] // CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[TMP0]], i32 0 -// CHECK-NEXT: [[TMP5:%.*]] = insertelement <16 x float> [[MATRIX_LOAD]], float [[TMP4]], i32 [[TMP3]] +// CHECK-NEXT: [[TMP5:%.*]] = getelementptr <16 x float>, ptr [[TMP1]], i32 0, i32 [[TMP3]] +// CHECK-NEXT: store float [[TMP4]], ptr [[TMP5]], align 4 // CHECK-NEXT: [[TMP6:%.*]] = add i32 4, [[TMP2]] // CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x float> [[TMP0]], i32 1 -// CHECK-NEXT: [[TMP8:%.*]] = insertelement <16 x float> [[TMP5]], float [[TMP7]], i32 [[TMP6]] +// CHECK-NEXT: [[TMP8:%.*]] = getelementptr <16 x float>, ptr [[TMP1]], i32 0, i32 [[TMP6]] +// CHECK-NEXT: store float [[TMP7]], ptr [[TMP8]], align 4 // CHECK-NEXT: [[TMP9:%.*]] = add i32 8, [[TMP2]] // CHECK-NEXT: [[TMP10:%.*]] = extractelement <4 x float> [[TMP0]], i32 2 -// CHECK-NEXT: [[TMP11:%.*]] = insertelement <16 x float> [[TMP8]], float [[TMP10]], i32 [[TMP9]] +// CHECK-NEXT: [[TMP11:%.*]] = getelementptr <16 x float>, ptr [[TMP1]], i32 0, i32 [[TMP9]] +// CHECK-NEXT: store float [[TMP10]], ptr [[TMP11]], align 4 // CHECK-NEXT: [[TMP12:%.*]] = add i32 12, [[TMP2]] // CHECK-NEXT: [[TMP13:%.*]] = extractelement <4 x float> [[TMP0]], i32 3 -// CHECK-NEXT: [[TMP14:%.*]] = insertelement <16 x float> [[TMP11]], float [[TMP13]], i32 [[TMP12]] -// CHECK-NEXT: store <16 x float> [[TMP14]], ptr [[TMP1]], align 4 +// CHECK-NEXT: [[TMP14:%.*]] = getelementptr <16 x float>, ptr [[TMP1]], i32 0, i32 [[TMP12]] +// CHECK-NEXT: store float [[TMP13]], ptr [[TMP14]], align 4 // CHECK-NEXT: ret void // void setMatrix(out float4x4 M, int index, float4 V) { @@ -47,11 +49,10 @@ void setMatrix(out float4x4 M, int index, float4 V) { // CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT]], <1 x float> poison, <1 x i32> zeroinitializer // CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3]], !align [[META4]] // CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[INDEX_ADDR]], align 4 -// CHECK-NEXT: [[MATRIX_LOAD:%.*]] = load <2 x float>, ptr [[TMP1]], align 4 // CHECK-NEXT: [[TMP3:%.*]] = add i32 0, [[TMP2]] // CHECK-NEXT: [[TMP4:%.*]] = extractelement <1 x float> [[SPLAT_SPLAT]], i32 0 -// CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x float> [[MATRIX_LOAD]], float [[TMP4]], i32 [[TMP3]] -// CHECK-NEXT: store <2 x float> [[TMP5]], ptr [[TMP1]], align 4 +// CHECK-NEXT: [[TMP5:%.*]] = getelementptr <2 x float>, ptr [[TMP1]], i32 0, i32 [[TMP3]] +// CHECK-NEXT: store float [[TMP4]], ptr [[TMP5]], align 4 // CHECK-NEXT: ret void // void setMatrixScalar(out float2x1 M, int index, float S) { @@ -72,21 +73,23 @@ void setMatrixScalar(out float2x1 M, int index, float S) { // CHECK-NEXT: [[LOADEDV:%.*]] = trunc <4 x i32> [[TMP1]] to <4 x i1> // CHECK-NEXT: [[TMP2:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3]], !align [[META4]] // CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[INDEX_ADDR]], align 4 -// CHECK-NEXT: [[MATRIX_LOAD:%.*]] = load <16 x i32>, ptr [[TMP2]], align 4 // CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i1> [[LOADEDV]] to <4 x i32> // CHECK-NEXT: [[TMP5:%.*]] = add i32 0, [[TMP3]] // CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i32> [[TMP4]], i32 0 -// CHECK-NEXT: [[TMP7:%.*]] = insertelement <16 x i32> [[MATRIX_LOAD]], i32 [[TMP6]], i32 [[TMP5]] +// CHECK-NEXT: [[TMP7:%.*]] = getelementptr <16 x i32>, ptr [[TMP2]], i32 0, i32 [[TMP5]] +// CHECK-NEXT: store i32 [[TMP6]], ptr [[TMP7]], align 4 // CHECK-NEXT: [[TMP8:%.*]] = add i32 4, [[TMP3]] // CHECK-NEXT: [[TMP9:%.*]] = extractelement <4 x i32> [[TMP4]], i32 1 -// CHECK-NEXT: [[TMP10:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[TMP9]], i32 [[TMP8]] +// CHECK-NEXT: [[TMP10:%.*]] = getelementptr <16 x i32>, ptr [[TMP2]], i32 0, i32 [[TMP8]] +// CHECK-NEXT: store i32 [[TMP9]], ptr [[TMP10]], align 4 // CHECK-NEXT: [[TMP11:%.*]] = add i32 8, [[TMP3]] // CHECK-NEXT: [[TMP12:%.*]] = extractelement <4 x i32> [[TMP4]], i32 2 -// CHECK-NEXT: [[TMP13:%.*]] = insertelement <16 x i32> [[TMP10]], i32 [[TMP12]], i32 [[TMP11]] +// CHECK-NEXT: [[TMP13:%.*]] = getelementptr <16 x i32>, ptr [[TMP2]], i32 0, i32 [[TMP11]] +// CHECK-NEXT: store i32 [[TMP12]], ptr [[TMP13]], align 4 // CHECK-NEXT: [[TMP14:%.*]] = add i32 12, [[TMP3]] // CHECK-NEXT: [[TMP15:%.*]] = extractelement <4 x i32> [[TMP4]], i32 3 -// CHECK-NEXT: [[TMP16:%.*]] = insertelement <16 x i32> [[TMP13]], i32 [[TMP15]], i32 [[TMP14]] -// CHECK-NEXT: store <16 x i32> [[TMP16]], ptr [[TMP2]], align 4 +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr <16 x i32>, ptr [[TMP2]], i32 0, i32 [[TMP14]] +// CHECK-NEXT: store i32 [[TMP15]], ptr [[TMP16]], align 4 // CHECK-NEXT: ret void // void setBoolMatrix(out bool4x4 M, int index, bool4 V) { @@ -109,12 +112,11 @@ void setBoolMatrix(out bool4x4 M, int index, bool4 V) { // CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x i1> [[SPLAT_SPLATINSERT]], <1 x i1> poison, <1 x i32> zeroinitializer // CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3]], !align [[META4]] // CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[INDEX_ADDR]], align 4 -// CHECK-NEXT: [[MATRIX_LOAD:%.*]] = load <2 x i32>, ptr [[TMP1]], align 4 // CHECK-NEXT: [[TMP3:%.*]] = zext <1 x i1> [[SPLAT_SPLAT]] to <1 x i32> // CHECK-NEXT: [[TMP4:%.*]] = add i32 0, [[TMP2]] // CHECK-NEXT: [[TMP5:%.*]] = extractelement <1 x i32> [[TMP3]], i32 0 -// CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x i32> [[MATRIX_LOAD]], i32 [[TMP5]], i32 [[TMP4]] -// CHECK-NEXT: store <2 x i32> [[TMP6]], ptr [[TMP1]], align 4 +// CHECK-NEXT: [[TMP6:%.*]] = getelementptr <2 x i32>, ptr [[TMP1]], i32 0, i32 [[TMP4]] +// CHECK-NEXT: store i32 [[TMP5]], ptr [[TMP6]], align 4 // CHECK-NEXT: ret void // void setBoolMatrixScalar(out bool2x1 M, int index, bool S) { @@ -138,16 +140,18 @@ void setBoolMatrixScalar(out bool2x1 M, int index, bool S) { // CHECK-NEXT: [[MATRIX_ELEM5:%.*]] = extractelement <16 x i32> [[TMP0]], i32 15 // CHECK-NEXT: [[MATRIX_ROW_INS6:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS4]], i32 [[MATRIX_ELEM5]], i32 3 // CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3]], !align [[META4]] -// CHECK-NEXT: [[MATRIX_LOAD:%.*]] = load <16 x i32>, ptr [[TMP1]], align 4 // CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS6]], i32 0 -// CHECK-NEXT: [[TMP3:%.*]] = insertelement <16 x i32> [[MATRIX_LOAD]], i32 [[TMP2]], i32 0 +// CHECK-NEXT: [[TMP3:%.*]] = getelementptr <16 x i32>, ptr [[TMP1]], i32 0, i32 0 +// CHECK-NEXT: store i32 [[TMP2]], ptr [[TMP3]], align 4 // CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS6]], i32 1 -// CHECK-NEXT: [[TMP5:%.*]] = insertelement <16 x i32> [[TMP3]], i32 [[TMP4]], i32 4 +// CHECK-NEXT: [[TMP5:%.*]] = getelementptr <16 x i32>, ptr [[TMP1]], i32 0, i32 4 +// CHECK-NEXT: store i32 [[TMP4]], ptr [[TMP5]], align 4 // CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS6]], i32 2 -// CHECK-NEXT: [[TMP7:%.*]] = insertelement <16 x i32> [[TMP5]], i32 [[TMP6]], i32 8 +// CHECK-NEXT: [[TMP7:%.*]] = getelementptr <16 x i32>, ptr [[TMP1]], i32 0, i32 8 +// CHECK-NEXT: store i32 [[TMP6]], ptr [[TMP7]], align 4 // CHECK-NEXT: [[TMP8:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS6]], i32 3 -// CHECK-NEXT: [[TMP9:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[TMP8]], i32 12 -// CHECK-NEXT: store <16 x i32> [[TMP9]], ptr [[TMP1]], align 4 +// CHECK-NEXT: [[TMP9:%.*]] = getelementptr <16 x i32>, ptr [[TMP1]], i32 0, i32 12 +// CHECK-NEXT: store i32 [[TMP8]], ptr [[TMP9]], align 4 // CHECK-NEXT: [[TMP10:%.*]] = load <16 x i32>, ptr [[N_ADDR]], align 4 // CHECK-NEXT: [[MATRIX_ELEM7:%.*]] = extractelement <16 x i32> [[TMP10]], i32 2 // CHECK-NEXT: [[MATRIX_ROW_INS8:%.*]] = insertelement <4 x i32> poison, i32 [[MATRIX_ELEM7]], i32 0 @@ -158,56 +162,62 @@ void setBoolMatrixScalar(out bool2x1 M, int index, bool S) { // CHECK-NEXT: [[MATRIX_ELEM13:%.*]] = extractelement <16 x i32> [[TMP10]], i32 14 // CHECK-NEXT: [[MATRIX_ROW_INS14:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS12]], i32 [[MATRIX_ELEM13]], i32 3 // CHECK-NEXT: [[TMP11:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3]], !align [[META4]] -// CHECK-NEXT: [[MATRIX_LOAD15:%.*]] = load <16 x i32>, ptr [[TMP11]], align 4 // CHECK-NEXT: [[TMP12:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS14]], i32 0 -// CHECK-NEXT: [[TMP13:%.*]] = insertelement <16 x i32> [[MATRIX_LOAD15]], i32 [[TMP12]], i32 1 +// CHECK-NEXT: [[TMP13:%.*]] = getelementptr <16 x i32>, ptr [[TMP11]], i32 0, i32 1 +// CHECK-NEXT: store i32 [[TMP12]], ptr [[TMP13]], align 4 // CHECK-NEXT: [[TMP14:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS14]], i32 1 -// CHECK-NEXT: [[TMP15:%.*]] = insertelement <16 x i32> [[TMP13]], i32 [[TMP14]], i32 5 +// CHECK-NEXT: [[TMP15:%.*]] = getelementptr <16 x i32>, ptr [[TMP11]], i32 0, i32 5 +// CHECK-NEXT: store i32 [[TMP14]], ptr [[TMP15]], align 4 // CHECK-NEXT: [[TMP16:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS14]], i32 2 -// CHECK-NEXT: [[TMP17:%.*]] = insertelement <16 x i32> [[TMP15]], i32 [[TMP16]], i32 9 +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr <16 x i32>, ptr [[TMP11]], i32 0, i32 9 +// CHECK-NEXT: store i32 [[TMP16]], ptr [[TMP17]], align 4 // CHECK-NEXT: [[TMP18:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS14]], i32 3 -// CHECK-NEXT: [[TMP19:%.*]] = insertelement <16 x i32> [[TMP17]], i32 [[TMP18]], i32 13 -// CHECK-NEXT: store <16 x i32> [[TMP19]], ptr [[TMP11]], align 4 +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr <16 x i32>, ptr [[TMP11]], i32 0, i32 13 +// CHECK-NEXT: store i32 [[TMP18]], ptr [[TMP19]], align 4 // CHECK-NEXT: [[TMP20:%.*]] = load <16 x i32>, ptr [[N_ADDR]], align 4 -// CHECK-NEXT: [[MATRIX_ELEM16:%.*]] = extractelement <16 x i32> [[TMP20]], i32 1 -// CHECK-NEXT: [[MATRIX_ROW_INS17:%.*]] = insertelement <4 x i32> poison, i32 [[MATRIX_ELEM16]], i32 0 -// CHECK-NEXT: [[MATRIX_ELEM18:%.*]] = extractelement <16 x i32> [[TMP20]], i32 5 -// CHECK-NEXT: [[MATRIX_ROW_INS19:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS17]], i32 [[MATRIX_ELEM18]], i32 1 -// CHECK-NEXT: [[MATRIX_ELEM20:%.*]] = extractelement <16 x i32> [[TMP20]], i32 9 -// CHECK-NEXT: [[MATRIX_ROW_INS21:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS19]], i32 [[MATRIX_ELEM20]], i32 2 -// CHECK-NEXT: [[MATRIX_ELEM22:%.*]] = extractelement <16 x i32> [[TMP20]], i32 13 -// CHECK-NEXT: [[MATRIX_ROW_INS23:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS21]], i32 [[MATRIX_ELEM22]], i32 3 +// CHECK-NEXT: [[MATRIX_ELEM15:%.*]] = extractelement <16 x i32> [[TMP20]], i32 1 +// CHECK-NEXT: [[MATRIX_ROW_INS16:%.*]] = insertelement <4 x i32> poison, i32 [[MATRIX_ELEM15]], i32 0 +// CHECK-NEXT: [[MATRIX_ELEM17:%.*]] = extractelement <16 x i32> [[TMP20]], i32 5 +// CHECK-NEXT: [[MATRIX_ROW_INS18:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS16]], i32 [[MATRIX_ELEM17]], i32 1 +// CHECK-NEXT: [[MATRIX_ELEM19:%.*]] = extractelement <16 x i32> [[TMP20]], i32 9 +// CHECK-NEXT: [[MATRIX_ROW_INS20:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS18]], i32 [[MATRIX_ELEM19]], i32 2 +// CHECK-NEXT: [[MATRIX_ELEM21:%.*]] = extractelement <16 x i32> [[TMP20]], i32 13 +// CHECK-NEXT: [[MATRIX_ROW_INS22:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS20]], i32 [[MATRIX_ELEM21]], i32 3 // CHECK-NEXT: [[TMP21:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3]], !align [[META4]] -// CHECK-NEXT: [[MATRIX_LOAD24:%.*]] = load <16 x i32>, ptr [[TMP21]], align 4 -// CHECK-NEXT: [[TMP22:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS23]], i32 0 -// CHECK-NEXT: [[TMP23:%.*]] = insertelement <16 x i32> [[MATRIX_LOAD24]], i32 [[TMP22]], i32 2 -// CHECK-NEXT: [[TMP24:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS23]], i32 1 -// CHECK-NEXT: [[TMP25:%.*]] = insertelement <16 x i32> [[TMP23]], i32 [[TMP24]], i32 6 -// CHECK-NEXT: [[TMP26:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS23]], i32 2 -// CHECK-NEXT: [[TMP27:%.*]] = insertelement <16 x i32> [[TMP25]], i32 [[TMP26]], i32 10 -// CHECK-NEXT: [[TMP28:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS23]], i32 3 -// CHECK-NEXT: [[TMP29:%.*]] = insertelement <16 x i32> [[TMP27]], i32 [[TMP28]], i32 14 -// CHECK-NEXT: store <16 x i32> [[TMP29]], ptr [[TMP21]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS22]], i32 0 +// CHECK-NEXT: [[TMP23:%.*]] = getelementptr <16 x i32>, ptr [[TMP21]], i32 0, i32 2 +// CHECK-NEXT: store i32 [[TMP22]], ptr [[TMP23]], align 4 +// CHECK-NEXT: [[TMP24:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS22]], i32 1 +// CHECK-NEXT: [[TMP25:%.*]] = getelementptr <16 x i32>, ptr [[TMP21]], i32 0, i32 6 +// CHECK-NEXT: store i32 [[TMP24]], ptr [[TMP25]], align 4 +// CHECK-NEXT: [[TMP26:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS22]], i32 2 +// CHECK-NEXT: [[TMP27:%.*]] = getelementptr <16 x i32>, ptr [[TMP21]], i32 0, i32 10 +// CHECK-NEXT: store i32 [[TMP26]], ptr [[TMP27]], align 4 +// CHECK-NEXT: [[TMP28:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS22]], i32 3 +// CHECK-NEXT: [[TMP29:%.*]] = getelementptr <16 x i32>, ptr [[TMP21]], i32 0, i32 14 +// CHECK-NEXT: store i32 [[TMP28]], ptr [[TMP29]], align 4 // CHECK-NEXT: [[TMP30:%.*]] = load <16 x i32>, ptr [[N_ADDR]], align 4 -// CHECK-NEXT: [[MATRIX_ELEM25:%.*]] = extractelement <16 x i32> [[TMP30]], i32 0 -// CHECK-NEXT: [[MATRIX_ROW_INS26:%.*]] = insertelement <4 x i32> poison, i32 [[MATRIX_ELEM25]], i32 0 -// CHECK-NEXT: [[MATRIX_ELEM27:%.*]] = extractelement <16 x i32> [[TMP30]], i32 4 -// CHECK-NEXT: [[MATRIX_ROW_INS28:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS26]], i32 [[MATRIX_ELEM27]], i32 1 -// CHECK-NEXT: [[MATRIX_ELEM29:%.*]] = extractelement <16 x i32> [[TMP30]], i32 8 -// CHECK-NEXT: [[MATRIX_ROW_INS30:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS28]], i32 [[MATRIX_ELEM29]], i32 2 -// CHECK-NEXT: [[MATRIX_ELEM31:%.*]] = extractelement <16 x i32> [[TMP30]], i32 12 -// CHECK-NEXT: [[MATRIX_ROW_INS32:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS30]], i32 [[MATRIX_ELEM31]], i32 3 +// CHECK-NEXT: [[MATRIX_ELEM23:%.*]] = extractelement <16 x i32> [[TMP30]], i32 0 +// CHECK-NEXT: [[MATRIX_ROW_INS24:%.*]] = insertelement <4 x i32> poison, i32 [[MATRIX_ELEM23]], i32 0 +// CHECK-NEXT: [[MATRIX_ELEM25:%.*]] = extractelement <16 x i32> [[TMP30]], i32 4 +// CHECK-NEXT: [[MATRIX_ROW_INS26:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS24]], i32 [[MATRIX_ELEM25]], i32 1 +// CHECK-NEXT: [[MATRIX_ELEM27:%.*]] = extractelement <16 x i32> [[TMP30]], i32 8 +// CHECK-NEXT: [[MATRIX_ROW_INS28:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS26]], i32 [[MATRIX_ELEM27]], i32 2 +// CHECK-NEXT: [[MATRIX_ELEM29:%.*]] = extractelement <16 x i32> [[TMP30]], i32 12 +// CHECK-NEXT: [[MATRIX_ROW_INS30:%.*]] = insertelement <4 x i32> [[MATRIX_ROW_INS28]], i32 [[MATRIX_ELEM29]], i32 3 // CHECK-NEXT: [[TMP31:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull [[META3]], !align [[META4]] -// CHECK-NEXT: [[MATRIX_LOAD33:%.*]] = load <16 x i32>, ptr [[TMP31]], align 4 -// CHECK-NEXT: [[TMP32:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS32]], i32 0 -// CHECK-NEXT: [[TMP33:%.*]] = insertelement <16 x i32> [[MATRIX_LOAD33]], i32 [[TMP32]], i32 3 -// CHECK-NEXT: [[TMP34:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS32]], i32 1 -// CHECK-NEXT: [[TMP35:%.*]] = insertelement <16 x i32> [[TMP33]], i32 [[TMP34]], i32 7 -// CHECK-NEXT: [[TMP36:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS32]], i32 2 -// CHECK-NEXT: [[TMP37:%.*]] = insertelement <16 x i32> [[TMP35]], i32 [[TMP36]], i32 11 -// CHECK-NEXT: [[TMP38:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS32]], i32 3 -// CHECK-NEXT: [[TMP39:%.*]] = insertelement <16 x i32> [[TMP37]], i32 [[TMP38]], i32 15 -// CHECK-NEXT: store <16 x i32> [[TMP39]], ptr [[TMP31]], align 4 +// CHECK-NEXT: [[TMP32:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS30]], i32 0 +// CHECK-NEXT: [[TMP33:%.*]] = getelementptr <16 x i32>, ptr [[TMP31]], i32 0, i32 3 +// CHECK-NEXT: store i32 [[TMP32]], ptr [[TMP33]], align 4 +// CHECK-NEXT: [[TMP34:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS30]], i32 1 +// CHECK-NEXT: [[TMP35:%.*]] = getelementptr <16 x i32>, ptr [[TMP31]], i32 0, i32 7 +// CHECK-NEXT: store i32 [[TMP34]], ptr [[TMP35]], align 4 +// CHECK-NEXT: [[TMP36:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS30]], i32 2 +// CHECK-NEXT: [[TMP37:%.*]] = getelementptr <16 x i32>, ptr [[TMP31]], i32 0, i32 11 +// CHECK-NEXT: store i32 [[TMP36]], ptr [[TMP37]], align 4 +// CHECK-NEXT: [[TMP38:%.*]] = extractelement <4 x i32> [[MATRIX_ROW_INS30]], i32 3 +// CHECK-NEXT: [[TMP39:%.*]] = getelementptr <16 x i32>, ptr [[TMP31]], i32 0, i32 15 +// CHECK-NEXT: store i32 [[TMP38]], ptr [[TMP39]], align 4 // CHECK-NEXT: ret void // void setMatrixConstIndex(out int4x4 M, int4x4 N ) { diff --git a/clang/test/CodeGenHLSL/BasicFeatures/matrix-type-indexing.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/matrix-type-indexing.hlsl index 7a63bbb45ecf7..3fff4976a9387 100644 --- a/clang/test/CodeGenHLSL/BasicFeatures/matrix-type-indexing.hlsl +++ b/clang/test/CodeGenHLSL/BasicFeatures/matrix-type-indexing.hlsl @@ -44,9 +44,8 @@ void storeAtMatrixSubscriptExpr(int row, int col, half value) { // ROW-CHECK-NEXT: [[row_major_index:%.*]] = add i32 [[row_offset]], [[col_load:%.*]] // COL-CHECK: [[col_offset:%.*]] = mul i32 [[col_load:%.*]], 2 // COL-CHECK-NEXT: [[col_major_index:%.*]] = add i32 [[col_offset]], [[row_load:%.*]] - // CHECK-NEXT: [[matrix_as_vec:%.*]] = load <6 x half>, ptr addrspace(2) @gM, align 2 - // ROW-CHECK-NEXT: [[matrix_after_insert:%.*]] = insertelement <6 x half> [[matrix_as_vec]], half [[value_load]], i32 [[row_major_index]] - // COL-CHECK-NEXT: [[matrix_after_insert:%.*]] = insertelement <6 x half> [[matrix_as_vec]], half [[value_load]], i32 [[col_major_index]] - // CHECK-NEXT: store <6 x half> [[matrix_after_insert]], ptr addrspace(2) @gM, align 2 + // ROW-CHECK-NEXT: [[matrix_gep:%.*]] = getelementptr <6 x half>, ptr addrspace(2) @gM, i32 0, i32 [[row_major_index]] + // COL-CHECK-NEXT: [[matrix_gep:%.*]] = getelementptr <6 x half>, ptr addrspace(2) @gM, i32 0, i32 [[col_major_index]] + // CHECK-NEXT: store half [[value_load]], ptr addrspace(2) [[matrix_gep]], align 2 gM[row][col] = value; } diff --git a/clang/test/CodeGenHLSL/BoolMatrix.hlsl b/clang/test/CodeGenHLSL/BoolMatrix.hlsl index 05c9ad4b926e6..824b9656e6848 100644 --- a/clang/test/CodeGenHLSL/BoolMatrix.hlsl +++ b/clang/test/CodeGenHLSL/BoolMatrix.hlsl @@ -98,9 +98,8 @@ bool fn4() { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[M:%.*]] = alloca [4 x i32], align 4 // CHECK-NEXT: store <4 x i32> splat (i32 1), ptr [[M]], align 4 -// CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[M]], align 4 -// CHECK-NEXT: [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, i32 3 -// CHECK-NEXT: store <4 x i32> [[MATINS]], ptr [[M]], align 4 +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr <4 x i32>, ptr [[M]], i32 0, i32 3 +// CHECK-NEXT: store i32 0, ptr [[TMP0]], align 4 // CHECK-NEXT: ret void // void fn5() { @@ -121,10 +120,9 @@ void fn5() { // CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[V]], align 4 // CHECK-NEXT: [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1 // CHECK-NEXT: [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0 -// CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr [[BM1]], align 1 -// CHECK-NEXT: [[TMP2:%.*]] = zext i1 [[LOADEDV]] to i32 -// CHECK-NEXT: [[MATINS:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[TMP2]], i32 1 -// CHECK-NEXT: store <4 x i32> [[MATINS]], ptr [[BM1]], align 1 +// CHECK-NEXT: [[TMP1:%.*]] = zext i1 [[LOADEDV]] to i32 +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr <4 x i32>, ptr [[BM1]], i32 0, i32 1 +// CHECK-NEXT: store i32 [[TMP1]], ptr [[TMP2]], align 4 // CHECK-NEXT: ret void // void fn6() { @@ -141,9 +139,8 @@ void fn6() { // CHECK-NEXT: [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x i32], ptr [[ARR]], i32 1 // CHECK-NEXT: store <4 x i32> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4 // CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], ptr [[ARR]], i32 0, i32 0 -// CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4 -// CHECK-NEXT: [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, i32 1 -// CHECK-NEXT: store <4 x i32> [[MATINS]], ptr [[ARRAYIDX]], align 4 +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr <4 x i32>, ptr [[ARRAYIDX]], i32 0, i32 1 +// CHECK-NEXT: store i32 0, ptr [[TMP0]], align 4 // CHECK-NEXT: ret void // void fn7() { >From 91caa8f2ad008f50cf95245f3361d2c0c1c2799b Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Thu, 15 Jan 2026 10:41:38 -0800 Subject: [PATCH 2/4] Reword comment on store to matrix row --- clang/lib/CodeGen/CGExpr.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 6f08d0ef1d3b1..391dca57f3d96 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -2764,8 +2764,8 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, } if (Dst.isMatrixRow()) { // NOTE: Since there are no other languages that implement matrix single - // subscripting, the logic here is specific to HLSL which allows stores to - // indivdual rows of matrices. + // subscripting, the logic here is specific to HLSL which allows + // per-element stores to rows of matrices. assert(getLangOpts().HLSL && "Store through matrix row LValues is only implemented for HLSL!"); QualType MatTy = Dst.getType(); >From acfdc5352dd85919192c7406e0d8e59795eb2639 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Thu, 15 Jan 2026 14:04:19 -0800 Subject: [PATCH 3/4] Remove matrix index assumption from HLSL path --- clang/lib/CodeGen/CGExpr.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 391dca57f3d96..696742b5ba228 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -2735,11 +2735,6 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, Val = Builder.CreateZExt(Val, ElemTy->getScalarType()); llvm::Value *Idx = Dst.getMatrixIdx(); - if (CGM.getCodeGenOpts().OptimizationLevel > 0) { - const auto *const MatTy = Dst.getType()->castAs<ConstantMatrixType>(); - llvm::MatrixBuilder MB(Builder); - MB.CreateIndexAssumption(Idx, MatTy->getNumElementsFlattened()); - } llvm::Value *Zero = llvm::ConstantInt::get(Int32Ty, 0); Address DstElemAddr = Builder.CreateGEP(DstAddr, {Zero, Idx}, DestAddrTy, ElemAlign); >From 13d849da4710f37710e909699c6150d7a52ec851 Mon Sep 17 00:00:00 2001 From: Deric Cheung <[email protected]> Date: Thu, 15 Jan 2026 14:05:02 -0800 Subject: [PATCH 4/4] Apply clang-format --- clang/lib/CodeGen/CGExpr.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 696742b5ba228..2a5ae8da72512 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -2773,8 +2773,8 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, Address DstAddr = Dst.getMatrixAddress(); llvm::Type *DestAddrTy = DstAddr.getElementType(); llvm::Type *ElemTy = DestAddrTy->getScalarType(); - CharUnits ElemAlign = CharUnits::fromQuantity( - CGM.getDataLayout().getPrefTypeAlign(ElemTy)); + CharUnits ElemAlign = + CharUnits::fromQuantity(CGM.getDataLayout().getPrefTypeAlign(ElemTy)); assert(ElemTy->getScalarSizeInBits() >= 8 && "matrix element type must be at least byte-sized"); @@ -2783,9 +2783,8 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, if (RowVal->getType()->getScalarType()->getPrimitiveSizeInBits() < ElemTy->getScalarSizeInBits()) { auto *RowValVecTy = cast<llvm::FixedVectorType>(RowVal->getType()); - llvm::Type *StorageElmTy = - llvm::FixedVectorType::get(ElemTy->getScalarType(), - RowValVecTy->getNumElements()); + llvm::Type *StorageElmTy = llvm::FixedVectorType::get( + ElemTy->getScalarType(), RowValVecTy->getNumElements()); RowVal = Builder.CreateZExt(RowVal, StorageElmTy); } _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
