This is an automated email from the ASF dual-hosted git repository.

tzulitai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-statefun.git

commit df792ac50d9b09f47b66567bec4f3b36326b3dbc
Author: Tzu-Li (Gordon) Tai <[email protected]>
AuthorDate: Mon Oct 12 13:28:16 2020 +0800

    [FLINK-20265] [core] Implement new extended protocol in RequestReplyFunction
    
    This commit lets the RequestReplyFunction handle
    IncompleteInvocationContext responses by retrying the original batch
    after registering the states that the response reported as missing
    from the original batch request.
    
    This closes #177.
---
 .../flink/core/reqreply/RequestReplyFunction.java  | 71 +++++++++++++++-------
 .../core/reqreply/RequestReplyFunctionTest.java    | 48 +++++++++++++++
 2 files changed, 98 insertions(+), 21 deletions(-)

diff --git 
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
 
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
index 66f3a24..01ee950 100644
--- 
a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
+++ 
b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
@@ -29,6 +29,7 @@ import 
org.apache.flink.statefun.flink.core.backpressure.InternalContext;
 import org.apache.flink.statefun.flink.core.metrics.RemoteInvocationMetrics;
 import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction;
 import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.EgressMessage;
+import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.IncompleteInvocationContext;
 import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.InvocationResponse;
 import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction;
 import 
org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction.Invocation;
@@ -41,6 +42,7 @@ import org.apache.flink.statefun.sdk.annotations.Persisted;
 import org.apache.flink.statefun.sdk.io.EgressIdentifier;
 import org.apache.flink.statefun.sdk.state.PersistedAppendingBuffer;
 import org.apache.flink.statefun.sdk.state.PersistedValue;
