From 30fede7ecea2ccffdd5d6c6d02ba5d5454a3348f Mon Sep 17 00:00:00 2001
From: Joseph Southwell <jsouthwell@serengeti.com>
Date: Tue, 23 May 2017 17:32:41 -0400
Subject: [PATCH 1/3] got external io working in libssh client

Signed-off-by: Joseph Southwell <jsouthwell@serengeti.com>
---
 include/libssh/callbacks.h  |  69 +++++++++++++++++++++++++
 include/libssh/libsshpp.hpp |   8 +--
 include/libssh/session.h    |   5 +-
 src/callbacks.c             |  17 +++++++
 src/client.c                |   1 +
 src/socket.c                |  22 +++++++-
 6 files changed, 101 insertions(+), 21 deletions(-)

diff --git a/include/libssh/callbacks.h b/include/libssh/callbacks.h
index 4e71b3b9..b382acb3 100644
--- a/include/libssh/callbacks.h
+++ b/include/libssh/callbacks.h
@@ -400,6 +400,52 @@ struct ssh_socket_callbacks_struct {
 };
 typedef struct ssh_socket_callbacks_struct *ssh_socket_callbacks;
 
+/** 
+* @brief callback to read data from the socket
+* If this function returns less than zero
+* it must update WSASetLastError or errno
+* @param socket_t to read from
+* @param userdata for callback
+* @param buffer for read data
+* @param size of read buffer
+* @return < 0 on error and >= 0 indicated how much data read into buffer
+*/
+typedef int(*ssh_callback_recv) (socket_t socket, void *userdata, char *buffer, int size);
+
+/**
+* @brief callback to write data to the socket
+* If this function returns less than zero
+* it must update WSASetLastError or errno
+* @param socket_t to write to
+* @param userdata for callback
+* @param buffer to write from
+* @param size how much data to write
+* @return < 0 on error and >= 0 indicated how much data consumed from buffer
+*/
+typedef int(*ssh_callback_send) (socket_t socket, void *userdata, const char *buffer, int size);
+
+/**
+* These are the callbacks exported by the socket structure
+* They are called by unbuffered read and write when socket data is sent or received
+*/
+struct ssh_socket_io_callbacks_struct {
+	/**
+	* User-provided data. User is free to set anything he wants here
+	*/
+	void *userdata;
+	/**
+	* This function will be called to write data to the socket. 
+	* The data not consumed will appear on the next send event.
+	*/
+	ssh_callback_send send;
+	/**
+	* This function will be called to read data from the socket.
+	*/
+	ssh_callback_recv recv;
+};
+
+typedef struct ssh_socket_io_callbacks_struct *ssh_socket_io_callbacks;
+
 #define SSH_SOCKET_FLOW_WRITEWILLBLOCK 1
 #define SSH_SOCKET_FLOW_WRITEWONTBLOCK 2
 
@@ -566,6 +612,29 @@ typedef struct ssh_packet_callbacks_struct *ssh_packet_callbacks;
 LIBSSH_API int ssh_set_callbacks(ssh_session session, ssh_callbacks cb);
 
 /**
+* @brief Set the session io callback functions.
+*
+* This functions sets the session io callback functions 
+* It allows you to use your own custom network interactions
+*
+* @code
+* struct ssh_socket_io_callbacks_struct io_cb = {
+*   .userdata = data,
+*   .send = my_send_function
+*   .recv = my_recv_function
+* };
+* ssh_set_io_callbacks(session, &io_cb);
+* @endcode
+*
+* @param  session      The session to set the callback structure.
+*
+* @param  io_cb           The callback structure itself.
+*
+* @return SSH_OK on success, SSH_ERROR on error.
+*/
+LIBSSH_API int ssh_set_io_callbacks(ssh_session session, ssh_socket_io_callbacks cb);
+
+/**
  * @brief SSH channel data callback. Called when data is available on a channel
  * @param session Current session handler
  * @param channel the actual channel
diff --git a/include/libssh/libsshpp.hpp b/include/libssh/libsshpp.hpp
index af08a914..1b058ad1 100644
--- a/include/libssh/libsshpp.hpp
+++ b/include/libssh/libsshpp.hpp
@@ -86,14 +86,14 @@ public:
    * @returns SSH_REQUEST_DENIED Request was denied by remote host
    * @see ssh_get_error_code
    */
