https://github.com/llvmbot created https://github.com/llvm/llvm-project/pull/159891
Backport 6119d1f115625cd1b8a2b9d331609eb9e9f676ce Requested by: @topperc >From 6b036105fdfcde8b38e99fa68ccd820c5652fc3a Mon Sep 17 00:00:00 2001 From: Craig Topper <[email protected]> Date: Fri, 19 Sep 2025 09:19:57 -0700 Subject: [PATCH] [RISCV] Re-work how VWADD_W_VL and similar _W_VL nodes are handled in combineOp_VLToVWOp_VL. (#159205) These instructions have one operand that is already narrow. Previously, we pretended that this operand was a supported extension. This could cause problems when we called getOrCreateExtendedOp on this narrow operand when creating the VWADD_VL. If the narrow operand happened to be an extend of the opposite type, we would peek through it and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16 (sext Y))). To prevent this, we ignore the operand instead and pass std::nullopt for SupportsExt to getOrCreateExtendedOp so it won't peek through any extends on the narrow source. Fixes #159152. 
(cherry picked from commit 6119d1f115625cd1b8a2b9d331609eb9e9f676ce) --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 86 +++++++++++-------- .../fixed-vectors-vw-web-simplification.ll | 23 +++++ 2 files changed, 72 insertions(+), 37 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 5fb16f5ac6b9e..347f6c99852e7 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16936,18 +16936,9 @@ struct NodeExtensionHelper { case RISCVISD::VWSUBU_W_VL: case RISCVISD::VFWADD_W_VL: case RISCVISD::VFWSUB_W_VL: - if (OperandIdx == 1) { - SupportsZExt = - Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL; - SupportsSExt = - Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL; - SupportsFPExt = - Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL; - // There's no existing extension here, so we don't have to worry about - // making sure it gets removed. - EnforceOneUse = false; + // Operand 1 can't be changed. 
+ if (OperandIdx == 1) break; - } [[fallthrough]]; default: fillUpExtensionSupport(Root, DAG, Subtarget); @@ -16985,20 +16976,20 @@ struct NodeExtensionHelper { case RISCVISD::ADD_VL: case RISCVISD::MUL_VL: case RISCVISD::OR_VL: - case RISCVISD::VWADD_W_VL: - case RISCVISD::VWADDU_W_VL: case RISCVISD::FADD_VL: case RISCVISD::FMUL_VL: - case RISCVISD::VFWADD_W_VL: case RISCVISD::VFMADD_VL: case RISCVISD::VFNMSUB_VL: case RISCVISD::VFNMADD_VL: case RISCVISD::VFMSUB_VL: return true; + case RISCVISD::VWADD_W_VL: + case RISCVISD::VWADDU_W_VL: case ISD::SUB: case RISCVISD::SUB_VL: case RISCVISD::VWSUB_W_VL: case RISCVISD::VWSUBU_W_VL: + case RISCVISD::VFWADD_W_VL: case RISCVISD::FSUB_VL: case RISCVISD::VFWSUB_W_VL: case ISD::SHL: @@ -17117,6 +17108,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS, Subtarget); } +/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS)) +/// +/// \returns std::nullopt if the pattern doesn't match or a CombineResult that +/// can be used to apply the pattern. +static std::optional<CombineResult> +canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG, + Subtarget); +} + +/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS)) +/// +/// \returns std::nullopt if the pattern doesn't match or a CombineResult that +/// can be used to apply the pattern. 
+static std::optional<CombineResult> +canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG, + Subtarget); +} + /// Check if \p Root follows a pattern Root(LHS, ext(RHS)) /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that @@ -17145,7 +17160,7 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS, return std::nullopt; } -/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS)) +/// Check if \p Root follows a pattern Root(sext(LHS), RHS) /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that /// can be used to apply the pattern. @@ -17153,11 +17168,14 @@ static std::optional<CombineResult> canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG, - Subtarget); + if (LHS.SupportsSExt) + return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()), + Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS, + /*RHSExt=*/std::nullopt); + return std::nullopt; } -/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS)) +/// Check if \p Root follows a pattern Root(zext(LHS), RHS) /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that /// can be used to apply the pattern. 
@@ -17165,11 +17183,14 @@ static std::optional<CombineResult> canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG, - Subtarget); + if (LHS.SupportsZExt) + return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()), + Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS, + /*RHSExt=*/std::nullopt); + return std::nullopt; } -/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS)) +/// Check if \p Root follows a pattern Root(fpext(LHS), RHS) /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that /// can be used to apply the pattern. @@ -17177,20 +17198,11 @@ static std::optional<CombineResult> canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG, - Subtarget); -} - -/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS)) -/// -/// \returns std::nullopt if the pattern doesn't match or a CombineResult that -/// can be used to apply the pattern. 
-static std::optional<CombineResult> -canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS, - const NodeExtensionHelper &RHS, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) { - return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG, - Subtarget); + if (LHS.SupportsFPExt) + return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()), + Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS, + /*RHSExt=*/std::nullopt); + return std::nullopt; } /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS)) @@ -17233,7 +17245,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { case RISCVISD::VFNMSUB_VL: Strategies.push_back(canFoldToVWWithSameExtension); if (Root->getOpcode() == RISCVISD::VFMADD_VL) - Strategies.push_back(canFoldToVWWithBF16EXT); + Strategies.push_back(canFoldToVWWithSameExtBF16); break; case ISD::MUL: case RISCVISD::MUL_VL: @@ -17245,7 +17257,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { case ISD::SHL: case RISCVISD::SHL_VL: // shl -> vwsll - Strategies.push_back(canFoldToVWWithZEXT); + Strategies.push_back(canFoldToVWWithSameExtZEXT); break; case RISCVISD::VWADD_W_VL: case RISCVISD::VWSUB_W_VL: diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll index 227a428831b60..ea4add2da5ebc 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll @@ -58,3 +58,26 @@ define <2 x i16> @vwmul_v2i16_multiple_users(ptr %x, ptr %y, ptr %z) { %i = or <2 x i16> %h, %g ret <2 x i16> %i } + +; Make sure we have a vsext.vl and a vwaddu.vx. 
+define <4 x i32> @pr159152(<4 x i8> %x) { +; NO_FOLDING-LABEL: pr159152: +; NO_FOLDING: # %bb.0: +; NO_FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; NO_FOLDING-NEXT: vsext.vf2 v9, v8 +; NO_FOLDING-NEXT: li a0, 9 +; NO_FOLDING-NEXT: vwaddu.vx v8, v9, a0 +; NO_FOLDING-NEXT: ret +; +; FOLDING-LABEL: pr159152: +; FOLDING: # %bb.0: +; FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; FOLDING-NEXT: vsext.vf2 v9, v8 +; FOLDING-NEXT: li a0, 9 +; FOLDING-NEXT: vwaddu.vx v8, v9, a0 +; FOLDING-NEXT: ret + %a = sext <4 x i8> %x to <4 x i16> + %b = zext <4 x i16> %a to <4 x i32> + %c = add <4 x i32> %b, <i32 9, i32 9, i32 9, i32 9> + ret <4 x i32> %c +} _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
