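The tx rekey operation creates a new psp_assoc with the same rx state
but with new tx state. The new psp_assoc references the current
assoc's 'prev' assoc as its own. The assoc previously referenced by
'sk->psp_assoc' has its reference dropped and is freed.

Previously, setting a tx key on a socket that already had one failed
with -EBUSY ("Tx key already set"); it now performs a tx rekey instead.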

Signed-off-by: Daniel Zahka <[email protected]>
---
 net/psp/psp_sock.c | 51 +++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 45 insertions(+), 6 deletions(-)

diff --git a/net/psp/psp_sock.c b/net/psp/psp_sock.c
index 9b0ecce8350f..f429b8b2d8f2 100644
--- a/net/psp/psp_sock.c
+++ b/net/psp/psp_sock.c
@@ -215,6 +215,47 @@ psp_pas_set_tx_key(struct psp_dev *psd, struct psp_assoc 
*pas,
        return rc;
 }
 
+/*
+ * psp_sock_tx_rekey() - install a new tx key on a socket that already has one
+ * @sk:     socket whose sk->psp_assoc is replaced
+ * @psd:    PSP device; sizes the new assoc's drv_data, and the new assoc
+ *          is added to psd->active_assocs
+ * @pas:    current assoc of @sk (it already has a tx key)
+ * @key:    new tx key to install
+ * @extack: netlink extended ack for error reporting
+ *
+ * Builds a new assoc that copies dev_id, generation, version, peer_tx,
+ * upgrade_seq, the rx state and the 'prev' assoc from @pas, installs @key
+ * on it, publishes it in sk->psp_assoc and then drops the reference on
+ * @pas that sk->psp_assoc held.
+ *
+ * Return: 0 on success; -ENOMEM or the error from psp_pas_set_tx_key()
+ * otherwise, in which case @sk and @pas are left untouched.
+ */
+static int
+psp_sock_tx_rekey(struct sock *sk, struct psp_dev *psd, struct psp_assoc *pas,
+                 struct psp_key_parsed *key, struct netlink_ext_ack *extack)
+{
+       struct psp_assoc *new;
+       size_t pas_sz;
+       int err;
+
+       pas_sz = struct_size(new, drv_data, psd->caps->assoc_drv_spc);
+       new = kzalloc(pas_sz, GFP_KERNEL_ACCOUNT);
+       if (!new)
+               return -ENOMEM;
+
+       /* don't increase refcounts until we know we won't fail */
+       new->psd         = pas->psd;
+       new->dev_id      = pas->dev_id;
+       new->generation  = pas->generation;
+       new->version     = pas->version;
+       new->peer_tx     = pas->peer_tx;
+       new->upgrade_seq = pas->upgrade_seq;
+       new->prev        = pas->prev;
+       refcount_set(&new->refcnt, 1);
+       memcpy(&new->rx, &pas->rx, sizeof(new->rx));
+
+       err = psp_pas_set_tx_key(psd, new, key, extack);
+       if (err) {
+               /* new->rx is a byte copy of pas->rx and new->tx may already
+                * carry state derived from @key; zero the allocation on free
+                * so key material does not linger in freed memory.
+                */
+               kfree_sensitive(new);
+               return err;
+       }
+
+       /* Past the last failure point: take the references the copied
+        * pointers need. new shares pas's 'prev', so it needs its own ref.
+        */
+       psp_dev_get(new->psd);
+       if (new->prev)
+               refcount_inc(&new->prev->refcnt);
+       /* NOTE(review): presumably the caller holds the lock protecting
+        * psd->active_assocs; confirm.
+        */
+       list_add_tail(&new->assocs_list, &psd->active_assocs);
+
+       /* Publish the new assoc before dropping sk's reference on the old. */
+       rcu_assign_pointer(sk->psp_assoc, new);
+       psp_assoc_put(pas);
+
+       return 0;
+}
+
 static int
 psp_sock_set_tx_key(struct sock *sk, struct psp_dev *psd, struct psp_assoc 
*pas,
                    struct psp_key_parsed *key, struct netlink_ext_ack *extack)
@@ -269,13 +310,11 @@ int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev 
*psd,
                err = -EINVAL;
                goto exit_unlock;
        }
-       if (pas->tx.spi) {
-               NL_SET_ERR_MSG(extack, "Tx key already set");
-               err = -EBUSY;
-               goto exit_unlock;
-       }
+       if (pas->tx.spi)
+               err = psp_sock_tx_rekey(sk, psd, pas, key, extack);
+       else
+               err = psp_sock_set_tx_key(sk, psd, pas, key, extack);
 
-       err = psp_sock_set_tx_key(sk, psd, pas, key, extack);
 exit_unlock:
        release_sock(sk);
        return err;

-- 
2.47.3


Reply via email to