-  int getCode(){
+  int getCode() const {
     return code;
   }
   /** @brief returns the error message of the last exception
    * @returns pointer to a c string containing the description of error
    * @see ssh_get_error
    */
-  std::string getError(){
+  std::string getError() const {
     return description;
   }
 private:
@@ -378,7 +378,7 @@ public:
     return_throwable;
   }
 
-private:
+protected:
   ssh_session c_session;
   ssh_session getCSession(){
     return c_session;
@@ -584,7 +584,7 @@ public:
     ssh_throw(ret);
     return ret;
   }
-private:
+protected:
   ssh_session getCSession(){
     return session->getCSession();
   }
diff --git a/include/libssh/session.h b/include/libssh/session.h
index 60d78578..50a05b31 100644
--- a/include/libssh/session.h
+++ b/include/libssh/session.h
@@ -169,8 +169,9 @@ struct ssh_session_struct {
     void (*ssh_connection_callback)( struct ssh_session_struct *session);
     struct ssh_packet_callbacks_struct default_packet_callbacks;
     struct ssh_list *packet_callbacks;
-    struct ssh_socket_callbacks_struct socket_callbacks;
-    ssh_poll_ctx default_poll_ctx;
+	struct ssh_socket_callbacks_struct socket_callbacks;
+	struct ssh_socket_io_callbacks_struct socket_io_callbacks;
+	ssh_poll_ctx default_poll_ctx;
     /* options */
 #ifdef WITH_PCAP
     ssh_pcap_context pcap_ctx; /* pcap debugging context */
diff --git a/src/callbacks.c b/src/callbacks.c
index 3ed2f11c..c8ee9e7b 100644
--- a/src/callbacks.c
+++ b/src/callbacks.c
@@ -67,6 +67,23 @@ int ssh_set_callbacks(ssh_session session, ssh_callbacks cb) {
   return 0;
 }
 
+int ssh_set_io_callbacks(ssh_session session, ssh_socket_io_callbacks io_cb) {
+	if (session == NULL ) {
+		return SSH_ERROR;
+	}
+
+	if (io_cb == NULL) {
+		ssh_set_error(session,
+			SSH_FATAL,
+			"Invalid callback passed in (badly initialized)");
+		return SSH_ERROR;
+	};
+
+	session->socket_io_callbacks = *io_cb;
+	return SSH_OK;
+}
+
+
 static int ssh_add_set_channel_callbacks(ssh_channel channel,
                                          ssh_channel_callbacks cb,
                                          int prepend)
diff --git a/src/client.c b/src/client.c
index 3b120bbc..7c903403 100644
--- a/src/client.c
+++ b/src/client.c
@@ -576,6 +576,7 @@ int ssh_connect(ssh_session session) {
   session->socket_callbacks.userdata=session;
   if (session->opts.fd != SSH_INVALID_SOCKET) {
     session->session_state=SSH_SESSION_STATE_SOCKET_CONNECTED;
+	ssh_socket_set_io_callbacks(session->socket, &session->socket_io_callbacks);
     ssh_socket_set_fd(session->socket, session->opts.fd);
     ret=SSH_OK;
 #ifndef _WIN32
diff --git a/src/socket.c b/src/socket.c
index 76dc55e5..b6af36af 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -87,6 +87,7 @@ struct ssh_socket_struct {
   ssh_buffer in_buffer;
   ssh_session session;
   ssh_socket_callbacks callbacks;
+  ssh_socket_io_callbacks io_callbacks;
   ssh_poll_handle poll_in;
   ssh_poll_handle poll_out;
 };
@@ -204,6 +205,19 @@ void ssh_socket_set_callbacks(ssh_socket s, ssh_socket_callbacks callbacks){
 }
 
 /**
+* @internal
+* @brief the socket io callbacks, i.e. the callbacks to called
+* to read or write data to the socket.
+* Only useful when using an externally provided socket fd
+* @param s socket to set callbacks on.
+* @param callbacks a ssh_socket_io_callback object reference.
+*/
+
+void ssh_socket_set_io_callbacks(ssh_socket s, ssh_socket_io_callbacks io_callbacks) {
+	s->io_callbacks = io_callbacks;
+}
+
+/**
  * @brief               SSH poll callback. This callback will be used when an event
  *                      caught on the socket.
  *
@@ -530,7 +544,9 @@ static int ssh_socket_unbuffered_read(ssh_socket s, void *buffer, uint32_t len)
   if (s->data_except) {
     return -1;
   }
-  if(s->fd_is_socket)
+  if (s->io_callbacks && s->io_callbacks->recv)
+	rc = s->io_callbacks->recv(s->fd_in, s->io_callbacks->userdata, buffer, len);
+  else if(s->fd_is_socket)
     rc = recv(s->fd_in,buffer, len, 0);
   else
     rc = read(s->fd_in,buffer, len);
@@ -558,7 +574,9 @@ static int ssh_socket_unbuffered_write(ssh_socket s, const void *buffer,
   if (s->data_except) {
     return -1;
   }
-  if (s->fd_is_socket)
+  if (s->io_callbacks && s->io_callbacks->send)
+	w = s->io_callbacks->send(s->fd_in, s->io_callbacks->userdata, buffer, len);
+  else if (s->fd_is_socket)
     w = send(s->fd_out,buffer, len, 0);
   else
     w = write(s->fd_out, buffer, len);
-- 
2.11.0.windows.3

