Repository: mina-sshd
Updated Branches:
  refs/heads/master 885bbdbf0 -> 2d6fbc94a


[SSHD-797] Provide more flexible control of SSH client and session usage in GitSshdSession


Project: http://git-wip-us.apache.org/repos/asf/mina-sshd/repo
Commit: http://git-wip-us.apache.org/repos/asf/mina-sshd/commit/2d6fbc94
Tree: http://git-wip-us.apache.org/repos/asf/mina-sshd/tree/2d6fbc94
Diff: http://git-wip-us.apache.org/repos/asf/mina-sshd/diff/2d6fbc94

Branch: refs/heads/master
Commit: 2d6fbc94a098fd0d1c2ec77f59a7f23decb3f85c
Parents: 885bbdb
Author: Goldstein Lyor <[email protected]>
Authored: Wed Jan 31 17:32:22 2018 +0200
Committer: Goldstein Lyor <[email protected]>
Committed: Wed Jan 31 17:32:41 2018 +0200

----------------------------------------------------------------------
 .../java/org/apache/sshd/client/SshClient.java  |  19 +++
 .../java/org/apache/sshd/server/SshServer.java  |  17 +++
 .../java/org/apache/sshd/client/ClientTest.java |  12 ++
 .../java/org/apache/sshd/server/ServerTest.java |  12 ++
 .../sshd/git/transport/GitSshdSession.java      | 135 +++++++++++++++----
 .../git/transport/GitSshdSessionFactory.java    |  96 ++++++++++++-
 .../git/transport/GitSshdSessionProcess.java    |  24 +++-
 7 files changed, 283 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/2d6fbc94/sshd-core/src/main/java/org/apache/sshd/client/SshClient.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/SshClient.java b/sshd-core/src/main/java/org/apache/sshd/client/SshClient.java
index f81dfc6..d8602e4 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/SshClient.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/SshClient.java
@@ -49,6 +49,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.TreeMap;
 import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.logging.ConsoleHandler;
 import java.util.logging.Formatter;
 import java.util.logging.Handler;
@@ -217,6 +218,7 @@ public class SshClient extends AbstractFactoryManager implements ClientFactoryMa
 
     private final List<Object> identities = new CopyOnWriteArrayList<>();
     private final AuthenticationIdentitiesProvider identitiesProvider;
+    private final AtomicBoolean started = new AtomicBoolean(false);
 
     public SshClient() {
         identitiesProvider = AuthenticationIdentitiesProvider.wrap(identities);
@@ -434,7 +436,19 @@ public class SshClient extends AbstractFactoryManager implements ClientFactoryMa
         }
     }
 
