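The following patch rejects stale connection requests: a REQ or REP whose
remote QPN collides with an existing connection is now answered with a REJ
carrying the stale-connection reason. It also fixes a couple of minor
error-handling cleanup issues found during testing, and fixes a bug in
handling a stale-connection reject message where the cm_id would access an
invalid timewait pointer.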

Signed-off-by: Sean Hefty <[EMAIL PROTECTED]>


Index: core/cm.c
===================================================================
--- core/cm.c   (revision 2346)
+++ core/cm.c   (working copy)
@@ -185,10 +185,45 @@ static int cm_alloc_msg(struct cm_id_pri
        return 0;
 }
 
+/*
+ * Allocate a send buffer for replying directly to a received MAD, without
+ * going through a cm_id.  The address handle is built from the work
+ * completion and GRH in mad_recv_wc, and the received P_Key index is reused.
+ * On success *msg holds the new buffer and 0 is returned; on failure
+ * anything allocated here is released and a negative errno is returned.
+ * msg->context[0] is not set here; cm_free_msg() treats NULL as "no cm_id
+ * reference".  NOTE(review): assumes ib_create_send_mad() zeroes it.
+ */
+static int cm_alloc_response_msg(struct cm_port *port,
+                                struct ib_mad_recv_wc *mad_recv_wc,
+                                struct ib_mad_send_buf **msg)
+{
+       struct ib_mad_send_buf *m;
+       struct ib_ah *ah;
+
+       ah = ib_create_ah_from_wc(port->mad_agent->qp->pd, mad_recv_wc->wc,
+                                 mad_recv_wc->recv_buf.grh, port->port_num);
+       if (IS_ERR(ah))
+               return PTR_ERR(ah);
+
+       m = ib_create_send_mad(port->mad_agent, 1, mad_recv_wc->wc->pkey_index,
+                              ah, 0, sizeof(struct ib_mad_hdr),
+                              sizeof(struct ib_mad)-sizeof(struct ib_mad_hdr),
+                              GFP_ATOMIC);
+       if (IS_ERR(m)) {
+               ib_destroy_ah(ah);
+               return PTR_ERR(m);
+       }
+       *msg = m;
+       return 0;
+}
+
 static void cm_free_msg(struct ib_mad_send_buf *msg)
 {
        ib_destroy_ah(msg->send_wr.wr.ud.ah);
-       cm_deref_id(msg->context[0]);
+       /* Responses built by cm_alloc_response_msg() hold no cm_id ref. */
+       if (msg->context[0])
+               cm_deref_id(msg->context[0]);
        ib_free_send_mad(msg);
 }
 
@@ -531,10 +556,12 @@ static void cm_cleanup_timewait(struct c
        spin_unlock_irqrestore(&cm.lock, flags);
 }
 
-static struct cm_timewait_info * cm_create_timewait_info(u32 local_id,
-                                                        u32 remote_id,
-                                                        u64 remote_ca_guid,
-                                                        u32 remote_qpn)
+/*
+ * Create zero-initialized timewait tracking for local_id.  The remote fields
+ * (work.remote_id, remote_ca_guid, remote_qpn) are left zero; callers fill
+ * them in once the peer's REQ or REP has been received.
+ */
+static struct cm_timewait_info * cm_create_timewait_info(u32 local_id)
 {
        struct cm_timewait_info *timewait_info;
 
@@ -544,10 +566,6 @@ static struct cm_timewait_info * cm_crea
        memset(timewait_info, 0, sizeof *timewait_info);
 
        timewait_info->work.local_id = local_id;
-       timewait_info->work.remote_id = remote_id;
-       timewait_info->remote_ca_guid = remote_ca_guid;
-       timewait_info->remote_qpn = remote_qpn;
-
        INIT_WORK(&timewait_info->work.work, cm_work_handler,
                  &timewait_info->work);
        timewait_info->work.cm_event.event = IB_CM_TIMEWAIT_EXIT;
@@ -674,30 +692,43 @@ int ib_cm_listen(struct ib_cm_id *cm_id,
 }
 EXPORT_SYMBOL(ib_cm_listen);
 
-static void cm_format_mad_hdr(struct ib_mad_hdr *hdr,
-                             struct cm_id_private *cm_id_priv,
-                             enum cm_msg_attr_id attr_id,
-                             enum cm_msg_sequence msg_seq)
+/*
+ * Build the MAD transaction ID for a message sent on behalf of cm_id_priv:
+ * the port agent's hi_tid fills the upper 32 bits, and the local comm ID
+ * OR'ed with the message sequence (shifted to bit 30) fills the lower 32.
+ * The result is returned in big-endian (wire) byte order.
+ * NOTE(review): local_id is not masked, so the sequence bits stay distinct
+ * only if local_id never sets bits 30-31 -- confirm.
+ */
+static u64 cm_form_tid(struct cm_id_private *cm_id_priv,
+                      enum cm_msg_sequence msg_seq)
 {
        u64 hi_tid, low_tid;
 
+       hi_tid   = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
+       /* Cast: an enum may promote to int, where seq << 30 can overflow. */
+       low_tid  = (u64) (cm_id_priv->id.local_id | ((u32) msg_seq << 30));
+       return cpu_to_be64(hi_tid | low_tid);
+}
+
+/* Fill in the common CM MAD header; tid is stored exactly as passed in. */
+static void cm_format_mad_hdr(struct ib_mad_hdr *hdr,
+                             enum cm_msg_attr_id attr_id, u64 tid)
+{
        hdr->base_version  = IB_MGMT_BASE_VERSION;
        hdr->mgmt_class    = IB_MGMT_CLASS_CM;
        hdr->class_version = IB_CM_CLASS_VERSION;
        hdr->method        = IB_MGMT_METHOD_SEND;
        hdr->attr_id       = attr_id;
-
-       hi_tid   = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32;
-       low_tid  = (u64) (cm_id_priv->id.local_id | (msg_seq << 30));
-       hdr->tid = cpu_to_be64(hi_tid | low_tid);
+       hdr->tid           = tid;
 }
 
 static void cm_format_req(struct cm_req_msg *req_msg,
                          struct cm_id_private *cm_id_priv,
                          struct ib_cm_req_param *param)
 {
-       cm_format_mad_hdr(&req_msg->hdr, cm_id_priv,
-                         CM_REQ_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+       cm_format_mad_hdr(&req_msg->hdr, CM_REQ_ATTR_ID,
+                         cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
 
        req_msg->local_comm_id = cm_id_priv->id.local_id;
        req_msg->service_id = param->service_id;
@@ -755,6 +776,10 @@ static void cm_format_req(struct cm_req_
 
 static inline int cm_validate_req_param(struct ib_cm_req_param *param)
 {
+       /* Peer-to-peer connections are not supported yet; reject up front. */
+       if (param->peer_to_peer)
+               return -EINVAL;
+
        if (!param->primary_path)
                return -EINVAL;
 
@@ -796,14 +821,27 @@ int ib_send_cm_req(struct ib_cm_id *cm_i
        }
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 
+       /*
+        * Allocate timewait tracking before sending the REQ; its remote
+        * fields are filled in from the peer's REP in cm_rep_handler().
+        */
+       cm_id_priv->timewait_info = cm_create_timewait_info(cm_id_priv->
+                                                           id.local_id);
+       if (IS_ERR(cm_id_priv->timewait_info)) {
+               ret = PTR_ERR(cm_id_priv->timewait_info);
+               /* Don't leave an ERR_PTR in the cm_id for later teardown. */
+               cm_id_priv->timewait_info = NULL;
+               goto out;
+       }
+
        ret = cm_init_av_by_path(param->primary_path, &cm_id_priv->av);
        if (ret)
-               goto out;
+               goto error1;
        if (param->alternate_path) {
                ret = cm_init_av_by_path(param->alternate_path,
                                         &cm_id_priv->alt_av);
                if (ret)
-                       goto out;
+                       goto error1;
        }
        cm_id->service_id = param->service_id;
        cm_id->service_mask = ~0ULL;
@@ -819,7 +849,7 @@ int ib_send_cm_req(struct ib_cm_id *cm_i
 
        ret = cm_alloc_msg(cm_id_priv, &cm_id_priv->msg);
        if (ret)
-               goto out;
+               goto error1;
 
        req_msg = (struct cm_req_msg *) cm_id_priv->msg->mad;
        cm_format_req(req_msg, cm_id_priv, param);
@@ -831,35 +861,70 @@ int ib_send_cm_req(struct ib_cm_id *cm_i
        cm_id_priv->local_ack_timeout =
                                cm_req_get_primary_local_ack_timeout(req_msg);
 
-       /*
-        * Received REQs won't match until we're in REQ_SENT state.  This
-        * simplifies error recovery if the send fails.
-        */
-       if (param->peer_to_peer) {
-               ret = -EINVAL;
-               goto out;
-       }
-
        spin_lock_irqsave(&cm_id_priv->lock, flags);
        ret = ib_post_send_mad(cm_id_priv->av.port->mad_agent,
                                &cm_id_priv->msg->send_wr, &bad_send_wr);
-
        if (ret) {
                spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-               /* if (param->peer_to_peer) {
-                       cleanup peer_service_table
-               } */
-               cm_free_msg(cm_id_priv->msg);
-               goto out;
+               goto error2;
        }
        BUG_ON(cm_id->state != IB_CM_IDLE);
        cm_id->state = IB_CM_REQ_SENT;
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-out:
-       return ret;
+       return 0;
+
+error2:        cm_free_msg(cm_id_priv->msg);
+error1:        kfree(cm_id_priv->timewait_info);
+       cm_id_priv->timewait_info = NULL;
+out:   return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_req);
 
+/*
+ * Send a REJ directly in response to a received CM message, without
+ * involving a cm_id; used to reject REQs and REPs that would create a stale
+ * connection.  The REJ reuses the received TID and swaps the sender's
+ * local/remote comm IDs.  If ari is non-NULL and ari_length non-zero, that
+ * many bytes are copied into the ARI field; ari_length is not range-checked
+ * here.  Returns 0 once the REJ is posted, or a negative errno.
+ */
+static int cm_issue_rej(struct cm_port *port,
+                       struct ib_mad_recv_wc *mad_recv_wc,
+                       enum ib_cm_rej_reason reason,
+                       enum cm_msg_response msg_rejected,
+                       void *ari, u8 ari_length)
+{
+       struct ib_mad_send_buf *msg;
+       struct ib_send_wr *bad_send_wr;
+       struct cm_rej_msg *rej_msg, *rcv_msg;
+       int ret;
+
+       ret = cm_alloc_response_msg(port, mad_recv_wc, &msg);
+       if (ret)
+               return ret;
+
+       /* We just need common CM header information.  Cast to any message. */
+       rcv_msg = (struct cm_rej_msg *) mad_recv_wc->recv_buf.mad;
+       rej_msg = (struct cm_rej_msg *) msg->mad;
+
+       cm_format_mad_hdr(&rej_msg->hdr, CM_REJ_ATTR_ID, rcv_msg->hdr.tid);
+       rej_msg->remote_comm_id = rcv_msg->local_comm_id;
+       rej_msg->local_comm_id = rcv_msg->remote_comm_id;
+       cm_rej_set_msg_rejected(rej_msg, msg_rejected);
+       rej_msg->reason = reason;
+
+       if (ari && ari_length) {
+               cm_rej_set_reject_info_len(rej_msg, ari_length);
+               memcpy(rej_msg->ari, ari, ari_length);
+       }
+
+       ret = ib_post_send_mad(port->mad_agent, &msg->send_wr, &bad_send_wr);
+       if (ret)
+               cm_free_msg(msg);
+
+       return ret;
+}
+
 static inline int cm_is_active_peer(u64 local_ca_guid, u64 remote_ca_guid,
                                    u32 local_qpn, u32 remote_qpn)
 {
@@ -992,29 +1048,34 @@ static int cm_req_handler(struct cm_work
        cm_id_priv->id.remote_id = req_msg->local_comm_id;
        cm_init_av_for_response(work->port, work->mad_recv_wc->wc,
                                &cm_id_priv->av);
-       cm_id_priv->timewait_info = cm_create_timewait_info(
-                                               cm_id_priv->id.local_id,
-                                               cm_id_priv->id.remote_id,
-                                               req_msg->local_ca_guid,
-                                               cm_req_get_local_qpn(req_msg));
+       cm_id_priv->timewait_info = cm_create_timewait_info(cm_id_priv->
+                                                           id.local_id);
        if (IS_ERR(cm_id_priv->timewait_info)) {
                ret = PTR_ERR(cm_id_priv->timewait_info);
+               /* Don't hand an ERR_PTR to ib_destroy_cm_id() below. */
+               cm_id_priv->timewait_info = NULL;
                goto error1;
        }
+       /* Peer IDs, used below to detect duplicate and stale REQs. */
+       cm_id_priv->timewait_info->work.remote_id = req_msg->local_comm_id;
+       cm_id_priv->timewait_info->remote_ca_guid = req_msg->local_ca_guid;
+       cm_id_priv->timewait_info->remote_qpn = cm_req_get_local_qpn(req_msg);
 
        spin_lock_irqsave(&cm.lock, flags);
        /* Check for duplicate REQ. */
        if (cm_insert_remote_id(cm_id_priv->timewait_info)) {
                spin_unlock_irqrestore(&cm.lock, flags);
                ret = -EINVAL;
-               goto error1;
+               goto error2;
        }
        /* Check for a stale connection. */
        if (cm_insert_remote_qpn(cm_id_priv->timewait_info)) {
                spin_unlock_irqrestore(&cm.lock, flags);
-               /* todo: reject as stale */
+               cm_issue_rej(work->port, work->mad_recv_wc,
+                            IB_CM_REJ_STALE_CONN, CM_MSG_RESPONSE_REQ,
+                            NULL, 0);
                ret = -EINVAL;
-               goto error1;
+               goto error2;
        }
        /* Find matching listen request. */
        listen_cm_id_priv = cm_find_listen(req_msg->service_id);
@@ -1035,11 +1093,11 @@ static int cm_req_handler(struct cm_work
        cm_format_paths_from_req(req_msg, &work->path[0], &work->path[1]);
        ret = cm_init_av_by_path(&work->path[0], &cm_id_priv->av);
        if (ret)
-               goto error2;
+               goto error3;
        if (req_msg->alt_local_lid) {
                ret = cm_init_av_by_path(&work->path[1], &cm_id_priv->alt_av);
                if (ret)
-                       goto error2;
+                       goto error3;
        }
        cm_id_priv->timeout_ms = cm_convert_to_ms(
                                        cm_req_get_local_resp_timeout(req_msg));
@@ -1058,11 +1116,13 @@ static int cm_req_handler(struct cm_work
        cm_process_work(cm_id_priv, work);
        cm_deref_id(listen_cm_id_priv);
        return 0;
-error2:
-       atomic_dec(&cm_id_priv->refcount);
+
+error3:        atomic_dec(&cm_id_priv->refcount);
        cm_deref_id(listen_cm_id_priv);
-error1:
-       ib_destroy_cm_id(&cm_id_priv->id);
+error2:        cm_cleanup_timewait(cm_id_priv->timewait_info);
+       kfree(cm_id_priv->timewait_info);
+       cm_id_priv->timewait_info = NULL;
+error1:        ib_destroy_cm_id(&cm_id_priv->id);
        return ret;
 }
 
@@ -1070,8 +1129,9 @@ static void cm_format_rep(struct cm_rep_
                          struct cm_id_private *cm_id_priv,
                          struct ib_cm_rep_param *param)
 {
-       cm_format_mad_hdr(&rep_msg->hdr, cm_id_priv,
-                         CM_REP_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+       /* todo: reuse the received REQ's TID instead of forming a new one */
+       cm_format_mad_hdr(&rep_msg->hdr, CM_REP_ATTR_ID,
+                         cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
 
        rep_msg->local_comm_id = cm_id_priv->id.local_id;
        rep_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1148,8 +1208,8 @@ static void cm_format_rtu(struct cm_rtu_
                          void *private_data,
                          u8 private_data_len)
 {
-       cm_format_mad_hdr(&rtu_msg->hdr, cm_id_priv,
-                         CM_RTU_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+       cm_format_mad_hdr(&rtu_msg->hdr, CM_RTU_ATTR_ID,
+                         cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
 
        rtu_msg->local_comm_id = cm_id_priv->id.local_id;
        rtu_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1262,7 +1322,6 @@ static void cm_dup_rep_handler(struct cm
 static int cm_rep_handler(struct cm_work *work)
 {
        struct cm_id_private *cm_id_priv;
-       struct cm_timewait_info *timewait_info;
        struct cm_rep_msg *rep_msg;
        unsigned long flags;
        int ret;
@@ -1274,27 +1333,33 @@ static int cm_rep_handler(struct cm_work
                return -EINVAL;
        }
 
-       timewait_info = cm_create_timewait_info(cm_id_priv->id.local_id,
-                                               rep_msg->local_comm_id,
-                                               rep_msg->local_ca_guid,
-                                               cm_rep_get_local_qpn(rep_msg));
-       if (IS_ERR(timewait_info)) {
-               ret = PTR_ERR(timewait_info);
-               goto error1;
-       }
+       /*
+        * NOTE(review): these writes happen before the state switch below, so
+        * a late REP can rewrite the IDs of an entry already inserted into
+        * the remote ID/QPN tables.  Consider validating the cm_id state first.
+        */
+       cm_id_priv->timewait_info->work.remote_id = rep_msg->local_comm_id;
+       cm_id_priv->timewait_info->remote_ca_guid = rep_msg->local_ca_guid;
+       cm_id_priv->timewait_info->remote_qpn = cm_rep_get_local_qpn(rep_msg);
+
        spin_lock_irqsave(&cm.lock, flags);
        /* Check for duplicate REP. */
-       if (cm_insert_remote_id(timewait_info)) {
+       if (cm_insert_remote_id(cm_id_priv->timewait_info)) {
                spin_unlock_irqrestore(&cm.lock, flags);
                ret = -EINVAL;
-               goto error2;
+               /*
+                * The colliding entry may be this cm_id's own, inserted by the
+                * first copy of this REP, so don't run cm_cleanup_timewait().
+                */
+               goto deref;
        }
        /* Check for a stale connection. */
-       if (cm_insert_remote_qpn(timewait_info)) {
+       if (cm_insert_remote_qpn(cm_id_priv->timewait_info)) {
                spin_unlock_irqrestore(&cm.lock, flags);
-               /* todo: reject as stale */
+               cm_issue_rej(work->port, work->mad_recv_wc,
+                            IB_CM_REJ_STALE_CONN, CM_MSG_RESPONSE_REP,
+                            NULL, 0);
                ret = -EINVAL;
-               goto error2;
+               goto error;
        }
        spin_unlock_irqrestore(&cm.lock, flags);
 
@@ -1308,7 +1365,7 @@ static int cm_rep_handler(struct cm_work
        default:
                spin_unlock_irqrestore(&cm_id_priv->lock, flags);
                ret = -EINVAL;
-               goto error2;
+               goto error;
        }
        cm_id_priv->id.state = IB_CM_REP_RCVD;
        cm_id_priv->id.remote_id = rep_msg->local_comm_id;
@@ -1317,7 +1374,6 @@ static int cm_rep_handler(struct cm_work
        cm_id_priv->responder_resources = rep_msg->initiator_depth;
        cm_id_priv->sq_psn = cm_rep_get_starting_psn(rep_msg);
        cm_id_priv->rnr_retry_count = cm_rep_get_rnr_retry_count(rep_msg);
-       cm_id_priv->timewait_info = timewait_info;
 
        /* todo: handle peer_to_peer */
 
@@ -1333,10 +1389,9 @@ static int cm_rep_handler(struct cm_work
        else
                cm_deref_id(cm_id_priv);
        return 0;
-error2:
-       cm_cleanup_timewait(timewait_info);
-       kfree(timewait_info);
-error1:
+
+error: cm_cleanup_timewait(cm_id_priv->timewait_info);
+deref:
        cm_deref_id(cm_id_priv);
        return ret;
 }
@@ -1420,8 +1474,8 @@ static void cm_format_dreq(struct cm_dre
                          void *private_data,
                          u8 private_data_len)
 {
-       cm_format_mad_hdr(&dreq_msg->hdr, cm_id_priv,
-                         CM_DREQ_ATTR_ID, CM_MSG_SEQUENCE_DREQ);
+       cm_format_mad_hdr(&dreq_msg->hdr, CM_DREQ_ATTR_ID,
+                         cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_DREQ));
 
        dreq_msg->local_comm_id = cm_id_priv->id.local_id;
        dreq_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1480,8 +1534,9 @@ static void cm_format_drep(struct cm_dre
                          void *private_data,
                          u8 private_data_len)
 {
-       cm_format_mad_hdr(&drep_msg->hdr, cm_id_priv,
-                         CM_DREP_ATTR_ID, CM_MSG_SEQUENCE_DREQ);
+       /* todo: TID should match received DREQ */
+       cm_format_mad_hdr(&drep_msg->hdr, CM_DREP_ATTR_ID,
+                         cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_DREQ));
 
        drep_msg->local_comm_id = cm_id_priv->id.local_id;
        drep_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -1642,8 +1697,9 @@ static void cm_format_rej(struct cm_rej_
                          void *private_data,
                          u8 private_data_len)
 {
-       cm_format_mad_hdr(&rej_msg->hdr, cm_id_priv,
-                         CM_REJ_ATTR_ID, CM_MSG_SEQUENCE_REQ);
+       /* todo: TID should match received REQ */
+       cm_format_mad_hdr(&rej_msg->hdr, CM_REJ_ATTR_ID,
+                         cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_REQ));
 
        rej_msg->remote_comm_id = cm_id_priv->id.remote_id;
 
@@ -1861,8 +1917,9 @@ static void cm_format_mra(struct cm_mra_
                return;
        }
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-       cm_format_mad_hdr(&mra_msg->hdr, cm_id_priv,
-                         CM_MRA_ATTR_ID, msg_sequence);
+       /* todo: TID should match the received REQ or LAP */
+       cm_format_mad_hdr(&mra_msg->hdr, CM_MRA_ATTR_ID,
+                         cm_form_tid(cm_id_priv, msg_sequence));
        cm_mra_set_msg_mraed(mra_msg, msg_mraed);
 
        mra_msg->local_comm_id = cm_id_priv->id.local_id;
@@ -1946,8 +2003,8 @@ static void cm_format_lap(struct cm_lap_
                          void *private_data,
                          u8 private_data_len)
 {
-       cm_format_mad_hdr(&lap_msg->hdr, cm_id_priv,
-                         CM_LAP_ATTR_ID, CM_MSG_SEQUENCE_LAP);
+       cm_format_mad_hdr(&lap_msg->hdr, CM_LAP_ATTR_ID,
+                         cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_LAP));
 
        lap_msg->local_comm_id = cm_id_priv->id.local_id;
        lap_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -2089,8 +2146,9 @@ static void cm_format_apr(struct cm_apr_
                          void *private_data,
                          u8 private_data_len)
 {
-       cm_format_mad_hdr(&apr_msg->hdr, cm_id_priv,
-                         CM_APR_ATTR_ID, CM_MSG_SEQUENCE_LAP);
+       /* todo: TID should match received LAP */
+       cm_format_mad_hdr(&apr_msg->hdr, CM_APR_ATTR_ID,
+                         cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_LAP));
 
        apr_msg->local_comm_id = cm_id_priv->id.local_id;
        apr_msg->remote_comm_id = cm_id_priv->id.remote_id;
@@ -2237,8 +2295,8 @@ static void cm_format_sidr_req(struct cm
                               struct cm_id_private *cm_id_priv,
                               struct ib_cm_sidr_req_param *param)
 {
-       cm_format_mad_hdr(&sidr_req_msg->hdr, cm_id_priv,
-                         CM_SIDR_REQ_ATTR_ID, CM_MSG_SEQUENCE_SIDR);
+       cm_format_mad_hdr(&sidr_req_msg->hdr, CM_SIDR_REQ_ATTR_ID,
+                         cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_SIDR));
 
        sidr_req_msg->request_id = cm_id_priv->id.local_id;
        sidr_req_msg->pkey = param->pkey;
@@ -2351,7 +2409,7 @@ static int cm_sidr_req_handler(struct cm
        if (!cur_cm_id_priv) {
                rb_erase(&cm_id_priv->sidr_id_node, &cm.remote_sidr_table);
                spin_unlock_irqrestore(&cm.lock, flags);
-               /* todo: reject with no match */
+               /* todo: reply with no match */
                goto out; /* No match. */
        }
        atomic_inc(&cur_cm_id_priv->refcount);
@@ -2375,8 +2433,9 @@ static void cm_format_sidr_rep(struct cm
                               struct cm_id_private *cm_id_priv,
                               struct ib_cm_sidr_rep_param *param)
 {
-       cm_format_mad_hdr(&sidr_rep_msg->hdr, cm_id_priv,
-                         CM_SIDR_REP_ATTR_ID, CM_MSG_SEQUENCE_SIDR);
+       /* todo: TID should match received SIDR REQ */
+       cm_format_mad_hdr(&sidr_rep_msg->hdr, CM_SIDR_REP_ATTR_ID,
+                         cm_form_tid(cm_id_priv, CM_MSG_SEQUENCE_SIDR));
 
        sidr_rep_msg->request_id = cm_id_priv->id.remote_id;
        sidr_rep_msg->status = param->status;



_______________________________________________
openib-general mailing list
openib-general@openib.org
http://openib.org/mailman/listinfo/openib-general

To unsubscribe, please visit http://openib.org/mailman/listinfo/openib-general

Reply via email to