Repository: mina-sshd Updated Branches: refs/heads/master f3d58b906 -> f57ad0da9
[SSHD-387] When using inverted streams on the client, the window should only be decreased when reading, not when writing to the pipe Project: http://git-wip-us.apache.org/repos/asf/mina-sshd/repo Commit: http://git-wip-us.apache.org/repos/asf/mina-sshd/commit/f57ad0da Tree: http://git-wip-us.apache.org/repos/asf/mina-sshd/tree/f57ad0da Diff: http://git-wip-us.apache.org/repos/asf/mina-sshd/diff/f57ad0da Branch: refs/heads/master Commit: f57ad0da99ca18da9cc98ddce23be572500da19d Parents: f3d58b9 Author: Guillaume Nodet <[email protected]> Authored: Fri Dec 12 16:03:37 2014 +0100 Committer: Guillaume Nodet <[email protected]> Committed: Fri Dec 12 16:03:37 2014 +0100 ---------------------------------------------------------------------- .../client/channel/AbstractClientChannel.java | 8 +- .../common/channel/ChannelPipedInputStream.java | 5 +- .../session/AbstractConnectionService.java | 5 + .../test/java/org/apache/sshd/WindowTest.java | 300 +++++++++++++++++++ 4 files changed, 312 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/f57ad0da/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java ---------------------------------------------------------------------- diff --git a/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java b/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java index 964ef3a..468ae3b 100644 --- a/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java +++ b/sshd-core/src/main/java/org/apache/sshd/client/channel/AbstractClientChannel.java @@ -272,7 +272,9 @@ public abstract class AbstractClientChannel extends AbstractChannel implements C } else if (out != null) { out.write(data, off, len); out.flush(); - localWindow.consumeAndCheck(len); + if (invertedOut == null) { + localWindow.consumeAndCheck(len); + } } else { throw new IllegalStateException("No output stream for channel"); } @@ -288,7 +290,9 @@ public abstract class AbstractClientChannel extends AbstractChannel implements C } else if (err != null) { err.write(data, off, len); err.flush(); - localWindow.consumeAndCheck(len); + if (invertedErr == null) { + localWindow.consumeAndCheck(len); + } } else { throw new IllegalStateException("No error stream for channel"); } http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/f57ad0da/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelPipedInputStream.java ---------------------------------------------------------------------- diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelPipedInputStream.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelPipedInputStream.java index 2257cd9..cb003b7 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelPipedInputStream.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelPipedInputStream.java @@ -92,7 +92,6 @@ public class ChannelPipedInputStream extends InputStream { @Override public int read(byte[] b, int off, int len) throws IOException { - int avail; long startTime = System.currentTimeMillis(); lock.lock(); try { @@ -128,11 +127,10 @@ public class ChannelPipedInputStream extends InputStream { if (buffer.rpos() > localWindow.getPacketSize() || buffer.available() == 0) { buffer.compact(); } - avail = localWindow.getMaxSize() - buffer.available(); } finally { lock.unlock(); } - localWindow.check(avail); + localWindow.consumeAndCheck(len); return len; } @@ -168,6 +166,5 @@ public class ChannelPipedInputStream extends InputStream { } finally { lock.unlock(); } - localWindow.consume(len); } } http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/f57ad0da/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractConnectionService.java ---------------------------------------------------------------------- diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractConnectionService.java b/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractConnectionService.java index 0ffed33..31a5e48 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractConnectionService.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractConnectionService.java @@ -19,6 +19,7 @@ package org.apache.sshd.common.session; import java.io.IOException; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -82,6 +83,10 @@ public abstract class AbstractConnectionService extends CloseableUtils.AbstractI tcpipForwarder = session.getFactoryManager().getTcpipForwarderFactory().create(this); } + public Collection<Channel> getChannels() { + return channels.values(); + } + public AbstractSession getSession() { return (AbstractSession) session; } http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/f57ad0da/sshd-core/src/test/java/org/apache/sshd/WindowTest.java ---------------------------------------------------------------------- diff --git a/sshd-core/src/test/java/org/apache/sshd/WindowTest.java b/sshd-core/src/test/java/org/apache/sshd/WindowTest.java new file mode 100644 index 0000000..8298b22 --- /dev/null +++ b/sshd-core/src/test/java/org/apache/sshd/WindowTest.java @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sshd; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.util.Arrays; +import java.util.concurrent.CountDownLatch; + +import org.apache.sshd.client.channel.ChannelShell; +import org.apache.sshd.client.future.OpenFuture; +import org.apache.sshd.common.Channel; +import org.apache.sshd.common.NamedFactory; +import org.apache.sshd.common.RuntimeSshException; +import org.apache.sshd.common.Service; +import org.apache.sshd.common.Session; +import org.apache.sshd.common.channel.Window; +import org.apache.sshd.common.forward.TcpipServerChannel; +import org.apache.sshd.common.io.IoReadFuture; +import org.apache.sshd.common.util.Buffer; +import org.apache.sshd.server.Command; +import org.apache.sshd.server.CommandFactory; +import org.apache.sshd.server.channel.ChannelSession; +import org.apache.sshd.server.command.UnknownCommand; +import org.apache.sshd.server.session.ServerConnectionService; +import org.apache.sshd.server.session.ServerUserAuthService; +import org.apache.sshd.util.AsyncEchoShellFactory; +import org.apache.sshd.util.BaseTest; +import org.apache.sshd.util.BogusPasswordAuthenticator; +import org.apache.sshd.util.BogusPublickeyAuthenticator; +import org.apache.sshd.util.EchoShellFactory; +import org.apache.sshd.util.Utils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +/** + * TODO Add javadoc + * + * @author <a href="mailto:[email protected]">Apache MINA SSHD Project</a> + */ +public class WindowTest extends BaseTest { + + private SshServer sshd; + private SshClient client; + private int port; + private CountDownLatch authLatch; + private CountDownLatch channelLatch; + + @Before + public void setUp() throws Exception { + authLatch = new CountDownLatch(0); + channelLatch = new CountDownLatch(0); + + sshd = SshServer.setUpDefaultServer(); + sshd.setKeyPairProvider(Utils.createTestHostKeyProvider()); + sshd.setShellFactory(new TestEchoShellFactory()); + sshd.setCommandFactory(new CommandFactory() { + public Command createCommand(String command) { + return new UnknownCommand(command); + } + }); + sshd.setPasswordAuthenticator(new BogusPasswordAuthenticator()); + sshd.setPublickeyAuthenticator(new BogusPublickeyAuthenticator()); + sshd.setServiceFactories(Arrays.asList( + new ServerUserAuthService.Factory() { + @Override + public Service create(Session session) throws IOException { + return new ServerUserAuthService(session) { + @Override + public void process(byte cmd, Buffer buffer) throws Exception { + authLatch.await(); + super.process(cmd, buffer); + } + }; + } + }, + new ServerConnectionService.Factory() + )); + sshd.setChannelFactories(Arrays.<NamedFactory<Channel>>asList( + new ChannelSession.Factory() { + @Override + public Channel create() { + return new ChannelSession() { + @Override + public OpenFuture open(int recipient, int rwsize, int rmpsize, Buffer buffer) { + try { + channelLatch.await(); + } catch (InterruptedException e) { + throw new RuntimeSshException(e); + } + return super.open(recipient, rwsize, rmpsize, buffer); + } + + @Override + public String toString() { + return "ChannelSession" + "[id=" + id + ", recipient=" + recipient + "]"; + } + }; + } + }, + new TcpipServerChannel.DirectTcpipFactory())); + sshd.start(); + port = sshd.getPort(); + + client = SshClient.setUpDefaultClient(); + } + + @After + public void tearDown() throws Exception { + if (sshd != null) { + sshd.stop(true); + } + if (client != null) { + client.stop(); + } + } + + @Test + public void testWindowConsumptionWithInvertedStreams() throws Exception { + sshd.setShellFactory(new AsyncEchoShellFactory()); + sshd.getProperties().put(SshServer.WINDOW_SIZE, "1024"); + client.getProperties().put(SshClient.WINDOW_SIZE, "1024"); + client.start(); + ClientSession session = client.connect("smx", "localhost", port).await().getSession(); + session.addPasswordIdentity("smx"); + session.auth().verify(); + final ChannelShell channel = session.createShellChannel(); + channel.open().verify(); + + final Channel serverChannel = sshd.getActiveSessions().iterator().next().getService(ServerConnectionService.class) + .getChannels().iterator().next(); + + Window clientLocal = channel.getLocalWindow(); + Window clientRemote = channel.getRemoteWindow(); + Window serverLocal = serverChannel.getLocalWindow(); + Window serverRemote = serverChannel.getRemoteWindow(); + + final String message = "0123456789"; + final int nbMessages = 500; + + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(channel.getInvertedIn())); + BufferedReader reader = new BufferedReader(new InputStreamReader(channel.getInvertedOut())); + for (int i = 0; i < nbMessages; i++) { + writer.write(message); + writer.write("\n"); + writer.flush(); + + Thread.sleep(5); + assertNotEquals("client local and server remote", clientLocal.getSize(), serverRemote.getSize()); + + String line = reader.readLine(); + assertEquals(message, line); + + Thread.sleep(5); + + assertEquals("client local and server remote", clientLocal.getSize(), serverRemote.getSize()); + assertEquals("client remote and server local", clientRemote.getSize(), serverLocal.getSize()); + } + } + + @Test + public void testWindowConsumptionWithDirectStreams() throws Exception { + sshd.setShellFactory(new AsyncEchoShellFactory()); + sshd.getProperties().put(SshServer.WINDOW_SIZE, "1024"); + client.getProperties().put(SshClient.WINDOW_SIZE, "1024"); + client.start(); + ClientSession session = client.connect("smx", "localhost", port).await().getSession(); + session.addPasswordIdentity("smx"); + session.auth().verify(); + final ChannelShell channel = session.createShellChannel(); + + PipedInputStream inPis = new PipedInputStream(); + PipedOutputStream inPos = new PipedOutputStream(inPis); + channel.setIn(inPis); + PipedInputStream outPis = new PipedInputStream(); + PipedOutputStream outPos = new PipedOutputStream(outPis); + channel.setOut(outPos); + channel.open().verify(); + + final Channel serverChannel = sshd.getActiveSessions().iterator().next().getService(ServerConnectionService.class) + .getChannels().iterator().next(); + + Window clientLocal = channel.getLocalWindow(); + Window clientRemote = channel.getRemoteWindow(); + Window serverLocal = serverChannel.getLocalWindow(); + Window serverRemote = serverChannel.getRemoteWindow(); + + final String message = "0123456789"; + final int nbMessages = 500; + + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(inPos)); + BufferedReader reader = new BufferedReader(new InputStreamReader(outPis)); + for (int i = 0; i < nbMessages; i++) { + writer.write(message); + writer.write("\n"); + writer.flush(); + + Thread.sleep(5); + assertEquals("client local and server remote", clientLocal.getSize(), serverRemote.getSize()); + + String line = reader.readLine(); + assertEquals(message, line); + + Thread.sleep(5); + + assertEquals("client local and server remote", clientLocal.getSize(), serverRemote.getSize()); + assertEquals("client remote and server local", clientRemote.getSize(), serverLocal.getSize()); + } + } + + @Test + public void testWindowConsumptionWithAsyncStreams() throws Exception { + sshd.setShellFactory(new AsyncEchoShellFactory()); + sshd.getProperties().put(SshServer.WINDOW_SIZE, "1024"); + client.getProperties().put(SshClient.WINDOW_SIZE, "1024"); + client.start(); + ClientSession session = client.connect("smx", "localhost", port).await().getSession(); + session.addPasswordIdentity("smx"); + session.auth().verify(); + final ChannelShell channel = session.createShellChannel(); + channel.setStreaming(ClientChannel.Streaming.Async); + channel.open().verify(); + + final Channel serverChannel = sshd.getActiveSessions().iterator().next().getService(ServerConnectionService.class) + .getChannels().iterator().next(); + + Window clientLocal = channel.getLocalWindow(); + Window clientRemote = channel.getRemoteWindow(); + Window serverLocal = serverChannel.getLocalWindow(); + Window serverRemote = serverChannel.getRemoteWindow(); + + final String message = "0123456789"; + final int nbMessages = 500; + + for (int i = 0; i < nbMessages; i++) { + + Buffer buffer = new Buffer((message + "\n").getBytes()); + channel.getAsyncIn().write(buffer).verify(); + + Thread.sleep(5); + assertNotEquals("client local and server remote", clientLocal.getSize(), serverRemote.getSize()); + + Buffer buf = new Buffer(16); + IoReadFuture future = channel.getAsyncOut().read(buf); + future.verify(); + assertEquals(11, buf.available()); + assertEquals(message + "\n", new String(buf.array(), buf.rpos(), buf.available())); + + Thread.sleep(5); + + assertEquals("client local and server remote", clientLocal.getSize(), serverRemote.getSize()); + assertEquals("client remote and server local", clientRemote.getSize(), serverLocal.getSize()); + } + } + + public static class TestEchoShellFactory extends EchoShellFactory { + @Override + public Command create() { + return new TestEchoShell(); + } + public static class TestEchoShell extends EchoShell { + + public static CountDownLatch latch = new CountDownLatch(1); + + @Override + public void destroy() { + if (latch != null) { + latch.countDown(); + } + super.destroy(); + } + } + } + +}