+    public boolean isStarted() {
+        return started.get();
+    }
+
+    /**
+     * Starts the SSH client and can start creating sessions using it.
+     * Ignored if already {@link #isStarted() started}.
+     */
     public void start() {
+        if (isStarted()) {
+            return;
+        }
+
         checkConfig();
         if (sessionFactory == null) {
             sessionFactory = createSessionFactory();
@@ -443,9 +457,14 @@ public class SshClient extends AbstractFactoryManager implements ClientFactoryMa
         setupSessionTimeout(sessionFactory);
 
         connector = createConnector();
+        started.set(true);
     }
 
     public void stop() {
+        if (!started.getAndSet(false)) {
+            return;
+        }
+
         try {
             long maxWait = this.getLongProperty(STOP_WAIT_TIME, 
DEFAULT_STOP_WAIT_TIME);
             boolean successful = close(true).await(maxWait);

http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/2d6fbc94/sshd-core/src/main/java/org/apache/sshd/server/SshServer.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/SshServer.java b/sshd-core/src/main/java/org/apache/sshd/server/SshServer.java
index 5085283..e8dc4bf 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/SshServer.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/SshServer.java
@@ -40,6 +40,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.TreeMap;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 import org.apache.sshd.common.Closeable;
 import org.apache.sshd.common.Factory;
@@ -136,6 +137,7 @@ public class SshServer extends AbstractFactoryManager implements ServerFactoryMa
     private KeyboardInteractiveAuthenticator interactiveAuthenticator;
     private HostBasedAuthenticator hostBasedAuthenticator;
     private GSSAuthenticator gssAuthenticator;
+    private final AtomicBoolean started = new AtomicBoolean(false);
 
     public SshServer() {
         super();
@@ -293,12 +295,21 @@ public class SshServer extends AbstractFactoryManager implements ServerFactoryMa
         }
     }
 
+    public boolean isStarted() {
+        return started.get();
+    }
+
     /**
      * Start the SSH server and accept incoming exceptions on the configured 
port.
+     * Ignored if already {@link #isStarted() started}.
      *
      * @throws IOException If failed to start
      */
     public void start() throws IOException {
+        if (isStarted()) {
+            return;
+        }
+
         checkConfig();
         if (sessionFactory == null) {
             sessionFactory = createSessionFactory();
@@ -335,6 +346,8 @@ public class SshServer extends AbstractFactoryManager implements ServerFactoryMa
                 log.info("start() listen on auto-allocated port=" + port);
             }
         }
+
+        started.set(true);
     }
 
     /**
@@ -346,6 +359,10 @@ public class SshServer extends AbstractFactoryManager implements ServerFactoryMa
     }
 
     public void stop(boolean immediately) throws IOException {
+        if (!started.getAndSet(false)) {
+            return;
+        }
+
         long maxWait = immediately ? this.getLongProperty(STOP_WAIT_TIME, 
DEFAULT_STOP_WAIT_TIME) : Long.MAX_VALUE;
         boolean successful = close(immediately).await(maxWait);
         if (!successful) {

http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/2d6fbc94/sshd-core/src/test/java/org/apache/sshd/client/ClientTest.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/test/java/org/apache/sshd/client/ClientTest.java b/sshd-core/src/test/java/org/apache/sshd/client/ClientTest.java
index 0e429cc..8d41661 100644
--- a/sshd-core/src/test/java/org/apache/sshd/client/ClientTest.java
+++ b/sshd-core/src/test/java/org/apache/sshd/client/ClientTest.java
@@ -239,6 +239,18 @@ public class ClientTest extends BaseTestSupport {
     }
 
     @Test
+    public void testClientStartedIndicator() throws Exception {
+        client.start();
+        try {
+            assertTrue("Client not marked as started", client.isStarted());
+        } finally {
+            client.stop();
+        }
+
+        assertFalse("Client not marked as stopped", client.isStarted());
+    }
+
+    @Test
     public void testPropertyResolutionHierarchy() throws Exception {
         String sessionPropName = getCurrentTestName() + "-session";
         AtomicReference<Object> sessionConfigValueHolder = new 
AtomicReference<>(null);

http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/2d6fbc94/sshd-core/src/test/java/org/apache/sshd/server/ServerTest.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/test/java/org/apache/sshd/server/ServerTest.java b/sshd-core/src/test/java/org/apache/sshd/server/ServerTest.java
index 84aee1f..2f07cab 100644
--- a/sshd-core/src/test/java/org/apache/sshd/server/ServerTest.java
+++ b/sshd-core/src/test/java/org/apache/sshd/server/ServerTest.java
@@ -114,6 +114,18 @@ public class ServerTest extends BaseTestSupport {
         }
     }
 
+    @Test
+    public void testServerStartedIndicator() throws Exception {
+        sshd.start();
+        try {
+            assertTrue("Server not marked as started", sshd.isStarted());
+        } finally {
+            sshd.stop();
+        }
+
+        assertFalse("Server not marked as stopped", sshd.isStarted());
+    }
+
     /*
      * Send bad password.  The server should disconnect after a few attempts
      */

http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/2d6fbc94/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSession.java
----------------------------------------------------------------------
diff --git a/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSession.java b/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSession.java
index cc4236a..2bac013 100644
--- a/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSession.java
+++ b/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSession.java
@@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit;
 import org.apache.sshd.client.SshClient;
 import org.apache.sshd.client.channel.ChannelExec;
 import org.apache.sshd.client.session.ClientSession;
+import org.apache.sshd.common.util.GenericUtils;
 import org.apache.sshd.common.util.logging.AbstractLoggingBean;
 import org.eclipse.jgit.transport.CredentialItem;
 import org.eclipse.jgit.transport.CredentialsProvider;
@@ -64,7 +65,7 @@ public class GitSshdSession extends AbstractLoggingBean implements RemoteSession
 
     public GitSshdSession(URIish uri, CredentialsProvider credentialsProvider, 
FS fs, int tms) throws IOException, InterruptedException {
         String user = uri.getUser();
-        final String pass = uri.getPass();
+        final String pass1 = uri.getPass();
         String host = uri.getHost();
         int port = uri.getPort();
         char[] pass2 = null;
@@ -82,47 +83,135 @@ public class GitSshdSession extends AbstractLoggingBean implements RemoteSession
         }
 
         client = createClient();
+        try {
+            if (!client.isStarted()) {
+                client.start();
+            }
 
-        client.start();
-
-        session = client.connect(user, host, port)
-                        .verify(client.getLongProperty(CONNECT_TIMEOUT_PROP, 
DEFAULT_CONNECT_TIMEOUT))
-                        .getSession();
-        if (log.isDebugEnabled()) {
-            log.debug("Connected to {}:{}", host, port);
+            session = createClientSession(client, host, user, port, pass1, 
(pass2 != null) ? new String(pass2) : null);
+        } catch (IOException | InterruptedException e) {
+            disconnectClient(client);
+            throw e;
         }
-        if (pass != null) {
-            session.addPasswordIdentity(pass);
+    }
+
+    protected ClientSession createClientSession(
+            SshClient clientInstance, String host, String username, int port, 
String... passwords)
+                throws IOException, InterruptedException {
+        boolean debugEnabled = log.isDebugEnabled();
+        if (debugEnabled) {
+            log.debug("Connecting to {}:{}", host, port);
         }
-        if (pass2 != null) {
-            session.addPasswordIdentity(new String(pass2));
+
+        ClientSession s = clientInstance.connect(username, host, port)
+            .verify(clientInstance.getLongProperty(CONNECT_TIMEOUT_PROP, 
DEFAULT_CONNECT_TIMEOUT))
+            .getSession();
+
+        if (debugEnabled) {
+            log.debug("Connected to {}:{}", host, port);
         }
-        session.auth().verify(session.getLongProperty(AUTH_TIMEOUT_PROP, 
DEFAULT_AUTH_TIMEOUT));
-        if (log.isDebugEnabled()) {
-            log.debug("Authenticated: {}", session);
+
+        try {
+            if (passwords == null) {
+                passwords = GenericUtils.EMPTY_STRING_ARRAY;
+            }
+
+            for (String p : passwords) {
+                if (p == null) {
+                    continue;
+                }
+                s.addPasswordIdentity(p);
+            }
+
+            if (debugEnabled) {
+                log.debug("Authenticating: {}", s);
+            }
+
+            s.auth().verify(s.getLongProperty(AUTH_TIMEOUT_PROP, 
DEFAULT_AUTH_TIMEOUT));
+
+            if (debugEnabled) {
+                log.debug("Authenticated: {}", s);
+            }
+
+            ClientSession result = s;
+            s = null;   // avoid auto-close at finally clause
+            return result;
+        } finally {
+            if (s != null) {
+                s.close(true);
+            }
         }
     }
 
     @Override
     public Process exec(String commandName, int timeout) throws IOException {
-        if (log.isTraceEnabled()) {
+        boolean traceEnabled = log.isTraceEnabled();
+        if (traceEnabled) {
             log.trace("exec({}) session={}, timeout={} sec.", commandName, 
session, timeout);
         }
 
         ChannelExec channel = session.createExecChannel(commandName);
-        
channel.open().verify(channel.getLongProperty(CHANNEL_OPEN_TIMEOUT_PROPT, 
DEFAULT_CHANNEL_OPEN_TIMEOUT));
-        return new GitSshdSessionProcess(channel, commandName, timeout);
+        if (traceEnabled) {
+            log.trace("exec({}) session={} - open channel", commandName, 
session);
+        }
+
+        try {
+            
channel.open().verify(channel.getLongProperty(CHANNEL_OPEN_TIMEOUT_PROPT, 
DEFAULT_CHANNEL_OPEN_TIMEOUT));
+            if (traceEnabled) {
+                log.trace("exec({}) session={} - channel open", commandName, 
session);
+            }
+
+            GitSshdSessionProcess process = new GitSshdSessionProcess(channel, 
commandName, timeout);
+            channel = null; // disable auto-close on finally clause
+            return process;
+        } finally {
+            if (channel != null) {
+                channel.close(true);
+            }
+        }
     }
 
     @Override
     public void disconnect() {
-        if (session.isOpen()) {
-            if (log.isDebugEnabled()) {
-                log.debug("Disconnecting from {}", session);
-            }
+        try {
+            disconnectSession(session);
+        } finally {
+            disconnectClient(client);
+        }
+    }
+
+    protected void disconnectSession(ClientSession sessionInstance) {
+        if ((sessionInstance == null) || (!sessionInstance.isOpen())) {
+            return; // debug breakpoint
+        }
+
+        boolean debugEnabled = log.isDebugEnabled();
+        if (debugEnabled) {
+            log.debug("Disconnecting from {}", sessionInstance);
         }
 
-        client.close(true);
+        sessionInstance.close(true);
+
+        if (debugEnabled) {
+            log.debug("Disconnected from {}", sessionInstance);
+        }
+    }
+
+    protected void disconnectClient(SshClient clientInstance) {
+        if ((clientInstance == null) || (!clientInstance.isStarted())) {
+            return; // debug breakpoint
+        }
+
+        boolean debugEnabled = log.isDebugEnabled();
+        if (debugEnabled) {
+            log.debug("Stopping {}", clientInstance);
+        }
+
+        clientInstance.stop();
+
+        if (debugEnabled) {
+            log.debug("Stopped {}", clientInstance);
+        }
     }
 
     protected SshClient createClient() {

http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/2d6fbc94/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSessionFactory.java
----------------------------------------------------------------------
diff --git a/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSessionFactory.java b/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSessionFactory.java
index 545657d..bd429e5 100644
--- a/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSessionFactory.java
+++ b/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSessionFactory.java
@@ -18,6 +18,13 @@
  */
 package org.apache.sshd.git.transport;
 
+import java.io.IOException;
+import java.util.Objects;
+
+import org.apache.sshd.client.SshClient;
+import org.apache.sshd.client.session.ClientSession;
+import org.apache.sshd.client.session.ClientSessionHolder;
+import org.apache.sshd.common.util.GenericUtils;
 import org.eclipse.jgit.errors.TransportException;
 import org.eclipse.jgit.transport.CredentialsProvider;
 import org.eclipse.jgit.transport.RemoteSession;
@@ -30,17 +37,100 @@ import org.eclipse.jgit.util.FS;
  *
  * @author <a href="mailto:[email protected]";>Apache MINA SSHD Project</a>
  */
-public class GitSshdSessionFactory extends SshSessionFactory {
+public class GitSshdSessionFactory extends SshSessionFactory implements 
ClientSessionHolder {
+    public static final GitSshdSessionFactory INSTANCE = new 
GitSshdSessionFactory();
+
+    private final SshClient client;
+    private final ClientSession session;
+
     public GitSshdSessionFactory() {
-        super();
+        this(null, null);
+    }
+
+    /**
+     * Used to provide an externally managed {@link SshClient} instance. In 
this case, the
+     * caller is responsible for start/stop-ing the client once no longer 
needed.
+     *
+     * @param client The (never {@code null}) client instance
+     */
+    public GitSshdSessionFactory(SshClient client) {
+        this(Objects.requireNonNull(client, "No client instance provided"), 
null);
+    }
+
+    /**
+     * Used to provide an externally managed {@link ClientSession} instance. 
In this case, the
+     * caller is responsible for connecting and disconnecting the session once 
no longer needed.
+     * <B>Note:</B> in this case, the connection and authentication phase are 
<U>skipped</U> - i.e.,
+     * any specific host/port/user/password(s) specified in the GIT URI are 
<U>not used</U>.
+     *
+     * @param session The (never {@code null}) session instance
+     */
+    public GitSshdSessionFactory(ClientSession session) {
+        this(null, session);
+    }
+
+    protected GitSshdSessionFactory(SshClient client, ClientSession session) {
+        this.client = client;
+        this.session = session;
     }
 
     @Override
     public RemoteSession getSession(URIish uri, CredentialsProvider 
credentialsProvider, FS fs, int tms) throws TransportException {
         try {
-            return new GitSshdSession(uri, credentialsProvider, fs, tms);
+            return new GitSshdSession(uri, credentialsProvider, fs, tms) {
+                @Override
+                protected SshClient createClient() {
+                    SshClient thisClient = getClient();
+                    if (thisClient != null) {
+                        return thisClient;
+                    }
+
+                    return super.createClient();
+                }
+
+                @Override
+                protected ClientSession createClientSession(
+                        SshClient clientInstance, String host, String 
username, int port, String... passwords)
+                            throws IOException, InterruptedException {
+                    ClientSession thisSession = getClientSession();
+                    if (thisSession != null) {
+                        return thisSession;
+                    }
+
+                    return super.createClientSession(clientInstance, host, 
username, port, passwords);
+                }
+
+                @Override
+                protected void disconnectSession(ClientSession 
sessionInstance) {
+                    ClientSession thisSession = getClientSession();
+                    if (GenericUtils.isSameReference(thisSession, 
sessionInstance)) {
+                        return; // do not close the session instance we were given
+                    }
+
+                    super.disconnectSession(sessionInstance);
+                }
+
+                @Override
+                protected void disconnectClient(SshClient clientInstance) {
+                    SshClient thisClient = getClient();
+                    if (GenericUtils.isSameReference(thisClient, 
clientInstance)) {
+                        return; // do not close the client the user gave us
+                    }
+
+                    super.disconnectClient(clientInstance);
+                }
+            };
         } catch (Exception e) {
             throw new TransportException("Unable to connect", e);
         }
     }
+
+    protected SshClient getClient() {
+        return client;
+    }
+
+    @Override
+    public ClientSession getClientSession() {
+        return session;
+    }
 }

http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/2d6fbc94/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSessionProcess.java
----------------------------------------------------------------------
diff --git a/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSessionProcess.java b/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSessionProcess.java
index 1f2b4db..dc36df0 100644
--- a/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSessionProcess.java
+++ b/sshd-git/src/main/java/org/apache/sshd/git/transport/GitSshdSessionProcess.java
@@ -22,8 +22,10 @@ package org.apache.sshd.git.transport;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.EnumSet;
 import java.util.Objects;
+import java.util.Set;
 import java.util.concurrent.TimeUnit;
 
 import org.apache.sshd.client.channel.ChannelExec;
@@ -35,6 +37,9 @@ import org.slf4j.LoggerFactory;
  * @author <a href="mailto:[email protected]";>Apache MINA SSHD Project</a>
  */
 public class GitSshdSessionProcess extends Process {
+    public static final Set<ClientChannelEvent> CLOSE_WAIT_EVENTS =
+        Collections.unmodifiableSet(EnumSet.of(ClientChannelEvent.CLOSED));
+
     protected final ChannelExec channel;
     protected final String commandName;
     protected final long waitTimeout;
@@ -64,12 +69,17 @@ public class GitSshdSessionProcess extends Process {
 
     @Override   // TODO in Java-8 implement also waitFor(long, TimeUnit)
     public int waitFor() throws InterruptedException {
-        Collection<ClientChannelEvent> res =
-                channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), 
waitTimeout);
-        if (log.isTraceEnabled()) {
-            log.trace("waitFor({}) channel={}, timeout={} millis.: {}",
-                      commandName, channel, waitTimeout, res);
+        boolean traceEnabled = log.isTraceEnabled();
+        if (traceEnabled) {
+            log.trace("waitFor({}) channel={} waiting {} millis", commandName, 
channel, waitTimeout);
         }
+
+        Collection<ClientChannelEvent> res = 
channel.waitFor(CLOSE_WAIT_EVENTS, waitTimeout);
+
+        if (traceEnabled) {
+            log.trace("waitFor({}) channel={} events={}", commandName, 
channel, res);
+        }
+
         if (res.contains(ClientChannelEvent.CLOSED)) {
             return 0;
         } else {
@@ -92,7 +102,9 @@ public class GitSshdSessionProcess extends Process {
 
     @Override
     public void destroy() {
-        channel.close(true);
+        if (channel.isOpen()) {
+            channel.close(true);
+        }
     }
 
     @Override

Reply via email to