This is an automated email from the ASF dual-hosted git repository.

dcapwell pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git


The following commit(s) were added to refs/heads/trunk by this push:
     new ffc0f01b0e Add a concept for retrying messages
ffc0f01b0e is described below

commit ffc0f01b0eede35518c2838d3c21f440d871c08a
Author: David Capwell <[email protected]>
AuthorDate: Thu Aug 29 13:45:48 2024 -0700

    Add a concept for retrying messages
    
    patch by David Capwell; reviewed by Alex Petrov for CASSANDRA-19856
---
 .../org/apache/cassandra/db/SystemKeyspace.java    |  13 +-
 .../org/apache/cassandra/net/MessageDelivery.java  | 160 ++++++++
 .../cassandra/repair/messages/RepairMessage.java   | 138 ++++---
 .../org/apache/cassandra/tcm/RemoteProcessor.java  |  86 +++--
 src/java/org/apache/cassandra/tcm/Retry.java       |   7 +-
 src/java/org/apache/cassandra/utils/Backoff.java   |  53 ++-
 .../org/apache/cassandra/utils/TriFunction.java    |  25 ++
 .../concurrent/SimulatedExecutorFactory.java       | 162 +++++++-
 .../apache/cassandra/net/MessageDeliveryTest.java  | 225 ++++++++++++
 .../cassandra/net/SimulatedMessageDelivery.java    | 408 +++++++++++++++++++++
 .../repair/messages/RepairMessageTest.java         |   5 +-
 11 files changed, 1139 insertions(+), 143 deletions(-)

diff --git a/src/java/org/apache/cassandra/db/SystemKeyspace.java b/src/java/org/apache/cassandra/db/SystemKeyspace.java
index 8709453280..05a2437546 100644
--- a/src/java/org/apache/cassandra/db/SystemKeyspace.java
+++ b/src/java/org/apache/cassandra/db/SystemKeyspace.java
@@ -120,6 +120,7 @@ import org.apache.cassandra.utils.FBUtilities;
 import org.apache.cassandra.utils.MD5Digest;
 import org.apache.cassandra.utils.Pair;
 import org.apache.cassandra.utils.TimeUUID;
+import org.apache.cassandra.utils.TriFunction;
 import org.apache.cassandra.utils.concurrent.Future;
 
 import static java.lang.String.format;
@@ -1938,8 +1939,8 @@ public final class SystemKeyspace
         int counter = 0;
         for (UntypedResultSet.Row row : resultSet)
         {
-            if 
(onLoaded.accept(MD5Digest.wrap(row.getByteArray("prepared_id")),
-                                row.getString("query_string"),
+            if (onLoaded.apply(MD5Digest.wrap(row.getByteArray("prepared_id")),
+                               row.getString("query_string"),
                                 row.has("logged_keyspace") ? 
row.getString("logged_keyspace") : null))
                 counter++;
         }
@@ -1953,18 +1954,14 @@ public final class SystemKeyspace
         int counter = 0;
         for (UntypedResultSet.Row row : resultSet)
         {
-            if 
(onLoaded.accept(MD5Digest.wrap(row.getByteArray("prepared_id")),
-                                row.getString("query_string"),
+            if (onLoaded.apply(MD5Digest.wrap(row.getByteArray("prepared_id")),
+                               row.getString("query_string"),
                                 row.has("logged_keyspace") ? 
row.getString("logged_keyspace") : null))
                 counter++;
         }
         return counter;
     }
 
-    public static interface TriFunction<A, B, C, D> {
-        D accept(A var1, B var2, C var3);
-    }
-
     public static void saveTopPartitions(TableMetadata metadata, String 
topType, Collection<TopPartitionTracker.TopPartition> topPartitions, long 
lastUpdate)
     {
         String cql = String.format("INSERT INTO %s.%s (keyspace_name, 
table_name, top_type, top, last_update) values (?, ?, ?, ?, ?)", 
SchemaConstants.SYSTEM_KEYSPACE_NAME, TOP_PARTITIONS);
diff --git a/src/java/org/apache/cassandra/net/MessageDelivery.java b/src/java/org/apache/cassandra/net/MessageDelivery.java
index 0b7890c08d..0d052cb3d8 100644
--- a/src/java/org/apache/cassandra/net/MessageDelivery.java
+++ b/src/java/org/apache/cassandra/net/MessageDelivery.java
@@ -19,19 +19,27 @@
 package org.apache.cassandra.net;
 
 import java.util.Collection;
+import java.util.Iterator;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 
+import javax.annotation.Nullable;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.cassandra.config.DatabaseDescriptor;
 import org.apache.cassandra.exceptions.RequestFailureReason;
 import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.utils.Backoff;
 import org.apache.cassandra.utils.Pair;
 import org.apache.cassandra.utils.concurrent.Accumulator;
+import org.apache.cassandra.utils.concurrent.AsyncPromise;
 import org.apache.cassandra.utils.concurrent.CountDownLatch;
 import org.apache.cassandra.utils.concurrent.Future;
+import org.apache.cassandra.utils.concurrent.Promise;
+
+import static org.apache.cassandra.net.MessageFlag.CALL_BACK_ON_FAILURE;
 
 public interface MessageDelivery
 {
@@ -74,9 +82,161 @@ public interface MessageDelivery
     public <REQ, RSP> void sendWithCallback(Message<REQ> message, 
InetAddressAndPort to, RequestCallback<RSP> cb);
     public <REQ, RSP> void sendWithCallback(Message<REQ> message, 
InetAddressAndPort to, RequestCallback<RSP> cb, ConnectionType 
specifyConnection);
     public <REQ, RSP> Future<Message<RSP>> sendWithResult(Message<REQ> 
message, InetAddressAndPort to);
+
+    public default <REQ, RSP> Future<Message<RSP>> sendWithRetries(Backoff 
backoff, RetryScheduler retryThreads,
+                                                                   Verb verb, 
REQ request,
+                                                                   
Iterator<InetAddressAndPort> candidates,
+                                                                   
RetryPredicate shouldRetry,
+                                                                   
RetryErrorMessage errorMessage)
+    {
+        Promise<Message<RSP>> promise = new AsyncPromise<>();
+        this.<REQ, RSP>sendWithRetries(backoff, retryThreads, verb, request, 
candidates,
+                                       (attempt, success, failure) -> {
+                                           if (failure != null) 
promise.tryFailure(failure);
+                                           else promise.trySuccess(success);
+                                       },
+                                       shouldRetry, errorMessage);
+        return promise;
+    }
+
+    public default <REQ, RSP> void sendWithRetries(Backoff backoff, 
RetryScheduler retryThreads,
+                                                   Verb verb, REQ request,
+                                                   
Iterator<InetAddressAndPort> candidates,
+                                                   OnResult<RSP> onResult,
+                                                   RetryPredicate shouldRetry,
+                                                   RetryErrorMessage 
errorMessage)
+    {
+        sendWithRetries(this, backoff, retryThreads, verb, request, 
candidates, onResult, shouldRetry, errorMessage, 0);
+    }
     public <V> void respond(V response, Message<?> message);
     public default void respondWithFailure(RequestFailureReason reason, 
Message<?> message)
     {
         send(Message.failureResponse(message.id(), message.expiresAtNanos(), 
reason), message.respondTo());
     }
+
+    interface OnResult<T>
+    {
+        void result(int attempt, @Nullable Message<T> success, @Nullable 
Throwable failure);
+    }
+
+    interface RetryPredicate
+    {
+        boolean test(int attempt, InetAddressAndPort from, 
RequestFailureReason failure);
+    }
+
+    interface RetryErrorMessage
+    {
+        String apply(int attempt, ResponseFailureReason retryFailure, 
@Nullable InetAddressAndPort from, @Nullable RequestFailureReason reason);
+    }
+
+    private static <REQ, RSP> void sendWithRetries(MessageDelivery messaging,
+                                                   Backoff backoff, 
RetryScheduler retryThreads,
+                                                   Verb verb, REQ request,
+                                                   
Iterator<InetAddressAndPort> candidates,
+                                                   OnResult<RSP> onResult,
+                                                   RetryPredicate shouldRetry,
+                                                   RetryErrorMessage 
errorMessage,
+                                                   int attempt)
+    {
+        if (Thread.currentThread().isInterrupted())
+        {
+            onResult.result(attempt, null, new 
InterruptedException(errorMessage.apply(attempt, 
ResponseFailureReason.Interrupted, null, null)));
+            return;
+        }
+        if (!candidates.hasNext())
+        {
+            onResult.result(attempt, null, new 
NoMoreCandidatesException(errorMessage.apply(attempt, 
ResponseFailureReason.NoMoreCandidates, null, null)));
+            return;
+        }
+        class Request implements RequestCallbackWithFailure<RSP>
+        {
+            @Override
+            public void onResponse(Message<RSP> msg)
+            {
+                onResult.result(attempt, msg, null);
+            }
+
+            @Override
+            public void onFailure(InetAddressAndPort from, 
RequestFailureReason failure)
+            {
+                if (!backoff.mayRetry(attempt))
+                {
+                    onResult.result(attempt, null, new 
MaxRetriesException(attempt, errorMessage.apply(attempt, 
ResponseFailureReason.MaxRetries, from, failure)));
+                    return;
+                }
+                if (!shouldRetry.test(attempt, from, failure))
+                {
+                    onResult.result(attempt, null, new 
FailedResponseException(from, failure, errorMessage.apply(attempt, 
ResponseFailureReason.Rejected, from, failure)));
+                    return;
+                }
+                try
+                {
+                    retryThreads.schedule(() -> sendWithRetries(messaging, 
backoff, retryThreads, verb, request, candidates, onResult, shouldRetry, 
errorMessage, attempt + 1),
+                                          backoff.computeWaitTime(attempt), 
backoff.unit());
+                }
+                catch (Throwable t)
+                {
+                    onResult.result(attempt, null, new 
FailedScheduleException(errorMessage.apply(attempt, 
ResponseFailureReason.FailedSchedule, from, failure), t));
+                }
+            }
+        }
+        messaging.sendWithCallback(Message.outWithFlag(verb, request, 
CALL_BACK_ON_FAILURE), candidates.next(), new Request());
+    }
+
+    enum ResponseFailureReason { MaxRetries, Rejected, NoMoreCandidates, 
Interrupted, FailedSchedule }
+
+    interface RetryScheduler
+    {
+        void schedule(Runnable command, long delay, TimeUnit unit);
+    }
+
+    enum ImmediateRetryScheduler implements RetryScheduler
+    {
+        instance;
+
+        @Override
+        public void schedule(Runnable command, long delay, TimeUnit unit)
+        {
+            command.run();
+        }
+    }
+
+    class NoMoreCandidatesException extends IllegalStateException
+    {
+        public NoMoreCandidatesException(String s)
+        {
+            super(s);
+        }
+    }
+
+    class FailedResponseException extends IllegalStateException
+    {
+        public final InetAddressAndPort from;
+        public final RequestFailureReason failure;
+
+        public FailedResponseException(InetAddressAndPort from, 
RequestFailureReason failure, String message)
+        {
+            super(message);
+            this.from = from;
+            this.failure = failure;
+        }
+    }
+
+    class MaxRetriesException extends IllegalStateException
+    {
+        public final int attempts;
+        public MaxRetriesException(int attempts, String message)
+        {
+            super(message);
+            this.attempts = attempts;
+        }
+    }
+
+    class FailedScheduleException extends IllegalStateException
+    {
+        public FailedScheduleException(String message, Throwable cause)
+        {
+            super(message, cause);
+        }
+    }
 }
diff --git a/src/java/org/apache/cassandra/repair/messages/RepairMessage.java b/src/java/org/apache/cassandra/repair/messages/RepairMessage.java
index f0cbf78f38..835f90fc68 100644
--- a/src/java/org/apache/cassandra/repair/messages/RepairMessage.java
+++ b/src/java/org/apache/cassandra/repair/messages/RepairMessage.java
@@ -23,11 +23,13 @@ import java.util.EnumSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 
 import javax.annotation.Nullable;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Iterators;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -50,8 +52,6 @@ import org.apache.cassandra.utils.NoSpamLogger;
 import org.apache.cassandra.utils.TimeUUID;
 import org.apache.cassandra.utils.concurrent.Future;
 
-import static org.apache.cassandra.net.MessageFlag.CALL_BACK_ON_FAILURE;
-
 /**
  * Base class of all repair related request/response messages.
  *
@@ -138,9 +138,7 @@ public abstract class RepairMessage
     {
         RepairRetrySpec retrySpec = DatabaseDescriptor.getRepairRetrySpec();
         RetrySpec spec = verb == Verb.VALIDATION_RSP ? 
retrySpec.getMerkleTreeResponseSpec() : retrySpec;
-        if (!spec.isEnabled())
-            return Backoff.None.INSTANCE;
-        return new Backoff.ExponentialBackoff(spec.maxAttempts.value, 
spec.baseSleepTime.toMilliseconds(), spec.maxSleepTime.toMilliseconds(), 
ctx.random().get()::nextDouble);
+        return Backoff.fromConfig(ctx, spec);
     }
 
     public static Supplier<Boolean> notDone(Future<?> f)
@@ -155,98 +153,94 @@ public abstract class RepairMessage
 
     public static <T> void sendMessageWithRetries(SharedContext ctx, 
Supplier<Boolean> allowRetry, RepairMessage request, Verb verb, 
InetAddressAndPort endpoint, RequestCallback<T> finalCallback)
     {
-        sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request, 
verb, endpoint, finalCallback, 0);
+        sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request, 
verb, endpoint, finalCallback);
     }
 
     public static <T> void sendMessageWithRetries(SharedContext ctx, 
RepairMessage request, Verb verb, InetAddressAndPort endpoint, 
RequestCallback<T> finalCallback)
     {
-        sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request, 
verb, endpoint, finalCallback, 0);
+        sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request, 
verb, endpoint, finalCallback);
     }
 
     public static void sendMessageWithRetries(SharedContext ctx, RepairMessage 
request, Verb verb, InetAddressAndPort endpoint)
     {
-        sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request, 
verb, endpoint, NOOP_CALLBACK, 0);
+        sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request, 
verb, endpoint, NOOP_CALLBACK);
     }
 
     public static void sendMessageWithRetries(SharedContext ctx, 
Supplier<Boolean> allowRetry, RepairMessage request, Verb verb, 
InetAddressAndPort endpoint)
     {
-        sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request, 
verb, endpoint, NOOP_CALLBACK, 0);
+        sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request, 
verb, endpoint, NOOP_CALLBACK);
     }
 
     @VisibleForTesting
-    static <T> void sendMessageWithRetries(SharedContext ctx, Backoff backoff, 
Supplier<Boolean> allowRetry, RepairMessage request, Verb verb, 
InetAddressAndPort endpoint, RequestCallback<T> finalCallback, int attempt)
+    static <T> void sendMessageWithRetries(SharedContext ctx, Backoff backoff, 
Supplier<Boolean> allowRetry, RepairMessage request, Verb verb, 
InetAddressAndPort endpoint, RequestCallback<T> finalCallback)
     {
         if (!ALLOWS_RETRY.contains(verb))
             throw new AssertionError("Repair verb " + verb + " does not 
support retry, but a request to send with retry was given!");
-        RequestCallback<T> callback = new RequestCallback<>()
-        {
-            @Override
-            public void onResponse(Message<T> msg)
+        BiConsumer<Integer, RequestFailureReason > maybeRecordRetry = 
(attempt, reason) -> {
+            if (attempt <= 0)
+                return;
+            // we don't know what the prefix kind is... so use NONE... this 
impacts logPrefix as it will cause us to use "repair" rather than "preview 
repair" which may not be correct... but close enough...
+            String prefix = 
PreviewKind.NONE.logPrefix(request.parentRepairSession());
+            RepairMetrics.retry(verb, attempt);
+            if (reason == null)
             {
-                maybeRecordRetry(null);
-                finalCallback.onResponse(msg);
+                noSpam.info("{} Retry of repair verb " + verb + " was 
successful after {} attempts", prefix, attempt);
             }
-
-            @Override
-            public void onFailure(InetAddressAndPort from, 
RequestFailureReason failureReason)
+            else if (reason == RequestFailureReason.TIMEOUT)
             {
-                ErrorHandling allowed = errorHandlingSupported(ctx, endpoint, 
verb, request.parentRepairSession());
-                switch (allowed)
-                {
-                    case NONE:
-                        logger.error("[#{}] {} failed on {}: {}", 
request.parentRepairSession(), verb, from, failureReason);
-                        return;
-                    case TIMEOUT:
-                        finalCallback.onFailure(from, failureReason);
-                        return;
-                    case RETRY:
-                        int maxAttempts = backoff.maxAttempts();
-                        if (failureReason == RequestFailureReason.TIMEOUT && 
attempt < maxAttempts && allowRetry.get())
-                        {
-                            ctx.optionalTasks().schedule(() -> 
sendMessageWithRetries(ctx, backoff, allowRetry, request, verb, endpoint, 
finalCallback, attempt + 1),
-                                                         
backoff.computeWaitTime(attempt), backoff.unit());
-                            return;
-                        }
-                        maybeRecordRetry(failureReason);
-                        finalCallback.onFailure(from, failureReason);
-                        return;
-                    default:
-                        throw new AssertionError("Unknown error handler: " + 
allowed);
-                }
+                noSpam.warn("{} Timeout for repair verb " + verb + "; could 
not complete within {} attempts", prefix, attempt);
+                RepairMetrics.retryTimeout(verb);
             }
-
-            private void maybeRecordRetry(@Nullable RequestFailureReason 
reason)
+            else
             {
-                if (attempt <= 0)
-                    return;
-                // we don't know what the prefix kind is... so use NONE... 
this impacts logPrefix as it will cause us to use "repair" rather than "preview 
repair" which may not be correct... but close enough...
-                String prefix = 
PreviewKind.NONE.logPrefix(request.parentRepairSession());
-                RepairMetrics.retry(verb, attempt);
-                if (reason == null)
-                {
-                    noSpam.info("{} Retry of repair verb " + verb + " was 
successful after {} attempts", prefix, attempt);
-                }
-                else if (reason == RequestFailureReason.TIMEOUT)
-                {
-                    noSpam.warn("{} Timeout for repair verb " + verb + "; 
could not complete within {} attempts", prefix, attempt);
-                    RepairMetrics.retryTimeout(verb);
-                }
-                else
-                {
-                    noSpam.warn("{} {} failure for repair verb " + verb + "; 
could not complete within {} attempts", prefix, reason, attempt);
-                    RepairMetrics.retryFailure(verb);
-                }
-            }
-
-            @Override
-            public boolean invokeOnFailure()
-            {
-                return true;
+                noSpam.warn("{} {} failure for repair verb " + verb + "; could 
not complete within {} attempts", prefix, reason, attempt);
+                RepairMetrics.retryFailure(verb);
             }
         };
-        ctx.messaging().sendWithCallback(Message.outWithFlag(verb, request, 
CALL_BACK_ON_FAILURE),
-                                         endpoint,
-                                         callback);
+        ctx.messaging().sendWithRetries(backoff, ctx.optionalTasks()::schedule,
+                                        verb, request, 
Iterators.cycle(endpoint),
+                                        (int attempt, Message<T> msg, 
Throwable failure) -> {
+                                            if (failure == null)
+                                            {
+                                                
maybeRecordRetry.accept(attempt, null);
+                                                finalCallback.onResponse(msg);
+                                            }
+                                        },
+                                        (attempt, from, failure) -> {
+                                            ErrorHandling allowed = 
errorHandlingSupported(ctx, endpoint, verb, request.parentRepairSession());
+                                            switch (allowed)
+                                            {
+                                                case NONE:
+                                                    logger.error("[#{}] {} 
failed on {}: {}", request.parentRepairSession(), verb, from, failure);
+                                                    return false;
+                                                case TIMEOUT:
+                                                    
finalCallback.onFailure(from, failure);
+                                                    return false;
+                                                case RETRY:
+                                                    if (failure == 
RequestFailureReason.TIMEOUT && allowRetry.get())
+                                                        return true;
+                                                    
maybeRecordRetry.accept(attempt, failure);
+                                                    
finalCallback.onFailure(from, failure);
+                                                    return false;
+                                                default:
+                                                    throw new 
AssertionError("Unknown error handler: " + allowed);
+                                            }
+                                        },
+                                        (attempt, retryReason, from, failure) 
-> {
+                                            switch (retryReason)
+                                            {
+                                                case MaxRetries:
+                                                    
maybeRecordRetry.accept(attempt, failure);
+                                                    
finalCallback.onFailure(from, failure);
+                                                    return null;
+                                                case Interrupted:
+                                                case Rejected:
+                                                case FailedSchedule:
+                                                    return null;
+                                                default:
+                                                    throw new 
UnsupportedOperationException(retryReason.name());
+                                            }
+                                        });
     }
 
     public static void sendMessageWithFailureCB(SharedContext ctx, 
Supplier<Boolean> allowRetry, RepairMessage request, Verb verb, 
InetAddressAndPort endpoint, RepairFailureCallback failureCallback)
diff --git a/src/java/org/apache/cassandra/tcm/RemoteProcessor.java b/src/java/org/apache/cassandra/tcm/RemoteProcessor.java
index 79f0b7cf7d..0ea055b908 100644
--- a/src/java/org/apache/cassandra/tcm/RemoteProcessor.java
+++ b/src/java/org/apache/cassandra/tcm/RemoteProcessor.java
@@ -39,6 +39,7 @@ import org.apache.cassandra.gms.FailureDetector;
 import org.apache.cassandra.locator.InetAddressAndPort;
 import org.apache.cassandra.metrics.TCMMetrics;
 import org.apache.cassandra.net.Message;
+import org.apache.cassandra.net.MessageDelivery;
 import org.apache.cassandra.net.MessagingService;
 import org.apache.cassandra.net.RequestCallbackWithFailure;
 import org.apache.cassandra.net.Verb;
@@ -47,6 +48,7 @@ import org.apache.cassandra.tcm.log.Entry;
 import org.apache.cassandra.tcm.log.LocalLog;
 import org.apache.cassandra.tcm.log.LogState;
 import org.apache.cassandra.utils.AbstractIterator;
+import org.apache.cassandra.utils.Backoff;
 import org.apache.cassandra.utils.FBUtilities;
 import org.apache.cassandra.utils.concurrent.AsyncPromise;
 import org.apache.cassandra.utils.concurrent.Future;
@@ -201,51 +203,45 @@ public final class RemoteProcessor implements Processor
 
     public static <REQ, RSP> void sendWithCallbackAsync(Promise<RSP> promise, 
Verb verb, REQ request, CandidateIterator candidates, Retry retryPolicy)
     {
-        class Request implements RequestCallbackWithFailure<RSP>
-        {
-            void retry()
-            {
-                if (promise.isCancelled() || promise.isDone())
-                    return;
-                if (Thread.currentThread().isInterrupted())
-                    promise.setFailure(new InterruptedException());
-                if (!candidates.hasNext())
-                    promise.tryFailure(new 
IllegalStateException(String.format("Ran out of candidates while sending %s: 
%s", verb, candidates)));
-
-                MessagingService.instance().sendWithCallback(Message.out(verb, 
request), candidates.next(), this);
-            }
-
-            @Override
-            public void onResponse(Message<RSP> msg)
-            {
-                promise.trySuccess(msg.payload);
-            }
-
-            @Override
-            public void onFailure(InetAddressAndPort from, 
RequestFailureReason reason)
-            {
-                if (reason == RequestFailureReason.NOT_CMS)
-                {
-                    logger.debug("{} is not a member of the CMS, querying it 
to discover current membership", from);
-                    DiscoveredNodes cms = tryDiscover(from);
-                    candidates.addCandidates(cms);
-                    candidates.timeout(from);
-                    logger.debug("Got CMS from {}: {}, retrying on: {}", from, 
cms, candidates);
-                }
-                else
-                {
-                    candidates.timeout(from);
-                    logger.warn("Got error from {}: {} when sending {}, 
retrying on {}", from, reason, verb, candidates);
-                }
-
-                if (retryPolicy.reachedMax())
-                    promise.tryFailure(new 
IllegalStateException(String.format("Could not succeed sending %s to %s after 
%d tries", verb, candidates, retryPolicy.tries)));
-                else
-                    retry();
-            }
-        }
-
-        new Request().retry();
+        //TODO (now): the retry defines how long to wait for a retry, but the 
old behavior scheduled the message right away... should this be delayed as well?
+        MessagingService.instance().<REQ, 
RSP>sendWithRetries(Backoff.fromRetry(retryPolicy), 
MessageDelivery.ImmediateRetryScheduler.instance,
+                                                              verb, request, 
candidates,
+                                                              (attempt, 
success, failure) -> {
+                                                                  if (failure 
!= null) promise.tryFailure(failure);
+                                                                  else 
promise.trySuccess(success.payload);
+                                                              },
+                                                              (attempt, from, 
failure) -> {
+                                                                  if 
(promise.isDone() || promise.isCancelled())
+                                                                      return 
false;
+                                                                  if (failure 
== RequestFailureReason.NOT_CMS)
+                                                                  {
+                                                                      
logger.debug("{} is not a member of the CMS, querying it to discover current 
membership", from);
+                                                                      
DiscoveredNodes cms = tryDiscover(from);
+                                                                      
candidates.addCandidates(cms);
+                                                                      
candidates.timeout(from);
+                                                                      
logger.debug("Got CMS from {}: {}, retrying on: {}", from, cms, candidates);
+                                                                  }
+                                                                  else
+                                                                  {
+                                                                      
candidates.timeout(from);
+                                                                      
logger.warn("Got error from {}: {} when sending {}, retrying on {}", from, 
failure, verb, candidates);
+                                                                  }
+                                                                  return true;
+                                                              },
+                                                              (attempt, 
reason, from, failure) -> {
+                                                                  switch 
(reason)
+                                                                  {
+                                                                      case 
NoMoreCandidates:
+                                                                          
return String.format("Ran out of candidates while sending %s: %s", verb, 
candidates);
+                                                                      case 
MaxRetries:
+                                                                          
return String.format("Could not succeed sending %s to %s after %d tries", verb, 
candidates, retryPolicy.tries);
+                                                                      case 
Interrupted:
+                                                                      case 
FailedSchedule:
+                                                                          
return null;
+                                                                      default:
+                                                                          
throw new UnsupportedOperationException(reason.name());
+                                                                  }
+                                                              });
     }
 
     private static DiscoveredNodes tryDiscover(InetAddressAndPort ep)
diff --git a/src/java/org/apache/cassandra/tcm/Retry.java b/src/java/org/apache/cassandra/tcm/Retry.java
index a1215fd649..703e590466 100644
--- a/src/java/org/apache/cassandra/tcm/Retry.java
+++ b/src/java/org/apache/cassandra/tcm/Retry.java
@@ -58,10 +58,15 @@ public abstract class Retry
     }
 
+    /**
+     * Blocks, via {@code sleepUninterruptibly}, for the number of milliseconds returned by
+     * {@link #computeSleepFor()}; that call also records the attempt.
+     */
     public void maybeSleep()
+    {
+        sleepUninterruptibly(computeSleepFor(), TimeUnit.MILLISECONDS);
+    }
+
+    /**
+     * Records an attempt (increments {@code tries} and marks {@code retryMeter}) and returns the
+     * delay in milliseconds chosen by {@link #sleepFor()} without sleeping, so callers can wait
+     * by other means.
+     */
+    public long computeSleepFor()
     {
         tries++;
         retryMeter.mark();
-        sleepUninterruptibly(sleepFor(), TimeUnit.MILLISECONDS);
+        return sleepFor();
     }
 
     protected abstract long sleepFor();
diff --git a/src/java/org/apache/cassandra/utils/Backoff.java 
b/src/java/org/apache/cassandra/utils/Backoff.java
index 2f0b7e2c9c..7974dbf346 100644
--- a/src/java/org/apache/cassandra/utils/Backoff.java
+++ b/src/java/org/apache/cassandra/utils/Backoff.java
@@ -21,23 +21,55 @@ package org.apache.cassandra.utils;
 import java.util.concurrent.TimeUnit;
 import java.util.function.DoubleSupplier;
 
+import org.apache.cassandra.config.RetrySpec;
+import org.apache.cassandra.repair.SharedContext;
+import org.apache.cassandra.tcm.Retry;
+
 public interface Backoff
 {
-    /**
-     * @return max attempts allowed, {@code == 0} implies no retries are 
allowed
-     */
-    int maxAttempts();
-    long computeWaitTime(int retryCount);
+    /**
+     * @param attempt the attempt count the caller has reached (NOTE(review): whether this is 0- or
+     *                1-based, and whether it includes the first send, is defined by callers; confirm there)
+     * @return {@code true} if another attempt may be made
+     */
+    boolean mayRetry(int attempt);
+
+    /**
+     * @param attempt the attempt count the caller has reached
+     * @return the delay before the next attempt, expressed in {@link #unit()}
+     */
+    long computeWaitTime(int attempt);
+
+    /** @return the unit of the values returned by {@link #computeWaitTime(int)} */
     TimeUnit unit();
 
+    static Backoff fromRetry(Retry retry)
+    {
+        return new Backoff()
+        {
+            @Override
+            public boolean mayRetry(int attempt)
+            {
+                return !retry.reachedMax();
+            }
+
+            @Override
+            public long computeWaitTime(int retryCount)
+            {
+                return retry.computeSleepFor();
+            }
+
+            @Override
+            public TimeUnit unit()
+            {
+                return TimeUnit.MILLISECONDS;
+            }
+        };
+    }
+
+    /**
+     * Builds a {@link Backoff} from configuration: {@link None#INSTANCE} when {@code spec} is disabled,
+     * otherwise an {@link ExponentialBackoff} using the configured attempt limit and sleep bounds
+     * (in milliseconds) and {@code ctx}'s random source.
+     */
+    static Backoff fromConfig(SharedContext ctx, RetrySpec spec)
+    {
+        if (spec.isEnabled())
+        {
+            var baseSleepMillis = spec.baseSleepTime.toMilliseconds();
+            var maxSleepMillis = spec.maxSleepTime.toMilliseconds();
+            return new Backoff.ExponentialBackoff(spec.maxAttempts.value, baseSleepMillis, maxSleepMillis, ctx.random().get()::nextDouble);
+        }
+        return Backoff.None.INSTANCE;
+    }
+
     enum None implements Backoff
     {
         INSTANCE;
 
+        /** Never allows a retry, whatever {@code attempt} is. */
         @Override
-        public int maxAttempts()
+        public boolean mayRetry(int attempt)
         {
-            return 0;
+            return false;
         }
 
         @Override
@@ -68,12 +100,17 @@ public interface Backoff
             this.randomSource = randomSource;
         }
 
-        @Override
+        /** @return the configured attempt limit ({@code maxAttempts}); not declared by {@link Backoff}. */
         public int maxAttempts()
         {
             return maxAttempts;
         }
 
+        /** Allows another attempt only while {@code attempt} is strictly less than {@code maxAttempts}. */
+        @Override
+        public boolean mayRetry(int attempt)
+        {
+            return attempt < maxAttempts;
+        }
+
         @Override
         public long computeWaitTime(int retryCount)
         {
diff --git a/src/java/org/apache/cassandra/utils/TriFunction.java 
b/src/java/org/apache/cassandra/utils/TriFunction.java
new file mode 100644
index 0000000000..c280850dad
--- /dev/null
+++ b/src/java/org/apache/cassandra/utils/TriFunction.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.utils;
+
+/**
+ * A function of three arguments: the three-argument counterpart of
+ * {@link java.util.function.BiFunction}.
+ *
+ * @param <A> type of the first argument
+ * @param <B> type of the second argument
+ * @param <C> type of the third argument
+ * @param <D> type of the result
+ */
+@FunctionalInterface
+public interface TriFunction<A, B, C, D>
+{
+    /** Applies this function to the three given arguments. */
+    D apply(A first, B second, C third);
+}
diff --git 
a/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java 
b/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java
index ffb1a0fabd..e206af390b 100644
--- a/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java
+++ b/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java
@@ -18,6 +18,7 @@
 
 package org.apache.cassandra.concurrent;
 
+import java.sql.Timestamp;
 import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
@@ -33,13 +34,25 @@ import java.util.concurrent.RejectedExecutionHandler;
 import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
 import java.util.function.LongSupplier;
 
+import javax.annotation.Nullable;
+
 import accord.utils.Gens;
 import accord.utils.RandomSource;
 import org.apache.cassandra.utils.Clock;
+import org.apache.cassandra.utils.Generators;
+import org.apache.cassandra.utils.concurrent.Future;
+import org.apache.cassandra.utils.concurrent.UncheckedInterruptedException;
 
 import static java.util.concurrent.TimeUnit.NANOSECONDS;
+import static 
org.apache.cassandra.concurrent.InfiniteLoopExecutor.InternalState.SHUTTING_DOWN_NOW;
+import static 
org.apache.cassandra.concurrent.InfiniteLoopExecutor.InternalState.TERMINATED;
+import static org.apache.cassandra.concurrent.Interruptible.State.INTERRUPTED;
+import static org.apache.cassandra.concurrent.Interruptible.State.NORMAL;
+import static 
org.apache.cassandra.concurrent.Interruptible.State.SHUTTING_DOWN;
+import static org.apache.cassandra.utils.Generators.toGen;
 
 public class SimulatedExecutorFactory implements ExecutorFactory, Clock
 {
@@ -79,22 +92,51 @@ public class SimulatedExecutorFactory implements 
ExecutorFactory, Clock
 
     private final RandomSource rs;
     private final long startTimeNanos;
+    /** Optional sink for failures of tasks submitted to this factory's executors; {@code null} disables reporting. */
+    @Nullable
+    private final Consumer<Throwable> onError;
     private final PriorityQueue<Item> queue = new PriorityQueue<>();
     private long seq = 0;
     private long nowNanos;
+    // Number of periodic tasks not yet completed. processOne()/hasWork() discount that many queued
+    // entries so that processAll() can terminate.
     private int repeatedTasks = 0;
 
+    /**
+     * Creates a factory whose start time is drawn from {@link Generators#TIMESTAMP_GEN} using {@code rs}.
+     *
+     * @param onError receives task failures; may be {@code null}
+     */
+    public SimulatedExecutorFactory(RandomSource rs, Consumer<Throwable> onError)
+    {
+        this(rs, randomStartTimeNanos(rs), onError);
+    }
+
+    /** Creates a factory with a randomly drawn start time and no failure listener. */
+    public SimulatedExecutorFactory(RandomSource rs)
+    {
+        this(rs, null);
+    }
+
+    /** Draws a timestamp from {@link Generators#TIMESTAMP_GEN} with {@code rs} and converts it from millis to nanos. */
+    private static long randomStartTimeNanos(RandomSource rs)
+    {
+        return toGen(Generators.TIMESTAMP_GEN.map(Timestamp::getTime))
+               .mapToLong(TimeUnit.MILLISECONDS::toNanos)
+               .next(rs);
+    }
+
+    /** Creates a factory with the given start time and no failure listener. */
     public SimulatedExecutorFactory(RandomSource rs, long startTimeNanos)
+    {
+        this(rs, startTimeNanos, null);
+    }
+
+    /**
+     * @param rs             random source used by this factory
+     * @param startTimeNanos start time, in nanoseconds
+     * @param onError        receives task failures; may be {@code null}, which disables reporting
+     */
+    public SimulatedExecutorFactory(RandomSource rs, long startTimeNanos, 
Consumer<Throwable> onError)
     {
         this.rs = rs;
         this.startTimeNanos = startTimeNanos;
+        this.onError = onError;
     }
 
-    public boolean processOne()
+    /** Forwards {@code task}'s failure, if it fails, to {@code onError}; does nothing when no listener was supplied. */
+    private void maybeAddFailureListener(Future<?> task)
+    {
+        if (onError != null)
+        {
+            task.addCallback((ignoredResult, failure) -> {
+                if (failure != null)
+                    onError.accept(failure);
+            });
+        }
+    }
+
+    public boolean hasWork()
+    {
+        return queue.size() > repeatedTasks;
+    }
+
+    public boolean processAny()
     {
-        // if we count the repeated tasks, then processAll will never complete
-        if (queue.size() == repeatedTasks)
-            return false;
         Item item = queue.poll();
         if (item == null)
             return false;
@@ -103,6 +145,21 @@ public class SimulatedExecutorFactory implements 
ExecutorFactory, Clock
         return true;
     }
 
+    /**
+     * Processes the next queued task, unless the queue holds only periodic tasks; counting those
+     * would keep {@link #processAll()} from ever finishing.
+     *
+     * @return false if nothing was processed
+     */
+    public boolean processOne()
+    {
+        return queue.size() != repeatedTasks && processAny();
+    }
+
+    /** Calls {@link #processOne()} repeatedly until it reports that nothing was processed. */
+    public void processAll()
+    {
+        boolean processed;
+        do
+        {
+            processed = processOne();
+        }
+        while (processed);
+    }
+
     @Override
     public long nanoTime()
     {
@@ -153,9 +210,92 @@ public class SimulatedExecutorFactory implements 
ExecutorFactory, Clock
     }
 
+    /**
+     * Simulated infinite loop: the task is driven by {@code runOne}, scheduled on a private
+     * {@code UnorderedScheduledExecutorService}, so each iteration is processed from this factory's
+     * queue rather than on a dedicated thread. {@code name}, {@code simulatorSafe}, {@code daemon} and
+     * {@code interrupts} are not used.
+     */
     @Override
-    public Interruptible infiniteLoop(String name, Interruptible.Task task, 
InfiniteLoopExecutor.SimulatorSafe simulatorSafe, InfiniteLoopExecutor.Daemon 
daemon, InfiniteLoopExecutor.Interrupts interrupts)
+    public Interruptible infiniteLoop(String name,
+                                      Interruptible.Task task,
+                                      InfiniteLoopExecutor.SimulatorSafe 
simulatorSafe,
+                                      InfiniteLoopExecutor.Daemon daemon,
+                                      InfiniteLoopExecutor.Interrupts 
interrupts)
     {
-        throw new UnsupportedOperationException("TODO");
+        var delegate = new UnorderedScheduledExecutorService();
+        // Holds the future returned by scheduleAtFixedRate (assigned at the end) so runOne can cancel it.
+        class Capture { UnorderedScheduledExecutorService.ScheduledFuture<?> 
f;}
+        Capture c = new Capture();
+        // Interruptible handle whose state transitions are applied by runOne on each scheduled iteration.
+        class I implements Interruptible
+        {
+            private Object state = NORMAL;
+            private boolean interrupted = false;
+            private void runOne()
+            {
+                Object cur = state;
+                // A pending shutdown or shutdownNow terminates the loop without running the task again.
+                if (cur == SHUTTING_DOWN_NOW || cur == SHUTTING_DOWN)
+                {
+                    state = TERMINATED;
+                    if (c.f != null)
+                        c.f.cancel(false);
+                    return;
+                }
+
+                // A pending interrupt() is passed to the task as the INTERRUPTED state.
+                if (cur == NORMAL && interrupted) cur = INTERRUPTED;
+                // NOTE(review): TERMINATED is an InfiniteLoopExecutor.InternalState, so if runOne ever ran
+                // after termination the (State) cast below would throw; the cancel(false) calls presumably prevent that.
+                try
+                {
+                    task.run((State) cur);
+                    interrupted = false;
+                }
+                catch (TerminateException ignore)
+                {
+                    // TerminateException ends the loop.
+                    state = TERMINATED;
+                    if (c.f != null)
+                        c.f.cancel(false);
+                }
+                catch (UncheckedInterruptedException | InterruptedException e)
+                {
+                    // An interruption thrown by the task also ends the loop.
+                    interrupted = false;
+                    state = TERMINATED;
+                    if (c.f != null)
+                        c.f.cancel(false);
+                }
+                catch (Throwable t)
+                {
+                    // Any other failure is reported to onError (if set); the loop is not stopped.
+                    if (onError != null)
+                        onError.accept(t);
+                }
+            }
+
+            @Override
+            public void interrupt()
+            {
+                interrupted = true;
+            }
+
+            @Override
+            public boolean isTerminated()
+            {
+                return state == TERMINATED;
+            }
+
+            @Override
+            public void shutdown()
+            {
+                if (state != TERMINATED && state != SHUTTING_DOWN_NOW)
+                    state = SHUTTING_DOWN;
+            }
+
+            @Override
+            public Object shutdownNow()
+            {
+                if (state != TERMINATED)
+                    state = SHUTTING_DOWN_NOW;
+                return null;
+            }
+
+            // Does not wait: only reports whether the loop has already terminated.
+            @Override
+            public boolean awaitTermination(long timeout, TimeUnit units)
+            {
+                return isTerminated();
+            }
+        }
+        I i = new I();
+        // NOTE(review): UnorderedScheduledExecutorService.schedule() builds its one-shot ScheduledFuture with
+        // period 0 and scheduleWithFixedDelay passes -delay, so a period of 0 here may yield a single run
+        // rather than a loop; confirm against ScheduledFuture's rescheduling logic.
+        c.f = delegate.scheduleAtFixedRate(i::runOne, 0, 0, NANOSECONDS);
+        return i;
     }
 
     @Override
@@ -329,7 +469,9 @@ public class SimulatedExecutorFactory implements 
ExecutorFactory, Clock
         public void execute(Runnable command)
         {
+            // Queues the command at nowWithJitter(); the failure listener is attached before queueing.
             checkNotShutdown();
-            queue.add(new Item(nowWithJitter(), 
SimulatedExecutorFactory.this.seq++, taskFor(command)));
+            var action = taskFor(command);
+            maybeAddFailureListener(action);
+            queue.add(new Item(nowWithJitter(), 
SimulatedExecutorFactory.this.seq++, action));
         }
 
         protected void checkNotShutdown()
@@ -365,6 +507,7 @@ public class SimulatedExecutorFactory implements 
ExecutorFactory, Clock
             if (next == null)
                 return;
 
+            maybeAddFailureListener(next.action);
             next.action.addCallback((s, f) -> afterExecution());
             queue.add(next);
         }
@@ -461,6 +604,8 @@ public class SimulatedExecutorFactory implements 
ExecutorFactory, Clock
                     catch (Throwable t)
                     {
                         tryFailure(t);
+                        if (onError != null)
+                            onError.accept(t);
                     }
                 }
             }
@@ -477,6 +622,7 @@ public class SimulatedExecutorFactory implements 
ExecutorFactory, Clock
         {
             checkNotShutdown();
             ScheduledFuture<V> task = new ScheduledFuture<>(seq++, delay, 0, 
NANOSECONDS, callable);
+            maybeAddFailureListener(task);
             queue.add(new Item(nowWithJitter() + unit.toNanos(delay), 
task.sequenceNumber, task));
             return task;
         }
@@ -486,6 +632,7 @@ public class SimulatedExecutorFactory implements 
ExecutorFactory, Clock
         {
             checkNotShutdown();
             ScheduledFuture<?> task = new ScheduledFuture<>(seq++, 
initialDelay, period, unit, Executors.callable(command));
+            maybeAddFailureListener(task);
             repeatedTasks++;
             task.addCallback((s, f) -> repeatedTasks--);
             queue.add(new Item(nowWithJitter() + unit.toNanos(initialDelay), 
task.sequenceNumber, task));
@@ -497,6 +644,7 @@ public class SimulatedExecutorFactory implements 
ExecutorFactory, Clock
         {
             checkNotShutdown();
             ScheduledFuture<?> task = new ScheduledFuture<>(seq++, 
initialDelay, -delay, unit, Executors.callable(command));
+            maybeAddFailureListener(task);
             repeatedTasks++;
             task.addCallback((s, f) -> repeatedTasks--);
             queue.add(new Item(nowWithJitter() + unit.toNanos(initialDelay), 
task.sequenceNumber, task));
diff --git a/test/unit/org/apache/cassandra/net/MessageDeliveryTest.java 
b/test/unit/org/apache/cassandra/net/MessageDeliveryTest.java
new file mode 100644
index 0000000000..59d7106506
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/MessageDeliveryTest.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Iterators;
+import org.junit.Assert;
+import org.junit.Test;
+
+import accord.utils.RandomSource;
+import org.apache.cassandra.concurrent.ScheduledExecutorPlus;
+import org.apache.cassandra.concurrent.SimulatedExecutorFactory;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.dht.Murmur3Partitioner;
+import org.apache.cassandra.exceptions.RequestFailureReason;
+import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.net.MessageDelivery.FailedResponseException;
+import org.apache.cassandra.net.MessageDelivery.MaxRetriesException;
+import org.apache.cassandra.net.SimulatedMessageDelivery.Action;
+import 
org.apache.cassandra.net.SimulatedMessageDelivery.SimulatedMessageReceiver;
+import org.apache.cassandra.tcm.ClusterMetadataService;
+import org.apache.cassandra.tcm.StubClusterMetadataService;
+import org.apache.cassandra.utils.Backoff;
+import org.mockito.Mockito;
+
+import static accord.utils.Property.qt;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests {@link MessageDelivery#sendWithRetries} against a {@link SimulatedMessageDelivery} driven by a
+ * {@link SimulatedExecutorFactory}, so delays and retries run in simulated time.
+ */
+public class MessageDeliveryTest
+{
+    private static final InetAddressAndPort ID1 = InetAddressAndPort.getByNameUnchecked("127.0.0.1");
+    private static final MessageDelivery.RetryErrorMessage RETRY_ERROR_MESSAGE = (i1, i2, i3, i4) -> null;
+    private static final MessageDelivery.RetryPredicate ALWAYS_RETRY = (i1, i2, i3) -> true;
+    private static final MessageDelivery.RetryPredicate ALWAYS_REJECT = (i1, i2, i3) -> false;
+
+    static
+    {
+        DatabaseDescriptor.clientInitialization();
+        DatabaseDescriptor.setPartitionerUnsafe(Murmur3Partitioner.instance);
+        ClusterMetadataService.setInstance(StubClusterMetadataService.forTesting());
+    }
+
+    @Test
+    public void sendWithRetryFailsAfterMaxAttempts()
+    {
+        qt().check(rs -> {
+            List<Throwable> failures = new ArrayList<>();
+            SimulatedExecutorFactory factory = new SimulatedExecutorFactory(rs.fork(), failures::add);
+            ScheduledExecutorPlus scheduler = factory.scheduled("ignored");
+            MessageDelivery messaging = simulatedMessages(rs, scheduler, failures, (i1, i2, i3) -> Action.DROP);
+
+            int expectedRetries = 3;
+            Backoff backoff = new Backoff.ExponentialBackoff(expectedRetries, 200, 1000, rs.fork()::nextDouble);
+
+            Future<Message<Void>> result = sendEcho(messaging, scheduler, backoff, ALWAYS_RETRY);
+            assertThat(result).isNotDone();
+            factory.processAll();
+            assertThat(result).isDone();
+
+            assertThat(getMaxRetriesException(result).attempts).isEqualTo(expectedRetries);
+        });
+    }
+
+    @Test
+    public void sendWithRetryFirstAttempt()
+    {
+        qt().check(rs -> {
+            List<Throwable> failures = new ArrayList<>();
+            SimulatedExecutorFactory factory = new SimulatedExecutorFactory(rs.fork(), failures::add);
+            ScheduledExecutorPlus scheduler = factory.scheduled("ignored");
+            MessageDelivery messaging = simulatedMessages(rs, scheduler, failures, (i1, i2, i3) -> Action.DELIVER);
+
+            Backoff backoff = Mockito.mock(Backoff.class);
+
+            Future<Message<Void>> result = sendEcho(messaging, scheduler, backoff, ALWAYS_RETRY);
+            assertThat(result).isNotDone();
+            factory.processAll();
+            assertThat(result).isDone();
+            assertThat(result.get().header.verb).isEqualTo(Verb.ECHO_RSP);
+            Mockito.verify(backoff, Mockito.never()).mayRetry(Mockito.anyInt());
+            Mockito.verify(backoff, Mockito.never()).computeWaitTime(Mockito.anyInt());
+            Mockito.verify(backoff, Mockito.never()).unit();
+        });
+    }
+
+    @Test
+    public void sendWithRetry()
+    {
+        qt().check(rs -> {
+            List<Throwable> failures = new ArrayList<>();
+            SimulatedExecutorFactory factory = new SimulatedExecutorFactory(rs.fork(), failures::add);
+            ScheduledExecutorPlus scheduler = factory.scheduled("ignored");
+
+            int maxAttempts = 3;
+            int expectedAttempts = 1;
+            AtomicInteger attempts = new AtomicInteger(0);
+            MessageDelivery messaging = simulatedMessages(rs, scheduler, failures, (i1, i2, i3) -> attempts.incrementAndGet() >= (expectedAttempts + 1) ? Action.DELIVER : Action.DROP);
+
+            Backoff backoff = Mockito.spy(new Backoff.ExponentialBackoff(maxAttempts, 200, 1000, rs.fork()::nextDouble));
+
+            Future<Message<Void>> result = sendEcho(messaging, scheduler, backoff, ALWAYS_RETRY);
+            assertThat(result).isNotDone();
+            factory.processAll();
+            assertThat(result).isDone();
+            assertThat(result.get().header.verb).isEqualTo(Verb.ECHO_RSP);
+            Mockito.verify(backoff, Mockito.times(expectedAttempts)).mayRetry(Mockito.anyInt());
+            Mockito.verify(backoff, Mockito.times(expectedAttempts)).computeWaitTime(Mockito.anyInt());
+            Mockito.verify(backoff, Mockito.times(expectedAttempts)).unit();
+        });
+    }
+
+    @Test
+    public void sendWithRetryDontAllowRetry()
+    {
+        qt().check(rs -> {
+            List<Throwable> failures = new ArrayList<>();
+            SimulatedExecutorFactory factory = new SimulatedExecutorFactory(rs.fork(), failures::add);
+            ScheduledExecutorPlus scheduler = factory.scheduled("ignored");
+
+            MessageDelivery messaging = simulatedMessages(rs, scheduler, failures, (i1, i2, i3) -> Action.DROP);
+
+            Backoff backoff = Mockito.spy(new Backoff.ExponentialBackoff(3, 200, 1000, rs.fork()::nextDouble));
+
+            Future<Message<Void>> result = sendEcho(messaging, scheduler, backoff, ALWAYS_REJECT);
+            assertThat(result).isNotDone();
+            factory.processAll();
+            assertThat(result).isDone();
+            FailedResponseException e = getFailedResponseException(result);
+            assertThat(e.from).isEqualTo(ID1);
+            assertThat(e.failure).isEqualTo(RequestFailureReason.TIMEOUT);
+            Mockito.verify(backoff, Mockito.times(1)).mayRetry(Mockito.anyInt());
+            Mockito.verify(backoff, Mockito.never()).computeWaitTime(Mockito.anyInt());
+            Mockito.verify(backoff, Mockito.never()).unit();
+        });
+    }
+
+    /**
+     * Sends an ECHO_REQ to {@link #ID1} (cycling over that single candidate) via
+     * {@link MessageDelivery#sendWithRetries}, passing {@code scheduler::schedule} as the retry scheduler.
+     */
+    private static Future<Message<Void>> sendEcho(MessageDelivery messaging,
+                                                  ScheduledExecutorPlus scheduler,
+                                                  Backoff backoff,
+                                                  MessageDelivery.RetryPredicate retryPredicate)
+    {
+        return messaging.sendWithRetries(backoff,
+                                         scheduler::schedule,
+                                         Verb.ECHO_REQ, NoPayload.noPayload,
+                                         Iterators.cycle(ID1),
+                                         retryPredicate,
+                                         RETRY_ERROR_MESSAGE);
+    }
+
+    /** Builds a single-node simulated network where {@link #ID1} answers every request with an empty response. */
+    private static MessageDelivery simulatedMessages(RandomSource rs, ScheduledExecutorPlus scheduler, List<Throwable> failures, SimulatedMessageDelivery.ActionSupplier actionSupplier)
+    {
+        Map<InetAddressAndPort, SimulatedMessageReceiver> receivers = new HashMap<>();
+        SimulatedMessageDelivery messaging = new SimulatedMessageDelivery(ID1,
+                                                                          actionSupplier,
+                                                                          SimulatedMessageDelivery.randomDelay(rs),
+                                                                          (to, message) -> scheduler.execute(() -> receivers.get(to).recieve(message)),
+                                                                          (i1, i2, i3) -> {},
+                                                                          scheduler::schedule,
+                                                                          failures::add);
+        receivers.put(ID1, messaging.receiver(m -> messaging.respond(NoPayload.noPayload, m)));
+        return messaging;
+    }
+
+    /**
+     * Waits for {@code result}, asserts that it failed with a cause of type {@code type}, and returns that cause.
+     * A cause of the wrong type fails with a descriptive assertion instead of a ClassCastException.
+     */
+    private static <T extends Throwable> T getFailureCause(Future<Message<Void>> result, Class<T> type) throws InterruptedException
+    {
+        try
+        {
+            result.get();
+        }
+        catch (ExecutionException e)
+        {
+            assertThat(e.getCause()).isInstanceOf(type);
+            return type.cast(e.getCause());
+        }
+        Assert.fail("Should have failed");
+        throw new AssertionError("Not Reachable");
+    }
+
+    private static FailedResponseException getFailedResponseException(Future<Message<Void>> result) throws InterruptedException
+    {
+        return getFailureCause(result, FailedResponseException.class);
+    }
+
+    private static MaxRetriesException getMaxRetriesException(Future<Message<Void>> result) throws InterruptedException
+    {
+        return getFailureCause(result, MaxRetriesException.class);
+    }
+}
diff --git a/test/unit/org/apache/cassandra/net/SimulatedMessageDelivery.java 
b/test/unit/org/apache/cassandra/net/SimulatedMessageDelivery.java
new file mode 100644
index 0000000000..f36d04d8a5
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/SimulatedMessageDelivery.java
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.LongSupplier;
+import javax.annotation.Nullable;
+
+import accord.utils.Gens;
+import accord.utils.RandomSource;
+import org.apache.cassandra.exceptions.RequestFailureReason;
+import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.utils.concurrent.AsyncPromise;
+import org.apache.cassandra.utils.concurrent.Future;
+
+public class SimulatedMessageDelivery implements MessageDelivery
+{
+    public enum Action { DELIVER, DELIVER_WITH_FAILURE, DROP, 
DROP_PARTITIONED, FAILURE }
+
+    public interface ActionSupplier
+    {
+        Action get(InetAddressAndPort self, Message<?> message, 
InetAddressAndPort to);
+    }
+
+    public interface NetworkDelaySupplier
+    {
+        @Nullable
+        Duration jitter(Message<?> message, InetAddressAndPort to);
+    }
+
+    public static NetworkDelaySupplier noDelay()
+    {
+        return (i1, i2) -> null;
+    }
+
+    /**
+     * Returns a {@link NetworkDelaySupplier} that gives each (from, to) pair its own latency stream,
+     * created from {@code rs} the first time the pair is seen; see {@link #newLatencySupplier}.
+     */
+    public static NetworkDelaySupplier randomDelay(RandomSource rs)
+    {
+        class Connection
+        {
+            final InetAddressAndPort from, to;
+
+            private Connection(InetAddressAndPort from, InetAddressAndPort to)
+            {
+                this.from = from;
+                this.to = to;
+            }
+
+            @Override
+            public boolean equals(Object o)
+            {
+                if (o == this)
+                    return true;
+                if (o == null || o.getClass() != getClass())
+                    return false;
+                Connection other = (Connection) o;
+                return from.equals(other.from) && to.equals(other.to);
+            }
+
+            @Override
+            public int hashCode()
+            {
+                return Objects.hash(from, to);
+            }
+
+            @Override
+            public String toString()
+            {
+                return "Connection{" + "from=" + from + ", to=" + to + '}';
+            }
+        }
+
+        Map<Connection, LongSupplier> latencyByConnection = new HashMap<>();
+        return (message, to) -> {
+            Connection key = new Connection(message.from(), to);
+            LongSupplier latency = latencyByConnection.computeIfAbsent(key, ignored -> newLatencySupplier(rs));
+            return Duration.ofNanos(latency.getAsLong());
+        };
+    }
+
+    /**
+     * Builds one connection's latency stream: runs of booleans (parameters drawn from {@code rs}) map
+     * to either a "large" delay between 5ms and 5s or a "small" one between 500us and 5ms.
+     */
+    private static LongSupplier newLatencySupplier(RandomSource rs)
+    {
+        long minNanos = TimeUnit.MICROSECONDS.toNanos(500);
+        long smallMaxNanos = TimeUnit.MILLISECONDS.toNanos(5);
+        long maxNanos = TimeUnit.SECONDS.toNanos(5);
+        double runRatio = rs.nextInt(1, 11) / 100.0D;
+        int runLength = rs.nextInt(3, 15);
+        return Gens.bools().runs(runRatio, runLength)
+                   .mapToLong(isLarge -> isLarge ? rs.nextLong(smallMaxNanos, maxNanos) : rs.nextLong(minNanos, smallMaxNanos))
+                   .asLongSupplier(rs.fork());
+    }
+
+    public interface Scheduler
+    {
+        void schedule(Runnable command, long delay, TimeUnit unit);
+    }
+
+    public interface DropListener
+    {
+        void onDrop(Action action, InetAddressAndPort from, Message<?> msg);
+    }
+
+    private final InetAddressAndPort self;
+    private final ActionSupplier actions;
+    private final NetworkDelaySupplier networkDelay;
+    private final BiConsumer<InetAddressAndPort, Message<?>> reciever;
+    private final DropListener onDropped;
+    private final Scheduler scheduler;
+    private final Consumer<Throwable> onError;
+    private final Map<CallbackKey, CallbackContext> callbacks = new 
HashMap<>();
+    private enum Status { Up, Down }
+    private Status status = Status.Up;
+
+    /**
+     * @param self         address stamped on every outgoing message via {@code withFrom}
+     * @param actions      chooses the {@link Action} for each outgoing message
+     * @param networkDelay supplies per-message network delay
+     * @param receiver     hands a message to its destination
+     * @param onDropped    listener for dropped messages
+     * @param scheduler    schedules delayed work
+     * @param onError      receives errors
+     */
+    public SimulatedMessageDelivery(InetAddressAndPort self,
+                                    ActionSupplier actions,
+                                    NetworkDelaySupplier networkDelay,
+                                    BiConsumer<InetAddressAndPort, Message<?>> receiver,
+                                    DropListener onDropped,
+                                    Scheduler scheduler,
+                                    Consumer<Throwable> onError)
+    {
+        this.onError = onError;
+        this.scheduler = scheduler;
+        this.onDropped = onDropped;
+        this.reciever = receiver;
+        this.networkDelay = networkDelay;
+        this.actions = actions;
+        this.self = self;
+    }
+
+    /** Marks this instance down (later sends are ignored) and forgets all outstanding callbacks. */
+    public void stop()
+    {
+        status = Status.Down;
+        callbacks.clear();
+    }
+
+    @Override
+    public <REQ> void send(Message<REQ> message, InetAddressAndPort to)
+    {
+        message = message.withFrom(self);
+        maybeEnqueue(message, to, null);
+    }
+
+    /** Sends {@code message} from this node to {@code to}; {@code cb} is completed by the response or a failure. */
+    @Override
+    public <REQ, RSP> void sendWithCallback(Message<REQ> message, InetAddressAndPort to, RequestCallback<RSP> cb)
+    {
+        Message<REQ> outbound = message.withFrom(self);
+        maybeEnqueue(outbound, to, cb);
+    }
+
+    /**
+     * Same as the three-argument variant: connections do not exist in the simulation, so
+     * {@code specifyConnection} is ignored.
+     */
+    @Override
+    public <REQ, RSP> void sendWithCallback(Message<REQ> message, InetAddressAndPort to, RequestCallback<RSP> cb, ConnectionType specifyConnection)
+    {
+        Message<REQ> outbound = message.withFrom(self);
+        maybeEnqueue(outbound, to, cb);
+    }
+
+    /**
+     * Sends {@code message} and exposes the outcome as a future. The future succeeds with the response, or fails
+     * with {@link MessagingService.FailureResponseException} when the callback is told about a failure.
+     */
+    @Override
+    public <REQ, RSP> Future<Message<RSP>> sendWithResult(Message<REQ> message, InetAddressAndPort to)
+    {
+        AsyncPromise<Message<RSP>> result = new AsyncPromise<>();
+        RequestCallback<RSP> completeResult = new RequestCallback<RSP>()
+        {
+            @Override
+            public void onResponse(Message<RSP> response)
+            {
+                result.trySuccess(response);
+            }
+
+            @Override
+            public void onFailure(InetAddressAndPort from, RequestFailureReason failure)
+            {
+                result.tryFailure(new MessagingService.FailureResponseException(from, failure));
+            }
+
+            @Override
+            public boolean invokeOnFailure()
+            {
+                // opt in, so timeouts and failure responses complete the future exceptionally
+                return true;
+            }
+        };
+        sendWithCallback(message, to, completeResult);
+        return result;
+    }
+
+    /** Replies to {@code message} with {@code response}, addressed to the message's respond-to address. */
+    @Override
+    public <V> void respond(V response, Message<?> message)
+    {
+        Message<V> reply = message.responseWith(response);
+        send(reply, message.respondTo());
+    }
+
+    /**
+     * Routes {@code message} to {@code to} according to the {@link Action} returned by {@code actions}:
+     * <ul>
+     *   <li>{@code DELIVER}: the callback (if any) is registered under {@code (message.id(), to)} and the message
+     *   is delivered, possibly after a simulated network delay.</li>
+     *   <li>{@code DROP} / {@code DROP_PARTITIONED}: the callback is registered, the drop listener is notified and
+     *   nothing is delivered.</li>
+     *   <li>{@code DELIVER_WITH_FAILURE}: the message is delivered, but the sender only ever observes a failure.</li>
+     *   <li>{@code FAILURE}: the drop listener is notified and the sender observes a failure.</li>
+     * </ul>
+     * A registered callback that has not been completed by a response within the verb's expiry is failed with
+     * {@link RequestFailureReason#TIMEOUT}.
+     * <p>
+     * For the two failure actions the callback is never registered. Any response from the peer is therefore
+     * ignored and no entry is left in {@link #callbacks}. Instead, the callback is failed with
+     * {@link RequestFailureReason#UNKNOWN} after the verb's expiry, unless this node has been stopped by then.
+     * <p>
+     * Failures only reach callbacks whose {@code invokeOnFailure()} is true. Anything a callback throws is
+     * passed to {@code onError}. Does nothing once {@link #stop()} has been called.
+     *
+     * @throws AssertionError if a callback is already registered for the same message id and peer
+     * @throws UnsupportedOperationException if the action supplier returns an unrecognised action
+     */
+    private <REQ, RSP> void maybeEnqueue(Message<REQ> message, InetAddressAndPort to, @Nullable RequestCallback<RSP> callback)
+    {
+        if (status != Status.Up)
+            return;
+        CallbackKey key = new CallbackKey(message.id(), to);
+        if (callback != null && callbacks.containsKey(key))
+            throw new AssertionError("Message id " + message.id() + " to " + to + " already has a callback");
+        CallbackContext cb = callback == null ? null : new CallbackContext(callback);
+        long expiresAfterNanos = message.verb().expiresAfterNanos();
+        Action action = actions.get(self, message, to);
+        switch (action)
+        {
+            case DELIVER:
+                // register before delivering: delivery (and the peer's reply) may happen synchronously
+                if (cb != null)
+                    callbacks.put(key, cb);
+                deliver(message, to);
+                break;
+            case DROP:
+            case DROP_PARTITIONED:
+                if (cb != null)
+                    callbacks.put(key, cb);
+                onDropped.onDrop(action, to, message);
+                break;
+            case DELIVER_WITH_FAILURE:
+            case FAILURE:
+                // The sender must observe a failure whatever the peer does, so the callback is deliberately NOT
+                // registered. A reply to a DELIVER_WITH_FAILURE message is then dropped by the receiver instead of
+                // completing the callback a second time, and no entry leaks into the callback map.
+                if (action == Action.FAILURE)
+                    onDropped.onDrop(action, to, message);
+                else
+                    deliver(message, to);
+                if (cb != null)
+                {
+                    scheduler.schedule(() -> {
+                        // stop() discards pending callbacks; honour that for this unregistered one as well
+                        if (status == Status.Up)
+                            completeWithFailure(cb, to, RequestFailureReason.UNKNOWN);
+                    }, expiresAfterNanos, TimeUnit.NANOSECONDS);
+                }
+                return;
+            default:
+                throw new UnsupportedOperationException("Unknown action type: " + action);
+        }
+        if (cb != null)
+        {
+            scheduler.schedule(() -> {
+                // null means a response already completed the callback (or stop() cleared it)
+                CallbackContext ctx = callbacks.remove(key);
+                if (ctx != null)
+                {
+                    assert ctx == cb;
+                    completeWithFailure(ctx, to, RequestFailureReason.TIMEOUT);
+                }
+            }, expiresAfterNanos, TimeUnit.NANOSECONDS);
+        }
+    }
+
+    /** Fails {@code ctx} (if it opted in to failures), passing anything it throws to {@code onError}. */
+    private void completeWithFailure(CallbackContext ctx, InetAddressAndPort peer, RequestFailureReason reason)
+    {
+        try
+        {
+            ctx.onFailure(peer, reason);
+        }
+        catch (Throwable t)
+        {
+            onError.accept(t);
+        }
+    }
+
+    /**
+     * Hands {@code message} to the receiver for {@code to}. Delivery is immediate when the network delay supplier
+     * returns {@code null}; otherwise it is scheduled after the returned delay.
+     */
+    private void deliver(Message<?> message, InetAddressAndPort to)
+    {
+        Duration latency = networkDelay.jitter(message, to);
+        if (latency != null)
+        {
+            scheduler.schedule(() -> reciever.accept(to, message), latency.toNanos(), TimeUnit.NANOSECONDS);
+            return;
+        }
+        reciever.accept(to, message);
+    }
+
+    /** Creates the inbound side for this node, dispatching requests to {@code onMessage}. */
+    @SuppressWarnings("rawtypes")
+    public SimulatedMessageReceiver receiver(IVerbHandler onMessage)
+    {
+        SimulatedMessageReceiver inbound = this.new SimulatedMessageReceiver(onMessage);
+        return inbound;
+    }
+
+    public class SimulatedMessageReceiver
+    {
+        @SuppressWarnings("rawtypes")
+        final IVerbHandler onMessage;
+
+        @SuppressWarnings("rawtypes")
+        public SimulatedMessageReceiver(IVerbHandler onMessage)
+        {
+            this.onMessage = onMessage;
+        }
+
+        public void recieve(Message<?> msg)
+        {
+            if (status != Status.Up)
+                return;
+            if (msg.verb().isResponse())
+            {
+                CallbackKey key = new CallbackKey(msg.id(), msg.from());
+                if (callbacks.containsKey(key))
+                {
+                    CallbackContext callback = callbacks.remove(key);
+                    if (callback == null)
+                        return;
+                    try
+                    {
+                        if (msg.isFailureResponse())
+                            callback.onFailure(msg.from(), 
(RequestFailureReason) msg.payload);
+                        else callback.onResponse(msg);
+                    }
+                    catch (Throwable t)
+                    {
+                        onError.accept(t);
+                    }
+                }
+            }
+            else
+            {
+                try
+                {
+                    //noinspection unchecked
+                    onMessage.doVerb(msg);
+                }
+                catch (Throwable t)
+                {
+                    onError.accept(t);
+                }
+            }
+        }
+    }
+
+    /** Dispatches each message to the handler registered for its verb. */
+    @SuppressWarnings("rawtypes")
+    public static class SimpleVerbHandler implements IVerbHandler
+    {
+        private final Map<Verb, IVerbHandler<?>> handlers;
+
+        public SimpleVerbHandler(Map<Verb, IVerbHandler<?>> handlers)
+        {
+            this.handlers = handlers;
+        }
+
+        /**
+         * Forwards {@code msg} to the handler for its verb.
+         *
+         * @throws AssertionError if no handler is registered for the message's verb
+         */
+        @Override
+        public void doVerb(Message msg) throws IOException
+        {
+            Verb verb = msg.verb();
+            IVerbHandler<?> handler = handlers.get(verb);
+            if (handler != null)
+            {
+                //noinspection unchecked
+                handler.doVerb(msg);
+                return;
+            }
+            throw new AssertionError("Unexpected verb: " + verb);
+        }
+    }
+
+    /** Wraps a caller-supplied callback and applies its {@code invokeOnFailure()} opt-in to failure delivery. */
+    private static class CallbackContext
+    {
+        @SuppressWarnings("rawtypes")
+        final RequestCallback callback;
+
+        @SuppressWarnings("rawtypes")
+        private CallbackContext(RequestCallback callback)
+        {
+            this.callback = Objects.requireNonNull(callback);
+        }
+
+        /** Forwards a successful response to the wrapped callback. */
+        @SuppressWarnings({ "rawtypes", "unchecked" })
+        public void onResponse(Message msg)
+        {
+            callback.onResponse(msg);
+        }
+
+        /** Forwards a failure, but only when the wrapped callback asked to be told about failures. */
+        public void onFailure(InetAddressAndPort from, RequestFailureReason failure)
+        {
+            if (!callback.invokeOnFailure())
+                return;
+            callback.onFailure(from, failure);
+        }
+    }
+
+    /** Identifies an outstanding request: the message id plus the peer the request was sent to. */
+    private static class CallbackKey
+    {
+        private final long id;
+        private final InetAddressAndPort peer;
+
+        private CallbackKey(long id, InetAddressAndPort peer)
+        {
+            this.id = id;
+            this.peer = peer;
+        }
+
+        @Override
+        public boolean equals(Object other)
+        {
+            if (other == this)
+                return true;
+            if (other == null || other.getClass() != getClass())
+                return false;
+            CallbackKey key = (CallbackKey) other;
+            return id == key.id && peer.equals(key.peer);
+        }
+
+        @Override
+        public int hashCode()
+        {
+            // identical to Objects.hash(id, peer), without the varargs array and boxing
+            return 31 * (31 + Long.hashCode(id)) + Objects.hashCode(peer);
+        }
+
+        @Override
+        public String toString()
+        {
+            return "CallbackKey{id=" + id + ", peer=" + peer + '}';
+        }
+    }
+}
diff --git 
a/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java 
b/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java
index b01a9fcbbd..fb3ce470f5 100644
--- a/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java
+++ b/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java
@@ -155,7 +155,8 @@ public class RepairMessageTest
     {
         SharedContext ctx = Mockito.mock(SharedContext.class, REJECT_ALL);
         MessageDelivery messaging = Mockito.mock(MessageDelivery.class, 
REJECT_ALL);
-        // allow the single method under test
+        // allow all retry methods and send with callback
+        
Mockito.doCallRealMethod().when(messaging).sendWithRetries(Mockito.any(), 
Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), 
Mockito.any(), Mockito.any());
         Mockito.doNothing().when(messaging).sendWithCallback(Mockito.any(), 
Mockito.any(), Mockito.any());
         IGossiper gossiper = Mockito.mock(IGossiper.class, REJECT_ALL);
         
Mockito.doReturn(RepairMessage.SUPPORTS_RETRY).when(gossiper).getReleaseVersion(Mockito.any());
@@ -205,7 +206,7 @@ public class RepairMessageTest
         {
             before();
 
-            sendMessageWithRetries(ctx, backoff(maxAttempts), always(), 
PAYLOAD, VERB, ADDRESS, RepairMessage.NOOP_CALLBACK, 0);
+            sendMessageWithRetries(ctx, backoff(maxAttempts), always(), 
PAYLOAD, VERB, ADDRESS, RepairMessage.NOOP_CALLBACK);
             for (int i = 0; i < maxAttempts; i++)
                 callback(messaging).onFailure(ADDRESS, 
RequestFailureReason.TIMEOUT);
             fn.test(maxAttempts, callback(messaging));


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to