This is an automated email from the ASF dual-hosted git repository.
tzulitai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-statefun.git
The following commit(s) were added to refs/heads/master by this push:
new e1b0c29 [FLINK-21308][core] Support delayed message cancellation
e1b0c29 is described below
commit e1b0c2918939879112a297612fc429ad05bbe4a6
Author: Igal Shilman <[email protected]>
AuthorDate: Mon Jun 14 12:45:10 2021 +0200
[FLINK-21308][core] Support delayed message cancellation
This closes #241.
---
.../statefun/e2e/smoke/CommandInterpreterTest.java | 6 ++
.../core/functions/AsyncMessageDecorator.java | 6 ++
.../flink/core/functions/DelayMessageHandler.java | 60 +++++++++++
.../statefun/flink/core/functions/DelaySink.java | 56 ++++------
.../core/functions/DelayedMessagesBuffer.java | 14 ++-
.../functions/FlinkStateDelayedMessagesBuffer.java | 114 ++++++++++++++++++---
.../core/functions/FunctionGroupOperator.java | 6 ++
.../statefun/flink/core/functions/Reductions.java | 3 +
.../flink/core/functions/ReusableContext.java | 17 +++
.../flink/statefun/flink/core/message/Message.java | 3 +
.../flink/core/message/MessageFactory.java | 4 +
.../flink/core/message/ProtobufMessage.java | 10 ++
.../statefun/flink/core/message/SdkMessage.java | 28 ++++-
.../flink/core/reqreply/RequestReplyFunction.java | 28 ++++-
.../src/main/protobuf/stateful-functions.proto | 4 +
.../functions/LocalStatefulFunctionGroupTest.java | 6 ++
.../flink/core/functions/ReductionsTest.java | 23 +++--
.../core/reqreply/RequestReplyFunctionTest.java | 47 ++++++++-
.../org/apache/flink/statefun/sdk/Context.java | 27 +++++
.../apache/flink/statefun/sdk/java/Context.java | 20 ++++
.../sdk/java/handler/ConcurrentContext.java | 40 ++++++++
.../src/main/protobuf/sdk/request-reply.proto | 9 ++
statefun-sdk-python/statefun/context.py | 14 ++-
statefun-sdk-python/statefun/request_reply_v3.py | 62 +++++++++--
statefun-sdk-python/tests/request_reply_test.py | 20 ++++
.../statefun/testutils/function/TestContext.java | 21 +++-
26 files changed, 557 insertions(+), 91 deletions(-)
diff --git
a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
index 226f418..db35de6 100644
---
a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
+++
b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
@@ -68,6 +68,12 @@ public class CommandInterpreterTest {
public void sendAfter(Duration duration, Address address, Object o) {}
@Override
+ public void sendAfter(Duration delay, Address to, Object message, String
cancellationToken) {}
+
+ @Override
+ public void cancelDelayedMessage(String cancellationToken) {}
+
+ @Override
public <M, T> void registerAsyncOperation(M m, CompletableFuture<T>
completableFuture) {}
}
}
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java
index c77adb7..eed001f 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java
@@ -17,6 +17,7 @@
*/
package org.apache.flink.statefun.flink.core.functions;
+import java.util.Optional;
import java.util.OptionalLong;
import javax.annotation.Nullable;
import org.apache.flink.core.memory.DataOutputView;
@@ -93,6 +94,11 @@ final class AsyncMessageDecorator<T> implements Message {
}
@Override
+ public Optional<String> cancellationToken() {
+ return message.cancellationToken();
+ }
+
+ @Override
public void postApply() {
pendingAsyncOperations.remove(source(), futureId);
}
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayMessageHandler.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayMessageHandler.java
new file mode 100644
index 0000000..1dfb66b
--- /dev/null
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayMessageHandler.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.statefun.flink.core.functions;
+
+import java.util.Objects;
+import java.util.function.Consumer;
+import org.apache.flink.statefun.flink.core.di.Inject;
+import org.apache.flink.statefun.flink.core.di.Label;
+import org.apache.flink.statefun.flink.core.di.Lazy;
+import org.apache.flink.statefun.flink.core.message.Message;
+
+/**
+ * Handles delayed messages that are due to fire at a specific timestamp. This handler
+ * dispatches each {@linkplain Message} locally if its target belongs to this partition, and
+ * remotely (shuffle) otherwise.
+ */
+final class DelayMessageHandler implements Consumer<Message> {
+ private final RemoteSink remoteSink;
+ private final Lazy<Reductions> reductions;
+ private final Partition thisPartition;
+
+ @Inject
+ public DelayMessageHandler(
+ RemoteSink remoteSink,
+ @Label("reductions") Lazy<Reductions> reductions,
+ Partition partition) {
+ this.remoteSink = Objects.requireNonNull(remoteSink);
+ this.reductions = Objects.requireNonNull(reductions);
+ this.thisPartition = Objects.requireNonNull(partition);
+ }
+
+ @Override
+ public void accept(Message message) {
+ if (thisPartition.contains(message.target())) {
+ reductions.get().enqueue(message);
+ } else {
+ remoteSink.accept(message);
+ }
+ }
+
+ public void onStart() {}
+
+ public void onComplete() {
+ reductions.get().processEnvelopes();
+ }
+}
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java
index 05cc212..ddf81c9 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java
@@ -18,10 +18,10 @@
package org.apache.flink.statefun.flink.core.functions;
import java.util.Objects;
+import java.util.OptionalLong;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.statefun.flink.core.di.Inject;
import org.apache.flink.statefun.flink.core.di.Label;
-import org.apache.flink.statefun.flink.core.di.Lazy;
import org.apache.flink.statefun.flink.core.message.Message;
import org.apache.flink.streaming.api.operators.InternalTimer;
import org.apache.flink.streaming.api.operators.InternalTimerService;
@@ -32,25 +32,17 @@ final class DelaySink implements Triggerable<String,
VoidNamespace> {
private final InternalTimerService<VoidNamespace>
delayedMessagesTimerService;
private final DelayedMessagesBuffer delayedMessagesBuffer;
-
- private final Lazy<Reductions> reductionsSupplier;
- private final Partition thisPartition;
- private final RemoteSink remoteSink;
+ private final DelayMessageHandler delayMessageHandler;
@Inject
DelaySink(
@Label("delayed-messages-buffer") DelayedMessagesBuffer
delayedMessagesBuffer,
@Label("delayed-messages-timer-service-factory")
TimerServiceFactory delayedMessagesTimerServiceFactory,
- @Label("reductions") Lazy<Reductions> reductionsSupplier,
- Partition thisPartition,
- RemoteSink remoteSink) {
+ DelayMessageHandler delayMessageHandler) {
this.delayedMessagesBuffer = Objects.requireNonNull(delayedMessagesBuffer);
- this.reductionsSupplier = Objects.requireNonNull(reductionsSupplier);
- this.thisPartition = Objects.requireNonNull(thisPartition);
- this.remoteSink = Objects.requireNonNull(remoteSink);
-
this.delayedMessagesTimerService =
delayedMessagesTimerServiceFactory.createTimerService(this);
+ this.delayMessageHandler = Objects.requireNonNull(delayMessageHandler);
}
void accept(Message message, long delayMillis) {
@@ -64,35 +56,25 @@ final class DelaySink implements Triggerable<String,
VoidNamespace> {
}
@Override
- public void onProcessingTime(InternalTimer<String, VoidNamespace> timer)
throws Exception {
- final long triggerTimestamp = timer.getTimestamp();
- final Reductions reductions = reductionsSupplier.get();
-
- Iterable<Message> delayedMessages =
delayedMessagesBuffer.getForTimestamp(triggerTimestamp);
- if (delayedMessages == null) {
- throw new IllegalStateException(
- "A delayed message timer was triggered with timestamp "
- + triggerTimestamp
- + ", but no messages were buffered for it.");
- }
- for (Message delayedMessage : delayedMessages) {
- if (thisPartition.contains(delayedMessage.target())) {
- reductions.enqueue(delayedMessage);
- } else {
- remoteSink.accept(delayedMessage);
- }
- }
- // we clear the delayedMessageBuffer *before* we process the enqueued
local reductions, because
- // processing the envelops might actually trigger a delayed message to be
sent with the same
- // @triggerTimestamp
- // so it would be re-enqueued into the delayedMessageBuffer.
- delayedMessagesBuffer.clearForTimestamp(triggerTimestamp);
- reductions.processEnvelopes();
+ public void onProcessingTime(InternalTimer<String, VoidNamespace> timer) {
+ delayMessageHandler.onStart();
+    // Ordering matters: the Flink-state buffer clears the messages for this timestamp (and their
+    // cancellation-token index entries) inside forEachMessageAt, i.e. *before* onComplete()
+    // processes the locally enqueued envelopes. Processing may register a new delayed message for
+    // this very timestamp, and that message must not be wiped by the clear.
+ delayedMessagesBuffer.forEachMessageAt(timer.getTimestamp(),
delayMessageHandler);
+ delayMessageHandler.onComplete();
}
@Override
- public void onEventTime(InternalTimer<String, VoidNamespace> timer) throws
Exception {
+ public void onEventTime(InternalTimer<String, VoidNamespace> timer) {
throw new UnsupportedOperationException(
"Delayed messages with event time semantics is not supported.");
}
+
+ void removeMessageByCancellationToken(String cancellationToken) {
+ Objects.requireNonNull(cancellationToken);
+ OptionalLong timerToClear =
+
delayedMessagesBuffer.removeMessageByCancellationToken(cancellationToken);
+ if (timerToClear.isPresent()) {
+ long timestamp = timerToClear.getAsLong();
+
delayedMessagesTimerService.deleteProcessingTimeTimer(VoidNamespace.INSTANCE,
timestamp);
+ }
+ }
}
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java
index 1b68e3f..cf35389 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java
@@ -17,13 +17,23 @@
*/
package org.apache.flink.statefun.flink.core.functions;
+import java.util.OptionalLong;
+import java.util.function.Consumer;
import org.apache.flink.statefun.flink.core.message.Message;
interface DelayedMessagesBuffer {
+ /** Add a message to be fired at a specific timestamp */
void add(Message message, long untilTimestamp);
- Iterable<Message> getForTimestamp(long timestamp);
+  /** Applies {@code fn} to each delayed message that is due to fire at {@code timestamp}. */
+ void forEachMessageAt(long timestamp, Consumer<Message> fn);
- void clearForTimestamp(long timestamp);
+ /**
+ * @param token a message cancellation token to delete.
+ * @return an optional timestamp that this message was meant to be fired at.
The timestamp will be
+ * present only if this message was the last message registered to fire
at that timestamp.
+ * (hence: safe to clear any underlying timer)
+ */
+ OptionalLong removeMessageByCancellationToken(String token);
}
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java
index a451fd0..ac7b9a5 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java
@@ -17,7 +17,14 @@
*/
package org.apache.flink.statefun.flink.core.functions;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Objects;
+import java.util.Optional;
+import java.util.OptionalLong;
+import java.util.function.Consumer;
+import javax.annotation.Nullable;
+import org.apache.flink.api.common.state.MapState;
import org.apache.flink.runtime.state.internal.InternalListState;
import org.apache.flink.statefun.flink.core.di.Inject;
import org.apache.flink.statefun.flink.core.di.Label;
@@ -26,41 +33,122 @@ import
org.apache.flink.statefun.flink.core.message.Message;
final class FlinkStateDelayedMessagesBuffer implements DelayedMessagesBuffer {
static final String BUFFER_STATE_NAME = "delayed-messages-buffer";
+ static final String INDEX_STATE_NAME = "delayed-message-index";
private final InternalListState<String, Long, Message> bufferState;
+ private final MapState<String, Long> cancellationTokenToTimestamp;
@Inject
FlinkStateDelayedMessagesBuffer(
- @Label("delayed-messages-buffer-state")
- InternalListState<String, Long, Message> bufferState) {
+ @Label("delayed-messages-buffer-state") InternalListState<String, Long,
Message> bufferState,
+ @Label("delayed-message-index") MapState<String, Long>
cancellationTokenToTimestamp) {
this.bufferState = Objects.requireNonNull(bufferState);
+ this.cancellationTokenToTimestamp =
Objects.requireNonNull(cancellationTokenToTimestamp);
}
@Override
- public void add(Message message, long untilTimestamp) {
- bufferState.setCurrentNamespace(untilTimestamp);
+ public void forEachMessageAt(long timestamp, Consumer<Message> fn) {
try {
- bufferState.add(message);
+ forEachMessageThrows(timestamp, fn);
} catch (Exception e) {
- throw new RuntimeException("Error adding delayed message to state
buffer: " + message, e);
+ throw new IllegalStateException(e);
}
}
@Override
- public Iterable<Message> getForTimestamp(long timestamp) {
- bufferState.setCurrentNamespace(timestamp);
-
+ public OptionalLong removeMessageByCancellationToken(String token) {
try {
- return bufferState.get();
+ return remove(token);
} catch (Exception e) {
- throw new RuntimeException(
- "Error accessing delayed message in state buffer for timestamp: " +
timestamp, e);
+ throw new IllegalStateException(
+ "Failed clearing a message with a cancellation token " + token, e);
}
}
@Override
- public void clearForTimestamp(long timestamp) {
+ public void add(Message message, long untilTimestamp) {
+ try {
+ addThrows(message, untilTimestamp);
+ } catch (Exception e) {
+ throw new RuntimeException("Error adding delayed message to state
buffer: " + message, e);
+ }
+ }
+
+ //
-----------------------------------------------------------------------------------------------------
+ // Internal
+ //
-----------------------------------------------------------------------------------------------------
+
+  /**
+   * Dispatches every message buffered for {@code timestamp} to {@code fn}, dropping each message's
+   * cancellation-token index entry, and then clears the buffer for that timestamp.
+   */
+  private void forEachMessageThrows(long timestamp, Consumer<Message> fn) throws Exception {
bufferState.setCurrentNamespace(timestamp);
+    // Flink's AppendingState#get() contract allows null for an empty state; treat that as
+    // "nothing to fire" instead of failing with a NullPointerException.
+    final @Nullable Iterable<Message> messages = bufferState.get();
+    if (messages == null) {
+      return;
+    }
+    for (Message message : messages) {
+      removeMessageIdMapping(message);
+      fn.accept(message);
+    }
+    // Cleared before the caller processes the dispatched envelopes, so delayed messages that get
+    // registered for this same timestamp during that processing are kept.
+    bufferState.clear();
+  }
+
+  /**
+   * Buffers {@code message} to fire at {@code untilTimestamp} and, if it carries a cancellation
+   * token, indexes that token to the timestamp.
+   *
+   * @throws IllegalStateException if another buffered message already uses the same token; in
+   *     that case neither the buffer nor the index is modified.
+   */
+  private void addThrows(Message message, long untilTimestamp) throws Exception {
+    final Optional<String> maybeToken = message.cancellationToken();
+    if (maybeToken.isPresent()) {
+      // Validate *before* mutating the buffer. Otherwise a rejected message would remain buffered
+      // without an index entry: it could still be dispatched, but never cancelled.
+      final String cancellationToken = maybeToken.get();
+      @Nullable Long previousTimestamp = cancellationTokenToTimestamp.get(cancellationToken);
+      if (previousTimestamp != null) {
+        throw new IllegalStateException(
+            "Trying to associate a message with cancellation token "
+                + cancellationToken
+                + " and timestamp "
+                + untilTimestamp
+                + ", but a message with the same cancellation token exists and with a timestamp "
+                + previousTimestamp);
+      }
+    }
+    bufferState.setCurrentNamespace(untilTimestamp);
+    bufferState.add(message);
+    if (maybeToken.isPresent()) {
+      cancellationTokenToTimestamp.put(maybeToken.get(), untilTimestamp);
+    }
+  }
+
+ private OptionalLong remove(String cancellationToken) throws Exception {
+ final @Nullable Long untilTimestamp =
cancellationTokenToTimestamp.get(cancellationToken);
+ if (untilTimestamp == null) {
+ // The message associated with @cancellationToken has already been
delivered, or previously
+ // removed.
+ return OptionalLong.empty();
+ }
+ cancellationTokenToTimestamp.remove(cancellationToken);
+ bufferState.setCurrentNamespace(untilTimestamp);
+ List<Message> newList = removeMessageByToken(bufferState.get(),
cancellationToken);
+ if (!newList.isEmpty()) {
+ // There are more messages to process, so we indicate to the caller that
+ // they should NOT cancel the timer.
+ bufferState.update(newList);
+ return OptionalLong.empty();
+ }
+    // No messages remain for @untilTimestamp, so we clear the buffer and indicate
+    // to our caller that the timer for @untilTimestamp can be removed.
bufferState.clear();
+ return OptionalLong.of(untilTimestamp);
+ }
+
+ //
---------------------------------------------------------------------------------------------------------
+ // Helpers
+ //
---------------------------------------------------------------------------------------------------------
+
+ private void removeMessageIdMapping(Message message) throws Exception {
+ Optional<String> maybeToken = message.cancellationToken();
+ if (maybeToken.isPresent()) {
+ cancellationTokenToTimestamp.remove(maybeToken.get());
+ }
+ }
+
+ private static List<Message> removeMessageByToken(Iterable<Message>
messages, String token) {
+ ArrayList<Message> newList = new ArrayList<>();
+ for (Message message : messages) {
+ Optional<String> thisMessageId = message.cancellationToken();
+ if (!thisMessageId.isPresent() || !Objects.equals(thisMessageId.get(),
token)) {
+ newList.add(message);
+ }
+ }
+ return newList;
}
}
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
index 741575b..8dcd01b 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
@@ -104,6 +104,11 @@ public class FunctionGroupOperator extends
AbstractStreamOperator<Message>
final ListStateDescriptor<Message> delayedMessageStateDescriptor =
new ListStateDescriptor<>(
FlinkStateDelayedMessagesBuffer.BUFFER_STATE_NAME,
envelopeSerializer.duplicate());
+ final MapStateDescriptor<String, Long> delayedMessageIndexDescriptor =
+ new MapStateDescriptor<>(
+ FlinkStateDelayedMessagesBuffer.INDEX_STATE_NAME, String.class,
Long.class);
+ final MapState<String, Long> delayedMessageIndex =
+ getRuntimeContext().getMapState(delayedMessageIndexDescriptor);
final MapState<Long, Message> asyncOperationState =
getRuntimeContext().getMapState(asyncOperationStateDescriptor);
@@ -130,6 +135,7 @@ public class FunctionGroupOperator extends
AbstractStreamOperator<Message>
new FlinkTimerServiceFactory(
super.getTimeServiceManager().orElseThrow(IllegalStateException::new)),
delayedMessagesBufferState(delayedMessageStateDescriptor),
+ delayedMessageIndex,
sideOutputs,
output,
MessageFactory.forKey(statefulFunctionsUniverse.messageFactoryKey()),
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
index 55b521f..b881a85 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
@@ -62,6 +62,7 @@ final class Reductions {
KeyedStateBackend<Object> keyedStateBackend,
TimerServiceFactory timerServiceFactory,
InternalListState<String, Long, Message> delayedMessagesBufferState,
+ MapState<String, Long> delayMessageIndex,
Map<EgressIdentifier<?>, OutputTag<Object>> sideOutputs,
Output<StreamRecord<Message>> output,
MessageFactory messageFactory,
@@ -117,6 +118,7 @@ final class Reductions {
// for delayed messages
container.add(
"delayed-messages-buffer-state", InternalListState.class,
delayedMessagesBufferState);
+ container.add("delayed-message-index", MapState.class, delayMessageIndex);
container.add(
"delayed-messages-buffer",
DelayedMessagesBuffer.class,
@@ -124,6 +126,7 @@ final class Reductions {
container.add(
"delayed-messages-timer-service-factory", TimerServiceFactory.class,
timerServiceFactory);
container.add(DelaySink.class);
+ container.add(DelayMessageHandler.class);
// lazy providers for the sinks
container.add("function-group", new Lazy<>(LocalFunctionGroup.class));
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java
index 77db7dc..e1a0c87 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java
@@ -109,6 +109,23 @@ final class ReusableContext implements ApplyingContext,
InternalContext {
}
@Override
+ public void sendAfter(Duration delay, Address to, Object message, String
cancellationToken) {
+ Objects.requireNonNull(delay);
+ Objects.requireNonNull(to);
+ Objects.requireNonNull(message);
+ Objects.requireNonNull(cancellationToken);
+
+ Message envelope = messageFactory.from(self(), to, message,
cancellationToken);
+ delaySink.accept(envelope, delay.toMillis());
+ }
+
+ @Override
+ public void cancelDelayedMessage(String cancellationToken) {
+ Objects.requireNonNull(cancellationToken);
+ delaySink.removeMessageByCancellationToken(cancellationToken);
+ }
+
+ @Override
public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T>
future) {
Objects.requireNonNull(metadata);
Objects.requireNonNull(future);
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java
index 2278fa5..be10e3f 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java
@@ -18,6 +18,7 @@
package org.apache.flink.statefun.flink.core.message;
import java.io.IOException;
+import java.util.Optional;
import java.util.OptionalLong;
import org.apache.flink.core.memory.DataOutputView;
@@ -35,6 +36,8 @@ public interface Message extends RoutableMessage {
*/
OptionalLong isBarrierMessage();
+ Optional<String> cancellationToken();
+
Message copy(MessageFactory context);
void writeTo(MessageFactory context, DataOutputView target) throws
IOException;
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
index e4d2d3a..780415e 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
@@ -54,6 +54,10 @@ public final class MessageFactory {
return new SdkMessage(from, to, payload);
}
+  /**
+   * Creates a message tagged with {@code cancellationToken}, so that a delayed delivery of it can
+   * later be cancelled by that token.
+   */
+  public Message from(Address from, Address to, Object payload, String cancellationToken) {
+    return new SdkMessage(from, to, payload, cancellationToken);
+  }
+
//
-------------------------------------------------------------------------------------------------------
void copy(DataInputView source, DataOutputView target) throws IOException {
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java
index dabda14..500958f 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java
@@ -19,6 +19,7 @@ package org.apache.flink.statefun.flink.core.message;
import java.io.IOException;
import java.util.Objects;
+import java.util.Optional;
import java.util.OptionalLong;
import javax.annotation.Nullable;
import org.apache.flink.core.memory.DataOutputView;
@@ -82,6 +83,15 @@ final class ProtobufMessage implements Message {
}
@Override
+ public Optional<String> cancellationToken() {
+ String token = envelope.getCancellationToken();
+ if (token.isEmpty()) {
+ return Optional.empty();
+ }
+ return Optional.of(token);
+ }
+
+ @Override
public Message copy(MessageFactory unused) {
return new ProtobufMessage(envelope);
}
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java
index c10f2e9..ca5ee50 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java
@@ -19,6 +19,7 @@ package org.apache.flink.statefun.flink.core.message;
import java.io.IOException;
import java.util.Objects;
+import java.util.Optional;
import java.util.OptionalLong;
import javax.annotation.Nullable;
import org.apache.flink.core.memory.DataOutputView;
@@ -29,18 +30,27 @@ import org.apache.flink.statefun.sdk.Address;
final class SdkMessage implements Message {
- @Nullable private final Address source;
-
private final Address target;
- private Object payload;
-
+ @Nullable private final Address source;
+ @Nullable private final String cancellationToken;
@Nullable private Envelope cachedEnvelope;
+ private Object payload;
+
SdkMessage(@Nullable Address source, Address target, Object payload) {
+ this(source, target, payload, null);
+ }
+
+ SdkMessage(
+ @Nullable Address source,
+ Address target,
+ Object payload,
+ @Nullable String cancellationToken) {
this.source = source;
this.target = Objects.requireNonNull(target);
this.payload = Objects.requireNonNull(payload);
+ this.cancellationToken = cancellationToken;
}
@Override
@@ -68,8 +78,13 @@ final class SdkMessage implements Message {
}
@Override
+ public Optional<String> cancellationToken() {
+ return Optional.ofNullable(cancellationToken);
+ }
+
+ @Override
public Message copy(MessageFactory factory) {
- return new SdkMessage(source, target, payload);
+ return new SdkMessage(source, target, payload, cancellationToken);
}
@Override
@@ -86,6 +101,9 @@ final class SdkMessage implements Message {
}
builder.setTarget(sdkAddressToProtobufAddress(target));
builder.setPayload(factory.serializeUserMessagePayload(payload));
+ if (cancellationToken != null) {
+ builder.setCancellationToken(cancellationToken);
+ }
cachedEnvelope = builder.build();
}
return cachedEnvelope;
diff --git
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
index d0ed196..b9fdc1a 100644
---
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
+++
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
@@ -233,14 +233,34 @@ public final class RequestReplyFunction implements
StatefulFunction {
private void handleOutgoingDelayedMessages(Context context,
InvocationResponse invocationResult) {
for (FromFunction.DelayedInvocation delayedInvokeCommand :
invocationResult.getDelayedInvocationsList()) {
- final Address to =
polyglotAddressToSdkAddress(delayedInvokeCommand.getTarget());
- final TypedValue message = delayedInvokeCommand.getArgument();
- final long delay = delayedInvokeCommand.getDelayInMs();
- context.sendAfter(Duration.ofMillis(delay), to, message);
+ if (delayedInvokeCommand.getIsCancellationRequest()) {
+ handleDelayedMessageCancellation(context, delayedInvokeCommand);
+ } else {
+ handleDelayedMessageSending(context, delayedInvokeCommand);
+ }
}
}
+  /**
+   * Schedules a delayed message requested by a remote function. If the request carries a
+   * cancellation token, the token is forwarded so that a later cancellation request (see
+   * handleDelayedMessageCancellation) can locate and remove the pending message.
+   */
+  private void handleDelayedMessageSending(
+      Context context, FromFunction.DelayedInvocation delayedInvokeCommand) {
+    final Address to = polyglotAddressToSdkAddress(delayedInvokeCommand.getTarget());
+    final TypedValue message = delayedInvokeCommand.getArgument();
+    final Duration delay = Duration.ofMillis(delayedInvokeCommand.getDelayInMs());
+    // Protobuf string getters return "" when the field is unset: an empty token means the
+    // message was sent without cancellation support.
+    final String token = delayedInvokeCommand.getCancellationToken();
+
+    if (token.isEmpty()) {
+      context.sendAfter(delay, to, message);
+    } else {
+      context.sendAfter(delay, to, message, token);
+    }
+  }
+
+ private void handleDelayedMessageCancellation(
+ Context context, FromFunction.DelayedInvocation delayedInvokeCommand) {
+ String token = delayedInvokeCommand.getCancellationToken();
+ if (token.isEmpty()) {
+ throw new IllegalArgumentException(
+ "Can not handle a cancellation request without a cancellation
token.");
+ }
+ context.cancelDelayedMessage(token);
+ }
+
//
--------------------------------------------------------------------------------
// Send Message to Remote Function
//
--------------------------------------------------------------------------------
diff --git
a/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto
b/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto
index 1b09239..1c17e0b 100644
---
a/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto
+++
b/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto
@@ -37,10 +37,14 @@ message Checkpoint {
int64 checkpoint_id = 1;
}
+
message Envelope {
EnvelopeAddress source = 1;
EnvelopeAddress target = 2;
+  // an optional token that can be used to track and cancel a delayed message.
+ string cancellation_token = 10;
+
oneof body {
Checkpoint checkpoint = 4;
Payload payload = 3;
diff --git
a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java
b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java
index 9c7ccd0..7cbd4b6 100644
---
a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java
+++
b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java
@@ -133,6 +133,12 @@ public class LocalStatefulFunctionGroupTest {
public void sendAfter(Duration duration, Address to, Object message) {}
@Override
+ // No-op stub: delayed messages (with or without a cancellation token) are ignored by this fake.
+ public void sendAfter(Duration delay, Address to, Object message, String
cancellationToken) {}
+
+ // No-op stub: cancellation requests are ignored by this fake.
+ @Override
+ public void cancelDelayedMessage(String cancellationToken) {}
+
+ @Override
public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T>
future) {}
@Override
diff --git
a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
index 3f84bce..cf3b19a 100644
---
a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
+++
b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
@@ -100,12 +100,13 @@ public class ReductionsTest {
new FakeKeyedStateBackend(),
new FakeTimerServiceFactory(),
new FakeInternalListState(),
+ new FakeMapState<>(),
new HashMap<>(),
new FakeOutput(),
TestUtils.ENVELOPE_FACTORY,
MoreExecutors.directExecutor(),
new FakeMetricGroup(),
- new FakeMapState());
+ new FakeMapState<>());
assertThat(reductions, notNullValue());
}
@@ -517,44 +518,44 @@ public class ReductionsTest {
}
}
- private static final class FakeMapState implements MapState<Long, Message> {
+ private static final class FakeMapState<K, V> implements MapState<K, V> {
@Override
- public Message get(Long key) throws Exception {
+ public V get(K key) throws Exception {
return null;
}
@Override
- public void put(Long key, Message value) throws Exception {}
+ public void put(K key, V value) throws Exception {}
@Override
- public void putAll(Map<Long, Message> map) throws Exception {}
+ public void putAll(Map<K, V> map) throws Exception {}
@Override
- public void remove(Long key) throws Exception {}
+ public void remove(K key) throws Exception {}
@Override
- public boolean contains(Long key) throws Exception {
+ public boolean contains(K key) throws Exception {
return false;
}
@Override
- public Iterable<Entry<Long, Message>> entries() throws Exception {
+ public Iterable<Entry<K, V>> entries() throws Exception {
return null;
}
@Override
- public Iterable<Long> keys() throws Exception {
+ public Iterable<K> keys() throws Exception {
return null;
}
@Override
- public Iterable<Message> values() throws Exception {
+ public Iterable<V> values() throws Exception {
return null;
}
@Override
- public Iterator<Entry<Long, Message>> iterator() throws Exception {
+ public Iterator<Entry<K, V>> iterator() throws Exception {
return null;
}
diff --git
a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
index f88916f..5b7c536 100644
---
a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
+++
b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
@@ -37,6 +37,7 @@ import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import java.util.stream.Collectors;
+import javax.annotation.Nullable;
import org.apache.flink.statefun.flink.core.backpressure.InternalContext;
import org.apache.flink.statefun.flink.core.metrics.FunctionTypeMetrics;
import org.apache.flink.statefun.flink.core.metrics.RemoteInvocationMetrics;
@@ -197,7 +198,7 @@ public class RequestReplyFunctionTest {
functionUnderTest.invoke(context, successfulAsyncOperation(response));
assertFalse(context.delayed.isEmpty());
- assertEquals(Duration.ofMillis(1), context.delayed.get(0).getKey());
+ assertEquals(Duration.ofMillis(1), context.delayed.get(0).delay());
}
@Test
@@ -376,6 +377,38 @@ public class RequestReplyFunctionTest {
}
}
+  /** A delayed message captured by {@code FakeContext}, together with its optional token. */
+  private static final class DelayedMessage {
+    private final Duration delay;
+    private final Address target;
+    private final Object message;
+    // The cancellation token passed to sendAfter, or null for the token-less overload.
+    private final @Nullable String messageId;
+
+    public DelayedMessage(
+        Duration delay, @Nullable String messageId, Address target, Object message) {
+      this.delay = delay;
+      this.target = target;
+      this.message = message;
+      this.messageId = messageId;
+    }
+
+    /** Returns the requested delivery delay. */
+    public Duration delay() {
+      return this.delay;
+    }
+
+    /** Returns the cancellation token, or {@code null} if none was attached. */
+    @Nullable
+    public String messageId() {
+      return this.messageId;
+    }
+
+    /** Returns the address the message was sent to. */
+    public Address target() {
+      return this.target;
+    }
+
+    /** Returns the message payload. */
+    public Object message() {
+      return this.message;
+    }
+  }
+
private static final class FakeContext implements InternalContext {
private final BacklogTrackingMetrics fakeMetrics = new
BacklogTrackingMetrics();
@@ -385,7 +418,7 @@ public class RequestReplyFunctionTest {
// capture emitted messages
List<Map.Entry<EgressIdentifier<?>, ?>> egresses = new ArrayList<>();
- List<Map.Entry<Duration, ?>> delayed = new ArrayList<>();
+ List<DelayedMessage> delayed = new ArrayList<>();
@Override
public void awaitAsyncOperationComplete() {
@@ -417,10 +450,18 @@ public class RequestReplyFunctionTest {
@Override
public void sendAfter(Duration delay, Address to, Object message) {
- delayed.add(new SimpleImmutableEntry<>(delay, message));
+ // This overload carries no cancellation token, hence the null messageId.
+ delayed.add(new DelayedMessage(delay, null, to, message));
+ }
+
+ // Records the message together with its cancellation token so tests can inspect both.
+ @Override
+ public void sendAfter(Duration delay, Address to, Object message, String
cancellationToken) {
+ delayed.add(new DelayedMessage(delay, cancellationToken, to, message));
}
+ // No-op stub: cancellation requests are not recorded by this fake.
@Override
+ public void cancelDelayedMessage(String cancellationToken) {}
+
+ @Override
public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T>
future) {}
}
diff --git
a/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java
b/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java
index d3a1231..ea9a32d 100644
---
a/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java
+++
b/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java
@@ -75,6 +75,33 @@ public interface Context {
void sendAfter(Duration delay, Address to, Object message);
/**
+ * Invokes another function with an input, identified by the target function's {@link Address},
+ * after a given delay, associating the message with a {@code cancellationToken}.
+ *
+ * <p>Attaching a token to a message allows "unsending" it later via {@link
+ * #cancelDelayedMessage(String)}.
+ *
+ * @param delay the amount of delay before invoking the target function. Value needs to be {@code
+ *     >= 0}.
+ * @param to the target function's address.
+ * @param message the input to provide for the delayed invocation.
+ * @param cancellationToken the non-empty, non-null, unique token to attach to this message, to be
+ *     used for message cancellation (see {@link #cancelDelayedMessage(String)}).
+ */
+ void sendAfter(Duration delay, Address to, Object message, String
cancellationToken);
+
+ /**
+ * Cancels a delayed message that was sent via {@link #sendAfter(Duration, Address, Object,
+ * String)} with the given token.
+ *
+ * <p>NOTE: this is a best-effort operation, since the message might have already been delivered.
+ * If the message was already delivered, this is a no-op operation.
+ *
+ * @param cancellationToken the token of the message to un-send.
+ */
+ void cancelDelayedMessage(String cancellationToken);
+
+ /**
* Invokes another function with an input, identified by the target
function's {@link
* FunctionType} and unique id.
*
diff --git
a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java
b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java
index 43ab6b8..8ef48e0 100644
---
a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java
+++
b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java
@@ -56,6 +56,26 @@ public interface Context {
void sendAfter(Duration duration, Message message);
/**
+ * Sends out a {@link Message} to another function, after a specified {@link Duration} delay,
+ * tagging it with a cancellation token.
+ *
+ * @param duration the amount of time to delay the message delivery.
+ * @param cancellationToken the non-empty, non-null, unique token to attach to this message, to be
+ *     used for message cancellation (see {@link #cancelDelayedMessage(String)}).
+ * @param message the message to send.
+ */
+ void sendAfter(Duration duration, String cancellationToken, Message message);
+
+ /**
+ * Cancels a delayed message that was sent via {@link #sendAfter(Duration, String, Message)} with
+ * the given token.
+ *
+ * <p>NOTE: this is a best-effort operation, since the message might have already been delivered.
+ * If the message was already delivered, this is a no-op operation.
+ *
+ * @param cancellationToken the token of the message to un-send.
+ */
+ void cancelDelayedMessage(String cancellationToken);
+
+ /**
* Sends out a {@link EgressMessage} to an egress.
*
* @param message the message to send.
diff --git
a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java
b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java
index 49e9fcc..07e44ca 100644
---
a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java
+++
b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java
@@ -116,6 +116,46 @@ final class ConcurrentContext implements Context {
}
@Override
+  public void sendAfter(Duration duration, String cancellationToken, Message message) {
+    Objects.requireNonNull(duration, "duration");
+    requireValidCancellationToken(cancellationToken);
+    Objects.requireNonNull(message, "message");
+
+    FromFunction.DelayedInvocation outInvocation =
+        FromFunction.DelayedInvocation.newBuilder()
+            .setArgument(getTypedValue(message))
+            .setTarget(protoAddressFromSdk(message.targetAddress()))
+            .setDelayInMs(duration.toMillis())
+            .setCancellationToken(cancellationToken)
+            .build();
+
+    // Append under the responseBuilder lock; checkNotDone() presumably rejects late writes.
+    synchronized (responseBuilder) {
+      checkNotDone();
+      responseBuilder.addDelayedInvocations(outInvocation);
+    }
+  }
+
+  @Override
+  public void cancelDelayedMessage(String cancellationToken) {
+    requireValidCancellationToken(cancellationToken);
+
+    // A cancellation is encoded as a DelayedInvocation carrying only the flag and the token.
+    FromFunction.DelayedInvocation cancellation =
+        FromFunction.DelayedInvocation.newBuilder()
+            .setIsCancellationRequest(true)
+            .setCancellationToken(cancellationToken)
+            .build();
+
+    synchronized (responseBuilder) {
+      checkNotDone();
+      responseBuilder.addDelayedInvocations(cancellation);
+    }
+  }
+
+  /**
+   * Validates a user supplied cancellation token; shared by both cancellation-related methods.
+   *
+   * @param cancellationToken the token to validate.
+   * @throws IllegalArgumentException if the token is {@code null} or empty.
+   */
+  private static void requireValidCancellationToken(String cancellationToken) {
+    if (cancellationToken == null || cancellationToken.isEmpty()) {
+      throw new IllegalArgumentException("message cancellation token can not be empty or null.");
+    }
+  }
+
+ @Override
public void send(EgressMessage message) {
Objects.requireNonNull(message);
diff --git a/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
b/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
index 19d9f2a..ac72d7c 100644
--- a/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
+++ b/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
@@ -115,6 +115,15 @@ message FromFunction {
// DelayedInvocation represents a delayed remote function call with a
target address, an argument
// and a delay in milliseconds, after which this message to be sent.
message DelayedInvocation {
+ // a boolean value (default false) that indicates whether this is a regular delayed message,
+ // or (true) a message cancellation request.
+ // in case of a regular delayed message all other fields are expected to be present; otherwise
+ // only the cancellation_token is expected.
+ bool is_cancellation_request = 10;
+
+ // an optional cancellation token that can be used to request the
"unsending" of a delayed message.
+ string cancellation_token = 11;
+
// the amount of milliseconds to wait before sending this message
int64 delay_in_ms = 1;
// the target address to send this message to
diff --git a/statefun-sdk-python/statefun/context.py
b/statefun-sdk-python/statefun/context.py
index b1692a5..578e020 100644
--- a/statefun-sdk-python/statefun/context.py
+++ b/statefun-sdk-python/statefun/context.py
@@ -24,7 +24,6 @@ from statefun.messages import Message, EgressMessage
class Context(abc.ABC):
-
__slots__ = ()
@property
@@ -62,13 +61,24 @@ class Context(abc.ABC):
"""
pass
- def send_after(self, duration: timedelta, message: Message):
+ def send_after(self, duration: timedelta, message: Message,
cancellation_token: str = ""):
"""
Send a message to a target function after a specified delay.
:param duration: the amount of time to wait before sending this
message out.
:param message: the message to send.
+ :param cancellation_token: an optional cancellation token to associate with this message.
+ Passing the same token to cancel_delayed_message later requests its cancellation.
+ """
+ pass
+
+ def cancel_delayed_message(self, cancellation_token: str):
"""
+ Cancel a delayed message (a message that was sent using send_after) with a given token.
+
+ Please note that this is a best-effort operation, since the message might have already been delivered.
+ If the message was delivered, this is a no-op operation.
+
+ :param cancellation_token: the token that was passed to send_after for the message to cancel.
+ """
+ pass
def send_egress(self, message: EgressMessage):
"""
diff --git a/statefun-sdk-python/statefun/request_reply_v3.py
b/statefun-sdk-python/statefun/request_reply_v3.py
index fa05253..f4db551 100644
--- a/statefun-sdk-python/statefun/request_reply_v3.py
+++ b/statefun-sdk-python/statefun/request_reply_v3.py
@@ -28,6 +28,16 @@ from statefun.statefun_builder import StatefulFunctions,
StatefulFunction
from statefun.request_reply_pb2 import ToFunction, FromFunction, Address,
TypedValue
from statefun.storage import resolve, Cell
+from dataclasses import dataclass
+
+
+@dataclass
+class DelayedMessage:
+    """
+    An outgoing delayed message, or a request to cancel one.
+
+    When ``is_cancellation`` is True only ``cancellation_token`` is meaningful. Otherwise
+    ``duration`` (milliseconds) and ``message`` describe the message to send, and
+    ``cancellation_token`` optionally tags it for later cancellation.
+    """
+    is_cancellation: bool = False
+    duration: typing.Optional[int] = None
+    # no trailing comma: `= None,` would make the default the tuple (None,)
+    message: typing.Optional[Message] = None
+    cancellation_token: typing.Optional[str] = None
+
class UserFacingContext(statefun.context.Context):
__slots__ = (
@@ -37,7 +47,7 @@ class UserFacingContext(statefun.context.Context):
def __init__(self, address, storage):
self._self_address = address
self._outgoing_messages = []
- self._outgoing_delayed_messages = []
+ self._outgoing_delayed_messages: typing.List[DelayedMessage] = []
self._outgoing_egress_messages = []
self._storage = storage
self._caller = None
@@ -66,15 +76,28 @@ class UserFacingContext(statefun.context.Context):
"""
self._outgoing_messages.append(message)
-    def send_after(self, duration: timedelta, message: Message):
+    def send_after(self, duration: timedelta, message: Message, cancellation_token: str = ""):
         """
         Send a message to a target function after a specified delay.
         :param duration: the amount of time to wait before sending this message out.
         :param message: the message to send.
+        :param cancellation_token: an optional cancellation token to associate with this message;
+            pass the same token to cancel_delayed_message to request its cancellation.
         """
         ms = int(duration.total_seconds() * 1000.0)
-        self._outgoing_delayed_messages.append((ms, message))
+        self._outgoing_delayed_messages.append(
+            DelayedMessage(is_cancellation=False,
+                           duration=ms,
+                           message=message,
+                           cancellation_token=cancellation_token))
+
+    def cancel_delayed_message(self, cancellation_token: str):
+        """
+        Cancel a delayed message (a message that was sent using send_after) with a given token.
+
+        Please note that this is a best-effort operation, since the message might have already been delivered.
+        If the message was delivered, this is a no-op operation.
+
+        :param cancellation_token: the non-empty token that was passed to send_after.
+        :raises ValueError: if cancellation_token is empty or None.
+        """
+        if not cancellation_token:
+            # Fail fast in the SDK: the runtime rejects cancellation requests without a token.
+            raise ValueError("message cancellation token can not be empty or None.")
+        record = DelayedMessage(is_cancellation=True, cancellation_token=cancellation_token)
+        self._outgoing_delayed_messages.append(record)
def send_egress(self, message: EgressMessage):
"""
@@ -145,17 +168,34 @@ def collect_messages(messages: typing.List[Message],
pb_invocation_result):
outgoing.argument.CopyFrom(message.typed_value)
-def collect_delayed(delayed_messages: typing.List[typing.Tuple[timedelta, Message]], invocation_result):
+def collect_delayed(delayed_messages: typing.List[DelayedMessage], invocation_result):
+    """
+    Copy collected delayed messages and cancellation requests into
+    ``invocation_result.delayed_invocations``, one entry per record, preserving order.
+    """
     delayed_invocations = invocation_result.delayed_invocations
-    for delay, message in delayed_messages:
+    for delayed_message in delayed_messages:
         outgoing = delayed_invocations.add()
-        namespace, type = parse_typename(message.target_typename)
-        outgoing.target.namespace = namespace
-        outgoing.target.type = type
-        outgoing.target.id = message.target_id
-        outgoing.delay_in_ms = delay
-        outgoing.argument.CopyFrom(message.typed_value)
+        if not delayed_message.is_cancellation:
+            target_message = delayed_message.message
+            namespace, type_name = parse_typename(target_message.target_typename)
+            outgoing.target.namespace = namespace
+            outgoing.target.type = type_name
+            outgoing.target.id = target_message.target_id
+            outgoing.delay_in_ms = delayed_message.duration
+            outgoing.argument.CopyFrom(target_message.typed_value)
+            if delayed_message.cancellation_token is not None:
+                outgoing.cancellation_token = delayed_message.cancellation_token
+            continue
+        # a cancellation request carries only the token and the flag
+        outgoing.cancellation_token = delayed_message.cancellation_token
+        outgoing.is_cancellation_request = True
+
+
+def collect_cancellations(tokens: typing.List[str], invocation_result):
+    """
+    Emit a cancellation request for every non-empty token in ``tokens``.
+
+    Cancellations are encoded as ``delayed_invocations`` entries with ``is_cancellation_request``
+    set, the same way collect_delayed encodes them and as the DelayedInvocation schema describes.
+    This commit adds no separate ``outgoing_delay_cancellations`` field, so writing to one would
+    raise AttributeError. Empty tokens are skipped.
+    """
+    delayed_invocations = invocation_result.delayed_invocations
+    for token in tokens:
+        if not token:
+            continue
+        cancellation = delayed_invocations.add()
+        cancellation.cancellation_token = token
+        cancellation.is_cancellation_request = True
def collect_egress(egresses: typing.List[EgressMessage], invocation_result):
diff --git a/statefun-sdk-python/tests/request_reply_test.py
b/statefun-sdk-python/tests/request_reply_test.py
index 612750f..5bd783a 100644
--- a/statefun-sdk-python/tests/request_reply_test.py
+++ b/statefun-sdk-python/tests/request_reply_test.py
@@ -91,6 +91,7 @@ def nth(n):
NTH_OUTGOING_MESSAGE = lambda n: [key("invocation_result"),
key("outgoing_messages"), nth(n)]
NTH_STATE_MUTATION = lambda n: [key("invocation_result"),
key("state_mutations"), nth(n)]
NTH_DELAYED_MESSAGE = lambda n: [key("invocation_result"),
key("delayed_invocations"), nth(n)]
+# NOTE(review): this commit adds no `outgoing_delay_cancellations` field to request-reply.proto;
+# cancellation requests are emitted as `delayed_invocations` entries with `is_cancellation_request`
+# set, and the visible test reads them via NTH_DELAYED_MESSAGE. This path appears unused.
+NTH_CANCELLATION_MESSAGE = lambda n: [key("invocation_result"),
key("outgoing_delay_cancellations"), nth(n)]
NTH_EGRESS = lambda n: [key("invocation_result"), key("outgoing_egresses"),
nth(n)]
NTH_MISSING_STATE_SPEC = lambda n: [key("incomplete_invocation_context"),
key("missing_values"), nth(n)]
@@ -123,6 +124,16 @@ class RequestReplyTestCase(unittest.TestCase):
message_builder(target_typename="night/owl",
target_id="1",
str_value="hoo hoo"))
+
+ # delayed with cancellation
+ context.send_after(timedelta(hours=1),
+ message_builder(target_typename="night/owl",
+ target_id="1",
+ str_value="hoo hoo"),
+ cancellation_token="token-1234")
+
+ context.cancel_delayed_message("token-1234")
+
# kafka egresses
context.send_egress(
kafka_egress_message(typename="e/kafka",
@@ -165,6 +176,15 @@ class RequestReplyTestCase(unittest.TestCase):
first_delayed = json_at(result_json, NTH_DELAYED_MESSAGE(0))
self.assertEqual(int(first_delayed['delay_in_ms']), 1000 * 60 * 60)
+ # assert delayed with token
+ second_delayed = json_at(result_json, NTH_DELAYED_MESSAGE(1))
+ self.assertEqual(second_delayed['cancellation_token'], "token-1234")
+
+ # assert cancellation
+ first_cancellation = json_at(result_json, NTH_DELAYED_MESSAGE(2))
+ self.assertTrue(first_cancellation['is_cancellation_request'])
+ self.assertEqual(first_cancellation['cancellation_token'],
"token-1234")
+
# assert egresses
first_egress = json_at(result_json, NTH_EGRESS(0))
self.assertEqual(first_egress['egress_namespace'], 'e')
diff --git
a/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java
b/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java
index 1c39775..fdf6be8 100644
---
a/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java
+++
b/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java
@@ -102,7 +102,21 @@ class TestContext implements Context {
@Override
public void sendAfter(Duration delay, Address to, Object message) {
pendingMessage.add(
- new PendingMessage(new Envelope(self(), to, message), watermark +
delay.toMillis()));
+ new PendingMessage(new Envelope(self(), to, message), watermark +
delay.toMillis(), null));
+ }
+
+ @Override
+ public void sendAfter(Duration delay, Address to, Object message, String
cancellationToken) {
+ Objects.requireNonNull(cancellationToken);
+ pendingMessage.add(
+ new PendingMessage(
+ new Envelope(self(), to, message), watermark + delay.toMillis(),
cancellationToken));
+ }
+
+ @Override
+ public void cancelDelayedMessage(String cancellationToken) {
+ pendingMessage.removeIf(
+ pendingMessage -> Objects.equals(pendingMessage.cancellationToken,
cancellationToken));
}
@Override
@@ -186,12 +200,13 @@ class TestContext implements Context {
private static class PendingMessage {
Envelope envelope;
-
+ // Token passed to sendAfter, or null when the message was sent without one.
+ String cancellationToken;
+ // Set by sendAfter to watermark + delay.toMillis().
long timer;
- PendingMessage(Envelope envelope, long timer) {
+ PendingMessage(Envelope envelope, long timer, String cancellationToken) {
this.envelope = envelope;
this.timer = timer;
+ this.cancellationToken = cancellationToken;
}
}
}