================ @@ -5247,50 +5247,85 @@ SDValue AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); - uint64_t EltSize = Op.getConstantOperandVal(2); - EVT VT = Op.getValueType(); - switch (EltSize) { - case 1: - if (VT != MVT::v16i8 && VT != MVT::nxv16i1) - return SDValue(); - break; - case 2: - if (VT != MVT::v8i8 && VT != MVT::nxv8i1) - return SDValue(); - break; - case 4: - if (VT != MVT::v4i16 && VT != MVT::nxv4i1) - return SDValue(); - break; - case 8: - if (VT != MVT::v2i32 && VT != MVT::nxv2i1) - return SDValue(); - break; - default: - // Other element sizes are incompatible with whilewr/rw, so expand instead - return SDValue(); - } + assert((Subtarget->hasSVE2() || + (Subtarget->hasSME() && Subtarget->isStreaming())) && + "Lowering loop_dependence_raw_mask or loop_dependence_war_mask " + "requires SVE or SME"); + + uint64_t EltSizeInBytes = Op.getConstantOperandVal(2); + // Other element sizes are incompatible with whilewr/rw, so expand instead + if (!is_contained({1u, 2u, 4u, 8u}, EltSizeInBytes)) + return SDValue(); + + EVT FullVT = Op.getValueType(); + EVT ExtractVT = FullVT; + EVT EltVT = MVT::getIntegerVT(EltSizeInBytes * 8); + unsigned NumElements = FullVT.getVectorMinNumElements(); + unsigned PredElements = getPackedSVEVectorVT(EltVT).getVectorMinNumElements(); + bool Split = NumElements > PredElements; + + if (EltSizeInBytes * NumElements < 16) + // The element size and vector length combination must at least form a + // 128-bit vector. Shorter vector lengths can be widened then extracted + FullVT = FullVT.getDoubleNumVectorElementsVT(*DAG.getContext()); + + auto LowerToWhile = [&](EVT VT, unsigned AddrScale) { ---------------- SamTebbs33 wrote:
Done. https://github.com/llvm/llvm-project/pull/153187 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits