https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/166361

In the MachineSMEABIPass, if we have a function with ZT0 state, then there are 
some additional cases where we need to zero ZA and ZT0.

If the function has a private ZA interface (i.e., new ZT0, and new ZA if
present), then ZT0/ZA must be zeroed when committing the incoming ZA save.

If the function has a shared ZA interface (e.g. new ZA and shared ZT0), then ZA
must be zeroed on function entry (without a ZA save commit).

The logic in the ABI pass has been reworked to use an "ENTRY" state to handle 
this (rather than the more specific "CALLER_DORMANT" state).
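For example, here are the two cases as covered by sme-zt0-state.ll (both
functions appear verbatim in the test diff below; the comments summarise the
zeroing expected by the updated CHECK lines):

```llvm
; Private ZA interface (new ZA + new ZT0): the incoming lazy save (if any) is
; committed via __arm_tpidr2_save, and ZA/ZT0 are zeroed as part of that commit.
define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" nounwind {
  call void %callee() "aarch64_inout_za" "aarch64_in_zt0"
  ret void
}

; Shared ZA interface (new ZA, shared ZT0): no ZA save commit is needed, so ZA
; is simply zeroed on entry ("zero {za}") and the live-in ZT0 is left untouched.
define void @new_za_shared_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_in_zt0" nounwind {
  call void %callee() "aarch64_inout_za" "aarch64_in_zt0"
  ret void
}
```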

From ae3ec416aaed38c254f0bbcef1c5b6671d1ce2a6 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <[email protected]>
Date: Mon, 3 Nov 2025 15:55:59 +0000
Subject: [PATCH] [AArch64][SME] Handle zeroing ZA and ZT0 in functions with
 ZT0 state

In the MachineSMEABIPass, if we have a function with ZT0 state, then
there are some additional cases where we need to zero ZA and ZT0.

If the function has a private ZA interface (i.e., new ZT0, and new ZA if
present), then ZT0/ZA must be zeroed when committing the incoming ZA
save.

If the function has a shared ZA interface (e.g. new ZA and shared ZT0),
then ZA must be zeroed on function entry (without a ZA save commit).

The logic in the ABI pass has been reworked to use an "ENTRY" state to
handle this (rather than the more specific "CALLER_DORMANT" state).

Change-Id: Ib91e9b13ffa4752320fe6a7a720afe919cf00198
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  9 --
 llvm/lib/Target/AArch64/MachineSMEABIPass.cpp | 99 +++++++++++--------
 llvm/test/CodeGen/AArch64/sme-zt0-state.ll    | 29 +++---
 3 files changed, 68 insertions(+), 69 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 60aa61e993b26..30f961043e78b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8735,15 +8735,6 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     }
   }
 
-  if (getTM().useNewSMEABILowering()) {
-    // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
-    if (Attrs.isNewZT0())
-      Chain = DAG.getNode(
-          ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
-          DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
-          DAG.getTargetConstant(0, DL, MVT::i32));
-  }
-
   return Chain;
 }
 
diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index 8f9aae944ad6d..bb4dfe8c60904 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -82,8 +82,8 @@ enum ZAState {
   // A ZA save has been set up or committed (i.e. ZA is dormant or off)
   LOCAL_SAVED,
 
-  // ZA is off or a lazy save has been set up by the caller
-  CALLER_DORMANT,
+  // The ZA/ZT0 state on entry to the function.
+  ENTRY,
 
   // ZA is off
   OFF,
@@ -200,7 +200,7 @@ StringRef getZAStateString(ZAState State) {
     MAKE_CASE(ZAState::ANY)
     MAKE_CASE(ZAState::ACTIVE)
     MAKE_CASE(ZAState::LOCAL_SAVED)
-    MAKE_CASE(ZAState::CALLER_DORMANT)
+    MAKE_CASE(ZAState::ENTRY)
     MAKE_CASE(ZAState::OFF)
   default:
     llvm_unreachable("Unexpected ZAState");
@@ -281,8 +281,8 @@ struct MachineSMEABI : public MachineFunctionPass {
   void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);
 
   // Emission routines for private and shared ZA functions (using lazy saves).
-  void emitNewZAPrologue(MachineBasicBlock &MBB,
-                         MachineBasicBlock::iterator MBBI);
+  void emitSMEPrologue(MachineBasicBlock &MBB,
+                       MachineBasicBlock::iterator MBBI);
   void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
                            MachineBasicBlock::iterator MBBI,
                            LiveRegs PhysLiveRegs);
@@ -395,9 +395,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
 
     if (MBB.isEntryBlock()) {
       // Entry block:
-      Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface()
-                                  ? ZAState::CALLER_DORMANT
-                                  : ZAState::ACTIVE;
+      Block.FixedEntryState = ZAState::ENTRY;
     } else if (MBB.isEHPad()) {
       // EH entry block:
       Block.FixedEntryState = ZAState::LOCAL_SAVED;
@@ -815,32 +813,49 @@ void MachineSMEABI::emitAllocateLazySaveBuffer(
   }
 }
 
-void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
-                                      MachineBasicBlock::iterator MBBI) {
+static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;
+
+void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB,
+                                    MachineBasicBlock::iterator MBBI) {
   auto *TLI = Subtarget->getTargetLowering();
   DebugLoc DL = getDebugLoc(MBB, MBBI);
 
-  // Get current TPIDR2_EL0.
-  Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
-  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
-      .addReg(TPIDR2EL0, RegState::Define)
-      .addImm(AArch64SysReg::TPIDR2_EL0);
-  // If TPIDR2_EL0 is non-zero, commit the lazy save.
-  // NOTE: Functions that only use ZT0 don't need to zero ZA.
-  bool ZeroZA = AFI->getSMEFnAttrs().hasZAState();
-  auto CommitZASave =
-      BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
-          .addReg(TPIDR2EL0)
-          .addImm(ZeroZA ? 1 : 0)
-          .addImm(/*ZeroZT0=*/false)
-          .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE))
-          .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
-  if (ZeroZA)
-    CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
-  // Enable ZA (as ZA could have previously been in the OFF state).
-  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
-      .addImm(AArch64SVCR::SVCRZA)
-      .addImm(1);
+  bool ZeroZA = AFI->getSMEFnAttrs().isNewZA();
+  bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0();
+  if (AFI->getSMEFnAttrs().hasPrivateZAInterface()) {
+    // Get current TPIDR2_EL0.
+    Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
+    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
+        .addReg(TPIDR2EL0, RegState::Define)
+        .addImm(AArch64SysReg::TPIDR2_EL0);
+    // If TPIDR2_EL0 is non-zero, commit the lazy save.
+    // NOTE: Functions that only use ZT0 don't need to zero ZA.
+    auto CommitZASave =
+        BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
+            .addReg(TPIDR2EL0)
+            .addImm(ZeroZA)
+            .addImm(ZeroZT0)
+            .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE))
+            .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+    if (ZeroZA)
+      CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
+    if (ZeroZT0)
+      CommitZASave.addDef(AArch64::ZT0, RegState::ImplicitDefine);
+    // Enable ZA (as ZA could have previously been in the OFF state).
+    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
+        .addImm(AArch64SVCR::SVCRZA)
+        .addImm(1);
+  } else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) {
+    if (ZeroZA) {
+      BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_M))
+          .addImm(ZERO_ALL_ZA_MASK)
+          .addDef(AArch64::ZAB0, RegState::ImplicitDefine);
+    }
+    if (ZeroZT0) {
+      DebugLoc DL = getDebugLoc(MBB, MBBI);
+      BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_T)).addDef(AArch64::ZT0);
+    }
+  }
 }
 
 void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
@@ -922,19 +937,19 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
   if (From == ZAState::ANY || To == ZAState::ANY)
     return;
 
-  // If we're exiting from the CALLER_DORMANT state that means this new ZA
-  // function did not touch ZA (so ZA was never turned on).
-  if (From == ZAState::CALLER_DORMANT && To == ZAState::OFF)
+  // If we're exiting from the ENTRY state that means that the function has not
+  // used ZA, so in the case of private ZA/ZT0 functions we can omit any set up.
+  if (From == ZAState::ENTRY && To == ZAState::OFF)
     return;
 
+  SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
+
   // TODO: Avoid setting up the save buffer if there's no transition to
   // LOCAL_SAVED.
-  if (From == ZAState::CALLER_DORMANT) {
-    assert(AFI->getSMEFnAttrs().hasPrivateZAInterface() &&
-           "CALLER_DORMANT state requires private ZA interface");
+  if (From == ZAState::ENTRY) {
     assert(&MBB == &MBB.getParent()->front() &&
-           "CALLER_DORMANT state only valid in entry block");
-    emitNewZAPrologue(MBB, MBB.getFirstNonPHI());
+           "ENTRY state only valid in entry block");
+    emitSMEPrologue(MBB, MBB.getFirstNonPHI());
     if (To == ZAState::ACTIVE)
       return; // Nothing more to do (ZA is active after the prologue).
 
@@ -949,9 +964,9 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
   else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
     emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
   else if (To == ZAState::OFF) {
-    assert(From != ZAState::CALLER_DORMANT &&
-           "CALLER_DORMANT to OFF should have already been handled");
-    assert(!AFI->getSMEFnAttrs().hasAgnosticZAInterface() &&
+    assert(From != ZAState::ENTRY &&
+           "ENTRY to OFF should have already been handled");
+    assert(!SMEFnAttrs.hasAgnosticZAInterface() &&
            "Should not turn ZA off in agnostic ZA function");
     emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
   } else {
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 5b81f5dafe421..4c48e41294a3a 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -199,9 +199,9 @@ define void @zt0_new_caller_zt0_new_callee(ptr %callee) "aarch64_new_zt0" nounwi
 ; CHECK-NEWLOWERING-NEXT:  // %bb.1:
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:  .LBB6_2:
 ; CHECK-NEWLOWERING-NEXT:    smstart za
-; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:    mov x19, sp
 ; CHECK-NEWLOWERING-NEXT:    str zt0, [x19]
 ; CHECK-NEWLOWERING-NEXT:    smstop za
@@ -252,9 +252,9 @@ define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind {
 ; CHECK-NEWLOWERING-NEXT:  // %bb.1:
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:  .LBB7_2:
 ; CHECK-NEWLOWERING-NEXT:    smstart za
-; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:    mov x19, sp
 ; CHECK-NEWLOWERING-NEXT:    str zt0, [x19]
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_state
@@ -302,9 +302,9 @@ define void @zt0_new_caller(ptr %callee) "aarch64_new_zt0" 
nounwind {
 ; CHECK-NEWLOWERING-NEXT:  // %bb.1:
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:  .LBB8_2:
 ; CHECK-NEWLOWERING-NEXT:    smstart za
-; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:    blr x0
 ; CHECK-NEWLOWERING-NEXT:    smstop za
 ; CHECK-NEWLOWERING-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -343,9 +343,9 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
 ; CHECK-NEWLOWERING-NEXT:    zero {za}
+; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:  .LBB9_2:
 ; CHECK-NEWLOWERING-NEXT:    smstart za
-; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:    blr x0
 ; CHECK-NEWLOWERING-NEXT:    smstop za
 ; CHECK-NEWLOWERING-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -356,20 +356,13 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n
 
 ; Expect clear ZA on entry
 define void @new_za_shared_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_in_zt0" nounwind {
-; CHECK-LABEL: new_za_shared_zt0_caller:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    zero {za}
-; CHECK-NEXT:    blr x0
-; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
-; CHECK-NEXT:    ret
-;
-; CHECK-NEWLOWERING-LABEL: new_za_shared_zt0_caller:
-; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    blr x0
-; CHECK-NEWLOWERING-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-COMMON-LABEL: new_za_shared_zt0_caller:
+; CHECK-COMMON:       // %bb.0:
+; CHECK-COMMON-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-COMMON-NEXT:    zero {za}
+; CHECK-COMMON-NEXT:    blr x0
+; CHECK-COMMON-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-COMMON-NEXT:    ret
   call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
   ret void;
 }
