================
@@ -5247,50 +5247,85 @@ SDValue
 AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
                                                  SelectionDAG &DAG) const {
   SDLoc DL(Op);
-  uint64_t EltSize = Op.getConstantOperandVal(2);
-  EVT VT = Op.getValueType();
-  switch (EltSize) {
-  case 1:
-    if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
-      return SDValue();
-    break;
-  case 2:
-    if (VT != MVT::v8i8 && VT != MVT::nxv8i1)
-      return SDValue();
-    break;
-  case 4:
-    if (VT != MVT::v4i16 && VT != MVT::nxv4i1)
-      return SDValue();
-    break;
-  case 8:
-    if (VT != MVT::v2i32 && VT != MVT::nxv2i1)
-      return SDValue();
-    break;
-  default:
-    // Other element sizes are incompatible with whilewr/rw, so expand instead
-    return SDValue();
-  }
+  assert((Subtarget->hasSVE2() ||
+          (Subtarget->hasSME() && Subtarget->isStreaming())) &&
+         "Lowering loop_dependence_raw_mask or loop_dependence_war_mask "
+         "requires SVE or SME");
+
+  uint64_t EltSizeInBytes = Op.getConstantOperandVal(2);
+  // Other element sizes are incompatible with whilewr/rw, so expand instead
+  if (!is_contained({1u, 2u, 4u, 8u}, EltSizeInBytes))
+    return SDValue();
+
+  EVT FullVT = Op.getValueType();
+  EVT ExtractVT = FullVT;
+  EVT EltVT = MVT::getIntegerVT(EltSizeInBytes * 8);
+  unsigned NumElements = FullVT.getVectorMinNumElements();
+  unsigned PredElements =
+      getPackedSVEVectorVT(EltVT).getVectorMinNumElements();
+  bool Split = NumElements > PredElements;
+
+  if (EltSizeInBytes * NumElements < 16)
+    // The element size and vector length combination must at least form a
+    // 128-bit vector. Shorter vector lengths can be widened and then extracted.
+    FullVT = FullVT.getDoubleNumVectorElementsVT(*DAG.getContext());
+
+  auto LowerToWhile = [&](EVT VT, unsigned AddrScale) {
+    SDValue PtrA = Op.getOperand(0);
+    SDValue PtrB = Op.getOperand(1);
+
+    if (AddrScale > 0) {
+      unsigned Offset =
+          VT.getVectorMinNumElements() * EltSizeInBytes * AddrScale;
+      SDValue Addend;
+
+      if (VT.isScalableVT())
+        Addend = DAG.getVScale(DL, MVT::i64, APInt(64, Offset));
+      else
+        Addend = DAG.getConstant(Offset, DL, MVT::i64);
 
-  SDValue PtrA = Op.getOperand(0);
-  SDValue PtrB = Op.getOperand(1);
+      PtrA = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrA, Addend);
+    }
 
-  if (VT.isScalableVT())
-    return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, Op.getOperand(2));
+    if (VT.isScalableVT())
+      return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, Op.getOperand(2));
 
-  // We can use the SVE whilewr/whilerw instruction to lower this
-  // intrinsic by creating the appropriate sequence of scalable vector
-  // operations and then extracting a fixed-width subvector from the scalable
-  // vector. Scalable vector variants are already legal.
-  EVT ContainerVT =
-      EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
-                       VT.getVectorNumElements(), true);
-  EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+    // We can use the SVE whilewr/whilerw instruction to lower this
+    // intrinsic by creating the appropriate sequence of scalable vector
+    // operations and then extracting a fixed-width subvector from the scalable
+    // vector. Scalable vector variants are already legal.
+    EVT ContainerVT =
+        EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
+                         VT.getVectorNumElements(), true);
+    EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
 
-  SDValue Mask =
-      DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, Op.getOperand(2));
-  SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
-  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
-                     DAG.getVectorIdxConstant(0, DL));
+    SDValue Mask =
+        DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, Op.getOperand(2));
+    SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
+                       DAG.getVectorIdxConstant(0, DL));
+  };
+
+  SDValue Result;
+  if (!Split) {
+    Result = LowerToWhile(FullVT, 0);
----------------
SamTebbs33 wrote:

It was being called to re-use the containerisation, not needed now with the new 
approach.

https://github.com/llvm/llvm-project/pull/153187
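
For anyone following the split path in the quoted hunk, here is a minimal
standalone sketch (not part of the patch) of the pointer-offset arithmetic the
new LowerToWhile lambda applies when the requested mask is wider than a single
whilewr/whilerw result. The element size, lane count, and the assumption that
one packed predicate covers 128 bits / element size lanes are illustrative
values only, not taken from the patch.

#include <cstdint>
#include <cstdio>

int main() {
  // Example values (assumptions for illustration): i8 elements and a 32-lane
  // mask, i.e. a request wider than one packed SVE predicate.
  const unsigned EltSizeInBytes = 1;
  const unsigned NumElements = 32;
  // Lanes covered by one packed predicate, assuming the 128-bit minimum SVE
  // register size (the role getPackedSVEVectorVT plays in the patch).
  const unsigned PredElements = 16 / EltSizeInBytes;
  const bool Split = NumElements > PredElements;

  std::printf("split needed: %s\n", Split ? "yes" : "no");
  for (unsigned Part = 0; Part * PredElements < NumElements; ++Part) {
    // Mirrors Offset = VT.getVectorMinNumElements() * EltSizeInBytes *
    // AddrScale in the lambda, with Part playing the role of AddrScale.
    uint64_t Offset = uint64_t(PredElements) * EltSizeInBytes * Part;
    std::printf("part %u: whilewr(PtrA + %llu%s, PtrB)\n", Part,
                (unsigned long long)Offset,
                Offset ? " * vscale (for scalable results)" : "");
  }
  return 0;
}

With these example values the mask splits into two parts: part 0 compares PtrA
and PtrB directly, and part 1 advances PtrA by 16 bytes (scaled by vscale for
scalable result types) before issuing the second whilewr/whilerw.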