This is an automated email from the ASF dual-hosted git repository.

iuliana pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/brooklyn-server.git


The following commit(s) were added to refs/heads/master by this push:
     new b902d6a5fd better check for thread interruption with sshj connections
     new 58608dd79b Merge pull request #1348 from ahgittin/interrupt-sshj
b902d6a5fd is described below

commit b902d6a5fd9198fc9fee3d70fd8e76da72f1f1e3
Author: Alex Heneveld <[email protected]>
AuthorDate: Mon Aug 1 18:16:31 2022 +0100

    better check for thread interruption with sshj connections
    
    sshj clears the thread's interrupted flag and hides the interruption
    (as the root cause) inside a ConnectionException; we now look inside
    that exception to detect the interruption.
    
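    Exceptions.propagate now also sets the thread's interrupted flag (via
    Thread.currentThread().interrupt()) when propagating an InterruptedException
    or an exception whose root cause is an interruption; previously it did so
    only for RuntimeInterruptedException.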
---
 .../util/core/internal/ssh/ShellAbstractTool.java  |  4 +-
 .../util/core/internal/ssh/sshj/SshjTool.java      | 97 ++++++++++++----------
 .../internal/ssh/sshj/SshjToolIntegrationTest.java | 63 +++++++++++++-
 .../brooklyn/util/exceptions/Exceptions.java       | 20 ++++-
 4 files changed, 132 insertions(+), 52 deletions(-)

diff --git 
a/core/src/main/java/org/apache/brooklyn/util/core/internal/ssh/ShellAbstractTool.java
 
b/core/src/main/java/org/apache/brooklyn/util/core/internal/ssh/ShellAbstractTool.java
index eb1e9cbbb2..89f5f1c2e5 100644
--- 
a/core/src/main/java/org/apache/brooklyn/util/core/internal/ssh/ShellAbstractTool.java
+++ 
b/core/src/main/java/org/apache/brooklyn/util/core/internal/ssh/ShellAbstractTool.java
@@ -35,6 +35,7 @@ import java.util.Map.Entry;
 import org.apache.brooklyn.config.ConfigKey;
 import org.apache.brooklyn.util.collections.MutableList;
 import org.apache.brooklyn.util.core.flags.TypeCoercions;
+import org.apache.brooklyn.util.exceptions.Exceptions;
 import org.apache.brooklyn.util.os.Os;
 import org.apache.brooklyn.util.ssh.BashCommands;
 import org.apache.brooklyn.util.ssh.BashCommandsConfigurable;
@@ -135,13 +136,14 @@ public abstract class ShellAbstractTool implements 
ShellTool {
                 closeable.close();
             } catch (IOException e) {
                 if (LOG.isDebugEnabled()) {
-                    String msg = String.format("<< exception during close, for 
%s -> %s (%s); continuing.", 
+                    String msg = String.format("<< exception during close, for 
%s -> %s (%s); continuing.",
                             context1, context2, closeable);
                     if (LOG.isTraceEnabled())
                         LOG.debug(msg + ": " + e);
                     else
                         LOG.trace(msg, e);
                 }
+                Exceptions.handleRootCauseIsInterruption(e);
             }
         }
     }
diff --git 
a/core/src/main/java/org/apache/brooklyn/util/core/internal/ssh/sshj/SshjTool.java
 
b/core/src/main/java/org/apache/brooklyn/util/core/internal/ssh/sshj/SshjTool.java
index 1626c1861a..9c379cf6ed 100644
--- 
a/core/src/main/java/org/apache/brooklyn/util/core/internal/ssh/sshj/SshjTool.java
+++ 
b/core/src/main/java/org/apache/brooklyn/util/core/internal/ssh/sshj/SshjTool.java
@@ -18,33 +18,33 @@
  */
 package org.apache.brooklyn.util.core.internal.ssh.sshj;
 
-import static com.google.common.base.Preconditions.checkNotNull;
-import static com.google.common.base.Throwables.getCausalChain;
-import static com.google.common.collect.Iterables.any;
-
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.Callable;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
-import java.util.concurrent.atomic.AtomicReference;
-
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.*;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.io.CountingOutputStream;
+import com.google.common.net.HostAndPort;
+import com.google.common.primitives.Ints;
 import net.schmizz.sshj.common.SecurityUtils;
+import net.schmizz.sshj.connection.ConnectionException;
+import net.schmizz.sshj.connection.channel.direct.PTYMode;
+import net.schmizz.sshj.connection.channel.direct.Session;
+import net.schmizz.sshj.connection.channel.direct.Session.Command;
+import net.schmizz.sshj.connection.channel.direct.Session.Shell;
+import net.schmizz.sshj.connection.channel.direct.SessionChannel;
+import net.schmizz.sshj.sftp.FileAttributes;
+import net.schmizz.sshj.sftp.SFTPClient;
+import net.schmizz.sshj.transport.TransportException;
+import net.schmizz.sshj.xfer.FileSystemFile;
+import net.schmizz.sshj.xfer.InMemorySourceFile;
+import net.schmizz.sshj.xfer.LocalDestFile;
 import org.apache.brooklyn.core.BrooklynFeatureEnablement;
 import org.apache.brooklyn.util.core.internal.ssh.BackoffLimitedRetryHandler;
 import org.apache.brooklyn.util.core.internal.ssh.ShellTool;
 import org.apache.brooklyn.util.core.internal.ssh.SshAbstractTool;
 import org.apache.brooklyn.util.core.internal.ssh.SshTool;
 import org.apache.brooklyn.util.exceptions.Exceptions;
+import org.apache.brooklyn.util.exceptions.RuntimeInterruptedException;
 import org.apache.brooklyn.util.exceptions.RuntimeTimeoutException;
 import org.apache.brooklyn.util.repeat.Repeater;
 import org.apache.brooklyn.util.stream.KnownSizeInputStream;
@@ -56,30 +56,19 @@ import org.apache.commons.io.input.ProxyInputStream;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Joiner;
-import com.google.common.base.Predicate;
-import com.google.common.base.Stopwatch;
-import com.google.common.base.Supplier;
-import com.google.common.base.Suppliers;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.io.CountingOutputStream;
-import com.google.common.net.HostAndPort;
-import com.google.common.primitives.Ints;
+import java.io.*;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Callable;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReference;
 
-import net.schmizz.sshj.connection.ConnectionException;
-import net.schmizz.sshj.connection.channel.direct.PTYMode;
-import net.schmizz.sshj.connection.channel.direct.Session;
-import net.schmizz.sshj.connection.channel.direct.Session.Command;
-import net.schmizz.sshj.connection.channel.direct.Session.Shell;
-import net.schmizz.sshj.connection.channel.direct.SessionChannel;
-import net.schmizz.sshj.sftp.FileAttributes;
-import net.schmizz.sshj.sftp.SFTPClient;
-import net.schmizz.sshj.transport.TransportException;
-import net.schmizz.sshj.xfer.FileSystemFile;
-import net.schmizz.sshj.xfer.InMemorySourceFile;
-import net.schmizz.sshj.xfer.LocalDestFile;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Throwables.getCausalChain;
+import static com.google.common.base.Throwables.getRootCause;
+import static com.google.common.collect.Iterables.any;
 
 /**
  * For ssh and scp-style commands, using the sshj library.
@@ -657,6 +646,10 @@ public class SshjTool extends SshAbstractTool implements 
SshTool {
                 } catch (Exception e2) {
                     LOG.debug("<< ("+toString()+") error closing connection: 
"+e+" / "+e2, e);
                 }
+                if (Thread.currentThread().isInterrupted()) {
+                    LOG.debug("<< {} (rethrowing, interrupted): {}", 
fullMessage, e.getMessage());
+                    throw propagate(e, fullMessage + "; interrupted");
+                }
                 if (i + 1 == sshTries) {
                     LOG.debug("<< {} (rethrowing, out of retries): {}", 
fullMessage, e.getMessage());
                     throw propagate(e, fullMessage + "; out of retries");
@@ -1010,8 +1003,14 @@ public class SshjTool extends SshAbstractTool implements 
SshTool {
                         try {
                             shell.join(1000, TimeUnit.MILLISECONDS);
                         } catch (ConnectionException e) {
+                            LOG.debug("SshjTool exception joining shell", e);
+                            if (isNonRetryableException(e)) {
+                                throw e;
+                            }
+                            // don't automatically give up here, it might be a 
transient network failure
                             last = e;
                         }
+                        LOG.info("SshjTool looping waiting for shell; thread 
"+Thread.currentThread()+" interrupted? 
"+Thread.currentThread().isInterrupted());
                         if (endBecauseReturned) {
                             // shell is still open, ie some process is running
                             // but we have a result code, so main shell is 
finished
@@ -1047,7 +1046,7 @@ public class SshjTool extends SshAbstractTool implements 
SshTool {
                         }
                     } catch (InterruptedException e) {
                         LOG.warn("Interrupted gobbling streams from ssh: 
"+commands, e);
-                        Thread.currentThread().interrupt();
+                        throw Exceptions.propagate(e);
                     }
                 }
 
@@ -1062,6 +1061,16 @@ public class SshjTool extends SshAbstractTool implements 
SshTool {
         }
     }
 
+    /**
+     * Decides whether a {@link ConnectionException} raised while waiting on an sshj shell
+     * should abort the wait loop instead of being treated as transient and retried.
+     * <p>
+     * sshj clears the thread's interrupted flag and hides the interruption inside a
+     * {@link ConnectionException}. If the root cause of {@code e} is an interruption, this
+     * method re-sets the interrupted flag on the current thread (a deliberate side effect)
+     * and returns {@code true}.
+     *
+     * @param e the exception thrown by sshj
+     * @return {@code true} if {@code e} must not be retried (currently only when it wraps an
+     *         interruption); {@code false} if it should be treated as a possibly transient
+     *         network failure
+     * @throws ConnectionException not thrown by this implementation; presumably declared so
+     *         that overrides may rethrow or translate {@code e}
+     */
+    protected boolean isNonRetryableException(ConnectionException e) throws ConnectionException {
+        if (Exceptions.isRootCauseIsInterruption(e)) {
+            // if we don't check for an interruption wrapped in e then the interrupt is
+            // swallowed; that's how sshj works :(
+            Thread.currentThread().interrupt();
+            return true;
+        }
+        // anything else: assume a transient network failure until something else
+        // (eg the shell) times out
+        return false;
+    }
+
     private byte[] toUTF8ByteArray(String string) {
         return org.bouncycastle.util.Strings.toUTF8ByteArray(string);
     }
diff --git 
a/core/src/test/java/org/apache/brooklyn/util/core/internal/ssh/sshj/SshjToolIntegrationTest.java
 
b/core/src/test/java/org/apache/brooklyn/util/core/internal/ssh/sshj/SshjToolIntegrationTest.java
index 2d7bfce5e1..22082c95eb 100644
--- 
a/core/src/test/java/org/apache/brooklyn/util/core/internal/ssh/sshj/SshjToolIntegrationTest.java
+++ 
b/core/src/test/java/org/apache/brooklyn/util/core/internal/ssh/sshj/SshjToolIntegrationTest.java
@@ -18,6 +18,10 @@
  */
 package org.apache.brooklyn.util.core.internal.ssh.sshj;
 
+import static 
org.apache.brooklyn.util.core.internal.ssh.ShellTool.PROP_ERR_STREAM;
+import static 
org.apache.brooklyn.util.core.internal.ssh.ShellTool.PROP_OUT_STREAM;
+import static org.apache.brooklyn.util.time.Duration.FIVE_SECONDS;
+import static org.apache.brooklyn.util.time.Duration.ONE_SECOND;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertFalse;
 import static org.testng.Assert.assertNotNull;
@@ -41,10 +45,14 @@ import org.apache.brooklyn.util.core.internal.ssh.ShellTool;
 import org.apache.brooklyn.util.core.internal.ssh.SshException;
 import org.apache.brooklyn.util.core.internal.ssh.SshTool;
 import 
org.apache.brooklyn.util.core.internal.ssh.SshToolAbstractIntegrationTest;
+import org.apache.brooklyn.util.core.task.ssh.SshPutTaskFactory;
 import org.apache.brooklyn.util.exceptions.Exceptions;
 import org.apache.brooklyn.util.exceptions.RuntimeTimeoutException;
 import org.apache.brooklyn.util.os.Os;
 import org.apache.brooklyn.util.time.Duration;
+import org.apache.brooklyn.util.time.Time;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import org.testng.annotations.Test;
 
 import com.google.common.base.Stopwatch;
@@ -58,6 +66,8 @@ import net.schmizz.sshj.connection.channel.direct.Session;
  */
 public class SshjToolIntegrationTest extends SshToolAbstractIntegrationTest {
 
+    private static final Logger log = 
LoggerFactory.getLogger(SshjToolIntegrationTest.class);
+
     @Override
     protected SshTool newUnregisteredTool(Map<String,?> flags) {
         return new SshjTool(flags);
@@ -76,7 +86,6 @@ public class SshjToolIntegrationTest extends 
SshToolAbstractIntegrationTest {
     @Test(groups = {"Integration"})
     public void testGivesUpAfterMaxRetries() throws Exception {
         final AtomicInteger callCount = new AtomicInteger();
-        
         final SshTool localtool = new SshjTool(ImmutableMap.of("sshTries", 3, 
"host", "localhost", "privateKeyFile", "~/.ssh/id_rsa")) {
             @Override
             protected SshAction<Session> newSessionAction() {
@@ -212,7 +221,7 @@ public class SshjToolIntegrationTest extends 
SshToolAbstractIntegrationTest {
                             "err", err, 
                             SshjTool.PROP_EXEC_ASYNC.getName(), true, 
                             SshjTool.PROP_NO_EXTRA_OUTPUT.getName(), true,
-                            
SshjTool.PROP_EXEC_ASYNC_POLLING_TIMEOUT.getName(), Duration.ONE_SECOND), 
+                            
SshjTool.PROP_EXEC_ASYNC_POLLING_TIMEOUT.getName(), ONE_SECOND),
                     cmds, 
                     ImmutableMap.<String,String>of());
             String outStr = new String(out.toByteArray());
@@ -335,5 +344,53 @@ public class SshjToolIntegrationTest extends 
SshToolAbstractIntegrationTest {
         assertEquals(exitcode, 0, outstr);
         return outstr;
     }
-    
+
+    /**
+     * Starts an ssh script that sleeps for 10 seconds on a second thread ("T2"), interrupts
+     * that thread, and asserts that it exits promptly rather than waiting for the script.
+     * This exercises detection of interruptions that sshj hides inside a ConnectionException.
+     */
+    @Test(groups = {"Integration"})
+    public void testSshIsInterrupted() {
+        log.info("STARTING");
+        final SshTool localTool = new SshjTool(ImmutableMap.of(
+                "sshTries", 3,
+                "host", "localhost",
+                "privateKeyFile", "~/.ssh/id_rsa"));
+        try {
+            Thread t = new Thread(() -> {
+                try {
+                    log.info("T2 starting - "+Thread.currentThread());
+                    localTool.connect();
+                    log.info("T2 executing");
+                    localTool.execScript(
+                            ImmutableMap.of(PROP_OUT_STREAM.getName(), System.out, PROP_ERR_STREAM.getName(), System.err),
+                            ImmutableList.of(
+                                    "echo hello world",
+                                    "ls /path/to/does-not-exist || echo no ls",
+                                    "sleep 10",
+                                    "echo slept"));
+                } catch (Exception e) {
+                    // expected when interrupted; logged for diagnosis only
+                    log.info("T2 error", e);
+                } finally {
+                    log.info("T2 ending - "+Thread.currentThread().isInterrupted());
+                }
+            });
+            // daemon, so a worker that ignores the interrupt cannot keep the test JVM alive
+            t.setDaemon(true);
+            t.start();
+            Time.sleep(FIVE_SECONDS);
+            log.info("INTERRUPTING");
+            t.interrupt();
+            Time.sleep(ONE_SECOND);
+            // record where T2 is after the interrupt, to help diagnose a slow join
+            for (StackTraceElement traceElement : t.getStackTrace()) {
+                log.info("T2 stack: {}", traceElement);
+            }
+            log.info("JOINING");
+            Stopwatch s = Stopwatch.createStarted();
+            // bounded join: if the interrupt were swallowed, T2 would otherwise keep running the script
+            t.join(FIVE_SECONDS.toMilliseconds());
+            if (t.isAlive() || Duration.of(s.elapsed()).isLongerThan(ONE_SECOND)) {
+                Asserts.fail("Join should have been immediate as other thread was interrupted, but instead took "+Duration.of(s.elapsed()));
+            }
+        } catch (Exception e) {
+            log.info("FAILED", e);
+            // keep the cause so the failure report shows what actually went wrong
+            throw new AssertionError("Shouldn't throw", e);
+        }
+        log.info("ENDING");
+    }
+
 }
diff --git 
a/utils/common/src/main/java/org/apache/brooklyn/util/exceptions/Exceptions.java
 
b/utils/common/src/main/java/org/apache/brooklyn/util/exceptions/Exceptions.java
index b91279b51e..29a9300b63 100644
--- 
a/utils/common/src/main/java/org/apache/brooklyn/util/exceptions/Exceptions.java
+++ 
b/utils/common/src/main/java/org/apache/brooklyn/util/exceptions/Exceptions.java
@@ -20,6 +20,7 @@ package org.apache.brooklyn.util.exceptions;
 
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Predicates.instanceOf;
+import static com.google.common.base.Throwables.getRootCause;
 
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.UndeclaredThrowableException;
@@ -119,11 +120,10 @@ public class Exceptions {
      * <li> wraps as PropagatedRuntimeException for easier filtering
      */
     public static RuntimeException propagate(Throwable throwable) {
-        if (throwable instanceof InterruptedException) {
-            throw new RuntimeInterruptedException((InterruptedException) 
throwable);
-        } else if (throwable instanceof RuntimeInterruptedException) {
+        if (throwable instanceof InterruptedException || throwable instanceof 
RuntimeInterruptedException || Exceptions.isRootCauseIsInterruption(throwable)) 
{
+            // previously only interrupted if we caught RuntimeInterrupted; 
but best seems to be to always set the interrupted bit
             Thread.currentThread().interrupt();
-            throw (RuntimeInterruptedException) throwable;
+            throw new RuntimeInterruptedException(throwable);
         }
         Throwables.propagateIfPossible(checkNotNull(throwable));
         throw new PropagatedRuntimeException(throwable);
@@ -221,6 +221,18 @@ public class Exceptions {
         return IsFatalPredicate.INSTANCE;
     }
 
+    /**
+     * Rethrows {@code e}, wrapped in a {@link RuntimeInterruptedException}, if its root cause
+     * is an interruption (as decided by {@code isRootCauseIsInterruption}); otherwise returns
+     * normally and {@code e} is not rethrown.
+     * <p>
+     * Before throwing, it re-sets the current thread's interrupted flag, because libraries
+     * such as sshj may have cleared it while hiding the interruption inside another exception.
+     *
+     * @param e the exception to inspect
+     * @throws RuntimeInterruptedException wrapping {@code e} if its root cause is an interruption
+     */
+    public static void handleRootCauseIsInterruption(Throwable e) {
+        if (isRootCauseIsInterruption(e)) {
+            Thread.currentThread().interrupt();
+            throw new RuntimeInterruptedException(e);
+        }
+    }
+
+    /**
+     * Returns whether the innermost cause of {@code e} (as found by Guava's
+     * {@code Throwables.getRootCause}) is an {@link InterruptedException} or a
+     * {@link RuntimeInterruptedException}.
+     * <p>
+     * Only the root of the causal chain is inspected. An interruption exception that itself
+     * has a further cause is therefore not detected. This method has no side effects; see
+     * {@code handleRootCauseIsInterruption} for the variant that re-sets the interrupted flag
+     * and throws.
+     * NOTE(review): newer Guava versions make {@code getRootCause} throw
+     * {@link IllegalArgumentException} when the causal chain contains a loop; confirm this
+     * against the Guava version in use.
+     *
+     * @param e the exception to inspect
+     * @return {@code true} if the root cause of {@code e} is an interruption
+     */
+    public static boolean isRootCauseIsInterruption(Throwable e) {
+        Throwable root = getRootCause(e);
+        return root instanceof InterruptedException || root instanceof RuntimeInterruptedException;
+    }
+
     private static class IsFatalPredicate implements Predicate<Throwable> {
         private static final IsFatalPredicate INSTANCE = new 
IsFatalPredicate();
         

Reply via email to