================ @@ -5248,49 +5248,94 @@ AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); uint64_t EltSize = Op.getConstantOperandVal(2); - EVT VT = Op.getValueType(); + EVT FullVT = Op.getValueType(); + unsigned NumElements = FullVT.getVectorMinNumElements(); + unsigned NumSplits = 0; + EVT EltVT; switch (EltSize) { case 1: - if (VT != MVT::v16i8 && VT != MVT::nxv16i1) - return SDValue(); + EltVT = MVT::i8; break; case 2: - if (VT != MVT::v8i8 && VT != MVT::nxv8i1) - return SDValue(); + if (NumElements >= 16) + NumSplits = NumElements / 16; + EltVT = MVT::i16; break; case 4: - if (VT != MVT::v4i16 && VT != MVT::nxv4i1) - return SDValue(); + if (NumElements >= 8) + NumSplits = NumElements / 8; + EltVT = MVT::i32; break; case 8: - if (VT != MVT::v2i32 && VT != MVT::nxv2i1) - return SDValue(); + if (NumElements >= 4) + NumSplits = NumElements / 4; + EltVT = MVT::i64; break; default: // Other element sizes are incompatible with whilewr/rw, so expand instead return SDValue(); } - SDValue PtrA = Op.getOperand(0); - SDValue PtrB = Op.getOperand(1); + auto LowerToWhile = [&](EVT VT, unsigned AddrScale) { + SDValue PtrA = Op.getOperand(0); + SDValue PtrB = Op.getOperand(1); - if (VT.isScalableVT()) - return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, Op.getOperand(2)); + EVT StoreVT = EVT::getVectorVT(*DAG.getContext(), EltVT, + VT.getVectorMinNumElements(), false); + if (AddrScale > 0) { + unsigned Offset = StoreVT.getStoreSizeInBits() / 8 * AddrScale; + SDValue Addend; - // We can use the SVE whilewr/whilerw instruction to lower this - // intrinsic by creating the appropriate sequence of scalable vector - // operations and then extracting a fixed-width subvector from the scalable - // vector. Scalable vector variants are already legal. 
- EVT ContainerVT = - EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(), - VT.getVectorNumElements(), true); - EVT WhileVT = ContainerVT.changeElementType(MVT::i1); + if (VT.isScalableVT()) + Addend = DAG.getVScale(DL, MVT::i64, APInt(64, Offset)); + else + Addend = DAG.getConstant(Offset, DL, MVT::i64); - SDValue Mask = - DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, Op.getOperand(2)); - SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask); - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt, - DAG.getVectorIdxConstant(0, DL)); + PtrA = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrA, Addend); ---------------- sdesmalen-arm wrote:
As I already pointed out in the other review, it is wrong to increment `PtrB`.

https://github.com/llvm/llvm-project/pull/153187
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits