This is an automated email from the ASF dual-hosted git repository.
dcapwell pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra.git
The following commit(s) were added to refs/heads/trunk by this push:
new ffc0f01b0e Add a concept for retrying messages
ffc0f01b0e is described below
commit ffc0f01b0eede35518c2838d3c21f440d871c08a
Author: David Capwell <[email protected]>
AuthorDate: Thu Aug 29 13:45:48 2024 -0700
Add a concept for retrying messages
patch by David Capwell; reviewed by Alex Petrov for CASSANDRA-19856
---
.../org/apache/cassandra/db/SystemKeyspace.java | 13 +-
.../org/apache/cassandra/net/MessageDelivery.java | 160 ++++++++
.../cassandra/repair/messages/RepairMessage.java | 138 ++++---
.../org/apache/cassandra/tcm/RemoteProcessor.java | 86 +++--
src/java/org/apache/cassandra/tcm/Retry.java | 7 +-
src/java/org/apache/cassandra/utils/Backoff.java | 53 ++-
.../org/apache/cassandra/utils/TriFunction.java | 25 ++
.../concurrent/SimulatedExecutorFactory.java | 162 +++++++-
.../apache/cassandra/net/MessageDeliveryTest.java | 225 ++++++++++++
.../cassandra/net/SimulatedMessageDelivery.java | 408 +++++++++++++++++++++
.../repair/messages/RepairMessageTest.java | 5 +-
11 files changed, 1139 insertions(+), 143 deletions(-)
diff --git a/src/java/org/apache/cassandra/db/SystemKeyspace.java b/src/java/org/apache/cassandra/db/SystemKeyspace.java
index 8709453280..05a2437546 100644
--- a/src/java/org/apache/cassandra/db/SystemKeyspace.java
+++ b/src/java/org/apache/cassandra/db/SystemKeyspace.java
@@ -120,6 +120,7 @@ import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.MD5Digest;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.TimeUUID;
+import org.apache.cassandra.utils.TriFunction;
import org.apache.cassandra.utils.concurrent.Future;
import static java.lang.String.format;
@@ -1938,8 +1939,8 @@ public final class SystemKeyspace
int counter = 0;
for (UntypedResultSet.Row row : resultSet)
{
- if
(onLoaded.accept(MD5Digest.wrap(row.getByteArray("prepared_id")),
- row.getString("query_string"),
+ if (onLoaded.apply(MD5Digest.wrap(row.getByteArray("prepared_id")),
+ row.getString("query_string"),
row.has("logged_keyspace") ?
row.getString("logged_keyspace") : null))
counter++;
}
@@ -1953,18 +1954,14 @@ public final class SystemKeyspace
int counter = 0;
for (UntypedResultSet.Row row : resultSet)
{
- if
(onLoaded.accept(MD5Digest.wrap(row.getByteArray("prepared_id")),
- row.getString("query_string"),
+ if (onLoaded.apply(MD5Digest.wrap(row.getByteArray("prepared_id")),
+ row.getString("query_string"),
row.has("logged_keyspace") ?
row.getString("logged_keyspace") : null))
counter++;
}
return counter;
}
- public static interface TriFunction<A, B, C, D> {
- D accept(A var1, B var2, C var3);
- }
-
public static void saveTopPartitions(TableMetadata metadata, String
topType, Collection<TopPartitionTracker.TopPartition> topPartitions, long
lastUpdate)
{
String cql = String.format("INSERT INTO %s.%s (keyspace_name,
table_name, top_type, top, last_update) values (?, ?, ?, ?, ?)",
SchemaConstants.SYSTEM_KEYSPACE_NAME, TOP_PARTITIONS);
diff --git a/src/java/org/apache/cassandra/net/MessageDelivery.java b/src/java/org/apache/cassandra/net/MessageDelivery.java
index 0b7890c08d..0d052cb3d8 100644
--- a/src/java/org/apache/cassandra/net/MessageDelivery.java
+++ b/src/java/org/apache/cassandra/net/MessageDelivery.java
@@ -19,19 +19,27 @@
package org.apache.cassandra.net;
import java.util.Collection;
+import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.exceptions.RequestFailureReason;
import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.utils.Backoff;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.concurrent.Accumulator;
+import org.apache.cassandra.utils.concurrent.AsyncPromise;
import org.apache.cassandra.utils.concurrent.CountDownLatch;
import org.apache.cassandra.utils.concurrent.Future;
+import org.apache.cassandra.utils.concurrent.Promise;
+
+import static org.apache.cassandra.net.MessageFlag.CALL_BACK_ON_FAILURE;
public interface MessageDelivery
{
@@ -74,9 +82,161 @@ public interface MessageDelivery
public <REQ, RSP> void sendWithCallback(Message<REQ> message,
InetAddressAndPort to, RequestCallback<RSP> cb);
public <REQ, RSP> void sendWithCallback(Message<REQ> message,
InetAddressAndPort to, RequestCallback<RSP> cb, ConnectionType
specifyConnection);
public <REQ, RSP> Future<Message<RSP>> sendWithResult(Message<REQ>
message, InetAddressAndPort to);
+
+ public default <REQ, RSP> Future<Message<RSP>> sendWithRetries(Backoff
backoff, RetryScheduler retryThreads,
+ Verb verb,
REQ request,
+
Iterator<InetAddressAndPort> candidates,
+
RetryPredicate shouldRetry,
+
RetryErrorMessage errorMessage)
+ {
+ Promise<Message<RSP>> promise = new AsyncPromise<>();
+ this.<REQ, RSP>sendWithRetries(backoff, retryThreads, verb, request,
candidates,
+ (attempt, success, failure) -> {
+ if (failure != null)
promise.tryFailure(failure);
+ else promise.trySuccess(success);
+ },
+ shouldRetry, errorMessage);
+ return promise;
+ }
+
+ public default <REQ, RSP> void sendWithRetries(Backoff backoff,
RetryScheduler retryThreads,
+ Verb verb, REQ request,
+
Iterator<InetAddressAndPort> candidates,
+ OnResult<RSP> onResult,
+ RetryPredicate shouldRetry,
+ RetryErrorMessage
errorMessage)
+ {
+ sendWithRetries(this, backoff, retryThreads, verb, request,
candidates, onResult, shouldRetry, errorMessage, 0);
+ }
public <V> void respond(V response, Message<?> message);
public default void respondWithFailure(RequestFailureReason reason,
Message<?> message)
{
send(Message.failureResponse(message.id(), message.expiresAtNanos(),
reason), message.respondTo());
}
+
+ interface OnResult<T>
+ {
+ void result(int attempt, @Nullable Message<T> success, @Nullable
Throwable failure);
+ }
+
+ interface RetryPredicate
+ {
+ boolean test(int attempt, InetAddressAndPort from,
RequestFailureReason failure);
+ }
+
+ interface RetryErrorMessage
+ {
+ String apply(int attempt, ResponseFailureReason retryFailure,
@Nullable InetAddressAndPort from, @Nullable RequestFailureReason reason);
+ }
+
+ private static <REQ, RSP> void sendWithRetries(MessageDelivery messaging,
+ Backoff backoff,
RetryScheduler retryThreads,
+ Verb verb, REQ request,
+
Iterator<InetAddressAndPort> candidates,
+ OnResult<RSP> onResult,
+ RetryPredicate shouldRetry,
+ RetryErrorMessage
errorMessage,
+ int attempt)
+ {
+ if (Thread.currentThread().isInterrupted())
+ {
+ onResult.result(attempt, null, new
InterruptedException(errorMessage.apply(attempt,
ResponseFailureReason.Interrupted, null, null)));
+ return;
+ }
+ if (!candidates.hasNext())
+ {
+ onResult.result(attempt, null, new
NoMoreCandidatesException(errorMessage.apply(attempt,
ResponseFailureReason.NoMoreCandidates, null, null)));
+ return;
+ }
+ class Request implements RequestCallbackWithFailure<RSP>
+ {
+ @Override
+ public void onResponse(Message<RSP> msg)
+ {
+ onResult.result(attempt, msg, null);
+ }
+
+ @Override
+ public void onFailure(InetAddressAndPort from,
RequestFailureReason failure)
+ {
+ if (!backoff.mayRetry(attempt))
+ {
+ onResult.result(attempt, null, new
MaxRetriesException(attempt, errorMessage.apply(attempt,
ResponseFailureReason.MaxRetries, from, failure)));
+ return;
+ }
+ if (!shouldRetry.test(attempt, from, failure))
+ {
+ onResult.result(attempt, null, new
FailedResponseException(from, failure, errorMessage.apply(attempt,
ResponseFailureReason.Rejected, from, failure)));
+ return;
+ }
+ try
+ {
+ retryThreads.schedule(() -> sendWithRetries(messaging,
backoff, retryThreads, verb, request, candidates, onResult, shouldRetry,
errorMessage, attempt + 1),
+ backoff.computeWaitTime(attempt),
backoff.unit());
+ }
+ catch (Throwable t)
+ {
+ onResult.result(attempt, null, new
FailedScheduleException(errorMessage.apply(attempt,
ResponseFailureReason.FailedSchedule, from, failure), t));
+ }
+ }
+ }
+ messaging.sendWithCallback(Message.outWithFlag(verb, request,
CALL_BACK_ON_FAILURE), candidates.next(), new Request());
+ }
+
+ enum ResponseFailureReason { MaxRetries, Rejected, NoMoreCandidates,
Interrupted, FailedSchedule }
+
+ interface RetryScheduler
+ {
+ void schedule(Runnable command, long delay, TimeUnit unit);
+ }
+
+ enum ImmediateRetryScheduler implements RetryScheduler
+ {
+ instance;
+
+ @Override
+ public void schedule(Runnable command, long delay, TimeUnit unit)
+ {
+ command.run();
+ }
+ }
+
+ class NoMoreCandidatesException extends IllegalStateException
+ {
+ public NoMoreCandidatesException(String s)
+ {
+ super(s);
+ }
+ }
+
+ class FailedResponseException extends IllegalStateException
+ {
+ public final InetAddressAndPort from;
+ public final RequestFailureReason failure;
+
+ public FailedResponseException(InetAddressAndPort from,
RequestFailureReason failure, String message)
+ {
+ super(message);
+ this.from = from;
+ this.failure = failure;
+ }
+ }
+
+ class MaxRetriesException extends IllegalStateException
+ {
+ public final int attempts;
+ public MaxRetriesException(int attempts, String message)
+ {
+ super(message);
+ this.attempts = attempts;
+ }
+ }
+
+ class FailedScheduleException extends IllegalStateException
+ {
+ public FailedScheduleException(String message, Throwable cause)
+ {
+ super(message, cause);
+ }
+ }
}
diff --git a/src/java/org/apache/cassandra/repair/messages/RepairMessage.java b/src/java/org/apache/cassandra/repair/messages/RepairMessage.java
index f0cbf78f38..835f90fc68 100644
--- a/src/java/org/apache/cassandra/repair/messages/RepairMessage.java
+++ b/src/java/org/apache/cassandra/repair/messages/RepairMessage.java
@@ -23,11 +23,13 @@ import java.util.EnumSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Iterators;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -50,8 +52,6 @@ import org.apache.cassandra.utils.NoSpamLogger;
import org.apache.cassandra.utils.TimeUUID;
import org.apache.cassandra.utils.concurrent.Future;
-import static org.apache.cassandra.net.MessageFlag.CALL_BACK_ON_FAILURE;
-
/**
* Base class of all repair related request/response messages.
*
@@ -138,9 +138,7 @@ public abstract class RepairMessage
{
RepairRetrySpec retrySpec = DatabaseDescriptor.getRepairRetrySpec();
RetrySpec spec = verb == Verb.VALIDATION_RSP ?
retrySpec.getMerkleTreeResponseSpec() : retrySpec;
- if (!spec.isEnabled())
- return Backoff.None.INSTANCE;
- return new Backoff.ExponentialBackoff(spec.maxAttempts.value,
spec.baseSleepTime.toMilliseconds(), spec.maxSleepTime.toMilliseconds(),
ctx.random().get()::nextDouble);
+ return Backoff.fromConfig(ctx, spec);
}
public static Supplier<Boolean> notDone(Future<?> f)
@@ -155,98 +153,94 @@ public abstract class RepairMessage
public static <T> void sendMessageWithRetries(SharedContext ctx,
Supplier<Boolean> allowRetry, RepairMessage request, Verb verb,
InetAddressAndPort endpoint, RequestCallback<T> finalCallback)
{
- sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request,
verb, endpoint, finalCallback, 0);
+ sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request,
verb, endpoint, finalCallback);
}
public static <T> void sendMessageWithRetries(SharedContext ctx,
RepairMessage request, Verb verb, InetAddressAndPort endpoint,
RequestCallback<T> finalCallback)
{
- sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request,
verb, endpoint, finalCallback, 0);
+ sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request,
verb, endpoint, finalCallback);
}
public static void sendMessageWithRetries(SharedContext ctx, RepairMessage
request, Verb verb, InetAddressAndPort endpoint)
{
- sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request,
verb, endpoint, NOOP_CALLBACK, 0);
+ sendMessageWithRetries(ctx, backoff(ctx, verb), always(), request,
verb, endpoint, NOOP_CALLBACK);
}
public static void sendMessageWithRetries(SharedContext ctx,
Supplier<Boolean> allowRetry, RepairMessage request, Verb verb,
InetAddressAndPort endpoint)
{
- sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request,
verb, endpoint, NOOP_CALLBACK, 0);
+ sendMessageWithRetries(ctx, backoff(ctx, verb), allowRetry, request,
verb, endpoint, NOOP_CALLBACK);
}
@VisibleForTesting
- static <T> void sendMessageWithRetries(SharedContext ctx, Backoff backoff,
Supplier<Boolean> allowRetry, RepairMessage request, Verb verb,
InetAddressAndPort endpoint, RequestCallback<T> finalCallback, int attempt)
+ static <T> void sendMessageWithRetries(SharedContext ctx, Backoff backoff,
Supplier<Boolean> allowRetry, RepairMessage request, Verb verb,
InetAddressAndPort endpoint, RequestCallback<T> finalCallback)
{
if (!ALLOWS_RETRY.contains(verb))
throw new AssertionError("Repair verb " + verb + " does not
support retry, but a request to send with retry was given!");
- RequestCallback<T> callback = new RequestCallback<>()
- {
- @Override
- public void onResponse(Message<T> msg)
+ BiConsumer<Integer, RequestFailureReason > maybeRecordRetry =
(attempt, reason) -> {
+ if (attempt <= 0)
+ return;
+ // we don't know what the prefix kind is... so use NONE... this
impacts logPrefix as it will cause us to use "repair" rather than "preview
repair" which may not be correct... but close enough...
+ String prefix =
PreviewKind.NONE.logPrefix(request.parentRepairSession());
+ RepairMetrics.retry(verb, attempt);
+ if (reason == null)
{
- maybeRecordRetry(null);
- finalCallback.onResponse(msg);
+ noSpam.info("{} Retry of repair verb " + verb + " was
successful after {} attempts", prefix, attempt);
}
-
- @Override
- public void onFailure(InetAddressAndPort from,
RequestFailureReason failureReason)
+ else if (reason == RequestFailureReason.TIMEOUT)
{
- ErrorHandling allowed = errorHandlingSupported(ctx, endpoint,
verb, request.parentRepairSession());
- switch (allowed)
- {
- case NONE:
- logger.error("[#{}] {} failed on {}: {}",
request.parentRepairSession(), verb, from, failureReason);
- return;
- case TIMEOUT:
- finalCallback.onFailure(from, failureReason);
- return;
- case RETRY:
- int maxAttempts = backoff.maxAttempts();
- if (failureReason == RequestFailureReason.TIMEOUT &&
attempt < maxAttempts && allowRetry.get())
- {
- ctx.optionalTasks().schedule(() ->
sendMessageWithRetries(ctx, backoff, allowRetry, request, verb, endpoint,
finalCallback, attempt + 1),
-
backoff.computeWaitTime(attempt), backoff.unit());
- return;
- }
- maybeRecordRetry(failureReason);
- finalCallback.onFailure(from, failureReason);
- return;
- default:
- throw new AssertionError("Unknown error handler: " +
allowed);
- }
+ noSpam.warn("{} Timeout for repair verb " + verb + "; could
not complete within {} attempts", prefix, attempt);
+ RepairMetrics.retryTimeout(verb);
}
-
- private void maybeRecordRetry(@Nullable RequestFailureReason
reason)
+ else
{
- if (attempt <= 0)
- return;
- // we don't know what the prefix kind is... so use NONE...
this impacts logPrefix as it will cause us to use "repair" rather than "preview
repair" which may not be correct... but close enough...
- String prefix =
PreviewKind.NONE.logPrefix(request.parentRepairSession());
- RepairMetrics.retry(verb, attempt);
- if (reason == null)
- {
- noSpam.info("{} Retry of repair verb " + verb + " was
successful after {} attempts", prefix, attempt);
- }
- else if (reason == RequestFailureReason.TIMEOUT)
- {
- noSpam.warn("{} Timeout for repair verb " + verb + ";
could not complete within {} attempts", prefix, attempt);
- RepairMetrics.retryTimeout(verb);
- }
- else
- {
- noSpam.warn("{} {} failure for repair verb " + verb + ";
could not complete within {} attempts", prefix, reason, attempt);
- RepairMetrics.retryFailure(verb);
- }
- }
-
- @Override
- public boolean invokeOnFailure()
- {
- return true;
+ noSpam.warn("{} {} failure for repair verb " + verb + "; could
not complete within {} attempts", prefix, reason, attempt);
+ RepairMetrics.retryFailure(verb);
}
};
- ctx.messaging().sendWithCallback(Message.outWithFlag(verb, request,
CALL_BACK_ON_FAILURE),
- endpoint,
- callback);
+ ctx.messaging().sendWithRetries(backoff, ctx.optionalTasks()::schedule,
+ verb, request,
Iterators.cycle(endpoint),
+ (int attempt, Message<T> msg,
Throwable failure) -> {
+ if (failure == null)
+ {
+
maybeRecordRetry.accept(attempt, null);
+ finalCallback.onResponse(msg);
+ }
+ },
+ (attempt, from, failure) -> {
+ ErrorHandling allowed =
errorHandlingSupported(ctx, endpoint, verb, request.parentRepairSession());
+ switch (allowed)
+ {
+ case NONE:
+ logger.error("[#{}] {}
failed on {}: {}", request.parentRepairSession(), verb, from, failure);
+ return false;
+ case TIMEOUT:
+
finalCallback.onFailure(from, failure);
+ return false;
+ case RETRY:
+ if (failure ==
RequestFailureReason.TIMEOUT && allowRetry.get())
+ return true;
+
maybeRecordRetry.accept(attempt, failure);
+
finalCallback.onFailure(from, failure);
+ return false;
+ default:
+ throw new
AssertionError("Unknown error handler: " + allowed);
+ }
+ },
+ (attempt, retryReason, from, failure)
-> {
+ switch (retryReason)
+ {
+ case MaxRetries:
+
maybeRecordRetry.accept(attempt, failure);
+
finalCallback.onFailure(from, failure);
+ return null;
+ case Interrupted:
+ case Rejected:
+ case FailedSchedule:
+ return null;
+ default:
+ throw new
UnsupportedOperationException(retryReason.name());
+ }
+ });
}
public static void sendMessageWithFailureCB(SharedContext ctx,
Supplier<Boolean> allowRetry, RepairMessage request, Verb verb,
InetAddressAndPort endpoint, RepairFailureCallback failureCallback)
diff --git a/src/java/org/apache/cassandra/tcm/RemoteProcessor.java b/src/java/org/apache/cassandra/tcm/RemoteProcessor.java
index 79f0b7cf7d..0ea055b908 100644
--- a/src/java/org/apache/cassandra/tcm/RemoteProcessor.java
+++ b/src/java/org/apache/cassandra/tcm/RemoteProcessor.java
@@ -39,6 +39,7 @@ import org.apache.cassandra.gms.FailureDetector;
import org.apache.cassandra.locator.InetAddressAndPort;
import org.apache.cassandra.metrics.TCMMetrics;
import org.apache.cassandra.net.Message;
+import org.apache.cassandra.net.MessageDelivery;
import org.apache.cassandra.net.MessagingService;
import org.apache.cassandra.net.RequestCallbackWithFailure;
import org.apache.cassandra.net.Verb;
@@ -47,6 +48,7 @@ import org.apache.cassandra.tcm.log.Entry;
import org.apache.cassandra.tcm.log.LocalLog;
import org.apache.cassandra.tcm.log.LogState;
import org.apache.cassandra.utils.AbstractIterator;
+import org.apache.cassandra.utils.Backoff;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.concurrent.AsyncPromise;
import org.apache.cassandra.utils.concurrent.Future;
@@ -201,51 +203,45 @@ public final class RemoteProcessor implements Processor
public static <REQ, RSP> void sendWithCallbackAsync(Promise<RSP> promise,
Verb verb, REQ request, CandidateIterator candidates, Retry retryPolicy)
{
- class Request implements RequestCallbackWithFailure<RSP>
- {
- void retry()
- {
- if (promise.isCancelled() || promise.isDone())
- return;
- if (Thread.currentThread().isInterrupted())
- promise.setFailure(new InterruptedException());
- if (!candidates.hasNext())
- promise.tryFailure(new
IllegalStateException(String.format("Ran out of candidates while sending %s:
%s", verb, candidates)));
-
- MessagingService.instance().sendWithCallback(Message.out(verb,
request), candidates.next(), this);
- }
-
- @Override
- public void onResponse(Message<RSP> msg)
- {
- promise.trySuccess(msg.payload);
- }
-
- @Override
- public void onFailure(InetAddressAndPort from,
RequestFailureReason reason)
- {
- if (reason == RequestFailureReason.NOT_CMS)
- {
- logger.debug("{} is not a member of the CMS, querying it
to discover current membership", from);
- DiscoveredNodes cms = tryDiscover(from);
- candidates.addCandidates(cms);
- candidates.timeout(from);
- logger.debug("Got CMS from {}: {}, retrying on: {}", from,
cms, candidates);
- }
- else
- {
- candidates.timeout(from);
- logger.warn("Got error from {}: {} when sending {},
retrying on {}", from, reason, verb, candidates);
- }
-
- if (retryPolicy.reachedMax())
- promise.tryFailure(new
IllegalStateException(String.format("Could not succeed sending %s to %s after
%d tries", verb, candidates, retryPolicy.tries)));
- else
- retry();
- }
- }
-
- new Request().retry();
+ //TODO (now): the retry defines how long to wait for a retry, but the
old behavior scheduled the message right away... should this be delayed as well?
+ MessagingService.instance().<REQ,
RSP>sendWithRetries(Backoff.fromRetry(retryPolicy),
MessageDelivery.ImmediateRetryScheduler.instance,
+ verb, request,
candidates,
+ (attempt,
success, failure) -> {
+ if (failure
!= null) promise.tryFailure(failure);
+ else
promise.trySuccess(success.payload);
+ },
+ (attempt, from,
failure) -> {
+ if
(promise.isDone() || promise.isCancelled())
+ return
false;
+ if (failure
== RequestFailureReason.NOT_CMS)
+ {
+
logger.debug("{} is not a member of the CMS, querying it to discover current
membership", from);
+
DiscoveredNodes cms = tryDiscover(from);
+
candidates.addCandidates(cms);
+
candidates.timeout(from);
+
logger.debug("Got CMS from {}: {}, retrying on: {}", from, cms, candidates);
+ }
+ else
+ {
+
candidates.timeout(from);
+
logger.warn("Got error from {}: {} when sending {}, retrying on {}", from,
failure, verb, candidates);
+ }
+ return true;
+ },
+ (attempt,
reason, from, failure) -> {
+ switch
(reason)
+ {
+ case
NoMoreCandidates:
+
return String.format("Ran out of candidates while sending %s: %s", verb,
candidates);
+ case
MaxRetries:
+
return String.format("Could not succeed sending %s to %s after %d tries", verb,
candidates, retryPolicy.tries);
+ case
Interrupted:
+ case
FailedSchedule:
+
return null;
+ default:
+
throw new UnsupportedOperationException(reason.name());
+ }
+ });
}
private static DiscoveredNodes tryDiscover(InetAddressAndPort ep)
diff --git a/src/java/org/apache/cassandra/tcm/Retry.java b/src/java/org/apache/cassandra/tcm/Retry.java
index a1215fd649..703e590466 100644
--- a/src/java/org/apache/cassandra/tcm/Retry.java
+++ b/src/java/org/apache/cassandra/tcm/Retry.java
@@ -58,10 +58,15 @@ public abstract class Retry
}
public void maybeSleep()
+ {
+ sleepUninterruptibly(computeSleepFor(), TimeUnit.MILLISECONDS);
+ }
+
+ public long computeSleepFor()
{
tries++;
retryMeter.mark();
- sleepUninterruptibly(sleepFor(), TimeUnit.MILLISECONDS);
+ return sleepFor();
}
protected abstract long sleepFor();
diff --git a/src/java/org/apache/cassandra/utils/Backoff.java b/src/java/org/apache/cassandra/utils/Backoff.java
index 2f0b7e2c9c..7974dbf346 100644
--- a/src/java/org/apache/cassandra/utils/Backoff.java
+++ b/src/java/org/apache/cassandra/utils/Backoff.java
@@ -21,23 +21,55 @@ package org.apache.cassandra.utils;
import java.util.concurrent.TimeUnit;
import java.util.function.DoubleSupplier;
+import org.apache.cassandra.config.RetrySpec;
+import org.apache.cassandra.repair.SharedContext;
+import org.apache.cassandra.tcm.Retry;
+
public interface Backoff
{
- /**
- * @return max attempts allowed, {@code == 0} implies no retries are
allowed
- */
- int maxAttempts();
- long computeWaitTime(int retryCount);
+ boolean mayRetry(int attempt);
+ long computeWaitTime(int attempt);
TimeUnit unit();
+ static Backoff fromRetry(Retry retry)
+ {
+ return new Backoff()
+ {
+ @Override
+ public boolean mayRetry(int attempt)
+ {
+ return !retry.reachedMax();
+ }
+
+ @Override
+ public long computeWaitTime(int retryCount)
+ {
+ return retry.computeSleepFor();
+ }
+
+ @Override
+ public TimeUnit unit()
+ {
+ return TimeUnit.MILLISECONDS;
+ }
+ };
+ }
+
+ static Backoff fromConfig(SharedContext ctx, RetrySpec spec)
+ {
+ if (!spec.isEnabled())
+ return Backoff.None.INSTANCE;
+ return new Backoff.ExponentialBackoff(spec.maxAttempts.value,
spec.baseSleepTime.toMilliseconds(), spec.maxSleepTime.toMilliseconds(),
ctx.random().get()::nextDouble);
+ }
+
enum None implements Backoff
{
INSTANCE;
@Override
- public int maxAttempts()
+ public boolean mayRetry(int attempt)
{
- return 0;
+ return false;
}
@Override
@@ -68,12 +100,17 @@ public interface Backoff
this.randomSource = randomSource;
}
- @Override
public int maxAttempts()
{
return maxAttempts;
}
+ @Override
+ public boolean mayRetry(int attempt)
+ {
+ return attempt < maxAttempts;
+ }
+
@Override
public long computeWaitTime(int retryCount)
{
diff --git a/src/java/org/apache/cassandra/utils/TriFunction.java b/src/java/org/apache/cassandra/utils/TriFunction.java
new file mode 100644
index 0000000000..c280850dad
--- /dev/null
+++ b/src/java/org/apache/cassandra/utils/TriFunction.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.utils;
+
+@FunctionalInterface
+public interface TriFunction<A, B, C, D>
+{
+ D apply(A var1, B var2, C var3);
+}
diff --git a/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java b/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java
index ffb1a0fabd..e206af390b 100644
--- a/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java
+++ b/test/unit/org/apache/cassandra/concurrent/SimulatedExecutorFactory.java
@@ -18,6 +18,7 @@
package org.apache.cassandra.concurrent;
+import java.sql.Timestamp;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
@@ -33,13 +34,25 @@ import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
import java.util.function.LongSupplier;
+import javax.annotation.Nullable;
+
import accord.utils.Gens;
import accord.utils.RandomSource;
import org.apache.cassandra.utils.Clock;
+import org.apache.cassandra.utils.Generators;
+import org.apache.cassandra.utils.concurrent.Future;
+import org.apache.cassandra.utils.concurrent.UncheckedInterruptedException;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
+import static
org.apache.cassandra.concurrent.InfiniteLoopExecutor.InternalState.SHUTTING_DOWN_NOW;
+import static
org.apache.cassandra.concurrent.InfiniteLoopExecutor.InternalState.TERMINATED;
+import static org.apache.cassandra.concurrent.Interruptible.State.INTERRUPTED;
+import static org.apache.cassandra.concurrent.Interruptible.State.NORMAL;
+import static
org.apache.cassandra.concurrent.Interruptible.State.SHUTTING_DOWN;
+import static org.apache.cassandra.utils.Generators.toGen;
public class SimulatedExecutorFactory implements ExecutorFactory, Clock
{
@@ -79,22 +92,51 @@ public class SimulatedExecutorFactory implements ExecutorFactory, Clock
private final RandomSource rs;
private final long startTimeNanos;
+ @Nullable
+ private final Consumer<Throwable> onError;
private final PriorityQueue<Item> queue = new PriorityQueue<>();
private long seq = 0;
private long nowNanos;
private int repeatedTasks = 0;
+ public SimulatedExecutorFactory(RandomSource rs, Consumer<Throwable>
onError)
+ {
+ this(rs,
toGen(Generators.TIMESTAMP_GEN.map(Timestamp::getTime)).mapToLong(TimeUnit.MILLISECONDS::toNanos).next(rs),
onError);
+ }
+
+ public SimulatedExecutorFactory(RandomSource rs)
+ {
+ this(rs, null);
+ }
+
public SimulatedExecutorFactory(RandomSource rs, long startTimeNanos)
+ {
+ this(rs, startTimeNanos, null);
+ }
+
+ public SimulatedExecutorFactory(RandomSource rs, long startTimeNanos,
Consumer<Throwable> onError)
{
this.rs = rs;
this.startTimeNanos = startTimeNanos;
+ this.onError = onError;
}
- public boolean processOne()
+ private void maybeAddFailureListener(Future<?> task)
+ {
+ if (onError == null) return;
+ task.addCallback((s, f) -> {
+ if (f != null)
+ onError.accept(f);
+ });
+ }
+
+ public boolean hasWork()
+ {
+ return queue.size() > repeatedTasks;
+ }
+
+ public boolean processAny()
{
- // if we count the repeated tasks, then processAll will never complete
- if (queue.size() == repeatedTasks)
- return false;
Item item = queue.poll();
if (item == null)
return false;
@@ -103,6 +145,21 @@ public class SimulatedExecutorFactory implements ExecutorFactory, Clock
return true;
}
+ public boolean processOne()
+ {
+ // if we count the repeated tasks, then processAll will never complete
+ if (queue.size() == repeatedTasks)
+ return false;
+ return processAny();
+ }
+
+ public void processAll()
+ {
+ while (processOne())
+ {
+ }
+ }
+
@Override
public long nanoTime()
{
@@ -153,9 +210,92 @@ public class SimulatedExecutorFactory implements ExecutorFactory, Clock
}
@Override
- public Interruptible infiniteLoop(String name, Interruptible.Task task,
InfiniteLoopExecutor.SimulatorSafe simulatorSafe, InfiniteLoopExecutor.Daemon
daemon, InfiniteLoopExecutor.Interrupts interrupts)
+ public Interruptible infiniteLoop(String name,
+ Interruptible.Task task,
+ InfiniteLoopExecutor.SimulatorSafe
simulatorSafe,
+ InfiniteLoopExecutor.Daemon daemon,
+ InfiniteLoopExecutor.Interrupts
interrupts)
{
- throw new UnsupportedOperationException("TODO");
+ var delegate = new UnorderedScheduledExecutorService();
+ class Capture { UnorderedScheduledExecutorService.ScheduledFuture<?>
f;}
+ Capture c = new Capture();
+ class I implements Interruptible
+ {
+ private Object state = NORMAL;
+ private boolean interrupted = false;
+ private void runOne()
+ {
+ Object cur = state;
+ if (cur == SHUTTING_DOWN_NOW || cur == SHUTTING_DOWN)
+ {
+ state = TERMINATED;
+ if (c.f != null)
+ c.f.cancel(false);
+ return;
+ }
+
+ if (cur == NORMAL && interrupted) cur = INTERRUPTED;
+ try
+ {
+ task.run((State) cur);
+ interrupted = false;
+ }
+ catch (TerminateException ignore)
+ {
+ state = TERMINATED;
+ if (c.f != null)
+ c.f.cancel(false);
+ }
+ catch (UncheckedInterruptedException | InterruptedException e)
+ {
+ interrupted = false;
+ state = TERMINATED;
+ if (c.f != null)
+ c.f.cancel(false);
+ }
+ catch (Throwable t)
+ {
+ if (onError != null)
+ onError.accept(t);
+ }
+ }
+
+ @Override
+ public void interrupt()
+ {
+ interrupted = true;
+ }
+
+ @Override
+ public boolean isTerminated()
+ {
+ return state == TERMINATED;
+ }
+
+ @Override
+ public void shutdown()
+ {
+ if (state != TERMINATED && state != SHUTTING_DOWN_NOW)
+ state = SHUTTING_DOWN;
+ }
+
+ @Override
+ public Object shutdownNow()
+ {
+ if (state != TERMINATED)
+ state = SHUTTING_DOWN_NOW;
+ return null;
+ }
+
+ @Override
+ public boolean awaitTermination(long timeout, TimeUnit units)
+ {
+ return isTerminated();
+ }
+ }
+ I i = new I();
+ c.f = delegate.scheduleAtFixedRate(i::runOne, 0, 0, NANOSECONDS);
+ return i;
}
@Override
@@ -329,7 +469,9 @@ public class SimulatedExecutorFactory implements ExecutorFactory, Clock
public void execute(Runnable command)
{
checkNotShutdown();
- queue.add(new Item(nowWithJitter(),
SimulatedExecutorFactory.this.seq++, taskFor(command)));
+ var action = taskFor(command);
+ maybeAddFailureListener(action);
+ queue.add(new Item(nowWithJitter(),
SimulatedExecutorFactory.this.seq++, action));
}
protected void checkNotShutdown()
@@ -365,6 +507,7 @@ public class SimulatedExecutorFactory implements
ExecutorFactory, Clock
if (next == null)
return;
+ maybeAddFailureListener(next.action);
next.action.addCallback((s, f) -> afterExecution());
queue.add(next);
}
@@ -461,6 +604,8 @@ public class SimulatedExecutorFactory implements
ExecutorFactory, Clock
catch (Throwable t)
{
tryFailure(t);
+ if (onError != null)
+ onError.accept(t);
}
}
}
@@ -477,6 +622,7 @@ public class SimulatedExecutorFactory implements
ExecutorFactory, Clock
{
checkNotShutdown();
ScheduledFuture<V> task = new ScheduledFuture<>(seq++, delay, 0,
NANOSECONDS, callable);
+ maybeAddFailureListener(task);
queue.add(new Item(nowWithJitter() + unit.toNanos(delay),
task.sequenceNumber, task));
return task;
}
@@ -486,6 +632,7 @@ public class SimulatedExecutorFactory implements
ExecutorFactory, Clock
{
checkNotShutdown();
ScheduledFuture<?> task = new ScheduledFuture<>(seq++,
initialDelay, period, unit, Executors.callable(command));
+ maybeAddFailureListener(task);
repeatedTasks++;
task.addCallback((s, f) -> repeatedTasks--);
queue.add(new Item(nowWithJitter() + unit.toNanos(initialDelay),
task.sequenceNumber, task));
@@ -497,6 +644,7 @@ public class SimulatedExecutorFactory implements
ExecutorFactory, Clock
{
checkNotShutdown();
ScheduledFuture<?> task = new ScheduledFuture<>(seq++,
initialDelay, -delay, unit, Executors.callable(command));
+ maybeAddFailureListener(task);
repeatedTasks++;
task.addCallback((s, f) -> repeatedTasks--);
queue.add(new Item(nowWithJitter() + unit.toNanos(initialDelay),
task.sequenceNumber, task));
diff --git a/test/unit/org/apache/cassandra/net/MessageDeliveryTest.java
b/test/unit/org/apache/cassandra/net/MessageDeliveryTest.java
new file mode 100644
index 0000000000..59d7106506
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/MessageDeliveryTest.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Iterators;
+import org.junit.Assert;
+import org.junit.Test;
+
+import accord.utils.RandomSource;
+import org.apache.cassandra.concurrent.ScheduledExecutorPlus;
+import org.apache.cassandra.concurrent.SimulatedExecutorFactory;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.dht.Murmur3Partitioner;
+import org.apache.cassandra.exceptions.RequestFailureReason;
+import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.net.MessageDelivery.FailedResponseException;
+import org.apache.cassandra.net.MessageDelivery.MaxRetriesException;
+import org.apache.cassandra.net.SimulatedMessageDelivery.Action;
+import
org.apache.cassandra.net.SimulatedMessageDelivery.SimulatedMessageReceiver;
+import org.apache.cassandra.tcm.ClusterMetadataService;
+import org.apache.cassandra.tcm.StubClusterMetadataService;
+import org.apache.cassandra.utils.Backoff;
+import org.mockito.Mockito;
+
+import static accord.utils.Property.qt;
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class MessageDeliveryTest
+{
+ private static final InetAddressAndPort ID1 =
InetAddressAndPort.getByNameUnchecked("127.0.0.1");
+ private static final MessageDelivery.RetryErrorMessage RETRY_ERROR_MESSAGE
= (i1, i2, i3, i4) -> null;
+ private static final MessageDelivery.RetryPredicate ALWAYS_RETRY = (i1,
i2, i3) -> true;
+ private static final MessageDelivery.RetryPredicate ALWAYS_REJECT = (i1,
i2, i3) -> false;
+
+ static
+ {
+ DatabaseDescriptor.clientInitialization();
+ DatabaseDescriptor.setPartitionerUnsafe(Murmur3Partitioner.instance);
+
ClusterMetadataService.setInstance(StubClusterMetadataService.forTesting());
+ }
+
+    /**
+     * Every message is dropped and every failure is allowed to retry, so the send must end with a
+     * {@link MaxRetriesException} whose {@code attempts} equals the value the backoff was built with.
+     */
+    @Test
+    public void sendWithRetryFailsAfterMaxAttempts()
+    {
+        qt().check(rs -> {
+            List<Throwable> errors = new ArrayList<>();
+            SimulatedExecutorFactory executors = new SimulatedExecutorFactory(rs.fork(), errors::add);
+            ScheduledExecutorPlus executor = executors.scheduled("ignored");
+            MessageDelivery delivery = simulatedMessages(rs, executor, errors, (self, msg, to) -> Action.DROP);
+
+            int maxAttempts = 3;
+            Backoff backoff = new Backoff.ExponentialBackoff(maxAttempts, 200, 1000, rs.fork()::nextDouble);
+
+            Future<Message<Void>> pending = delivery.sendWithRetries(backoff,
+                                                                     executor::schedule,
+                                                                     Verb.ECHO_REQ, NoPayload.noPayload,
+                                                                     Iterators.cycle(ID1),
+                                                                     ALWAYS_RETRY,
+                                                                     RETRY_ERROR_MESSAGE);
+            // the future is expected to complete only once the simulated executor has drained its queue
+            assertThat(pending).isNotDone();
+            executors.processAll();
+            assertThat(pending).isDone();
+
+            MaxRetriesException thrown = getMaxRetriesException(pending);
+            assertThat(thrown.attempts).isEqualTo(maxAttempts);
+        });
+    }
+
+    /**
+     * When the very first message is delivered, the response must complete the future and the
+     * (mocked) backoff must never be consulted.
+     */
+    @Test
+    public void sendWithRetryFirstAttempt()
+    {
+        qt().check(rs -> {
+            List<Throwable> errors = new ArrayList<>();
+            SimulatedExecutorFactory executors = new SimulatedExecutorFactory(rs.fork(), errors::add);
+            ScheduledExecutorPlus executor = executors.scheduled("ignored");
+            MessageDelivery delivery = simulatedMessages(rs, executor, errors, (self, msg, to) -> Action.DELIVER);
+
+            Backoff untouched = Mockito.mock(Backoff.class);
+
+            Future<Message<Void>> pending = delivery.sendWithRetries(untouched,
+                                                                     executor::schedule,
+                                                                     Verb.ECHO_REQ, NoPayload.noPayload,
+                                                                     Iterators.cycle(ID1),
+                                                                     ALWAYS_RETRY,
+                                                                     RETRY_ERROR_MESSAGE);
+            assertThat(pending).isNotDone();
+            executors.processAll();
+            assertThat(pending).isDone();
+
+            Message<Void> reply = pending.get();
+            assertThat(reply.header.verb).isEqualTo(Verb.ECHO_RSP);
+
+            // success on the first attempt means no retry decision and no wait-time computation
+            Mockito.verify(untouched, Mockito.never()).mayRetry(Mockito.anyInt());
+            Mockito.verify(untouched, Mockito.never()).computeWaitTime(Mockito.anyInt());
+            Mockito.verify(untouched, Mockito.never()).unit();
+        });
+    }
+
+    /**
+     * The first message is dropped and later ones are delivered, so the send must succeed after the
+     * backoff has been consulted exactly once.
+     */
+    @Test
+    public void sendWithRetry()
+    {
+        qt().check(rs -> {
+            List<Throwable> errors = new ArrayList<>();
+            SimulatedExecutorFactory executors = new SimulatedExecutorFactory(rs.fork(), errors::add);
+            ScheduledExecutorPlus executor = executors.scheduled("ignored");
+
+            int maxAttempts = 3;
+            int expectedAttempts = 1;
+            // the action supplier answers DROP until it has been consulted more than expectedAttempts times
+            int firstDelivered = expectedAttempts + 1;
+            AtomicInteger sends = new AtomicInteger(0);
+            MessageDelivery delivery = simulatedMessages(rs, executor, errors,
+                                                         (self, msg, to) -> sends.incrementAndGet() < firstDelivered ? Action.DROP : Action.DELIVER);
+
+            Backoff backoff = Mockito.spy(new Backoff.ExponentialBackoff(maxAttempts, 200, 1000, rs.fork()::nextDouble));
+
+            Future<Message<Void>> pending = delivery.sendWithRetries(backoff,
+                                                                     executor::schedule,
+                                                                     Verb.ECHO_REQ, NoPayload.noPayload,
+                                                                     Iterators.cycle(ID1),
+                                                                     ALWAYS_RETRY,
+                                                                     RETRY_ERROR_MESSAGE);
+            assertThat(pending).isNotDone();
+            executors.processAll();
+            assertThat(pending).isDone();
+
+            Message<Void> reply = pending.get();
+            assertThat(reply.header.verb).isEqualTo(Verb.ECHO_RSP);
+
+            Mockito.verify(backoff, Mockito.times(expectedAttempts)).mayRetry(Mockito.anyInt());
+            Mockito.verify(backoff, Mockito.times(expectedAttempts)).computeWaitTime(Mockito.anyInt());
+            Mockito.verify(backoff, Mockito.times(expectedAttempts)).unit();
+        });
+    }
+
+    /**
+     * Messages are dropped and the retry predicate rejects every failure, so the future must fail with
+     * a {@link FailedResponseException} carrying the peer and a TIMEOUT reason, without any wait time
+     * ever being computed.
+     */
+    @Test
+    public void sendWithRetryDontAllowRetry()
+    {
+        qt().check(rs -> {
+            List<Throwable> errors = new ArrayList<>();
+            SimulatedExecutorFactory executors = new SimulatedExecutorFactory(rs.fork(), errors::add);
+            ScheduledExecutorPlus executor = executors.scheduled("ignored");
+
+            MessageDelivery delivery = simulatedMessages(rs, executor, errors, (self, msg, to) -> Action.DROP);
+
+            Backoff backoff = Mockito.spy(new Backoff.ExponentialBackoff(3, 200, 1000, rs.fork()::nextDouble));
+
+            Future<Message<Void>> pending = delivery.sendWithRetries(backoff,
+                                                                     executor::schedule,
+                                                                     Verb.ECHO_REQ, NoPayload.noPayload,
+                                                                     Iterators.cycle(ID1),
+                                                                     ALWAYS_REJECT,
+                                                                     RETRY_ERROR_MESSAGE);
+            assertThat(pending).isNotDone();
+            executors.processAll();
+            assertThat(pending).isDone();
+
+            FailedResponseException thrown = getFailedResponseException(pending);
+            assertThat(thrown.from).isEqualTo(ID1);
+            assertThat(thrown.failure).isEqualTo(RequestFailureReason.TIMEOUT);
+
+            // the backoff is asked once whether a retry is possible, but no delay is ever derived from it
+            Mockito.verify(backoff, Mockito.times(1)).mayRetry(Mockito.anyInt());
+            Mockito.verify(backoff, Mockito.never()).computeWaitTime(Mockito.anyInt());
+            Mockito.verify(backoff, Mockito.never()).unit();
+        });
+    }
+
+    /**
+     * Builds a single-node simulated messaging layer for {@code ID1}: delivered messages are handed to
+     * the node's own receiver through {@code scheduler}, and every request is answered with an empty payload.
+     *
+     * @param rs             randomness used for the simulated network delay
+     * @param scheduler      executor used both for delivery and for scheduled work
+     * @param failures       sink for errors reported by the messaging layer
+     * @param actionSupplier decides what happens to each outgoing message
+     */
+    private static MessageDelivery simulatedMessages(RandomSource rs, ScheduledExecutorPlus scheduler, List<Throwable> failures, SimulatedMessageDelivery.ActionSupplier actionSupplier)
+    {
+        Map<InetAddressAndPort, SimulatedMessageReceiver> receivers = new HashMap<>();
+        SimulatedMessageDelivery.NetworkDelaySupplier latency = SimulatedMessageDelivery.randomDelay(rs);
+        SimulatedMessageDelivery.DropListener ignoreDrops = (action, from, msg) -> {};
+        SimulatedMessageDelivery messaging = new SimulatedMessageDelivery(ID1,
+                                                                          actionSupplier,
+                                                                          latency,
+                                                                          (to, message) -> scheduler.execute(() -> receivers.get(to).recieve(message)),
+                                                                          ignoreDrops,
+                                                                          scheduler::schedule,
+                                                                          failures::add);
+        // the receiver is registered after construction because it needs the messaging instance to respond
+        SimulatedMessageReceiver echo = messaging.receiver(m -> messaging.respond(NoPayload.noPayload, m));
+        receivers.put(ID1, echo);
+        return messaging;
+    }
+
+    /**
+     * Waits for {@code result}, which is expected to have failed, and returns the cause of that
+     * failure cast to {@link FailedResponseException}. Fails the test if the future succeeded.
+     */
+    private static FailedResponseException getFailedResponseException(Future<Message<Void>> result) throws InterruptedException
+    {
+        try
+        {
+            result.get();
+        }
+        catch (ExecutionException e)
+        {
+            return (FailedResponseException) e.getCause();
+        }
+        Assert.fail("Should have failed");
+        throw new AssertionError("Not Reachable");
+    }
+
+    /**
+     * Waits for {@code result}, which is expected to have failed, and returns the cause of that
+     * failure cast to {@link MaxRetriesException}. Fails the test if the future succeeded.
+     */
+    private static MaxRetriesException getMaxRetriesException(Future<Message<Void>> result) throws InterruptedException
+    {
+        try
+        {
+            result.get();
+        }
+        catch (ExecutionException e)
+        {
+            return (MaxRetriesException) e.getCause();
+        }
+        Assert.fail("Should have failed");
+        throw new AssertionError("Not Reachable");
+    }
+}
\ No newline at end of file
diff --git a/test/unit/org/apache/cassandra/net/SimulatedMessageDelivery.java
b/test/unit/org/apache/cassandra/net/SimulatedMessageDelivery.java
new file mode 100644
index 0000000000..f36d04d8a5
--- /dev/null
+++ b/test/unit/org/apache/cassandra/net/SimulatedMessageDelivery.java
@@ -0,0 +1,408 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.net;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.LongSupplier;
+import javax.annotation.Nullable;
+
+import accord.utils.Gens;
+import accord.utils.RandomSource;
+import org.apache.cassandra.exceptions.RequestFailureReason;
+import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.utils.concurrent.AsyncPromise;
+import org.apache.cassandra.utils.concurrent.Future;
+
+public class SimulatedMessageDelivery implements MessageDelivery
+{
+    /** What the simulated network does with one outgoing message. */
+    public enum Action
+    {
+        /** Hand the message to the receiver. */
+        DELIVER,
+        /** Hand the message to the receiver and also schedule a failure on the sender's callback. */
+        DELIVER_WITH_FAILURE,
+        /** Do not deliver; the drop listener is notified and any callback is left to time out. */
+        DROP,
+        /** Handled exactly like {@link #DROP}. */
+        DROP_PARTITIONED,
+        /** Do not deliver; notify the drop listener and schedule a failure on the sender's callback. */
+        FAILURE
+    }
+
+    /** Chooses the {@link Action} applied to a message that {@code self} is sending to {@code to}. */
+    public interface ActionSupplier
+    {
+        Action get(InetAddressAndPort self, Message<?> message, InetAddressAndPort to);
+    }
+
+    /** Supplies the delay applied before a message is delivered. */
+    public interface NetworkDelaySupplier
+    {
+        /** @return the delay for this message, or {@code null} to deliver without scheduling */
+        @Nullable
+        Duration jitter(Message<?> message, InetAddressAndPort to);
+    }
+
+    /** @return a supplier that always answers {@code null}, i.e. never asks for a delay */
+    public static NetworkDelaySupplier noDelay()
+    {
+        return (message, to) -> null;
+    }
+
+    /**
+     * Creates a delay supplier that keeps one latency generator per (sender, recipient) pair, built
+     * lazily the first time the pair is seen. A generator yields either a "small" delay (between
+     * 500 microseconds and 5 milliseconds) or a "large" one (between 5 milliseconds and 5 seconds).
+     * NOTE(review): the two random arguments to {@code runs} presumably control how often and how
+     * long the large delays occur; confirm against {@code Gens}.
+     *
+     * @param rs source of randomness for both the generator parameters and the delays themselves
+     */
+    public static NetworkDelaySupplier randomDelay(RandomSource rs)
+    {
+        // map key identifying one direction of traffic between two nodes
+        class Connection
+        {
+            final InetAddressAndPort from;
+            final InetAddressAndPort to;
+
+            private Connection(InetAddressAndPort from, InetAddressAndPort to)
+            {
+                this.from = from;
+                this.to = to;
+            }
+
+            @Override
+            public boolean equals(Object o)
+            {
+                if (o == this) return true;
+                if (o == null) return false;
+                if (o.getClass() != getClass()) return false;
+                Connection other = (Connection) o;
+                if (!from.equals(other.from)) return false;
+                return to.equals(other.to);
+            }
+
+            @Override
+            public int hashCode()
+            {
+                return Objects.hash(from, to);
+            }
+
+            @Override
+            public String toString()
+            {
+                return "Connection{" + "from=" + from + ", to=" + to + '}';
+            }
+        }
+        final Map<Connection, LongSupplier> latencyByLink = new HashMap<>();
+        return (msg, to) -> {
+            Connection link = new Connection(msg.from(), to);
+            LongSupplier latency = latencyByLink.computeIfAbsent(link, unused -> {
+                long min = TimeUnit.MICROSECONDS.toNanos(500);
+                long maxSmall = TimeUnit.MILLISECONDS.toNanos(5);
+                long max = TimeUnit.SECONDS.toNanos(5);
+                return Gens.bools().runs(rs.nextInt(1, 11) / 100.0D, rs.nextInt(3, 15))
+                           .mapToLong(large -> large ? rs.nextLong(maxSmall, max) : rs.nextLong(min, maxSmall))
+                           .asLongSupplier(rs.fork());
+            });
+            return Duration.ofNanos(latency.getAsLong());
+        };
+    }
+
+    /** Minimal scheduling hook: run {@code command} after {@code delay} expressed in {@code unit}. */
+    public interface Scheduler
+    {
+        void schedule(Runnable command, long delay, TimeUnit unit);
+    }
+
+    /** Notified with the chosen {@link Action} whenever a message is not handed to the receiver. */
+    public interface DropListener
+    {
+        void onDrop(Action action, InetAddressAndPort from, Message<?> msg);
+    }
+
+ private final InetAddressAndPort self;
+ private final ActionSupplier actions;
+ private final NetworkDelaySupplier networkDelay;
+ private final BiConsumer<InetAddressAndPort, Message<?>> reciever;
+ private final DropListener onDropped;
+ private final Scheduler scheduler;
+ private final Consumer<Throwable> onError;
+ private final Map<CallbackKey, CallbackContext> callbacks = new
HashMap<>();
+ private enum Status { Up, Down }
+ private Status status = Status.Up;
+
+    /**
+     * @param self         address stamped as the sender of every outgoing message
+     * @param actions      decides what happens to each outgoing message
+     * @param networkDelay delay applied before a message is handed to {@code reciever}
+     * @param reciever     sink that takes a delivered message together with its destination
+     * @param onDropped    notified when a message is dropped or failed instead of delivered
+     * @param scheduler    used for delayed delivery and for failure/timeout notifications
+     * @param onError      receives exceptions thrown by callbacks and verb handlers
+     */
+    public SimulatedMessageDelivery(InetAddressAndPort self,
+                                    ActionSupplier actions,
+                                    NetworkDelaySupplier networkDelay,
+                                    BiConsumer<InetAddressAndPort, Message<?>> reciever,
+                                    DropListener onDropped,
+                                    Scheduler scheduler,
+                                    Consumer<Throwable> onError)
+    {
+        this.self = self;
+
+        // network behaviour
+        this.actions = actions;
+        this.networkDelay = networkDelay;
+        this.scheduler = scheduler;
+
+        // sinks and listeners
+        this.reciever = reciever;
+        this.onDropped = onDropped;
+        this.onError = onError;
+    }
+
+    /** Marks this instance as down and forgets every pending callback. */
+    public void stop()
+    {
+        status = Status.Down;
+        callbacks.clear();
+    }
+
+    /** Sends without a callback; the message is stamped with this node as its sender. */
+    @Override
+    public <REQ> void send(Message<REQ> message, InetAddressAndPort to)
+    {
+        maybeEnqueue(message.withFrom(self), to, null);
+    }
+
+    /** Sends with a callback; the message is stamped with this node as its sender. */
+    @Override
+    public <REQ, RSP> void sendWithCallback(Message<REQ> message, InetAddressAndPort to, RequestCallback<RSP> cb)
+    {
+        Message<REQ> stamped = message.withFrom(self);
+        maybeEnqueue(stamped, to, cb);
+    }
+
+    /** Same as the three-argument overload; {@code specifyConnection} is not used by the simulation. */
+    @Override
+    public <REQ, RSP> void sendWithCallback(Message<REQ> message, InetAddressAndPort to, RequestCallback<RSP> cb, ConnectionType specifyConnection)
+    {
+        Message<REQ> stamped = message.withFrom(self);
+        maybeEnqueue(stamped, to, cb);
+    }
+
+    /**
+     * Sends {@code message} and exposes the outcome as a future: a response completes it successfully,
+     * a failure completes it with a {@code MessagingService.FailureResponseException}.
+     */
+    @Override
+    public <REQ, RSP> Future<Message<RSP>> sendWithResult(Message<REQ> message, InetAddressAndPort to)
+    {
+        AsyncPromise<Message<RSP>> outcome = new AsyncPromise<>();
+        RequestCallback<RSP> completer = new RequestCallback<RSP>()
+        {
+            @Override
+            public void onResponse(Message<RSP> msg)
+            {
+                outcome.trySuccess(msg);
+            }
+
+            @Override
+            public void onFailure(InetAddressAndPort from, RequestFailureReason failure)
+            {
+                outcome.tryFailure(new MessagingService.FailureResponseException(from, failure));
+            }
+
+            // opt in to failure notifications so the future cannot be left incomplete on timeout
+            @Override
+            public boolean invokeOnFailure()
+            {
+                return true;
+            }
+        };
+        sendWithCallback(message, to, completer);
+        return outcome;
+    }
+
+    /** Builds the response to {@code message} carrying {@code response} and sends it to the reply address. */
+    @Override
+    public <V> void respond(V response, Message<?> message)
+    {
+        Message<?> reply = message.responseWith(response);
+        InetAddressAndPort replyTo = message.respondTo();
+        send(reply, replyTo);
+    }
+
+    /**
+     * Core send path: registers the callback (if any), asks {@code actions} what to do with the
+     * message, and then delivers, drops or fails it accordingly. Does nothing once this instance
+     * has been stopped.
+     *
+     * @param message  message to send, already stamped with the sender by the callers in this class
+     * @param to       destination
+     * @param callback optional callback, keyed by (message id, destination)
+     * @throws AssertionError if a callback is already registered for the same id and destination
+     */
+    private <REQ, RSP> void maybeEnqueue(Message<REQ> message, InetAddressAndPort to, @Nullable RequestCallback<RSP> callback)
+    {
+        if (status != Status.Up)
+            return;
+
+        final CallbackContext registered;
+        if (callback == null)
+        {
+            registered = null;
+        }
+        else
+        {
+            CallbackKey key = new CallbackKey(message.id(), to);
+            if (callbacks.containsKey(key))
+                throw new AssertionError("Message id " + message.id() + " to " + to + " already has a callback");
+            registered = new CallbackContext(callback);
+            callbacks.put(key, registered);
+        }
+
+        Action action = actions.get(self, message, to);
+        switch (action)
+        {
+            case DELIVER:
+                deliver(message, to);
+                break;
+            case DROP:
+            case DROP_PARTITIONED:
+                onDropped.onDrop(action, to, message);
+                break;
+            case DELIVER_WITH_FAILURE:
+            case FAILURE:
+                // both report a failure to the sender; they differ only in whether the message is delivered
+                if (action == Action.DELIVER_WITH_FAILURE)
+                    deliver(message, to);
+                else
+                    onDropped.onDrop(action, to, message);
+                // NOTE: this path invokes the raw callback directly and returns before the timeout
+                // below is scheduled, so the entry registered in `callbacks` is not removed here
+                if (callback != null)
+                    scheduler.schedule(() -> callback.onFailure(to, RequestFailureReason.UNKNOWN),
+                                       message.verb().expiresAfterNanos(), TimeUnit.NANOSECONDS);
+                return;
+            default:
+                throw new UnsupportedOperationException("Unknown action type: " + action);
+        }
+
+        if (registered == null)
+            return;
+
+        // time the callback out if nothing has claimed it by the time the verb expires
+        scheduler.schedule(() -> {
+            CallbackContext expired = callbacks.remove(new CallbackKey(message.id(), to));
+            if (expired == null)
+                return;
+            assert expired == registered;
+            try
+            {
+                expired.onFailure(to, RequestFailureReason.TIMEOUT);
+            }
+            catch (Throwable t)
+            {
+                onError.accept(t);
+            }
+        }, message.verb().expiresAfterNanos(), TimeUnit.NANOSECONDS);
+    }
+
+    /** Hands the message to {@code reciever}, immediately when no delay is supplied, otherwise after it. */
+    private void deliver(Message<?> message, InetAddressAndPort to)
+    {
+        Duration latency = networkDelay.jitter(message, to);
+        if (latency == null)
+        {
+            reciever.accept(to, message);
+            return;
+        }
+        scheduler.schedule(() -> reciever.accept(to, message), latency.toNanos(), TimeUnit.NANOSECONDS);
+    }
+
+    /** Creates the receiving side for this instance, dispatching requests to {@code onMessage}. */
+    @SuppressWarnings("rawtypes")
+    public SimulatedMessageReceiver receiver(IVerbHandler onMessage)
+    {
+        SimulatedMessageReceiver created = new SimulatedMessageReceiver(onMessage);
+        return created;
+    }
+
+    /**
+     * Receiving side of the enclosing delivery: responses are routed to the callback registered for
+     * (message id, sender), everything else goes to the verb handler. Exceptions from either are
+     * forwarded to {@code onError}.
+     */
+    public class SimulatedMessageReceiver
+    {
+        @SuppressWarnings("rawtypes")
+        final IVerbHandler onMessage;
+
+        @SuppressWarnings("rawtypes")
+        public SimulatedMessageReceiver(IVerbHandler onMessage)
+        {
+            this.onMessage = onMessage;
+        }
+
+        public void recieve(Message<?> msg)
+        {
+            if (status != Status.Up)
+                return;
+
+            if (!msg.verb().isResponse())
+            {
+                try
+                {
+                    //noinspection unchecked
+                    onMessage.doVerb(msg);
+                }
+                catch (Throwable t)
+                {
+                    onError.accept(t);
+                }
+                return;
+            }
+
+            // a response with no registered callback (already answered, timed out, or never registered) is ignored
+            CallbackContext pending = callbacks.remove(new CallbackKey(msg.id(), msg.from()));
+            if (pending == null)
+                return;
+            try
+            {
+                if (msg.isFailureResponse())
+                    pending.onFailure(msg.from(), (RequestFailureReason) msg.payload);
+                else
+                    pending.onResponse(msg);
+            }
+            catch (Throwable t)
+            {
+                onError.accept(t);
+            }
+        }
+    }
+
+    /** Verb handler that dispatches each message to the handler registered for its verb. */
+    @SuppressWarnings("rawtypes")
+    public static class SimpleVerbHandler implements IVerbHandler
+    {
+        private final Map<Verb, IVerbHandler<?>> handlers;
+
+        public SimpleVerbHandler(Map<Verb, IVerbHandler<?>> handlers)
+        {
+            this.handlers = handlers;
+        }
+
+        /** @throws AssertionError if no handler is registered for the message's verb */
+        @Override
+        public void doVerb(Message msg) throws IOException
+        {
+            IVerbHandler<?> target = handlers.get(msg.verb());
+            if (target != null)
+            {
+                //noinspection unchecked
+                target.doVerb(msg);
+                return;
+            }
+            throw new AssertionError("Unexpected verb: " + msg.verb());
+        }
+    }
+
+    /** Wrapper around a registered callback; failures are only forwarded when the callback opts in. */
+    private static class CallbackContext
+    {
+        @SuppressWarnings("rawtypes")
+        final RequestCallback callback;
+
+        @SuppressWarnings("rawtypes")
+        private CallbackContext(RequestCallback callback)
+        {
+            this.callback = Objects.requireNonNull(callback);
+        }
+
+        @SuppressWarnings({ "rawtypes", "unchecked" })
+        public void onResponse(Message msg)
+        {
+            callback.onResponse(msg);
+        }
+
+        public void onFailure(InetAddressAndPort from, RequestFailureReason failure)
+        {
+            // callbacks that do not ask for failure notifications are silently skipped
+            if (!callback.invokeOnFailure())
+                return;
+            callback.onFailure(from, failure);
+        }
+    }
+
+    /** Map key identifying a pending callback by message id and the peer the message was sent to. */
+    private static class CallbackKey
+    {
+        private final long id;
+        private final InetAddressAndPort peer;
+
+        private CallbackKey(long id, InetAddressAndPort peer)
+        {
+            this.id = id;
+            this.peer = peer;
+        }
+
+        @Override
+        public boolean equals(Object o)
+        {
+            if (o == this) return true;
+            if (o == null) return false;
+            if (o.getClass() != getClass()) return false;
+            CallbackKey other = (CallbackKey) o;
+            if (id != other.id) return false;
+            return peer.equals(other.peer);
+        }
+
+        @Override
+        public int hashCode()
+        {
+            return Objects.hash(id, peer);
+        }
+
+        @Override
+        public String toString()
+        {
+            return "CallbackKey{id=" + id + ", peer=" + peer + '}';
+        }
+    }
+}
diff --git
a/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java
b/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java
index b01a9fcbbd..fb3ce470f5 100644
--- a/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java
+++ b/test/unit/org/apache/cassandra/repair/messages/RepairMessageTest.java
@@ -155,7 +155,8 @@ public class RepairMessageTest
{
SharedContext ctx = Mockito.mock(SharedContext.class, REJECT_ALL);
MessageDelivery messaging = Mockito.mock(MessageDelivery.class,
REJECT_ALL);
- // allow the single method under test
+ // allow all retry methods and send with callback
+
Mockito.doCallRealMethod().when(messaging).sendWithRetries(Mockito.any(),
Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
Mockito.any(), Mockito.any());
Mockito.doNothing().when(messaging).sendWithCallback(Mockito.any(),
Mockito.any(), Mockito.any());
IGossiper gossiper = Mockito.mock(IGossiper.class, REJECT_ALL);
Mockito.doReturn(RepairMessage.SUPPORTS_RETRY).when(gossiper).getReleaseVersion(Mockito.any());
@@ -205,7 +206,7 @@ public class RepairMessageTest
{
before();
- sendMessageWithRetries(ctx, backoff(maxAttempts), always(),
PAYLOAD, VERB, ADDRESS, RepairMessage.NOOP_CALLBACK, 0);
+ sendMessageWithRetries(ctx, backoff(maxAttempts), always(),
PAYLOAD, VERB, ADDRESS, RepairMessage.NOOP_CALLBACK);
for (int i = 0; i < maxAttempts; i++)
callback(messaging).onFailure(ADDRESS,
RequestFailureReason.TIMEOUT);
fn.test(maxAttempts, callback(messaging));
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org