http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/92b8098c/proton-c/src/windows/driver.c ---------------------------------------------------------------------- diff --git a/proton-c/src/windows/driver.c b/proton-c/src/windows/driver.c index 3dadadb..ddccd82 100644 --- a/proton-c/src/windows/driver.c +++ b/proton-c/src/windows/driver.c @@ -19,59 +19,37 @@ * */ -/* - * Copy of posix poll-based driver with minimal changes to use - * select(). TODO: fully native implementaton with I/O completion - * ports. - * - * This implementation comments out the posix max_fds arg to select - * which has no meaning on windows. The number of fd_set slots are - * configured at compile time via FD_SETSIZE, chosen "large enough" - * for the limited scalability of select() at the expense of - * 2*N*sizeof(unsigned int) bytes per driver instance. select (and - * associated macros like FD_ZERO) are otherwise unaffected - * performance-wise by increasing FD_SETSIZE. - */ - -#define FD_SETSIZE 2048 -#ifndef _WIN32_WINNT -#define _WIN32_WINNT 0x0501 -#endif -#if _WIN32_WINNT < 0x0501 -#error "Proton requires Windows API support for XP or later." 
-#endif -#include <winsock2.h> -#include <Ws2tcpip.h> -#define PN_WINAPI - #include <assert.h> #include <stdio.h> #include <ctype.h> #include <sys/types.h> #include <fcntl.h> -#include "../platform.h" -#include <proton/io.h> #include <proton/driver.h> #include <proton/driver_extras.h> #include <proton/error.h> +#include <proton/io.h> #include <proton/sasl.h> #include <proton/ssl.h> -#include <proton/util.h> -#include "../util.h" -#include "../ssl/ssl-internal.h" - +#include <proton/object.h> +#include <proton/selector.h> #include <proton/types.h> +#include "selectable.h" +#include "util.h" +#include "platform.h" -/* Posix compatibility helpers */ - -static int pn_socket_pair(SOCKET sv[2]); -#define close(sock) closesocket(sock) -static int pn_i_error_from_errno_wrap(pn_error_t *error, const char *msg) { - errno = WSAGetLastError(); - return pn_i_error_from_errno(error, msg); -} -#define pn_i_error_from_errno(e,m) pn_i_error_from_errno_wrap(e,m) +/* + * This driver provides limited thread safety for some operations on pn_connector_t objects. + * + * These calls are: pn_connector_process(), pn_connector_activate(), pn_connector_activated(), + * pn_connector_close(), and others that only touch the connection object, i.e. + * pn_connector_context(). These calls provide limited safety in that simultaneous calls are + * not allowed to the same pn_connector_t object. + * + * The application must call pn_driver_wakeup() and resume its wait loop logic if a call to + * pn_wait() may have overlapped with any of the above calls that could affect a pn_wait() + * outcome. 
+ */ /* Decls */ @@ -81,82 +59,137 @@ static int pn_i_error_from_errno_wrap(pn_error_t *error, const char *msg) { struct pn_driver_t { pn_error_t *error; pn_io_t *io; + pn_selector_t *selector; pn_listener_t *listener_head; pn_listener_t *listener_tail; pn_listener_t *listener_next; pn_connector_t *connector_head; pn_connector_t *connector_tail; - pn_connector_t *connector_next; + pn_listener_t *ready_listener_head; + pn_listener_t *ready_listener_tail; + pn_connector_t *ready_connector_head; + pn_connector_t *ready_connector_tail; + pn_selectable_t *ctrl_selectable; size_t listener_count; size_t connector_count; - size_t closed_count; - fd_set readfds; - fd_set writefds; - fd_set exceptfds; - // int max_fds; - bool overflow; pn_socket_t ctrl[2]; //pipe for updating selectable status - pn_trace_t trace; - pn_timestamp_t wakeup; }; +typedef enum {LISTENER, CONNECTOR} sel_type_t; + struct pn_listener_t { + sel_type_t type; pn_driver_t *driver; pn_listener_t *listener_next; pn_listener_t *listener_prev; - int idx; + pn_listener_t *ready_listener_next; + pn_listener_t *ready_listener_prev; + void *context; + pn_selectable_t *selectable; bool pending; - pn_socket_t fd; bool closed; - void *context; }; #define PN_NAME_MAX (256) struct pn_connector_t { + sel_type_t type; pn_driver_t *driver; pn_connector_t *connector_next; pn_connector_t *connector_prev; + pn_connector_t *ready_connector_next; + pn_connector_t *ready_connector_prev; char name[PN_NAME_MAX]; + pn_timestamp_t wakeup; + pn_timestamp_t posted_wakeup; + pn_connection_t *connection; + pn_transport_t *transport; + pn_sasl_t *sasl; + pn_listener_t *listener; + void *context; + pn_selectable_t *selectable; int idx; + int status; + int posted_status; + pn_trace_t trace; bool pending_tick; bool pending_read; bool pending_write; - pn_socket_t fd; - int status; - pn_trace_t trace; bool closed; - pn_timestamp_t wakeup; - pn_connection_t *connection; - pn_transport_t *transport; - pn_sasl_t *sasl; bool input_done; bool 
output_done; - pn_listener_t *listener; - void *context; }; +static void get_new_events(pn_driver_t *); + /* Impls */ // listener +static void driver_listener_readable(pn_selectable_t *sel) +{ + // do nothing +} + +static void driver_listener_writable(pn_selectable_t *sel) +{ + // do nothing +} + +static void driver_listener_expired(pn_selectable_t *sel) +{ + // do nothing +} + +static ssize_t driver_listener_capacity(pn_selectable_t *sel) +{ + return 1; +} + +static ssize_t driver_listener_pending(pn_selectable_t *sel) +{ + return 0; +} + +static pn_timestamp_t driver_listener_deadline(pn_selectable_t *sel) +{ + return 0; +} + +static void driver_listener_finalize(pn_selectable_t *sel) +{ + // do nothing +} + + static void pn_driver_add_listener(pn_driver_t *d, pn_listener_t *l) { if (!l->driver) return; LL_ADD(d, listener, l); l->driver = d; d->listener_count++; + pn_selector_add(d->selector, l->selectable); +} + +static void ready_listener_list_remove(pn_driver_t *d, pn_listener_t *l) +{ + LL_REMOVE(d, ready_listener, l); + l->ready_listener_next = NULL; + l->ready_listener_prev = NULL; } static void pn_driver_remove_listener(pn_driver_t *d, pn_listener_t *l) { if (!l->driver) return; + pn_selector_remove(d->selector, l->selectable); + if (l == d->ready_listener_head || l->ready_listener_prev) + ready_listener_list_remove(d, l); + if (l == d->listener_next) { d->listener_next = l->listener_next; } - LL_REMOVE(d, listener, l); l->driver = NULL; d->listener_count--; @@ -169,7 +202,7 @@ pn_listener_t *pn_listener(pn_driver_t *driver, const char *host, pn_socket_t sock = pn_listen(driver->io, host, port); - if (sock == INVALID_SOCKET) { + if (sock == PN_INVALID_SOCKET) { return NULL; } else { pn_listener_t *l = pn_listener_fd(driver, sock, context); @@ -186,15 +219,24 @@ pn_listener_t *pn_listener_fd(pn_driver_t *driver, pn_socket_t fd, void *context pn_listener_t *l = (pn_listener_t *) malloc(sizeof(pn_listener_t)); if (!l) return NULL; + l->type = LISTENER; 
l->driver = driver; l->listener_next = NULL; l->listener_prev = NULL; - l->idx = 0; + l->ready_listener_next = NULL; + l->ready_listener_prev = NULL; l->pending = false; - l->fd = fd; l->closed = false; l->context = context; - + l->selectable = pni_selectable(driver_listener_capacity, + driver_listener_pending, + driver_listener_deadline, + driver_listener_readable, + driver_listener_writable, + driver_listener_expired, + driver_listener_finalize); + pni_selectable_set_fd(l->selectable, fd); + pni_selectable_set_context(l->selectable, l); pn_driver_add_listener(driver, l); return l; } @@ -202,7 +244,7 @@ pn_listener_t *pn_listener_fd(pn_driver_t *driver, pn_socket_t fd, void *context pn_socket_t pn_listener_get_fd(pn_listener_t *listener) { assert(listener); - return listener->fd; + return pn_selectable_fd(listener->selectable); } pn_listener_t *pn_listener_head(pn_driver_t *driver) @@ -234,8 +276,8 @@ pn_connector_t *pn_listener_accept(pn_listener_t *l) if (!l || !l->pending) return NULL; char name[PN_NAME_MAX]; - pn_socket_t sock = pn_accept(l->driver->io, l->fd, name, PN_NAME_MAX); - if (sock == INVALID_SOCKET) { + pn_socket_t sock = pn_accept(l->driver->io, pn_selectable_fd(l->selectable), name, PN_NAME_MAX); + if (sock == PN_INVALID_SOCKET) { return NULL; } else { if (l->driver->trace & (PN_TRACE_FRM | PN_TRACE_RAW | PN_TRACE_DRV)) @@ -252,8 +294,7 @@ void pn_listener_close(pn_listener_t *l) if (!l) return; if (l->closed) return; - if (close(l->fd) == -1) - perror("close"); + pn_close(l->driver->io, pn_selectable_fd(l->selectable)); l->closed = true; } @@ -262,45 +303,85 @@ void pn_listener_free(pn_listener_t *l) if (!l) return; if (l->driver) pn_driver_remove_listener(l->driver, l); + pn_selectable_free(l->selectable); free(l); } // connector +static ssize_t driver_connection_capacity(pn_selectable_t *sel) +{ + pn_connector_t *c = (pn_connector_t *) pni_selectable_get_context(sel); + return c->posted_status & PN_SEL_RD ? 
1 : 0; +} + +static ssize_t driver_connection_pending(pn_selectable_t *sel) +{ + pn_connector_t *c = (pn_connector_t *) pni_selectable_get_context(sel); + return c->posted_status & PN_SEL_WR ? 1 : 0; +} + +static pn_timestamp_t driver_connection_deadline(pn_selectable_t *sel) +{ + pn_connector_t *c = (pn_connector_t *) pni_selectable_get_context(sel); + return c->posted_wakeup; +} + +static void driver_connection_readable(pn_selectable_t *sel) +{ + // do nothing +} + +static void driver_connection_writable(pn_selectable_t *sel) +{ + // do nothing +} + +static void driver_connection_expired(pn_selectable_t *sel) +{ + // do nothing +} + +static void driver_connection_finalize(pn_selectable_t *sel) +{ + // do nothing +} + static void pn_driver_add_connector(pn_driver_t *d, pn_connector_t *c) { if (!c->driver) return; LL_ADD(d, connector, c); c->driver = d; d->connector_count++; + pn_selector_add(d->selector, c->selectable); +} + +static void ready_connector_list_remove(pn_driver_t *d, pn_connector_t *c) +{ + LL_REMOVE(d, ready_connector, c); + c->ready_connector_next = NULL; + c->ready_connector_prev = NULL; } static void pn_driver_remove_connector(pn_driver_t *d, pn_connector_t *c) { if (!c->driver) return; - if (c == d->connector_next) { - d->connector_next = c->connector_next; - } + pn_selector_remove(d->selector, c->selectable); + if (c == d->ready_connector_head || c->ready_connector_prev) + ready_connector_list_remove(d, c); LL_REMOVE(d, connector, c); c->driver = NULL; d->connector_count--; - if (c->closed) { - d->closed_count--; - } } -pn_connector_t *pn_connector(pn_driver_t *driver, const char *hostarg, +pn_connector_t *pn_connector(pn_driver_t *driver, const char *host, const char *port, void *context) { if (!driver) return NULL; - // convert "0.0.0.0" to "127.0.0.1" on Windows for outgoing sockets - const char *host = strcmp("0.0.0.0", hostarg) ? 
hostarg : "127.0.0.1"; - pn_socket_t sock = pn_connect(driver->io, host, port); - pn_connector_t *c = pn_connector_fd(driver, sock, context); snprintf(c->name, PN_NAME_MAX, "%s:%s", host, port); if (driver->trace & (PN_TRACE_FRM | PN_TRACE_RAW | PN_TRACE_DRV)) @@ -308,28 +389,28 @@ pn_connector_t *pn_connector(pn_driver_t *driver, const char *hostarg, return c; } -static void pn_connector_read(pn_connector_t *ctor); -static void pn_connector_write(pn_connector_t *ctor); - pn_connector_t *pn_connector_fd(pn_driver_t *driver, pn_socket_t fd, void *context) { if (!driver) return NULL; pn_connector_t *c = (pn_connector_t *) malloc(sizeof(pn_connector_t)); if (!c) return NULL; + c->type = CONNECTOR; c->driver = driver; c->connector_next = NULL; c->connector_prev = NULL; + c->ready_connector_next = NULL; + c->ready_connector_prev = NULL; c->pending_tick = false; c->pending_read = false; c->pending_write = false; c->name[0] = '\0'; - c->idx = 0; - c->fd = fd; c->status = PN_SEL_RD | PN_SEL_WR; + c->posted_status = -1; c->trace = driver->trace; c->closed = false; c->wakeup = 0; + c->posted_wakeup = 0; c->connection = NULL; c->transport = pn_transport(); c->sasl = pn_sasl(c->transport); @@ -337,7 +418,15 @@ pn_connector_t *pn_connector_fd(pn_driver_t *driver, pn_socket_t fd, void *conte c->output_done = false; c->context = context; c->listener = NULL; - + c->selectable = pni_selectable(driver_connection_capacity, + driver_connection_pending, + driver_connection_deadline, + driver_connection_readable, + driver_connection_writable, + driver_connection_expired, + driver_connection_finalize); + pni_selectable_set_fd(c->selectable, fd); + pni_selectable_set_context(c->selectable, c); pn_connector_trace(c, driver->trace); pn_driver_add_connector(driver, c); @@ -347,7 +436,7 @@ pn_connector_t *pn_connector_fd(pn_driver_t *driver, pn_socket_t fd, void *conte pn_socket_t pn_connector_get_fd(pn_connector_t *connector) { assert(connector); - return connector->fd; + return 
pn_selectable_fd(connector->selectable); } pn_connector_t *pn_connector_head(pn_driver_t *driver) @@ -380,8 +469,15 @@ pn_transport_t *pn_connector_transport(pn_connector_t *ctor) void pn_connector_set_connection(pn_connector_t *ctor, pn_connection_t *connection) { if (!ctor) return; + if (ctor->connection) { + pn_decref(ctor->connection); + pn_transport_unbind(ctor->transport); + } ctor->connection = connection; - pn_transport_bind(ctor->transport, connection); + if (ctor->connection) { + pn_incref(ctor->connection); + pn_transport_bind(ctor->transport, connection); + } if (ctor->transport) pn_transport_trace(ctor->transport, ctor->trace); } @@ -418,10 +514,8 @@ void pn_connector_close(pn_connector_t *ctor) if (!ctor) return; ctor->status = 0; - if (close(ctor->fd) == -1) - perror("close"); + pn_close(ctor->driver->io, pn_selectable_fd(ctor->selectable)); ctor->closed = true; - ctor->driver->closed_count++; } bool pn_connector_closed(pn_connector_t *ctor) @@ -434,9 +528,11 @@ void pn_connector_free(pn_connector_t *ctor) if (!ctor) return; if (ctor->driver) pn_driver_remove_connector(ctor->driver, ctor); - ctor->connection = NULL; pn_transport_free(ctor->transport); ctor->transport = NULL; + if (ctor->connection) pn_decref(ctor->connection); + ctor->connection = NULL; + pn_selectable_free(ctor->selectable); free(ctor); } @@ -487,6 +583,7 @@ void pn_connector_process(pn_connector_t *c) if (c->closed) return; pn_transport_t *transport = c->transport; + pn_socket_t sock = pn_selectable_fd(c->selectable); /// /// Socket read @@ -497,7 +594,7 @@ void pn_connector_process(pn_connector_t *c) c->status |= PN_SEL_RD; if (c->pending_read) { c->pending_read = false; - ssize_t n = pn_recv(c->driver->io, c->fd, pn_transport_tail(transport), capacity); + ssize_t n = pn_recv(c->driver->io, sock, pn_transport_tail(transport), capacity); if (n < 0) { if (errno != EAGAIN) { perror("read"); @@ -540,7 +637,7 @@ void pn_connector_process(pn_connector_t *c) c->status |= PN_SEL_WR; if 
(c->pending_write) { c->pending_write = false; - ssize_t n = pn_send(c->driver->io, c->fd, pn_transport_head(transport), pending); + ssize_t n = pn_send(c->driver->io, sock, pn_transport_head(transport), pending); if (n < 0) { // XXX if (errno != EAGAIN) { @@ -574,35 +671,41 @@ void pn_connector_process(pn_connector_t *c) // driver +static pn_selectable_t *create_ctrl_selectable(pn_socket_t fd); + pn_driver_t *pn_driver() { pn_driver_t *d = (pn_driver_t *) malloc(sizeof(pn_driver_t)); if (!d) return NULL; + d->error = pn_error(); d->io = pn_io(); + d->selector = pn_io_selector(d->io); d->listener_head = NULL; d->listener_tail = NULL; d->listener_next = NULL; + d->ready_listener_head = NULL; + d->ready_listener_tail = NULL; d->connector_head = NULL; d->connector_tail = NULL; - d->connector_next = NULL; + d->ready_connector_head = NULL; + d->ready_connector_tail = NULL; d->listener_count = 0; d->connector_count = 0; - d->closed_count = 0; - // d->max_fds = 0; d->ctrl[0] = 0; d->ctrl[1] = 0; d->trace = ((pn_env_bool("PN_TRACE_RAW") ? PN_TRACE_RAW : PN_TRACE_OFF) | (pn_env_bool("PN_TRACE_FRM") ? PN_TRACE_FRM : PN_TRACE_OFF) | (pn_env_bool("PN_TRACE_DRV") ? 
PN_TRACE_DRV : PN_TRACE_OFF)); - d->wakeup = 0; // XXX - if (pn_socket_pair(d->ctrl)) { + if (pn_pipe(d->io, d->ctrl)) { perror("Can't create control pipe"); free(d); return NULL; } + d->ctrl_selectable = create_ctrl_selectable(d->ctrl[0]); + pn_selector_add(d->selector, d->ctrl_selectable); return d; } @@ -626,8 +729,9 @@ void pn_driver_free(pn_driver_t *d) { if (!d) return; - close(d->ctrl[0]); - close(d->ctrl[1]); + pn_selectable_free(d->ctrl_selectable); + pn_close(d->io, d->ctrl[0]); + pn_close(d->io, d->ctrl[1]); while (d->connector_head) pn_connector_free(d->connector_head); while (d->listener_head) @@ -640,7 +744,7 @@ void pn_driver_free(pn_driver_t *d) int pn_driver_wakeup(pn_driver_t *d) { if (d) { - ssize_t count = send(d->ctrl[1], "x", 1, 0); + ssize_t count = pn_write(d->io, d->ctrl[1], "x", 1); if (count <= 0) { return count; } else { @@ -651,158 +755,57 @@ int pn_driver_wakeup(pn_driver_t *d) } } -static void pn_driver_rebuild(pn_driver_t *d) +void pn_driver_wait_1(pn_driver_t *d) { - d->wakeup = 0; - d->overflow = false; - int r_avail = FD_SETSIZE; - int w_avail = FD_SETSIZE; - // d->max_fds = -1; - FD_ZERO(&d->readfds); - FD_ZERO(&d->writefds); - FD_ZERO(&d->exceptfds); - - FD_SET(d->ctrl[0], &d->readfds); - // if (d->ctrl[0] > d->max_fds) d->max_fds = d->ctrl[0]; - - pn_listener_t *l = d->listener_head; - for (unsigned i = 0; i < d->listener_count; i++) { - if (r_avail) { - FD_SET(l->fd, &d->readfds); - // if (l->fd > d->max_fds) d->max_fds = l->fd; - r_avail--; - l = l->listener_next; - } - else { - d->overflow = true; - break; - } - } +} +int pn_driver_wait_2(pn_driver_t *d, int timeout) +{ + // These lists will normally be empty + while (d->ready_listener_head) + ready_listener_list_remove(d, d->ready_listener_head); + while (d->ready_connector_head) + ready_connector_list_remove(d, d->ready_connector_head); pn_connector_t *c = d->connector_head; for (unsigned i = 0; i < d->connector_count; i++) { - if (!c->closed) { - FD_SET(c->fd, 
&d->exceptfds); - d->wakeup = pn_timestamp_min(d->wakeup, c->wakeup); - if (c->status & PN_SEL_RD) { - if (r_avail) { - FD_SET(c->fd, &d->readfds); - r_avail--; - } - else { - d->overflow = true; - break; - } - } - if (c->status & PN_SEL_WR) { - if (w_avail) { - FD_SET(c->fd, &d->writefds); - w_avail--; - } - else { - d->overflow = true; - break; - } - } - // if (c->fd > d->max_fds) d->max_fds = c->fd; + // Optimistically use a snapshot of the non-threadsafe vars. + // If they are in flux, the app will guarantee progress with a pn_driver_wakeup(). + int current_status = c->status; + pn_timestamp_t current_wakeup = c->wakeup; + if (c->posted_status != current_status || c->posted_wakeup != current_wakeup) { + c->posted_status = current_status; + c->posted_wakeup = current_wakeup; + pn_selector_update(c->driver->selector, c->selectable); + } + if (c->closed) { + c->pending_read = false; + c->pending_write = false; + c->pending_tick = false; + LL_ADD(d, ready_connector, c); } c = c->connector_next; } -} -void pn_driver_wait_1(pn_driver_t *d) -{ - pn_driver_rebuild(d); -} - -int pn_driver_wait_2(pn_driver_t *d, int timeout) -{ - if (d->overflow) - return pn_error_set(d->error, PN_ERR, "maximum driver sockets exceeded"); - if (d->wakeup) { - pn_timestamp_t now = pn_i_now(); - if (now >= d->wakeup) - timeout = 0; - else - timeout = (timeout < 0) ? 
d->wakeup-now : pn_min(timeout, d->wakeup - now); - } + if (d->ready_connector_head) + timeout = 0; // We found closed connections - struct timeval to = {0}; - struct timeval *to_arg = &to; - // block only if (timeout == 0) and (closed_count == 0) - if (d->closed_count == 0) { - if (timeout > 0) { - // convert millisecs to sec and usec: - to.tv_sec = timeout/1000; - to.tv_usec = (timeout - (to.tv_sec * 1000)) * 1000; - } - else if (timeout < 0) { - to_arg = NULL; - } - } - int nfds = select(/* d->max_fds */ 0, &d->readfds, &d->writefds, &d->exceptfds, to_arg); - if (nfds == SOCKET_ERROR) { - errno = WSAGetLastError(); - pn_i_error_from_errno(d->error, "select"); + int code = pn_selector_select(d->selector, timeout); + if (code) { + pn_error_set(d->error, code, "select"); return -1; } + get_new_events(d); return 0; } int pn_driver_wait_3(pn_driver_t *d) { - bool woken = false; - if (FD_ISSET(d->ctrl[0], &d->readfds)) { - woken = true; - //clear the pipe - char buffer[512]; - while (recv(d->ctrl[0], buffer, 512, 0) == 512); - } - - pn_listener_t *l = d->listener_head; - while (l) { - l->pending = (FD_ISSET(l->fd, &d->readfds)); - l = l->listener_next; - } - - pn_timestamp_t now = pn_i_now(); - pn_connector_t *c = d->connector_head; - while (c) { - if (c->closed) { - c->pending_read = false; - c->pending_write = false; - c->pending_tick = false; - } else { - c->pending_read = FD_ISSET(c->fd, &d->readfds); - c->pending_write = FD_ISSET(c->fd, &d->writefds); - c->pending_tick = (c->wakeup && c->wakeup <= now); -// Unlike Posix no distinction of POLLERR and POLLHUP -// if (idx && d->fds[idx].revents & POLLERR) -// pn_connector_close(c); -// else if (idx && (d->fds[idx].revents & POLLHUP)) { -// [...] -// Strategy, defer error to a recv or send if read or write pending. -// Otherwise proclaim the connection dead. - if (!c->pending_read && !c->pending_write) { - if (FD_ISSET(c->fd, &d->exceptfds)) { - // can't defer error to a read or write, close now. 
- // How to get WSAlastError() equivalent info? - fprintf(stderr, "connector cleanup on unknown error %s\n", c->name); - pn_connector_close(c); - } - } - } - c = c->connector_next; - } - - d->listener_next = d->listener_head; - d->connector_next = d->connector_head; - - return woken ? PN_INTR : 0; + // no-op with new selector/selectables + return 0; } -// + // XXX - pn_driver_wait has been divided into three internal functions as a // temporary workaround for a multi-threading problem. A multi-threaded // application must hold a lock on parts 1 and 3, but not on part 2. @@ -821,103 +824,75 @@ int pn_driver_wait(pn_driver_t *d, int timeout) return pn_driver_wait_3(d); } +static void get_new_events(pn_driver_t *d) +{ + bool woken = false; + int events; + pn_selectable_t *sel; + while ((sel = pn_selector_next(d->selector, &events)) != NULL) { + if (sel == d->ctrl_selectable) { + woken = true; + //clear the pipe + char buffer[512]; + while (pn_read(d->io, d->ctrl[0], buffer, 512) == 512); + continue; + } + + void *ctx = pni_selectable_get_context(sel); + sel_type_t *type = (sel_type_t *) ctx; + if (*type == CONNECTOR) { + pn_connector_t *c = (pn_connector_t *) ctx; + if (!c->closed) { + LL_ADD(d, ready_connector, c); + c->pending_read = events & PN_READABLE; + c->pending_write = events & PN_WRITABLE; + c->pending_tick = events & PN_EXPIRED; + } + } else { + pn_listener_t *l = (pn_listener_t *) ctx; + LL_ADD(d, ready_listener, l); + l->pending = events & PN_READABLE; + } + } +} + pn_listener_t *pn_driver_listener(pn_driver_t *d) { if (!d) return NULL; - while (d->listener_next) { - pn_listener_t *l = d->listener_next; - d->listener_next = l->listener_next; - - if (l->pending) { + pn_listener_t *l = d->ready_listener_head; + while (l) { + ready_listener_list_remove(d, l); + if (l->pending) return l; - } + l = d->ready_listener_head; } - return NULL; } pn_connector_t *pn_driver_connector(pn_driver_t *d) { if (!d) return NULL; - while (d->connector_next) { - pn_connector_t 
*c = d->connector_next; - d->connector_next = c->connector_next; - + pn_connector_t *c = d->ready_connector_head; + while (c) { + ready_connector_list_remove(d, c); if (c->closed || c->pending_read || c->pending_write || c->pending_tick) { return c; } + c = d->ready_connector_head; } - return NULL; } -static int pn_socket_pair (SOCKET sv[2]) { - // no socketpair on windows. provide pipe() semantics using sockets - - SOCKET sock = socket(AF_INET, SOCK_STREAM, getprotobyname("tcp")->p_proto); - if (sock == INVALID_SOCKET) { - perror("socket"); - return -1; - } - - BOOL b = 1; - if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const char *) &b, sizeof(b)) == -1) { - perror("setsockopt"); - closesocket(sock); - return -1; - } - else { - struct sockaddr_in addr = {0}; - addr.sin_family = AF_INET; - addr.sin_port = 0; - addr.sin_addr.s_addr = htonl (INADDR_LOOPBACK); - - if (bind(sock, (struct sockaddr *)&addr, sizeof(addr)) == -1) { - perror("bind"); - closesocket(sock); - return -1; - } - } - - if (listen(sock, 50) == -1) { - perror("listen"); - closesocket(sock); - return -1; - } - - if ((sv[1] = socket(AF_INET, SOCK_STREAM, getprotobyname("tcp")->p_proto)) == INVALID_SOCKET) { - perror("sock1"); - closesocket(sock); - return -1; - } - else { - struct sockaddr addr = {0}; - int l = sizeof(addr); - if (getsockname(sock, &addr, &l) == -1) { - perror("getsockname"); - closesocket(sock); - return -1; - } - - if (connect(sv[1], &addr, sizeof(addr)) == -1) { - int err = WSAGetLastError(); - fprintf(stderr, "connect wsaerrr %d\n", err); - closesocket(sock); - closesocket(sv[1]); - return -1; - } - - if ((sv[0] = accept(sock, &addr, &l)) == INVALID_SOCKET) { - perror("accept"); - closesocket(sock); - closesocket(sv[1]); - return -1; - } - } - - u_long v = 1; - ioctlsocket (sv[0], FIONBIO, &v); - ioctlsocket (sv[1], FIONBIO, &v); - closesocket(sock); - return 0; +static pn_selectable_t *create_ctrl_selectable(pn_socket_t fd) +{ + // ctrl input only needs to know about read 
events, just like a listener. + pn_selectable_t *sel = pni_selectable(driver_listener_capacity, + driver_listener_pending, + driver_listener_deadline, + driver_listener_readable, + driver_listener_writable, + driver_listener_expired, + driver_listener_finalize); + pni_selectable_set_fd(sel, fd); + return sel; }
http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/92b8098c/proton-c/src/windows/io.c ---------------------------------------------------------------------- diff --git a/proton-c/src/windows/io.c b/proton-c/src/windows/io.c index b5660be..b2d528a 100644 --- a/proton-c/src/windows/io.c +++ b/proton-c/src/windows/io.c @@ -103,7 +103,7 @@ void pn_io_finalize(void *obj) pn_io_t *pn_io(void) { static const pn_class_t clazz = PN_CLASS(pn_io); - pn_io_t *io = (pn_io_t *) pn_new(sizeof(pn_io_t), &clazz); + pn_io_t *io = (pn_io_t *) pn_class_new(&clazz, sizeof(pn_io_t)); return io; } @@ -210,14 +210,16 @@ pn_socket_t pn_listen(pn_io_t *io, const char *host, const char *port) return INVALID_SOCKET; } - iocpdesc_t *iocpd = pni_iocpdesc_create(io->iocp, sock, false); - if (!iocpd) { - pn_i_error_from_errno(io->error, "register"); - closesocket(sock); - return INVALID_SOCKET; + if (io->iocp->selector) { + iocpdesc_t *iocpd = pni_iocpdesc_create(io->iocp, sock, false); + if (!iocpd) { + pn_i_error_from_errno(io->error, "register"); + closesocket(sock); + return INVALID_SOCKET; + } + pni_iocpdesc_start(iocpd); } - pni_iocpdesc_start(iocpd); return sock; } @@ -242,7 +244,22 @@ pn_socket_t pn_connect(pn_io_t *io, const char *hostarg, const char *port) ensure_unique(io, sock); pn_configure_sock(io, sock); - return pni_iocp_begin_connect(io->iocp, sock, addr, io->error); + + if (io->iocp->selector) { + return pni_iocp_begin_connect(io->iocp, sock, addr, io->error); + } else { + if (connect(sock, addr->ai_addr, addr->ai_addrlen) != 0) { + if (WSAGetLastError() != WSAEWOULDBLOCK) { + pni_win32_error(io->error, "connect", WSAGetLastError()); + freeaddrinfo(addr); + closesocket(sock); + return INVALID_SOCKET; + } + } + + freeaddrinfo(addr); + return sock; + } } pn_socket_t pn_accept(pn_io_t *io, pn_socket_t listen_sock, char *name, size_t size) http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/92b8098c/proton-c/src/windows/iocp.c 
---------------------------------------------------------------------- diff --git a/proton-c/src/windows/iocp.c b/proton-c/src/windows/iocp.c index 614b130..3c0451a 100644 --- a/proton-c/src/windows/iocp.c +++ b/proton-c/src/windows/iocp.c @@ -30,7 +30,7 @@ #include <Ws2tcpip.h> #define PN_WINAPI -#include "../platform.h" +#include "platform.h" #include <proton/object.h> #include <proton/io.h> #include <proton/selector.h> @@ -162,8 +162,7 @@ typedef struct { } accept_result_t; static accept_result_t *accept_result(iocpdesc_t *listen_sock) { - accept_result_t *result = (accept_result_t *) pn_new(sizeof(accept_result_t), 0); - memset(result, 0, sizeof(accept_result_t)); + accept_result_t *result = (accept_result_t *)calloc(1, sizeof(accept_result_t)); if (result) { result->base.type = IOCP_ACCEPT; result->base.iocpd = listen_sock; @@ -192,7 +191,7 @@ struct pni_acceptor_t { static void pni_acceptor_initialize(void *object) { pni_acceptor_t *acceptor = (pni_acceptor_t *) object; - acceptor->accepts = pn_list(IOCP_MAX_ACCEPTS, 0); + acceptor->accepts = pn_list(PN_VOID, IOCP_MAX_ACCEPTS); } static void pni_acceptor_finalize(void *object) @@ -200,14 +199,15 @@ static void pni_acceptor_finalize(void *object) pni_acceptor_t *acceptor = (pni_acceptor_t *) object; size_t len = pn_list_size(acceptor->accepts); for (size_t i = 0; i < len; i++) - pn_free(pn_list_get(acceptor->accepts, i)); + free(pn_list_get(acceptor->accepts, i)); pn_free(acceptor->accepts); } static pni_acceptor_t *pni_acceptor(iocpdesc_t *iocpd) { + static const pn_cid_t CID_pni_acceptor = CID_pn_void; static const pn_class_t clazz = PN_CLASS(pni_acceptor); - pni_acceptor_t *acceptor = (pni_acceptor_t *) pn_new(sizeof(pni_acceptor_t), &clazz); + pni_acceptor_t *acceptor = (pni_acceptor_t *) pn_class_new(&clazz, sizeof(pni_acceptor_t)); acceptor->listen_sock = iocpd; acceptor->accept_queue_size = 0; acceptor->signalled = false; @@ -221,7 +221,7 @@ static void begin_accept(pni_acceptor_t *acceptor, 
accept_result_t *result) { if (acceptor->listen_sock->closing) { if (result) { - pn_free(result); + free(result); acceptor->accept_queue_size--; } if (acceptor->accept_queue_size == 0) @@ -272,7 +272,7 @@ static void complete_accept(accept_result_t *result, HRESULT status) if (ld->read_closed) { if (!result->new_sock->closing) pni_iocp_begin_close(result->new_sock); - pn_free(result); // discard + free(result); // discard reap_check(ld); } else { result->base.status = status; @@ -364,8 +364,9 @@ static void connect_result_finalize(void *object) } static connect_result_t *connect_result(iocpdesc_t *iocpd, struct addrinfo *addr) { + static const pn_cid_t CID_connect_result = CID_pn_void; static const pn_class_t clazz = PN_CLASS(connect_result); - connect_result_t *result = (connect_result_t *) pn_new(sizeof(connect_result_t), &clazz); + connect_result_t *result = (connect_result_t *) pn_class_new(&clazz, sizeof(connect_result_t)); if (result) { memset(result, 0, sizeof(connect_result_t)); result->base.type = IOCP_CONNECT; @@ -599,7 +600,7 @@ static void begin_zero_byte_read(iocpdesc_t *iocpd) } static void drain_until_closed(iocpdesc_t *iocpd) { - int max_drain = 16 * 1024; + size_t max_drain = 16 * 1024; char buf[512]; read_result_t *result = iocpd->read_result; while (result->drain_count < max_drain) { @@ -730,9 +731,10 @@ static uintptr_t pni_iocpdesc_hashcode(void *object) // Reference counted in the iocpdesc map, zombie_list, selector. static iocpdesc_t *pni_iocpdesc(pn_socket_t s) { + static const pn_cid_t CID_pni_iocpdesc = CID_pn_void; static pn_class_t clazz = PN_CLASS(pni_iocpdesc); assert (s != INVALID_SOCKET); - iocpdesc_t *iocpd = (iocpdesc_t *) pn_new(sizeof(iocpdesc_t), &clazz); + iocpdesc_t *iocpd = (iocpdesc_t *) pn_class_new(&clazz, sizeof(iocpdesc_t)); assert(iocpd); iocpd->socket = s; return iocpd; @@ -983,7 +985,7 @@ static void drain_zombie_completions(iocp_t *iocp) static pn_list_t *iocp_map_close_all(iocp_t *iocp) { // Zombify stragglers, i.e. 
no pn_close() from the application. - pn_list_t *externals = pn_list(0, PN_REFCOUNT); + pn_list_t *externals = pn_list(PN_OBJECT, 0); for (pn_handle_t entry = pn_hash_head(iocp->iocpdesc_map); entry; entry = pn_hash_next(iocp->iocpdesc_map, entry)) { iocpdesc_t *iocpd = (iocpdesc_t *) pn_hash_value(iocp->iocpdesc_map, entry); @@ -1101,8 +1103,8 @@ void pni_iocp_initialize(void *obj) pni_shared_pool_create(iocp); iocp->completion_port = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); assert(iocp->completion_port != NULL); - iocp->iocpdesc_map = pn_hash(0, 0.75, PN_REFCOUNT); - iocp->zombie_list = pn_list(0, PN_REFCOUNT); + iocp->iocpdesc_map = pn_hash(PN_OBJECT, 0, 0.75); + iocp->zombie_list = pn_list(PN_OBJECT, 0); iocp->iocp_trace = pn_env_bool("PN_TRACE_DRV"); iocp->selector = NULL; } @@ -1132,7 +1134,8 @@ void pni_iocp_finalize(void *obj) iocp_t *pni_iocp() { + static const pn_cid_t CID_pni_iocp = CID_pn_void; static const pn_class_t clazz = PN_CLASS(pni_iocp); - iocp_t *iocp = (iocp_t *) pn_new(sizeof(iocp_t), &clazz); + iocp_t *iocp = (iocp_t *) pn_class_new(&clazz, sizeof(iocp_t)); return iocp; } http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/92b8098c/proton-c/src/windows/iocp.h ---------------------------------------------------------------------- diff --git a/proton-c/src/windows/iocp.h b/proton-c/src/windows/iocp.h index bc64dd0..91ded50 100644 --- a/proton-c/src/windows/iocp.h +++ b/proton-c/src/windows/iocp.h @@ -47,7 +47,7 @@ struct iocp_t { char *shared_pool_memory; write_result_t **shared_results; write_result_t **available_results; - int shared_available_count; + size_t shared_available_count; size_t writer_count; int loopback_bufsize; bool iocp_trace; http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/92b8098c/proton-c/src/windows/schannel.c ---------------------------------------------------------------------- diff --git a/proton-c/src/windows/schannel.c b/proton-c/src/windows/schannel.c new file mode 100644 index 
0000000..7aaf464 --- /dev/null +++ b/proton-c/src/windows/schannel.c @@ -0,0 +1,1320 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +/* + * SChannel is designed to encrypt and decrypt data in place. So a + * given buffer is expected to sometimes contain encrypted data, + * sometimes decrypted data, and occasionally both. Outgoing buffers + * need to reserve space for the TLS header and trailer. Read + * operations need to ignore the same headers and trailers from + * incoming buffers. Outgoing is simple because we choose record + * boundaries. Incoming is complicated by handling incomplete TLS + * records, and buffering contiguous data for the app layer that may + * span many records. A lazy double buffering system is used for + * the latter. + */ + +#include <proton/ssl.h> +#include <proton/engine.h> +#include "engine/engine-internal.h" +#include "platform.h" +#include "util.h" + +#include <assert.h> + +// security.h needs to see this to distinguish from kernel use. +#include <windows.h> +#define SECURITY_WIN32 +#include <security.h> +#include <Schnlsp.h> +#undef SECURITY_WIN32 + + +/** @file + * SSL/TLS support API. 
+ * + * This file contains an SChannel-based implementation of the SSL/TLS API for Windows platforms. + */ + +#define SSL_DATA_SIZE 16384 +#define SSL_BUF_SIZE (SSL_DATA_SIZE + 5 + 2048 + 32) + +typedef enum { UNKNOWN_CONNECTION, SSL_CONNECTION, CLEAR_CONNECTION } connection_mode_t; +typedef struct pn_ssl_session_t pn_ssl_session_t; + +struct pn_ssl_domain_t { + int ref_count; + pn_ssl_mode_t mode; + bool has_ca_db; // true when CA database configured + bool has_certificate; // true when certificate configured + char *keyfile_pw; + + // settings used for all connections + pn_ssl_verify_mode_t verify_mode; + bool allow_unsecured; + + // SChannel + HCERTSTORE cert_store; + PCCERT_CONTEXT cert_context; + SCHANNEL_CRED credential; +}; + +typedef enum { CREATED, CLIENT_HELLO, NEGOTIATING, + RUNNING, SHUTTING_DOWN, SSL_CLOSED } ssl_state_t; + +struct pn_ssl_t { + pn_transport_t *transport; + pn_io_layer_t *io_layer; + pn_ssl_domain_t *domain; + const char *session_id; + const char *peer_hostname; + ssl_state_t state; + + bool queued_shutdown; + bool ssl_closed; // shutdown complete, or SSL error + ssize_t app_input_closed; // error code returned by upper layer process input + ssize_t app_output_closed; // error code returned by upper layer process output + + // OpenSSL hides the protocol envelope bytes, SChannel has them in-line.
+ char *sc_outbuf; // SChannel output buffer + size_t sc_out_size; + size_t sc_out_count; + char *network_outp; // network ready bytes within sc_outbuf + size_t network_out_pending; + + char *sc_inbuf; // SChannel input buffer + size_t sc_in_size; + size_t sc_in_count; + bool sc_in_incomplete; + + char *inbuf_extra; // Still encrypted data from following Record(s) + size_t extra_count; + + char *in_data; // Just the plaintext data part of sc_inbuf, decrypted in place + size_t in_data_size; + size_t in_data_count; + bool decrypting; + size_t max_data_size; // computed in the handshake + + pn_bytes_t app_inbytes; // Virtual decrypted datastream, presented to app layer + + pn_buffer_t *inbuf2; // Second input buf if longer contiguous bytes needed + bool double_buffered; + + bool sc_input_shutdown; + + pn_trace_t trace; + + CredHandle cred_handle; + CtxtHandle ctxt_handle; + SecPkgContext_StreamSizes sc_sizes; +}; + +struct pn_ssl_session_t { + const char *id; +// TODO + pn_ssl_session_t *ssn_cache_next; + pn_ssl_session_t *ssn_cache_prev; +}; + + +static ssize_t process_input_ssl( pn_io_layer_t *io_layer, const char *input_data, size_t len); +static ssize_t process_output_ssl( pn_io_layer_t *io_layer, char *input_data, size_t len); +static ssize_t process_input_unknown( pn_io_layer_t *io_layer, const char *input_data, size_t len); +static ssize_t process_output_unknown( pn_io_layer_t *io_layer, char *input_data, size_t len); +static ssize_t process_input_done(pn_io_layer_t *io_layer, const char *input_data, size_t len); +static ssize_t process_output_done(pn_io_layer_t *io_layer, char *input_data, size_t len); +static connection_mode_t check_for_ssl_connection( const char *data, size_t len ); +static pn_ssl_session_t *ssn_cache_find( pn_ssl_domain_t *, const char * ); +static void ssl_session_free( pn_ssl_session_t *); +static size_t buffered_output( pn_io_layer_t *io_layer ); +static size_t buffered_input( pn_io_layer_t *io_layer ); +static void 
start_ssl_shutdown(pn_ssl_t *ssl); +static void rewind_sc_inbuf(pn_ssl_t *ssl); +static bool grow_inbuf2(pn_ssl_t *ssl, size_t minimum_size); + + +// @todo: used to avoid littering the code with calls to printf... +static void ssl_log_error(const char *fmt, ...) +{ + va_list ap; + va_start(ap, fmt); + vfprintf(stderr, fmt, ap); + va_end(ap); + fflush(stderr); +} + +// @todo: used to avoid littering the code with calls to printf... +static void ssl_log(pn_ssl_t *ssl, const char *fmt, ...) +{ + if (PN_TRACE_DRV & ssl->trace) { + va_list ap; + va_start(ap, fmt); + vfprintf(stderr, fmt, ap); + va_end(ap); + fflush(stderr); + } +} + +static void ssl_log_error_status(HRESULT status, const char *fmt, ...) +{ + char buf[512]; + va_list ap; + + if (fmt) { + va_start(ap, fmt); + vfprintf(stderr, fmt, ap); + va_end(ap); + } + + if (FormatMessage(FORMAT_MESSAGE_MAX_WIDTH_MASK | FORMAT_MESSAGE_FROM_SYSTEM, + 0, status, 0, buf, sizeof(buf), 0)) + ssl_log_error("%s\n", buf); + else + fprintf(stderr, "pn internal Windows error: %lu\n", GetLastError()); + + fflush(stderr); +} + +static void ssl_log_clear_data(pn_ssl_t *ssl, const char *data, size_t len) +{ + if (PN_TRACE_RAW & ssl->trace) { + fprintf(stderr, "SSL decrypted data: \""); + pn_fprint_data( stderr, data, len ); + fprintf(stderr, "\"\n"); + } +} + +static size_t _pni_min(size_t a, size_t b) +{ + return (a < b) ? a : b; +} + +// unrecoverable SSL failure occurred, notify transport and generate error code.
+static int ssl_failed(pn_ssl_t *ssl, char *reason) +{ + char buf[512] = "Unknown error."; + if (!reason) { + HRESULT status = GetLastError(); + + FormatMessage(FORMAT_MESSAGE_MAX_WIDTH_MASK | FORMAT_MESSAGE_FROM_SYSTEM, + 0, status, 0, buf, sizeof(buf), 0); + reason = buf; + } + ssl->ssl_closed = true; + ssl->app_input_closed = ssl->app_output_closed = PN_ERR; + ssl->transport->tail_closed = true; + ssl->state = SSL_CLOSED; + pn_do_error(ssl->transport, "amqp:connection:framing-error", "SSL Failure: %s", reason); + return PN_EOS; +} + +/* match the DNS name pattern from the peer certificate against our configured peer + hostname */ +static bool match_dns_pattern( const char *hostname, + const char *pattern, int plen ) +{ + return false; // TODO +} + + +static pn_ssl_session_t *ssn_cache_find( pn_ssl_domain_t *domain, const char *id ) +{ +// TODO: + return NULL; +} + +static void ssl_session_free( pn_ssl_session_t *ssn) +{ + if (ssn) { + if (ssn->id) free( (void *)ssn->id ); + free( ssn ); + } +} + + +/** Public API - visible to application code */ + +pn_ssl_domain_t *pn_ssl_domain( pn_ssl_mode_t mode ) +{ + pn_ssl_domain_t *domain = (pn_ssl_domain_t *) calloc(1, sizeof(pn_ssl_domain_t)); + if (!domain) return NULL; + + memset(domain, 0, sizeof(domain)); + domain->credential.dwVersion = SCHANNEL_CRED_VERSION; + domain->credential.dwFlags = SCH_CRED_NO_DEFAULT_CREDS; + + domain->ref_count = 1; + domain->mode = mode; + switch(mode) { + case PN_SSL_MODE_CLIENT: + // TODO + break; + + case PN_SSL_MODE_SERVER: + // TODO + break; + + default: + ssl_log_error("Invalid mode for pn_ssl_mode_t: %d\n", mode); + free(domain); + return NULL; + } + + return domain; +} + +void pn_ssl_domain_free( pn_ssl_domain_t *domain ) +{ + if (--domain->ref_count == 0) { + + if (domain->cert_context) + CertFreeCertificateContext(domain->cert_context); + if (domain->cert_store) + CertCloseStore(domain->cert_store, CERT_CLOSE_STORE_FORCE_FLAG); + + if (domain->keyfile_pw) 
free(domain->keyfile_pw); + free(domain); + } +} + + +int pn_ssl_domain_set_credentials( pn_ssl_domain_t *domain, + const char *certificate_file, + const char *private_key_file, + const char *password) +{ + if (!domain) return -1; + + // TODO: + + return 0; +} + + +int pn_ssl_domain_set_trusted_ca_db(pn_ssl_domain_t *domain, + const char *certificate_db) +{ + if (!domain) return -1; + // TODO: support for alternate ca db? or just return -1 + domain->has_ca_db = true; + return 0; +} + + +int pn_ssl_domain_set_peer_authentication(pn_ssl_domain_t *domain, + const pn_ssl_verify_mode_t mode, + const char *trusted_CAs) +{ + if (!domain) return -1; + + switch (mode) { + case PN_SSL_VERIFY_PEER: + case PN_SSL_VERIFY_PEER_NAME: + // TODO + break; + + case PN_SSL_ANONYMOUS_PEER: // hippie free love mode... :) + // TODO + break; + + default: + ssl_log_error( "Invalid peer authentication mode given.\n" ); + return -1; + } + + domain->verify_mode = mode; + return 0; +} + +int pn_ssl_init(pn_ssl_t *ssl, pn_ssl_domain_t *domain, const char *session_id) +{ + if (!ssl || !domain || ssl->domain) return -1; + if (ssl->state != CREATED) return -1; + + ssl->domain = domain; + domain->ref_count++; + if (domain->allow_unsecured) { + ssl->io_layer->process_input = process_input_unknown; + ssl->io_layer->process_output = process_output_unknown; + } else { + ssl->io_layer->process_input = process_input_ssl; + ssl->io_layer->process_output = process_output_ssl; + } + + if (session_id && domain->mode == PN_SSL_MODE_CLIENT) + ssl->session_id = pn_strdup(session_id); + + TimeStamp cred_expiry; + SECURITY_STATUS status = AcquireCredentialsHandle(NULL, UNISP_NAME, SECPKG_CRED_OUTBOUND, + NULL, &domain->credential, NULL, NULL, &ssl->cred_handle, + &cred_expiry); + if (status != SEC_E_OK) { + ssl_log_error_status(status, "AcquireCredentialsHandle"); + return -1; + } + + ssl->state = (domain->mode == PN_SSL_MODE_CLIENT) ? 
CLIENT_HELLO : NEGOTIATING; + return 0; +} + + +int pn_ssl_domain_allow_unsecured_client(pn_ssl_domain_t *domain) +{ + if (!domain) return -1; + if (domain->mode != PN_SSL_MODE_SERVER) { + ssl_log_error("Cannot permit unsecured clients - not a server.\n"); + return -1; + } + domain->allow_unsecured = true; + return 0; +} + + +bool pn_ssl_get_cipher_name(pn_ssl_t *ssl, char *buffer, size_t size ) +{ + *buffer = '\0'; + snprintf( buffer, size, "%s", "TODO: cipher_name" ); + return true; +} + +bool pn_ssl_get_protocol_name(pn_ssl_t *ssl, char *buffer, size_t size ) +{ + *buffer = '\0'; + snprintf( buffer, size, "%s", "TODO: protocol name" ); + return true; +} + + +void pn_ssl_free( pn_ssl_t *ssl) +{ + if (!ssl) return; + ssl_log( ssl, "SSL socket freed.\n" ); + // clean up Windows per TLS session data before releasing the domain count + if (SecIsValidHandle(&ssl->ctxt_handle)) + DeleteSecurityContext(&ssl->ctxt_handle); + if (SecIsValidHandle(&ssl->cred_handle)) + FreeCredentialsHandle(&ssl->cred_handle); + + if (ssl->domain) pn_ssl_domain_free(ssl->domain); + if (ssl->session_id) free((void *)ssl->session_id); + if (ssl->peer_hostname) free((void *)ssl->peer_hostname); + if (ssl->sc_inbuf) free((void *)ssl->sc_inbuf); + if (ssl->sc_outbuf) free((void *)ssl->sc_outbuf); + if (ssl->inbuf2) pn_buffer_free(ssl->inbuf2); + free(ssl); +} + +pn_ssl_t *pn_ssl(pn_transport_t *transport) +{ + if (!transport) return NULL; + if (transport->ssl) return transport->ssl; + + pn_ssl_t *ssl = (pn_ssl_t *) calloc(1, sizeof(pn_ssl_t)); + if (!ssl) return NULL; + ssl->sc_out_size = ssl->sc_in_size = SSL_BUF_SIZE; + + ssl->sc_outbuf = (char *)malloc(ssl->sc_out_size); + if (!ssl->sc_outbuf) { + free(ssl); + return NULL; + } + ssl->sc_inbuf = (char *)malloc(ssl->sc_in_size); + if (!ssl->sc_inbuf) { + free(ssl->sc_outbuf); + free(ssl); + return NULL; + } + + ssl->inbuf2 = pn_buffer(0); + if (!ssl->inbuf2) { + free(ssl->sc_inbuf); + free(ssl->sc_outbuf); + free(ssl); + return NULL; + } + + 
ssl->transport = transport; + transport->ssl = ssl; + + ssl->io_layer = &transport->io_layers[PN_IO_SSL]; + ssl->io_layer->context = ssl; + ssl->io_layer->process_input = pn_io_layer_input_passthru; + ssl->io_layer->process_output = pn_io_layer_output_passthru; + ssl->io_layer->process_tick = pn_io_layer_tick_passthru; + ssl->io_layer->buffered_output = buffered_output; + ssl->io_layer->buffered_input = buffered_input; + + ssl->trace = (transport->disp) ? transport->disp->trace : PN_TRACE_OFF; + SecInvalidateHandle(&ssl->cred_handle); + SecInvalidateHandle(&ssl->ctxt_handle); + ssl->state = CREATED; + ssl->decrypting = true; + + return ssl; +} + +void pn_ssl_trace(pn_ssl_t *ssl, pn_trace_t trace) +{ + ssl->trace = trace; +} + + +pn_ssl_resume_status_t pn_ssl_resume_status( pn_ssl_t *ssl ) +{ + // TODO + return PN_SSL_RESUME_UNKNOWN; +} + + +int pn_ssl_set_peer_hostname( pn_ssl_t *ssl, const char *hostname ) +{ + if (!ssl) return -1; + + if (ssl->peer_hostname) free((void *)ssl->peer_hostname); + ssl->peer_hostname = NULL; + if (hostname) { + ssl->peer_hostname = pn_strdup(hostname); + if (!ssl->peer_hostname) return -2; + } + return 0; +} + +int pn_ssl_get_peer_hostname( pn_ssl_t *ssl, char *hostname, size_t *bufsize ) +{ + if (!ssl) return -1; + if (!ssl->peer_hostname) { + *bufsize = 0; + if (hostname) *hostname = '\0'; + return 0; + } + unsigned len = strlen(ssl->peer_hostname); + if (hostname) { + if (len >= *bufsize) return -1; + strcpy( hostname, ssl->peer_hostname ); + } + *bufsize = len; + return 0; +} + + +/** SChannel specific: */ + +static void ssl_encrypt(pn_ssl_t *ssl, char *app_data, size_t count) +{ + // Get SChannel to encrypt exactly one Record. 
+ SecBuffer buffs[4]; + buffs[0].cbBuffer = ssl->sc_sizes.cbHeader; + buffs[0].BufferType = SECBUFFER_STREAM_HEADER; + buffs[0].pvBuffer = ssl->sc_outbuf; + buffs[1].cbBuffer = count; + buffs[1].BufferType = SECBUFFER_DATA; + buffs[1].pvBuffer = app_data; + buffs[2].cbBuffer = ssl->sc_sizes.cbTrailer; + buffs[2].BufferType = SECBUFFER_STREAM_TRAILER; + buffs[2].pvBuffer = &app_data[count]; + buffs[3].cbBuffer = 0; + buffs[3].BufferType = SECBUFFER_EMPTY; + buffs[3].pvBuffer = 0; + SecBufferDesc buff_desc; + buff_desc.ulVersion = SECBUFFER_VERSION; + buff_desc.cBuffers = 4; + buff_desc.pBuffers = buffs; + SECURITY_STATUS status = EncryptMessage(&ssl->ctxt_handle, 0, &buff_desc, 0); + assert(status == SEC_E_OK); + + // EncryptMessage encrypts the data in place. The header and trailer + // areas were reserved previously and must now be included in the updated + // count of bytes to write to the peer. + ssl->sc_out_count = buffs[0].cbBuffer + buffs[1].cbBuffer + buffs[2].cbBuffer; + ssl->network_outp = ssl->sc_outbuf; + ssl->network_out_pending = ssl->sc_out_count; + ssl_log(ssl, "ssl_encrypt %d network bytes\n", ssl->network_out_pending); +} + +// Returns true if decryption succeeded (even for empty content) +static bool ssl_decrypt(pn_ssl_t *ssl) +{ + // Get SChannel to decrypt input. May have an incomplete Record, + // exactly one, or more than one. Check also for session ending, + // session renegotiation. 
+ + SecBuffer recv_buffs[4]; + recv_buffs[0].cbBuffer = ssl->sc_in_count; + recv_buffs[0].BufferType = SECBUFFER_DATA; + recv_buffs[0].pvBuffer = ssl->sc_inbuf; + recv_buffs[1].BufferType = SECBUFFER_EMPTY; + recv_buffs[2].BufferType = SECBUFFER_EMPTY; + recv_buffs[3].BufferType = SECBUFFER_EMPTY; + SecBufferDesc buff_desc; + buff_desc.ulVersion = SECBUFFER_VERSION; + buff_desc.cBuffers = 4; + buff_desc.pBuffers = recv_buffs; + SECURITY_STATUS status = ::DecryptMessage(&ssl->ctxt_handle, &buff_desc, 0, NULL); + + if (status == SEC_E_INCOMPLETE_MESSAGE) { + // Less than a full Record, come back later with more network data + ssl->sc_in_incomplete = true; + return false; + } + + ssl->decrypting = false; + + if (status != SEC_E_OK) { + rewind_sc_inbuf(ssl); + switch (status) { + case SEC_I_CONTEXT_EXPIRED: + // TLS shutdown alert record. Ignore all subsequent input. + ssl->state = SHUTTING_DOWN; + ssl->sc_input_shutdown = true; + return false; + + case SEC_I_RENEGOTIATE: + // TODO. Fall through for now. + default: + ssl_failed(ssl, 0); + return false; + } + } + + ssl->decrypting = false; + // have a decrypted Record and possible (still-encrypted) data of + // one (or more) later Records. Adjust pointers accordingly.
+ for (int i = 0; i < 4; i++) { + switch (recv_buffs[i].BufferType) { + case SECBUFFER_DATA: + ssl->in_data = (char *) recv_buffs[i].pvBuffer; + ssl->in_data_size = ssl->in_data_count = recv_buffs[i].cbBuffer; + break; + case SECBUFFER_EXTRA: + ssl->inbuf_extra = (char *)recv_buffs[i].pvBuffer; + ssl->extra_count = recv_buffs[i].cbBuffer; + break; + default: + // SECBUFFER_STREAM_HEADER: + // SECBUFFER_STREAM_TRAILER: + break; + } + } + return true; +} + +static void client_handshake_init(pn_ssl_t *ssl) +{ + // Tell SChannel to create the first handshake token (ClientHello) + // and place it in sc_outbuf + SEC_CHAR *host = const_cast<SEC_CHAR *>(ssl->peer_hostname); + ULONG ctxt_requested = ISC_REQ_STREAM | ISC_REQ_USE_SUPPLIED_CREDS; + ULONG ctxt_attrs; + + SecBuffer send_buffs[2]; + send_buffs[0].cbBuffer = ssl->sc_out_size; + send_buffs[0].BufferType = SECBUFFER_TOKEN; + send_buffs[0].pvBuffer = ssl->sc_outbuf; + send_buffs[1].cbBuffer = 0; + send_buffs[1].BufferType = SECBUFFER_EMPTY; + send_buffs[1].pvBuffer = 0; + SecBufferDesc send_buff_desc; + send_buff_desc.ulVersion = SECBUFFER_VERSION; + send_buff_desc.cBuffers = 2; + send_buff_desc.pBuffers = send_buffs; + SECURITY_STATUS status = InitializeSecurityContext(&ssl->cred_handle, + NULL, host, ctxt_requested, 0, 0, NULL, 0, + &ssl->ctxt_handle, &send_buff_desc, + &ctxt_attrs, NULL); + + if (status == SEC_I_CONTINUE_NEEDED) { + ssl->sc_out_count = send_buffs[0].cbBuffer; + ssl->network_out_pending = ssl->sc_out_count; + // the token is the whole quantity to send + ssl->network_outp = ssl->sc_outbuf; + ssl_log(ssl, "Sending client hello %d bytes\n", ssl->network_out_pending); + } else { + ssl_log_error_status(status, "InitializeSecurityContext failed"); + ssl_failed(ssl, 0); + } +} + +static void client_handshake( pn_ssl_t* ssl) { + // Feed SChannel ongoing responses from the server until the handshake is complete. 
+ SEC_CHAR *host = const_cast<SEC_CHAR *>(ssl->peer_hostname); + ULONG ctxt_requested = ISC_REQ_STREAM | ISC_REQ_USE_SUPPLIED_CREDS; + ULONG ctxt_attrs; + size_t max = 0; + + // token_buffs describe the buffer that's coming in. It should have + // a token from the SSL server, or empty if sending final shutdown alert. + bool shutdown = ssl->state == SHUTTING_DOWN; + SecBuffer token_buffs[2]; + token_buffs[0].cbBuffer = shutdown ? 0 : ssl->sc_in_count; + token_buffs[0].BufferType = SECBUFFER_TOKEN; + token_buffs[0].pvBuffer = shutdown ? 0 : ssl->sc_inbuf; + token_buffs[1].cbBuffer = 0; + token_buffs[1].BufferType = SECBUFFER_EMPTY; + token_buffs[1].pvBuffer = 0; + SecBufferDesc token_buff_desc; + token_buff_desc.ulVersion = SECBUFFER_VERSION; + token_buff_desc.cBuffers = 2; + token_buff_desc.pBuffers = token_buffs; + + // send_buffs will hold information to forward to the peer. + SecBuffer send_buffs[2]; + send_buffs[0].cbBuffer = ssl->sc_out_size; + send_buffs[0].BufferType = SECBUFFER_TOKEN; + send_buffs[0].pvBuffer = ssl->sc_outbuf; + send_buffs[1].cbBuffer = 0; + send_buffs[1].BufferType = SECBUFFER_EMPTY; + send_buffs[1].pvBuffer = 0; + SecBufferDesc send_buff_desc; + send_buff_desc.ulVersion = SECBUFFER_VERSION; + send_buff_desc.cBuffers = 2; + send_buff_desc.pBuffers = send_buffs; + + SECURITY_STATUS status = InitializeSecurityContext(&ssl->cred_handle, + &ssl->ctxt_handle, host, ctxt_requested, 0, 0, + &token_buff_desc, 0, NULL, &send_buff_desc, + &ctxt_attrs, NULL); + switch (status) { + case SEC_E_INCOMPLETE_MESSAGE: + // Not enough - get more data from the server then try again. + // Leave input buffers untouched. + ssl_log(ssl, "client handshake: incomplete record\n"); + ssl->sc_in_incomplete = true; + return; + + case SEC_I_CONTINUE_NEEDED: + // Successful handshake step, requiring data to be sent to peer. 
+ // TODO: check if server has requested a client certificate + ssl->sc_out_count = send_buffs[0].cbBuffer; + // the token is the whole quantity to send + ssl->network_out_pending = ssl->sc_out_count; + ssl->network_outp = ssl->sc_outbuf; + ssl_log(ssl, "client handshake token %d bytes\n", ssl->network_out_pending); + break; + + case SEC_E_OK: + // Handshake complete. + if (shutdown) { + if (send_buffs[0].cbBuffer > 0) { + ssl->sc_out_count = send_buffs[0].cbBuffer; + // the token is the whole quantity to send + ssl->network_out_pending = ssl->sc_out_count; + ssl->network_outp = ssl->sc_outbuf; + ssl_log(ssl, "client shutdown token %d bytes\n", ssl->network_out_pending); + } else { + ssl->state = SSL_CLOSED; + } + // we didn't touch sc_inbuf, no need to reset + return; + } + if (send_buffs[0].cbBuffer != 0) { + ssl_failed(ssl, "unexpected final server token"); + break; + } + if (token_buffs[1].BufferType == SECBUFFER_EXTRA && token_buffs[1].cbBuffer > 0) { + // This seems to work but not documented, plus logic differs from decrypt message + // since the pvBuffer value is not set. Grrr. 
+ ssl->extra_count = token_buffs[1].cbBuffer; + ssl->inbuf_extra = ssl->sc_inbuf + (ssl->sc_in_count - ssl->extra_count); + } + + QueryContextAttributes(&ssl->ctxt_handle, + SECPKG_ATTR_STREAM_SIZES, &ssl->sc_sizes); + max = ssl->sc_sizes.cbMaximumMessage + ssl->sc_sizes.cbHeader + ssl->sc_sizes.cbTrailer; + if (max > ssl->sc_out_size) { + ssl_log_error("Buffer size mismatch have %d, need %d\n", (int) ssl->sc_out_size, (int) max); + ssl->state = SHUTTING_DOWN; + ssl->app_input_closed = ssl->app_output_closed = PN_ERR; + start_ssl_shutdown(ssl); + pn_do_error(ssl->transport, "amqp:connection:framing-error", "SSL Failure: buffer size"); + break; + } + + ssl->state = RUNNING; + ssl->max_data_size = max - ssl->sc_sizes.cbHeader - ssl->sc_sizes.cbTrailer; + ssl_log(ssl, "client handshake successful %d max record size\n", max); + break; + + case SEC_I_CONTEXT_EXPIRED: + // ended before we got going + default: + ssl_log(ssl, "client handshake failed %d\n", (int) status); + ssl_failed(ssl, 0); + break; + } + ssl->decrypting = false; + rewind_sc_inbuf(ssl); +} + + +static void ssl_handshake(pn_ssl_t* ssl) { + if (ssl->domain->mode == PN_SSL_MODE_CLIENT) + client_handshake(ssl); + else { + ssl_log( ssl, "TODO: server handshake.\n" ); + ssl_failed(ssl, "internal runtime error, not yet implemented"); + } +} + +static bool grow_inbuf2(pn_ssl_t *ssl, size_t minimum_size) { + size_t old_capacity = pn_buffer_capacity(ssl->inbuf2); + size_t new_capacity = old_capacity ? 
old_capacity * 2 : 1024; + + while (new_capacity < minimum_size) + new_capacity *= 2; + + uint32_t max_frame = pn_transport_get_max_frame(ssl->transport); + if (max_frame != 0) { + if (old_capacity >= max_frame) { + // already big enough + ssl_log(ssl, "Application expecting %d bytes (> negotiated maximum frame)\n", new_capacity); + ssl_failed(ssl, "TLS: transport maximimum frame size error"); + return false; + } + } + + size_t extra_bytes = new_capacity - pn_buffer_size(ssl->inbuf2); + int err = pn_buffer_ensure(ssl->inbuf2, extra_bytes); + if (err) { + ssl_log(ssl, "TLS memory allocation failed for %d bytes\n", max_frame); + ssl_failed(ssl, "TLS memory allocation failed"); + return false; + } + return true; +} + + +// Peer initiated a session end by sending us a shutdown alert (and we should politely +// reciprocate), or else we are initiating the session end (and will not bother to wait +// for the peer shutdown alert). Stop processing input immediately, and stop processing +// output once this is sent. + +static void start_ssl_shutdown(pn_ssl_t *ssl) +{ + assert(ssl->network_out_pending == 0); + if (ssl->queued_shutdown) + return; + ssl->queued_shutdown = true; + ssl_log(ssl, "Shutting down SSL connection...\n"); + + DWORD shutdown = SCHANNEL_SHUTDOWN; + SecBuffer shutBuff; + shutBuff.cbBuffer = sizeof(DWORD); + shutBuff.BufferType = SECBUFFER_TOKEN; + shutBuff.pvBuffer = &shutdown; + SecBufferDesc desc; + desc.ulVersion = SECBUFFER_VERSION; + desc.cBuffers = 1; + desc.pBuffers = &shutBuff; + ::ApplyControlToken(&ssl->ctxt_handle, &desc); + + // Next handshake will generate the shutdown alert token + ssl_handshake(ssl); +} + +static int setup_ssl_connection(pn_ssl_t *ssl) +{ + ssl_log( ssl, "SSL connection detected.\n"); + ssl->io_layer->process_input = process_input_ssl; + ssl->io_layer->process_output = process_output_ssl; + return 0; +} + +static void rewind_sc_inbuf(pn_ssl_t *ssl) +{ + // Decrypted bytes have been drained or double buffered.
Prepare for the next SSL Record. + assert(ssl->in_data_count == 0); + if (ssl->decrypting) + return; + ssl->decrypting = true; + if (ssl->inbuf_extra) { + // A previous read picked up more than one Record. Move it to the beginning. + memmove(ssl->sc_inbuf, ssl->inbuf_extra, ssl->extra_count); + ssl->sc_in_count = ssl->extra_count; + ssl->inbuf_extra = 0; + ssl->extra_count = 0; + } else { + ssl->sc_in_count = 0; + } +} + +static void app_inbytes_add(pn_ssl_t *ssl) +{ + if (!ssl->app_inbytes.start) { + ssl->app_inbytes.start = ssl->in_data; + ssl->app_inbytes.size = ssl->in_data_count; + return; + } + + if (ssl->double_buffered) { + if (pn_buffer_available(ssl->inbuf2) == 0) { + if (!grow_inbuf2(ssl, 1024)) + // could not add room + return; + } + size_t count = _pni_min(ssl->in_data_count, pn_buffer_available(ssl->inbuf2)); + pn_buffer_append(ssl->inbuf2, ssl->in_data, count); + ssl->in_data += count; + ssl->in_data_count -= count; + ssl->app_inbytes = pn_buffer_bytes(ssl->inbuf2); + } else { + assert(ssl->app_inbytes.size == 0); + ssl->app_inbytes.start = ssl->in_data; + ssl->app_inbytes.size = ssl->in_data_count; + } +} + + +static void app_inbytes_progress(pn_ssl_t *ssl, size_t minimum) +{ + // Make more decrypted data available, if possible. Otherwise, move + // unread bytes to front of inbuf2 to make room for next bulk decryption. + // SSL may have chopped up data that app layer expects to be + // contiguous. Start, continue or stop double buffering here. 
+ if (ssl->double_buffered) { + if (ssl->app_inbytes.size == 0) { + // no straggler bytes, optimistically stop for now + ssl->double_buffered = false; + pn_buffer_clear(ssl->inbuf2); + ssl->app_inbytes.start = ssl->in_data; + ssl->app_inbytes.size = ssl->in_data_count; + } else { + pn_bytes_t ib2 = pn_buffer_bytes(ssl->inbuf2); + assert(ssl->app_inbytes.size <= ib2.size); + size_t consumed = ib2.size - ssl->app_inbytes.size; + if (consumed > 0) { + memmove((void *)ib2.start, ib2.start + consumed, consumed); + pn_buffer_trim(ssl->inbuf2, 0, consumed); + } + if (!pn_buffer_available(ssl->inbuf2)) { + if (!grow_inbuf2(ssl, minimum)) + // could not add room + return; + } + size_t count = _pni_min(ssl->in_data_count, pn_buffer_available(ssl->inbuf2)); + pn_buffer_append(ssl->inbuf2, ssl->in_data, count); + ssl->in_data += count; + ssl->in_data_count -= count; + ssl->app_inbytes = pn_buffer_bytes(ssl->inbuf2); + } + } else { + if (ssl->app_inbytes.size) { + // start double buffering the left over bytes + ssl->double_buffered = true; + pn_buffer_clear(ssl->inbuf2); + if (!pn_buffer_available(ssl->inbuf2)) { + if (!grow_inbuf2(ssl, minimum)) + // could not add room + return; + } + size_t count = _pni_min(ssl->in_data_count, pn_buffer_available(ssl->inbuf2)); + pn_buffer_append(ssl->inbuf2, ssl->in_data, count); + ssl->in_data += count; + ssl->in_data_count -= count; + ssl->app_inbytes = pn_buffer_bytes(ssl->inbuf2); + } else { + // already pointing at all available bytes until next decrypt + } + } + if (ssl->in_data_count == 0) + rewind_sc_inbuf(ssl); +} + + +static void app_inbytes_advance(pn_ssl_t *ssl, size_t consumed) +{ + if (consumed == 0) { + // more contiguous bytes required + app_inbytes_progress(ssl, ssl->app_inbytes.size + 1); + return; + } + assert(consumed <= ssl->app_inbytes.size); + ssl->app_inbytes.start += consumed; + ssl->app_inbytes.size -= consumed; + if (!ssl->double_buffered) { + ssl->in_data += consumed; + ssl->in_data_count -= consumed; + } + if 
(ssl->app_inbytes.size == 0) + app_inbytes_progress(ssl, 0); +} + +static void read_closed(pn_ssl_t *ssl, ssize_t error) +{ + if (ssl->app_input_closed) + return; + if (ssl->state == RUNNING && !error) { + pn_io_layer_t *io_next = ssl->io_layer->next; + // Signal end of stream + ssl->app_input_closed = io_next->process_input(io_next, ssl->app_inbytes.start, 0); + } + if (!ssl->app_input_closed) + ssl->app_input_closed = error ? error : PN_ERR; + + if (ssl->app_output_closed) { + // both sides of app closed, and no more app output pending: + ssl->state = SHUTTING_DOWN; + if (ssl->network_out_pending == 0 && !ssl->queued_shutdown) { + start_ssl_shutdown(ssl); + } + } +} + + +// Read up to "available" bytes from the network, decrypt it and pass plaintext to application. + +static ssize_t process_input_ssl(pn_io_layer_t *io_layer, const char *input_data, size_t available) +{ + pn_ssl_t *ssl = (pn_ssl_t *)io_layer->context; + ssl_log( ssl, "process_input_ssl( data size=%d )\n",available ); + ssize_t consumed = 0; + ssize_t forwarded = 0; + bool new_app_input; + + if (available == 0) { + // No more inbound network data + read_closed(ssl,0); + return 0; + } + + do { + if (ssl->sc_input_shutdown) { + // TLS protocol shutdown detected on input + read_closed(ssl,0); + return consumed; + } + + // sc_inbuf should be ready for new or additional network encrypted bytes. + // i.e. no straggling decrypted bytes pending. 
+ assert(ssl->in_data_count == 0 && ssl->decrypting); + new_app_input = false; + size_t count; + + if (ssl->state != RUNNING) { + count = _pni_min(ssl->sc_in_size - ssl->sc_in_count, available); + } else { + // look for TLS record boundaries + if (ssl->sc_in_count < 5) { + ssl->sc_in_incomplete = true; + size_t hc = _pni_min(available, 5 - ssl->sc_in_count); + memmove(ssl->sc_inbuf + ssl->sc_in_count, input_data, hc); + ssl->sc_in_count += hc; + input_data += hc; + available -= hc; + consumed += hc; + if (ssl->sc_in_count < 5 || available == 0) + break; + } + + // Top up sc_inbuf from network input_data hoping for a complete TLS Record + // We try to guess the length as an optimization, but let SChannel + // ultimately decide if there is spoofing going on. + unsigned char low = (unsigned char) ssl->sc_inbuf[4]; + unsigned char high = (unsigned char) ssl->sc_inbuf[3]; + size_t rec_len = high * 256 + low + 5; + if (rec_len < 5 || rec_len == ssl->sc_in_count || rec_len > ssl->sc_in_size) + rec_len = ssl->sc_in_size; + + count = _pni_min(rec_len - ssl->sc_in_count, available); + } + + if (count > 0) { + memmove(ssl->sc_inbuf + ssl->sc_in_count, input_data, count); + ssl->sc_in_count += count; + input_data += count; + available -= count; + consumed += count; + ssl->sc_in_incomplete = false; + } + + // Try to decrypt another TLS Record. + + if (ssl->sc_in_count > 0 && ssl->state <= SHUTTING_DOWN) { + if (ssl->state == NEGOTIATING) { + ssl_handshake(ssl); + } else { + if (ssl_decrypt(ssl)) { + // Ignore TLS Record with 0 length data (does not mean EOS) + if (ssl->in_data_size > 0) { + new_app_input = true; + app_inbytes_add(ssl); + } else { + assert(ssl->decrypting == false); + rewind_sc_inbuf(ssl); + } + } + ssl_log(ssl, "Next decryption, %d left over\n", available); + } + } + + if (ssl->state == SHUTTING_DOWN) { + if (ssl->network_out_pending == 0 && !ssl->queued_shutdown) { + start_ssl_shutdown(ssl); + } + } else if (ssl->state == SSL_CLOSED) { + return consumed ? 
consumed : -1; + } + + // Consume or discard the decrypted bytes + if (new_app_input && (ssl->state == RUNNING || ssl->state == SHUTTING_DOWN)) { + // present app_inbytes to io_next only if it has new content + while (ssl->app_inbytes.size > 0) { + if (!ssl->app_input_closed) { + pn_io_layer_t *io_next = ssl->io_layer->next; + ssize_t count = io_next->process_input(io_next, ssl->app_inbytes.start, ssl->app_inbytes.size); + if (count > 0) { + forwarded += count; + // advance() can increase app_inbytes.size if double buffered + app_inbytes_advance(ssl, count); + ssl_log(ssl, "Application consumed %d bytes from peer\n", (int) count); + } else if (count == 0) { + size_t old_size = ssl->app_inbytes.size; + app_inbytes_advance(ssl, 0); + if (ssl->app_inbytes.size == old_size) { + break; // no additional contiguous decrypted data available, get more network data + } + } else { + // count < 0 + ssl_log(ssl, "Application layer closed its input, error=%d (discarding %d bytes)\n", + (int) count, (int)ssl->app_inbytes.size); + app_inbytes_advance(ssl, ssl->app_inbytes.size); // discard + read_closed(ssl, count); + } + } else { + ssl_log(ssl, "Input closed discard %d bytes\n", + (int)ssl->app_inbytes.size); + app_inbytes_advance(ssl, ssl->app_inbytes.size); // discard + } + } + } + } while (available || (ssl->sc_in_count && !ssl->sc_in_incomplete)); + + if (ssl->app_input_closed && ssl->state >= SHUTTING_DOWN) { + consumed = ssl->app_input_closed; + ssl->io_layer->process_input = process_input_done; + } + ssl_log(ssl, "process_input_ssl() returning %d, forwarded %d\n", (int) consumed, (int) forwarded); + return consumed; +} + +static ssize_t process_output_ssl( pn_io_layer_t *io_layer, char *buffer, size_t max_len) +{ + pn_ssl_t *ssl = (pn_ssl_t *)io_layer->context; + if (!ssl) return PN_ERR; + ssl_log( ssl, "process_output_ssl( max_len=%d )\n",max_len ); + + ssize_t written = 0; + ssize_t total_app_bytes = 0; + bool work_pending; + + if (ssl->state == CLIENT_HELLO) { + // 
output buffers exclusively for internal handshake use until negotiation complete + client_handshake_init(ssl); + if (ssl->state == SSL_CLOSED) + return PN_ERR; + ssl->state = NEGOTIATING; + } + + do { + work_pending = false; + + if (ssl->network_out_pending > 0) { + size_t wcount = _pni_min(ssl->network_out_pending, max_len); + memmove(buffer, ssl->network_outp, wcount); + ssl->network_outp += wcount; + ssl->network_out_pending -= wcount; + buffer += wcount; + max_len -= wcount; + written += wcount; + } + + if (ssl->network_out_pending == 0 && ssl->state == RUNNING && !ssl->app_output_closed) { + // refill the buffer with app data and encrypt it + + char *app_data = ssl->sc_outbuf + ssl->sc_sizes.cbHeader; + char *app_outp = app_data; + size_t remaining = ssl->max_data_size; + ssize_t app_bytes; + do { + pn_io_layer_t *io_next = ssl->io_layer->next; + app_bytes = io_next->process_output(io_next, app_outp, remaining); + if (app_bytes > 0) { + app_outp += app_bytes; + remaining -= app_bytes; + ssl_log( ssl, "Gathered %d bytes from app to send to peer\n", app_bytes ); + } else { + if (app_bytes < 0) { + ssl_log(ssl, "Application layer closed its output, error=%d (%d bytes pending send)\n", + (int) app_bytes, (int) ssl->network_out_pending); + ssl->app_output_closed = app_bytes; + if (ssl->app_input_closed) + ssl->state = SHUTTING_DOWN; + } else if (total_app_bytes == 0 && ssl->app_input_closed) { + // We've drained all the App layer can provide + ssl_log(ssl, "Application layer blocked on input, closing\n"); + ssl->state = SHUTTING_DOWN; + ssl->app_output_closed = PN_ERR; + } + } + } while (app_bytes > 0); + if (app_outp > app_data) { + work_pending = (max_len > 0); + ssl_encrypt(ssl, app_data, app_outp - app_data); + } + } + + if (ssl->network_out_pending == 0 && ssl->state == SHUTTING_DOWN) { + if (!ssl->queued_shutdown) { + start_ssl_shutdown(ssl); + work_pending = true; + } else { + ssl->state = SSL_CLOSED; + } + } + } while (work_pending); + + if (written == 0 && 
ssl->state == SSL_CLOSED) { + written = ssl->app_output_closed ? ssl->app_output_closed : PN_EOS; + ssl->io_layer->process_output = process_output_done; + } + ssl_log(ssl, "process_output_ssl() returning %d\n", (int) written); + return written; +} + + +static int setup_cleartext_connection( pn_ssl_t *ssl ) +{ + ssl_log( ssl, "Cleartext connection detected.\n"); + ssl->io_layer->process_input = pn_io_layer_input_passthru; + ssl->io_layer->process_output = pn_io_layer_output_passthru; + return 0; +} + + +// until we determine if the client is using SSL or not: + +static ssize_t process_input_unknown(pn_io_layer_t *io_layer, const char *input_data, size_t len) +{ + pn_ssl_t *ssl = (pn_ssl_t *)io_layer->context; + switch (check_for_ssl_connection( input_data, len )) { + case SSL_CONNECTION: + setup_ssl_connection( ssl ); + return ssl->io_layer->process_input( ssl->io_layer, input_data, len ); + case CLEAR_CONNECTION: + setup_cleartext_connection( ssl ); + return ssl->io_layer->process_input( ssl->io_layer, input_data, len ); + default: + return 0; + } +} + +static ssize_t process_output_unknown(pn_io_layer_t *io_layer, char *input_data, size_t len) +{ + // do not do output until we know if SSL is used or not + return 0; +} + +static connection_mode_t check_for_ssl_connection( const char *data, size_t len ) +{ + if (len >= 5) { + const unsigned char *buf = (unsigned char *)data; + /* + * SSLv2 Client Hello format + * http://www.mozilla.org/projects/security/pki/nss/ssl/draft02.html + * + * Bytes 0-1: RECORD-LENGTH + * Byte 2: MSG-CLIENT-HELLO (1) + * Byte 3: CLIENT-VERSION-MSB + * Byte 4: CLIENT-VERSION-LSB + * + * Allowed versions: + * 2.0 - SSLv2 + * 3.0 - SSLv3 + * 3.1 - TLS 1.0 + * 3.2 - TLS 1.1 + * 3.3 - TLS 1.2 + * + * The version sent in the Client-Hello is the latest version supported by + * the client. NSS may send version 3.x in an SSLv2 header for + * maximum compatibility. 
+ */ + int isSSL2Handshake = buf[2] == 1 && // MSG-CLIENT-HELLO + ((buf[3] == 3 && buf[4] <= 3) || // SSL 3.0 & TLS 1.0-1.2 (v3.1-3.3) + (buf[3] == 2 && buf[4] == 0)); // SSL 2 + + /* + * SSLv3/TLS Client Hello format + * RFC 2246 + * + * Byte 0: ContentType (handshake - 22) + * Bytes 1-2: ProtocolVersion {major, minor} + * + * Allowed versions: + * 3.0 - SSLv3 + * 3.1 - TLS 1.0 + * 3.2 - TLS 1.1 + * 3.3 - TLS 1.2 + */ + int isSSL3Handshake = buf[0] == 22 && // handshake + (buf[1] == 3 && buf[2] <= 3); // SSL 3.0 & TLS 1.0-1.2 (v3.1-3.3) + + if (isSSL2Handshake || isSSL3Handshake) { + return SSL_CONNECTION; + } else { + return CLEAR_CONNECTION; + } + } + return UNKNOWN_CONNECTION; +} + +static ssize_t process_input_done(pn_io_layer_t *io_layer, const char *input_data, size_t len) +{ + return PN_EOS; +} + +static ssize_t process_output_done(pn_io_layer_t *io_layer, char *input_data, size_t len) +{ + return PN_EOS; +} + +// return # output bytes sitting in this layer +static size_t buffered_output(pn_io_layer_t *io_layer) +{ + size_t count = 0; + pn_ssl_t *ssl = (pn_ssl_t *)io_layer->context; + if (ssl) { + count += ssl->network_out_pending; + if (count == 0 && ssl->state == SHUTTING_DOWN && ssl->queued_shutdown) + count++; + } + return count; +} + +// return # input bytes sitting in this layer +static size_t buffered_input( pn_io_layer_t *io_layer ) +{ + size_t count = 0; + pn_ssl_t *ssl = (pn_ssl_t *)io_layer->context; + if (ssl) { + count += ssl->in_data_count; + } + return count; +} http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/92b8098c/proton-c/src/windows/selector.c ---------------------------------------------------------------------- diff --git a/proton-c/src/windows/selector.c b/proton-c/src/windows/selector.c index b01c27a..a7ee49f 100644 --- a/proton-c/src/windows/selector.c +++ b/proton-c/src/windows/selector.c @@ -65,8 +65,8 @@ void pn_selector_initialize(void *obj) selector->iocp = NULL; selector->deadlines = NULL; selector->capacity = 0; - 
selector->selectables = pn_list(0, 0); - selector->iocp_descriptors = pn_list(0, PN_REFCOUNT); + selector->selectables = pn_list(PN_WEAKREF, 0); + selector->iocp_descriptors = pn_list(PN_OBJECT, 0); selector->deadline = 0; selector->current = 0; selector->current_triggered = NULL; @@ -95,7 +95,7 @@ void pn_selector_finalize(void *obj) pn_selector_t *pni_selector() { static const pn_class_t clazz = PN_CLASS(pn_selector); - pn_selector_t *selector = (pn_selector_t *) pn_new(sizeof(pn_selector_t), &clazz); + pn_selector_t *selector = (pn_selector_t *) pn_class_new(&clazz, sizeof(pn_selector_t)); return selector; } http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/92b8098c/proton-c/src/windows/write_pipeline.c ---------------------------------------------------------------------- diff --git a/proton-c/src/windows/write_pipeline.c b/proton-c/src/windows/write_pipeline.c index 3160fa8..438ba88 100644 --- a/proton-c/src/windows/write_pipeline.c +++ b/proton-c/src/windows/write_pipeline.c @@ -168,8 +168,9 @@ static void write_pipeline_finalize(void *object) write_pipeline_t *pni_write_pipeline(iocpdesc_t *iocpd) { + static const pn_cid_t CID_write_pipeline = CID_pn_void; static const pn_class_t clazz = PN_CLASS(write_pipeline); - write_pipeline_t *pipeline = (write_pipeline_t *) pn_new(sizeof(write_pipeline_t), &clazz); + write_pipeline_t *pipeline = (write_pipeline_t *) pn_class_new(&clazz, sizeof(write_pipeline_t)); pipeline->iocpd = iocpd; pipeline->primary->base.iocpd = iocpd; return pipeline; @@ -243,15 +244,15 @@ size_t pni_write_pipeline_reserve(write_pipeline_t *pl, size_t count) iocp_t *iocp = pl->iocpd->iocp; confirm_as_writer(pl); - int wanted = (count / IOCP_WBUFSIZE); + size_t wanted = (count / IOCP_WBUFSIZE); if (count % IOCP_WBUFSIZE) wanted++; size_t pending = pl->pending_count; assert(pending < pl->depth); - int bufs = pn_min(wanted, pl->depth - pending); + size_t bufs = pn_min(wanted, pl->depth - pending); // Can draw from shared pool or the 
primary... but share with others. size_t writers = iocp->writer_count; - int shared_count = (iocp->shared_available_count + writers - 1) / writers; + size_t shared_count = (iocp->shared_available_count + writers - 1) / writers; bufs = pn_min(bufs, shared_count + 1); pl->reserved_count = pending + bufs; http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/92b8098c/proton-j/src/main/java/org/apache/qpid/proton/amqp/Binary.java ---------------------------------------------------------------------- diff --git a/proton-j/src/main/java/org/apache/qpid/proton/amqp/Binary.java b/proton-j/src/main/java/org/apache/qpid/proton/amqp/Binary.java index 31989d2..a416f3f 100644 --- a/proton-j/src/main/java/org/apache/qpid/proton/amqp/Binary.java +++ b/proton-j/src/main/java/org/apache/qpid/proton/amqp/Binary.java @@ -50,6 +50,7 @@ public final class Binary return ByteBuffer.wrap(_data, _offset, _length); } + @Override public final int hashCode() { int hc = _hashCode; @@ -64,13 +65,20 @@ public final class Binary return hc; } + @Override public final boolean equals(Object o) { - Binary buf = (Binary) o; - if(o == null) + if (this == o) + { + return true; + } + + if (o == null || getClass() != o.getClass()) { return false; } + + Binary buf = (Binary) o; final int size = _length; if (size != buf._length) { http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/92b8098c/proton-j/src/main/java/org/apache/qpid/proton/codec/EncoderImpl.java ---------------------------------------------------------------------- diff --git a/proton-j/src/main/java/org/apache/qpid/proton/codec/EncoderImpl.java b/proton-j/src/main/java/org/apache/qpid/proton/codec/EncoderImpl.java index 4b45e01..d681ffe 100644 --- a/proton-j/src/main/java/org/apache/qpid/proton/codec/EncoderImpl.java +++ b/proton-j/src/main/java/org/apache/qpid/proton/codec/EncoderImpl.java @@ -21,7 +21,11 @@ package org.apache.qpid.proton.codec; import java.nio.ByteBuffer; -import java.util.*; +import java.util.Date; +import 
java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; import org.apache.qpid.proton.amqp.Binary; import org.apache.qpid.proton.amqp.Decimal128; @@ -770,29 +774,43 @@ public final class EncoderImpl implements ByteBufferEncoder void writeRaw(String string) { final int length = string.length(); - char c; + int c; for (int i = 0; i < length; i++) { c = string.charAt(i); - if ((c >= 0x0001) && (c <= 0x007F)) + if ((c & 0xFF80) == 0) /* U+0000..U+007F */ { _buffer.put((byte) c); - } - else if (c > 0x07FF) + else if ((c & 0xF800) == 0) /* U+0080..U+07FF */ + { + _buffer.put((byte)(0xC0 | ((c >> 6) & 0x1F))); + _buffer.put((byte)(0x80 | (c & 0x3F))); + } + else if ((c & 0xD800) != 0xD800) /* U+0800..U+FFFF - excluding surrogate pairs */ { - _buffer.put((byte) (0xE0 | ((c >> 12) & 0x0F))); - _buffer.put((byte) (0x80 | ((c >> 6) & 0x3F))); - _buffer.put((byte) (0x80 | (c & 0x3F))); + _buffer.put((byte)(0xE0 | ((c >> 12) & 0x0F))); + _buffer.put((byte)(0x80 | ((c >> 6) & 0x3F))); + _buffer.put((byte)(0x80 | (c & 0x3F))); } else { - _buffer.put((byte) (0xC0 | ((c >> 6) & 0x1F))); - _buffer.put((byte) (0x80 | (c & 0x3F))); + int low; + + if(((c & 0xDC00) == 0xDC00) || (++i == length) || ((low = string.charAt(i)) & 0xDC00) != 0xDC00) + { + throw new IllegalArgumentException("String contains invalid Unicode code points"); + } + + c = 0x010000 + ((c & 0x03FF) << 10) + (low & 0x03FF); + + _buffer.put((byte)(0xF0 | ((c >> 18) & 0x07))); + _buffer.put((byte)(0x80 | ((c >> 12) & 0x3F))); + _buffer.put((byte)(0x80 | ((c >> 6) & 0x3F))); + _buffer.put((byte)(0x80 | (c & 0x3F))); } } - } http://git-wip-us.apache.org/repos/asf/qpid-proton/blob/92b8098c/proton-j/src/main/java/org/apache/qpid/proton/codec/StringType.java ---------------------------------------------------------------------- diff --git a/proton-j/src/main/java/org/apache/qpid/proton/codec/StringType.java b/proton-j/src/main/java/org/apache/qpid/proton/codec/StringType.java index 
3728a42..aa988f9 100644 --- a/proton-j/src/main/java/org/apache/qpid/proton/codec/StringType.java +++ b/proton-j/src/main/java/org/apache/qpid/proton/codec/StringType.java @@ -83,29 +83,22 @@ public class StringType extends AbstractPrimitiveType<String> return encoding; } - private static int calculateUTF8Length(final String s) + static int calculateUTF8Length(final String s) { int len = s.length(); - int i = 0; - final int length = s.length(); - while(i < length) + final int length = len; + for (int i = 0; i < length; i++) { - char c = s.charAt(i); - if(c > 127) + int c = s.charAt(i); + if ((c & 0xFF80) != 0) /* U+0080.. */ { len++; - if(c > 0x07ff) + // surrogate pairs should always combine to create a code point with a 4 octet representation + if(((c & 0xF800) != 0) && ((c & 0xD800) != 0xD800)) /* U+0800.. excluding surrogate pairs */ { len++; - if(c >= 0xD800 && c <= 0xDBFF) - { - i++; - len++; - } } } - i++; - } return len; } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