+import org.apache.flink.types.Either;
 
 public final class RequestReplyFunction implements StatefulFunction {
 
@@ -124,8 +126,48 @@ public final class RequestReplyFunction implements 
StatefulFunction {
       sendToFunction(context, batch);
       return;
     }
-    InvocationResponse invocationResult = 
unpackInvocationOrThrow(context.self(), asyncResult);
-    handleInvocationResponse(context, invocationResult);
+    if (asyncResult.failure()) {
+      throw new IllegalStateException(
+          "Failure forwarding a message to a remote function " + 
context.self(),
+          asyncResult.throwable());
+    }
+
+    final Either<InvocationResponse, IncompleteInvocationContext> response =
+        unpackResponse(asyncResult.value());
+    if (response.isRight()) {
+      handleIncompleteInvocationContextResponse(context, response.right(), 
asyncResult.metadata());
+    } else {
+      handleInvocationResultResponse(context, response.left());
+    }
+  }
+
+  private static Either<InvocationResponse, IncompleteInvocationContext> 
unpackResponse(
+      FromFunction fromFunction) {
+    if (fromFunction.hasIncompleteInvocationContext()) {
+      return Either.Right(fromFunction.getIncompleteInvocationContext());
+    }
+    if (fromFunction.hasInvocationResult()) {
+      return Either.Left(fromFunction.getInvocationResult());
+    }
+    // function had no side effects
+    return Either.Left(InvocationResponse.getDefaultInstance());
+  }
+
+  private void handleIncompleteInvocationContextResponse(
+      InternalContext context,
+      IncompleteInvocationContext incompleteContext,
+      ToFunction originalBatch) {
+    managedStates.registerStates(incompleteContext.getMissingValuesList());
+
+    final InvocationBatchRequest.Builder retryBatch = 
createRetryBatch(originalBatch);
+    sendToFunction(context, retryBatch);
+  }
+
+  private void handleInvocationResultResponse(InternalContext context, 
InvocationResponse result) {
+    handleOutgoingMessages(context, result);
+    handleOutgoingDelayedMessages(context, result);
+    handleEgressMessages(context, result);
+    managedStates.updateStateValues(result.getStateMutationsList());
 
     final int numBatched = requestState.getOrDefault(-1);
     if (numBatched < 0) {
@@ -136,7 +178,8 @@ public final class RequestReplyFunction implements 
StatefulFunction {
       final InvocationBatchRequest.Builder nextBatch = getNextBatch();
       // an async request was just completed, but while it was in flight we 
have
       // accumulated a batch, we now proceed with:
-      // a) clearing the batch from our own persisted state (the batch moves 
to the async operation
+      // a) clearing the batch from our own persisted state (the batch moves 
to the async
+      // operation
       // state)
       // b) sending the accumulated batch to the remote function.
       requestState.set(0);
@@ -146,19 +189,6 @@ public final class RequestReplyFunction implements 
StatefulFunction {
     }
   }
 
-  private InvocationResponse unpackInvocationOrThrow(
-      Address self, AsyncOperationResult<ToFunction, FromFunction> result) {
-    if (result.failure()) {
-      throw new IllegalStateException(
-          "Failure forwarding a message to a remote function " + self, 
result.throwable());
-    }
-    FromFunction fromFunction = result.value();
-    if (fromFunction.hasInvocationResult()) {
-      return fromFunction.getInvocationResult();
-    }
-    return InvocationResponse.getDefaultInstance();
-  }
-
   private InvocationBatchRequest.Builder getNextBatch() {
     InvocationBatchRequest.Builder builder = 
InvocationBatchRequest.newBuilder();
     Iterable<Invocation> view = batch.view();
@@ -166,11 +196,10 @@ public final class RequestReplyFunction implements 
StatefulFunction {
     return builder;
   }
 
-  private void handleInvocationResponse(Context context, InvocationResponse 
invocationResult) {
-    handleOutgoingMessages(context, invocationResult);
-    handleOutgoingDelayedMessages(context, invocationResult);
-    handleEgressMessages(context, invocationResult);
-    managedStates.updateStateValues(invocationResult.getStateMutationsList());
+  private InvocationBatchRequest.Builder createRetryBatch(ToFunction 
toFunction) {
+    InvocationBatchRequest.Builder builder = 
InvocationBatchRequest.newBuilder();
+    builder.addAllInvocations(toFunction.getInvocation().getInvocationsList());
+    return builder;
   }
 
   private void handleEgressMessages(Context context, InvocationResponse 
invocationResult) {
diff --git 
a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
 
b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
index 8ca1fc9..c4eb85a 100644
--- 
a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
+++ 
b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
@@ -19,6 +19,7 @@ package org.apache.flink.statefun.flink.core.reqreply;
 
 import static org.apache.flink.statefun.flink.core.TestUtils.FUNCTION_1_ADDR;
 import static 
org.apache.flink.statefun.flink.core.common.PolyglotUtil.polyglotAddressToSdkAddress;
+import static org.hamcrest.CoreMatchers.hasItems;
 import static org.hamcrest.CoreMatchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -33,8 +34,10 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 import org.apache.flink.statefun.flink.core.TestUtils;
 import org.apache.flink.statefun.flink.core.backpressure.InternalContext;
 import org.apache.flink.statefun.flink.core.httpfn.StateSpec;
@@ -43,9 +46,12 @@ import 
org.apache.flink.statefun.flink.core.metrics.RemoteInvocationMetrics;
 import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction;
 import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.DelayedInvocation;
 import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.EgressMessage;
+import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.ExpirationSpec;
+import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.IncompleteInvocationContext;
 import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.InvocationResponse;
 import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.PersistedValueMutation;
 import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.PersistedValueMutation.MutationType;
+import 
org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.PersistedValueSpec;
 import org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction;
 import 
org.apache.flink.statefun.flink.core.polyglot.generated.ToFunction.Invocation;
 import org.apache.flink.statefun.sdk.Address;
@@ -211,6 +217,37 @@ public class RequestReplyFunctionTest {
   }
 
   @Test
+  public void retryBatchOnIncompleteInvocationContextResponse() {
+    Any any = Any.pack(TestUtils.DUMMY_PAYLOAD);
+    functionUnderTest.invoke(context, any);
+
+    FromFunction response =
+        FromFunction.newBuilder()
+            .setIncompleteInvocationContext(
+                IncompleteInvocationContext.newBuilder()
+                    .addMissingValues(
+                        PersistedValueSpec.newBuilder()
+                            .setStateName("new-state")
+                            .setExpirationSpec(
+                                ExpirationSpec.newBuilder()
+                                    
.setMode(ExpirationSpec.ExpireMode.AFTER_INVOKE)
+                                    .setExpireAfterMillis(5000)
+                                    .build())))
+            .build();
+
+    functionUnderTest.invoke(context, 
successfulAsyncOperation(client.wasSentToFunction, response));
+
+    // re-sent batch should have identical invocation input messages
+    assertTrue(client.wasSentToFunction.hasInvocation());
+    assertThat(client.capturedInvocationBatchSize(), is(1));
+    assertThat(client.capturedInvocation(0).getArgument(), is(any));
+
+    // re-sent batch should have new state as well as originally registered 
state
+    assertThat(client.capturedStateNames().size(), is(2));
+    assertThat(client.capturedStateNames(), hasItems("session", "new-state"));
+  }
+
+  @Test
   public void backlogMetricsIncreasedOnInvoke() {
     functionUnderTest.invoke(context, Any.getDefaultInstance());
 
@@ -246,6 +283,11 @@ public class RequestReplyFunctionTest {
     return new AsyncOperationResult<>(new Object(), Status.SUCCESS, 
fromFunction, null);
   }
 
+  private static AsyncOperationResult<ToFunction, FromFunction> 
successfulAsyncOperation(
+      ToFunction toFunction, FromFunction fromFunction) {
+    return new AsyncOperationResult<>(toFunction, Status.SUCCESS, 
fromFunction, null);
+  }
+
   private static final class FakeClient implements RequestReplyClient {
     ToFunction wasSentToFunction;
     Supplier<FromFunction> fromFunction = FromFunction::getDefaultInstance;
@@ -274,6 +316,12 @@ public class RequestReplyFunctionTest {
       return wasSentToFunction.getInvocation().getState(n).getStateValue();
     }
 
+    Set<String> capturedStateNames() {
+      return wasSentToFunction.getInvocation().getStateList().stream()
+          .map(ToFunction.PersistedValue::getStateName)
+          .collect(Collectors.toSet());
+    }
+
     public int capturedInvocationBatchSize() {
       return wasSentToFunction.getInvocation().getInvocationsCount();
     }

Reply via email to