This is an automated email from the ASF dual-hosted git repository. tzulitai pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-statefun.git
commit 69c658ec361682fb3bef50bd810f0646d7332a0c Author: Tzu-Li (Gordon) Tai <[email protected]> AuthorDate: Thu Jan 28 16:40:16 2021 +0800 [FLINK-21171] Wire in TypedValue throughout the runtime as state values and message payloads This closes #195. --- statefun-e2e-tests/statefun-smoke-e2e/pom.xml | 59 +++++++++++++ .../statefun/e2e/smoke/CommandFlinkSource.java | 15 ++-- .../statefun/e2e/smoke/CommandInterpreter.java | 28 ++++--- .../flink/statefun/e2e/smoke/CommandRouter.java | 12 +-- .../apache/flink/statefun/e2e/smoke/Constants.java | 12 +-- .../apache/flink/statefun/e2e/smoke/Module.java | 9 +- .../flink/statefun/e2e/smoke/ProtobufUtils.java | 34 -------- .../statefun/e2e/smoke/CommandInterpreterTest.java | 4 +- .../flink/statefun/e2e/smoke/HarnessTest.java | 4 +- .../flink/statefun/e2e/smoke/SmokeRunner.java | 4 +- .../org/apache/flink/statefun/e2e/smoke/Utils.java | 15 ++-- .../run-example.py | 26 ++++-- statefun-flink/statefun-flink-common/pom.xml | 57 +++++++++++++ .../flink/common/types/TypedValueUtil.java | 55 ++++++++++++ .../flink/core/jsonmodule/EgressJsonEntity.java | 6 +- .../protorouter/AutoRoutableProtobufRouter.java | 15 ++-- .../reqreply/PersistedRemoteFunctionValues.java | 37 +++++---- .../flink/core/reqreply/RequestReplyFunction.java | 16 ++-- .../flink/core/jsonmodule/JsonModuleTest.java | 5 +- .../PersistedRemoteFunctionValuesTest.java | 51 +++++++++--- .../core/reqreply/RequestReplyFunctionTest.java | 97 +++++++++++++--------- statefun-flink/statefun-flink-io-bundle/pom.xml | 10 +++ .../io/kafka/GenericKafkaEgressSerializer.java | 15 ++-- .../flink/io/kafka/GenericKafkaSinkProvider.java | 6 +- .../polyglot/GenericKinesisEgressSerializer.java | 13 +-- .../polyglot/GenericKinesisSinkProvider.java | 6 +- .../io/kafka/GenericKafkaSinkProviderTest.java | 4 +- .../io/kinesis/GenericKinesisSinkProviderTest.java | 4 +- statefun-python-sdk/statefun/core.py | 7 ++ statefun-python-sdk/statefun/request_reply.py | 27 +++--- 
statefun-python-sdk/statefun/typed_value_utils.py | 49 +++++++++++ statefun-python-sdk/tests/request_reply_test.py | 34 ++++++-- .../src/main/protobuf/sdk/request-reply.proto | 21 +++-- 33 files changed, 537 insertions(+), 220 deletions(-) diff --git a/statefun-e2e-tests/statefun-smoke-e2e/pom.xml b/statefun-e2e-tests/statefun-smoke-e2e/pom.xml index 71bb3c3..26318c2 100644 --- a/statefun-e2e-tests/statefun-smoke-e2e/pom.xml +++ b/statefun-e2e-tests/statefun-smoke-e2e/pom.xml @@ -30,6 +30,7 @@ under the License. <properties> <testcontainers.version>1.12.5</testcontainers.version> <commons-math3.version>3.5</commons-math3.version> + <additional-sources.dir>target/additional-sources</additional-sources.dir> </properties> <dependencies> @@ -41,6 +42,11 @@ under the License. </dependency> <dependency> <groupId>org.apache.flink</groupId> + <artifactId>statefun-sdk-protos</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.apache.flink</groupId> <artifactId>statefun-flink-io</artifactId> <version>${project.version}</version> </dependency> @@ -132,10 +138,63 @@ under the License. <build> <plugins> + <!-- + The following plugin is executed in the generated-sources phase, + and is responsible to extract the additional *.proto files located + at statefun-sdk-protos.jar. 
+ --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-dependency-plugin</artifactId> + <executions> + <execution> + <id>unpack</id> + <phase>generate-sources</phase> + <goals> + <goal>unpack</goal> + </goals> + <configuration> + <artifactItems> + <artifactItem> + <groupId>org.apache.flink</groupId> + <artifactId>statefun-sdk-protos</artifactId> + <version>${project.version}</version> + <type>jar</type> + <outputDirectory>${additional-sources.dir}</outputDirectory> + <includes>sdk/*.proto</includes> + </artifactItem> + </artifactItems> + </configuration> + </execution> + </executions> + </plugin> + <!-- + The following plugin invokes protoc to generate Java classes out of the *.proto + definitions located at: (1) src/main/protobuf (2) ${additional-sources.dir}. + --> <plugin> <groupId>com.github.os72</groupId> <artifactId>protoc-jar-maven-plugin</artifactId> <version>${protoc-jar-maven-plugin.version}</version> + <executions> + <execution> + <id>generate-protobuf-sources</id> + <phase>generate-sources</phase> + <goals> + <goal>run</goal> + </goals> + <configuration> + <includeStdTypes>true</includeStdTypes> + <protocVersion>${protobuf.version}</protocVersion> + <cleanOutputFolder>true</cleanOutputFolder> + <inputDirectories> + <inputDirectory>src/main/protobuf</inputDirectory> + <inputDirectory>${additional-sources.dir}</inputDirectory> + </inputDirectories> + <outputDirectory>${basedir}/target/generated-sources/protoc-jar</outputDirectory> + </configuration> + </execution> + </executions> </plugin> </plugins> </build> diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandFlinkSource.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandFlinkSource.java index ea4ed39..374d9e8 100644 --- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandFlinkSource.java +++ 
b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandFlinkSource.java @@ -20,7 +20,6 @@ package org.apache.flink.statefun.e2e.smoke; import static org.apache.flink.statefun.e2e.smoke.generated.Command.Verify; import static org.apache.flink.statefun.e2e.smoke.generated.Command.newBuilder; -import com.google.protobuf.Any; import java.util.Iterator; import java.util.Objects; import java.util.OptionalInt; @@ -38,6 +37,8 @@ import org.apache.flink.statefun.e2e.smoke.generated.Command; import org.apache.flink.statefun.e2e.smoke.generated.Commands; import org.apache.flink.statefun.e2e.smoke.generated.SourceCommand; import org.apache.flink.statefun.e2e.smoke.generated.SourceSnapshot; +import org.apache.flink.statefun.flink.common.types.TypedValueUtil; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.functions.source.RichSourceFunction; import org.slf4j.Logger; @@ -54,7 +55,7 @@ import org.slf4j.LoggerFactory; * to {@code verification} step. At this step, it would keep sending (every 2 seconds) a {@link * Verify} command to every function indefinitely. 
*/ -final class CommandFlinkSource extends RichSourceFunction<Any> +final class CommandFlinkSource extends RichSourceFunction<TypedValue> implements CheckpointedFunction, CheckpointListener { private static final Logger LOG = LoggerFactory.getLogger(CommandFlinkSource.class); @@ -132,7 +133,7 @@ final class CommandFlinkSource extends RichSourceFunction<Any> // ------------------------------------------------------------------------------------------------------------ @Override - public void run(SourceContext<Any> ctx) { + public void run(SourceContext<TypedValue> ctx) { generate(ctx); do { verify(ctx); @@ -145,7 +146,7 @@ final class CommandFlinkSource extends RichSourceFunction<Any> } while (true); } - private void generate(SourceContext<Any> ctx) { + private void generate(SourceContext<TypedValue> ctx) { final int startPosition = this.commandsSentSoFar; final OptionalInt kaboomIndex = computeFailureIndex(startPosition, failuresSoFar, moduleParameters.getMaxFailures()); @@ -170,13 +171,13 @@ final class CommandFlinkSource extends RichSourceFunction<Any> return; } functionStateTracker.apply(command); - ctx.collect(Any.pack(command)); + ctx.collect(TypedValueUtil.packProtobufMessage(command)); this.commandsSentSoFar = i; } } } - private void verify(SourceContext<Any> ctx) { + private void verify(SourceContext<TypedValue> ctx) { FunctionStateTracker functionStateTracker = this.functionStateTracker; for (int i = 0; i < moduleParameters.getNumberOfFunctionInstances(); i++) { @@ -190,7 +191,7 @@ final class CommandFlinkSource extends RichSourceFunction<Any> .setCommands(Commands.newBuilder().addCommand(verify)) .build(); synchronized (ctx.getCheckpointLock()) { - ctx.collect(Any.pack(command)); + ctx.collect(TypedValueUtil.packProtobufMessage(command)); } } } diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreter.java 
b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreter.java index 343c8f2..036e6e0 100644 --- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreter.java +++ b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreter.java @@ -17,9 +17,10 @@ */ package org.apache.flink.statefun.e2e.smoke; -import static org.apache.flink.statefun.e2e.smoke.ProtobufUtils.unpack; +import static org.apache.flink.statefun.flink.common.types.TypedValueUtil.isProtobufTypeOf; +import static org.apache.flink.statefun.flink.common.types.TypedValueUtil.packProtobufMessage; +import static org.apache.flink.statefun.flink.common.types.TypedValueUtil.unpackProtobufMessage; -import com.google.protobuf.Any; import java.time.Duration; import java.util.Objects; import java.util.concurrent.CompletableFuture; @@ -30,6 +31,7 @@ import org.apache.flink.statefun.e2e.smoke.generated.VerificationResult; import org.apache.flink.statefun.sdk.AsyncOperationResult; import org.apache.flink.statefun.sdk.Context; import org.apache.flink.statefun.sdk.FunctionType; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.statefun.sdk.state.PersistedValue; public final class CommandInterpreter { @@ -50,18 +52,18 @@ public final class CommandInterpreter { interpret(state, context, res.metadata()); return; } - if (!(message instanceof Any)) { + if (!(message instanceof TypedValue)) { throw new IllegalArgumentException("wtf " + message); } - Any any = (Any) message; - if (any.is(SourceCommand.class)) { - SourceCommand sourceCommand = unpack(any, SourceCommand.class); + TypedValue typedValue = (TypedValue) message; + if (isProtobufTypeOf(typedValue, SourceCommand.getDescriptor())) { + SourceCommand sourceCommand = unpackProtobufMessage(typedValue, SourceCommand.parser()); interpret(state, context, sourceCommand.getCommands()); - } 
else if (any.is(Commands.class)) { - Commands commands = unpack(any, Commands.class); + } else if (isProtobufTypeOf(typedValue, Commands.getDescriptor())) { + Commands commands = unpackProtobufMessage(typedValue, Commands.parser()); interpret(state, context, commands); } else { - throw new IllegalArgumentException("Unknown message type " + any.getTypeUrl()); + throw new IllegalArgumentException("Unknown message type " + typedValue.getTypename()); } } @@ -96,14 +98,14 @@ public final class CommandInterpreter { .setActual(actual) .setExpected(expected) .build(); - context.send(Constants.VERIFICATION_RESULT, Any.pack(verificationResult)); + context.send(Constants.VERIFICATION_RESULT, packProtobufMessage(verificationResult)); } private void sendEgress( @SuppressWarnings("unused") PersistedValue<Long> state, Context context, @SuppressWarnings("unused") Command.SendEgress sendEgress) { - context.send(Constants.OUT, Any.getDefaultInstance()); + context.send(Constants.OUT, TypedValue.getDefaultInstance()); } private void sendAfter( @@ -112,14 +114,14 @@ public final class CommandInterpreter { Command.SendAfter send) { FunctionType functionType = Constants.FN_TYPE; String id = ids.idOf(send.getTarget()); - context.sendAfter(sendAfterDelay, functionType, id, Any.pack(send.getCommands())); + context.sendAfter(sendAfterDelay, functionType, id, packProtobufMessage(send.getCommands())); } private void send( @SuppressWarnings("unused") PersistedValue<Long> state, Context context, Command.Send send) { FunctionType functionType = Constants.FN_TYPE; String id = ids.idOf(send.getTarget()); - context.send(functionType, id, Any.pack(send.getCommands())); + context.send(functionType, id, packProtobufMessage(send.getCommands())); } private void registerAsyncOps( diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandRouter.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandRouter.java 
index e08ae8d..00af145 100644 --- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandRouter.java +++ b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/CommandRouter.java @@ -17,13 +17,14 @@ */ package org.apache.flink.statefun.e2e.smoke; -import com.google.protobuf.Any; import java.util.Objects; import org.apache.flink.statefun.e2e.smoke.generated.SourceCommand; +import org.apache.flink.statefun.flink.common.types.TypedValueUtil; import org.apache.flink.statefun.sdk.FunctionType; import org.apache.flink.statefun.sdk.io.Router; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; -public class CommandRouter implements Router<Any> { +public class CommandRouter implements Router<TypedValue> { private final Ids ids; public CommandRouter(Ids ids) { @@ -31,10 +32,11 @@ public class CommandRouter implements Router<Any> { } @Override - public void route(Any any, Downstream<Any> downstream) { - SourceCommand sourceCommand = ProtobufUtils.unpack(any, SourceCommand.class); + public void route(TypedValue command, Downstream<TypedValue> downstream) { + SourceCommand sourceCommand = + TypedValueUtil.unpackProtobufMessage(command, SourceCommand.parser()); FunctionType type = Constants.FN_TYPE; String id = ids.idOf(sourceCommand.getTarget()); - downstream.forward(type, id, any); + downstream.forward(type, id, command); } } diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Constants.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Constants.java index f5cf262..8f1c222 100644 --- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Constants.java +++ b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Constants.java @@ -17,19 +17,21 @@ */ package org.apache.flink.statefun.e2e.smoke; -import 
com.google.protobuf.Any; import org.apache.flink.statefun.sdk.FunctionType; import org.apache.flink.statefun.sdk.io.EgressIdentifier; import org.apache.flink.statefun.sdk.io.IngressIdentifier; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; public class Constants { - public static final IngressIdentifier<Any> IN = new IngressIdentifier<>(Any.class, "", "source"); + public static final IngressIdentifier<TypedValue> IN = + new IngressIdentifier<>(TypedValue.class, "", "source"); - public static final EgressIdentifier<Any> OUT = new EgressIdentifier<>("", "sink", Any.class); + public static final EgressIdentifier<TypedValue> OUT = + new EgressIdentifier<>("", "sink", TypedValue.class); public static final FunctionType FN_TYPE = new FunctionType("v", "f1"); - public static final EgressIdentifier<Any> VERIFICATION_RESULT = - new EgressIdentifier<>("", "verification", Any.class); + public static final EgressIdentifier<TypedValue> VERIFICATION_RESULT = + new EgressIdentifier<>("", "verification", TypedValue.class); } diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Module.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Module.java index 21db25b..2673ac5 100644 --- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Module.java +++ b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/Module.java @@ -20,13 +20,13 @@ package org.apache.flink.statefun.e2e.smoke; import static org.apache.flink.statefun.e2e.smoke.Constants.IN; import com.google.auto.service.AutoService; -import com.google.protobuf.Any; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.Map; import org.apache.flink.api.common.serialization.SerializationSchema; import org.apache.flink.statefun.flink.io.datastream.SinkFunctionSpec; import 
org.apache.flink.statefun.flink.io.datastream.SourceFunctionSpec; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.statefun.sdk.spi.StatefulFunctionModule; import org.apache.flink.streaming.api.functions.sink.DiscardingSink; import org.apache.flink.streaming.api.functions.sink.SocketClientSink; @@ -51,7 +51,7 @@ public class Module implements StatefulFunctionModule { FunctionProvider provider = new FunctionProvider(ids); binder.bindFunctionProvider(Constants.FN_TYPE, provider); - SocketClientSink<Any> client = + SocketClientSink<TypedValue> client = new SocketClientSink<>( moduleParameters.getVerificationServerHost(), moduleParameters.getVerificationServerPort(), @@ -62,10 +62,11 @@ public class Module implements StatefulFunctionModule { binder.bindEgress(new SinkFunctionSpec<>(Constants.VERIFICATION_RESULT, client)); } - private static final class VerificationResultSerializer implements SerializationSchema<Any> { + private static final class VerificationResultSerializer + implements SerializationSchema<TypedValue> { @Override - public byte[] serialize(Any element) { + public byte[] serialize(TypedValue element) { try { ByteArrayOutputStream out = new ByteArrayOutputStream(element.getSerializedSize() + 8); element.writeDelimitedTo(out); diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/ProtobufUtils.java b/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/ProtobufUtils.java deleted file mode 100644 index 25aec2a..0000000 --- a/statefun-e2e-tests/statefun-smoke-e2e/src/main/java/org/apache/flink/statefun/e2e/smoke/ProtobufUtils.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.statefun.e2e.smoke; - -import com.google.protobuf.Any; -import com.google.protobuf.InvalidProtocolBufferException; -import com.google.protobuf.Message; - -final class ProtobufUtils { - private ProtobufUtils() {} - - public static <T extends Message> T unpack(Any any, Class<T> messageType) { - try { - return any.unpack(messageType); - } catch (InvalidProtocolBufferException e) { - throw new IllegalStateException(e); - } - } -} diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java index 1010666..226f418 100644 --- a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java +++ b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java @@ -21,10 +21,10 @@ import static org.apache.flink.statefun.e2e.smoke.Utils.aStateModificationComman import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; -import com.google.protobuf.Any; import java.time.Duration; import java.util.concurrent.CompletableFuture; import org.apache.flink.statefun.e2e.smoke.generated.SourceCommand; +import 
org.apache.flink.statefun.flink.common.types.TypedValueUtil; import org.apache.flink.statefun.sdk.Address; import org.apache.flink.statefun.sdk.Context; import org.apache.flink.statefun.sdk.io.EgressIdentifier; @@ -41,7 +41,7 @@ public class CommandInterpreterTest { Context context = new MockContext(); SourceCommand sourceCommand = aStateModificationCommand(); - interpreter.interpret(state, context, Any.pack(sourceCommand)); + interpreter.interpret(state, context, TypedValueUtil.packProtobufMessage(sourceCommand)); assertThat(state.get(), is(1L)); } diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/HarnessTest.java b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/HarnessTest.java index 88864f8..382eefe 100644 --- a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/HarnessTest.java +++ b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/HarnessTest.java @@ -21,8 +21,8 @@ package org.apache.flink.statefun.e2e.smoke; import static org.apache.flink.statefun.e2e.smoke.Utils.awaitVerificationSuccess; import static org.apache.flink.statefun.e2e.smoke.Utils.startProtobufServer; -import com.google.protobuf.Any; import org.apache.flink.statefun.flink.harness.Harness; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.junit.Ignore; import org.junit.Test; import org.slf4j.Logger; @@ -51,7 +51,7 @@ public class HarnessTest { harness.withConfiguration("state.checkpoints.dir", "file:///tmp/checkpoints"); // start the Protobuf server - SimpleProtobufServer.StartedServer<Any> started = startProtobufServer(); + SimpleProtobufServer.StartedServer<TypedValue> started = startProtobufServer(); // configure test parameters. 
ModuleParameters parameters = new ModuleParameters(); diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/SmokeRunner.java b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/SmokeRunner.java index 9f2065e..55c857c 100644 --- a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/SmokeRunner.java +++ b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/SmokeRunner.java @@ -21,8 +21,8 @@ package org.apache.flink.statefun.e2e.smoke; import static org.apache.flink.statefun.e2e.smoke.Utils.awaitVerificationSuccess; import static org.apache.flink.statefun.e2e.smoke.Utils.startProtobufServer; -import com.google.protobuf.Any; import org.apache.flink.statefun.e2e.common.StatefulFunctionsAppContainers; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.util.function.ThrowingRunnable; import org.junit.runner.Description; import org.junit.runners.model.Statement; @@ -34,7 +34,7 @@ public final class SmokeRunner { private static final Logger LOG = LoggerFactory.getLogger(SmokeRunner.class); public static void run(ModuleParameters parameters) throws Throwable { - SimpleProtobufServer.StartedServer<Any> server = startProtobufServer(); + SimpleProtobufServer.StartedServer<TypedValue> server = startProtobufServer(); parameters.setVerificationServerHost("host.testcontainers.internal"); parameters.setVerificationServerPort(server.port()); diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/Utils.java b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/Utils.java index 85f527d..ffbd57c 100644 --- a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/Utils.java +++ b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/Utils.java @@ 
-17,7 +17,6 @@ */ package org.apache.flink.statefun.e2e.smoke; -import com.google.protobuf.Any; import java.util.HashSet; import java.util.Set; import java.util.function.Supplier; @@ -25,6 +24,8 @@ import org.apache.flink.statefun.e2e.smoke.generated.Command; import org.apache.flink.statefun.e2e.smoke.generated.Commands; import org.apache.flink.statefun.e2e.smoke.generated.SourceCommand; import org.apache.flink.statefun.e2e.smoke.generated.VerificationResult; +import org.apache.flink.statefun.flink.common.types.TypedValueUtil; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; class Utils { @@ -60,11 +61,13 @@ class Utils { } /** Blocks the currently executing thread until enough successful verification results supply. */ - static void awaitVerificationSuccess(Supplier<Any> results, final int numberOfFunctionInstances) { + static void awaitVerificationSuccess( + Supplier<TypedValue> results, final int numberOfFunctionInstances) { Set<Integer> successfullyVerified = new HashSet<>(); while (successfullyVerified.size() != numberOfFunctionInstances) { - Any any = results.get(); - VerificationResult result = ProtobufUtils.unpack(any, VerificationResult.class); + TypedValue typedValue = results.get(); + VerificationResult result = + TypedValueUtil.unpackProtobufMessage(typedValue, VerificationResult.parser()); if (result.getActual() == result.getExpected()) { successfullyVerified.add(result.getId()); } else if (result.getActual() > result.getExpected()) { @@ -80,8 +83,8 @@ class Utils { } /** starts a simple Protobuf TCP server that accepts {@link com.google.protobuf.Any}. 
*/ - static SimpleProtobufServer.StartedServer<Any> startProtobufServer() { - SimpleProtobufServer<Any> server = new SimpleProtobufServer<>(Any.parser()); + static SimpleProtobufServer.StartedServer<TypedValue> startProtobufServer() { + SimpleProtobufServer<TypedValue> server = new SimpleProtobufServer<>(TypedValue.parser()); return server.start(); } } diff --git a/statefun-examples/statefun-python-walkthrough-example/run-example.py b/statefun-examples/statefun-python-walkthrough-example/run-example.py index 3795e8f..8cee3b5 100644 --- a/statefun-examples/statefun-python-walkthrough-example/run-example.py +++ b/statefun-examples/statefun-python-walkthrough-example/run-example.py @@ -22,7 +22,7 @@ import requests from google.protobuf.json_format import MessageToDict from google.protobuf.any_pb2 import Any -from statefun.request_reply_pb2 import ToFunction, FromFunction +from statefun.request_reply_pb2 import ToFunction, FromFunction, TypedValue from walkthrough_pb2 import Hello, AnotherHello, Counter @@ -41,9 +41,7 @@ class InvocationBuilder(object): state = self.to_function.invocation.state.add() state.state_name = name if value: - any = Any() - any.Pack(value) - state.state_value = any.SerializeToString() + state.state_value.CopyFrom(self.to_typed_value_any_state(value)) return self def with_invocation(self, arg, caller=None): @@ -51,13 +49,31 @@ class InvocationBuilder(object): if caller: (ns, type, id) = caller InvocationBuilder.set_address(ns, type, id, invocation.caller) - invocation.argument.Pack(arg) + invocation.argument.CopyFrom(self.to_typed_value(arg)) return self def SerializeToString(self): return self.to_function.SerializeToString() @staticmethod + def to_typed_value(proto_msg): + any = Any() + any.Pack(proto_msg) + typed_value = TypedValue() + typed_value.typename = any.type_url + typed_value.value = any.value + return typed_value + + @staticmethod + def to_typed_value_any_state(proto_msg): + any = Any() + any.Pack(proto_msg) + typed_value = 
TypedValue() + typed_value.typename = "type.googleapis.com/google.protobuf.Any" + typed_value.value = any.SerializeToString() + return typed_value + + @staticmethod def set_address(namespace, type, id, address): address.namespace = namespace address.type = type diff --git a/statefun-flink/statefun-flink-common/pom.xml b/statefun-flink/statefun-flink-common/pom.xml index f4ef3f5..8063972 100644 --- a/statefun-flink/statefun-flink-common/pom.xml +++ b/statefun-flink/statefun-flink-common/pom.xml @@ -29,6 +29,10 @@ under the License. <artifactId>statefun-flink-common</artifactId> + <properties> + <additional-sources.dir>target/additional-sources</additional-sources.dir> + </properties> + <dependencies> <!-- flink runtime --> <dependency> @@ -84,10 +88,63 @@ under the License. <build> <plugins> + <!-- + The following plugin is executed in the generate-sources phase, + and is responsible for extracting the additional *.proto files located + in statefun-sdk-protos.jar. + --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-dependency-plugin</artifactId> + <executions> + <execution> + <id>unpack</id> + <phase>generate-sources</phase> + <goals> + <goal>unpack</goal> + </goals> + <configuration> + <artifactItems> + <artifactItem> + <groupId>org.apache.flink</groupId> + <artifactId>statefun-sdk-protos</artifactId> + <version>${project.version}</version> + <type>jar</type> + <outputDirectory>${additional-sources.dir}</outputDirectory> + <includes>sdk/*.proto</includes> + </artifactItem> + </artifactItems> + </configuration> + </execution> + </executions> + </plugin> + <!-- + The following plugin invokes protoc to generate Java classes out of the *.proto + definitions located at: (1) src/main/protobuf (2) ${additional-sources.dir}.
+ --> <plugin> <groupId>com.github.os72</groupId> <artifactId>protoc-jar-maven-plugin</artifactId> <version>${protoc-jar-maven-plugin.version}</version> + <executions> + <execution> + <id>generate-protobuf-sources</id> + <phase>generate-sources</phase> + <goals> + <goal>run</goal> + </goals> + <configuration> + <includeStdTypes>true</includeStdTypes> + <protocVersion>${protobuf.version}</protocVersion> + <cleanOutputFolder>true</cleanOutputFolder> + <inputDirectories> + <inputDirectory>src/main/protobuf</inputDirectory> + <inputDirectory>${additional-sources.dir}</inputDirectory> + </inputDirectories> + <outputDirectory>${basedir}/target/generated-sources/protoc-jar</outputDirectory> + </configuration> + </execution> + </executions> </plugin> </plugins> </build> diff --git a/statefun-flink/statefun-flink-common/src/main/java/org/apache/flink/statefun/flink/common/types/TypedValueUtil.java b/statefun-flink/statefun-flink-common/src/main/java/org/apache/flink/statefun/flink/common/types/TypedValueUtil.java new file mode 100644 index 0000000..38f9808 --- /dev/null +++ b/statefun-flink/statefun-flink-common/src/main/java/org/apache/flink/statefun/flink/common/types/TypedValueUtil.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.statefun.flink.common.types; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import com.google.protobuf.Parser; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; + +public final class TypedValueUtil { + + private TypedValueUtil() {} + + public static boolean isProtobufTypeOf( + TypedValue typedValue, Descriptors.Descriptor messageDescriptor) { + return typedValue.getTypename().equals(protobufTypeUrl(messageDescriptor)); + } + + public static TypedValue packProtobufMessage(Message protobufMessage) { + return TypedValue.newBuilder() + .setTypename(protobufTypeUrl(protobufMessage.getDescriptorForType())) + .setValue(protobufMessage.toByteString()) + .build(); + } + + public static <PB extends Message> PB unpackProtobufMessage( + TypedValue typedValue, Parser<PB> protobufMessageParser) { + try { + return protobufMessageParser.parseFrom(typedValue.getValue()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalStateException(e); + } + } + + private static String protobufTypeUrl(Descriptors.Descriptor messageDescriptor) { + return "type.googleapis.com/" + messageDescriptor.getFullName(); + } +} diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/jsonmodule/EgressJsonEntity.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/jsonmodule/EgressJsonEntity.java index 813b740..d3040b7 100644 --- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/jsonmodule/EgressJsonEntity.java +++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/jsonmodule/EgressJsonEntity.java @@ -18,7 +18,6 @@ package org.apache.flink.statefun.flink.core.jsonmodule; -import 
com.google.protobuf.Any; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonPointer; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; import org.apache.flink.statefun.flink.common.json.NamespaceNamePair; @@ -26,6 +25,7 @@ import org.apache.flink.statefun.flink.common.json.Selectors; import org.apache.flink.statefun.flink.io.spi.JsonEgressSpec; import org.apache.flink.statefun.sdk.EgressType; import org.apache.flink.statefun.sdk.io.EgressIdentifier; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.statefun.sdk.spi.StatefulFunctionModule.Binder; final class EgressJsonEntity implements JsonEntity { @@ -55,9 +55,9 @@ final class EgressJsonEntity implements JsonEntity { return new EgressType(nn.namespace(), nn.name()); } - private static EgressIdentifier<Any> egressId(JsonNode spec) { + private static EgressIdentifier<TypedValue> egressId(JsonNode spec) { String egressId = Selectors.textAt(spec, MetaPointers.ID); NamespaceNamePair nn = NamespaceNamePair.from(egressId); - return new EgressIdentifier<>(nn.namespace(), nn.name(), Any.class); + return new EgressIdentifier<>(nn.namespace(), nn.name(), TypedValue.class); } } diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/protorouter/AutoRoutableProtobufRouter.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/protorouter/AutoRoutableProtobufRouter.java index eb37fe8..d5e0347 100644 --- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/protorouter/AutoRoutableProtobufRouter.java +++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/protorouter/AutoRoutableProtobufRouter.java @@ -18,7 +18,6 @@ package org.apache.flink.statefun.flink.core.protorouter; -import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Message; import 
org.apache.flink.statefun.flink.io.generated.AutoRoutable; @@ -26,15 +25,21 @@ import org.apache.flink.statefun.flink.io.generated.RoutingConfig; import org.apache.flink.statefun.flink.io.generated.TargetFunctionType; import org.apache.flink.statefun.sdk.FunctionType; import org.apache.flink.statefun.sdk.io.Router; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; /** * A {@link Router} that recognizes messages of type {@link AutoRoutable}. * * <p>For each incoming {@code AutoRoutable}, this router forwards the wrapped payload to the - * configured target addresses as a Protobuf {@link Any} message. + * configured target addresses as a {@link TypedValue} message. */ public final class AutoRoutableProtobufRouter implements Router<Message> { + /** + * Note: while the input and type of this method is both {@link Message}, we actually do a + * conversion here. The input {@link Message} is an {@link AutoRoutable}, which gets converted to + * a {@link TypedValue} as the output after slicing the target address and actual payload. 
+ */ @Override public void route(Message message, Downstream<Message> downstream) { final AutoRoutable routable = asAutoRoutable(message); @@ -43,7 +48,7 @@ public final class AutoRoutableProtobufRouter implements Router<Message> { downstream.forward( sdkFunctionType(targetFunction), routable.getId(), - anyPayload(config.getTypeUrl(), routable.getPayloadBytes())); + typedValuePayload(config.getTypeUrl(), routable.getPayloadBytes())); } } @@ -60,7 +65,7 @@ public final class AutoRoutableProtobufRouter implements Router<Message> { return new FunctionType(targetFunctionType.getNamespace(), targetFunctionType.getType()); } - private static Any anyPayload(String typeUrl, ByteString payloadBytes) { - return Any.newBuilder().setTypeUrl(typeUrl).setValue(payloadBytes).build(); + private static TypedValue typedValuePayload(String typeUrl, ByteString payloadBytes) { + return TypedValue.newBuilder().setTypename(typeUrl).setValue(payloadBytes).build(); } } diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java index 42cffbe..c47c2ac 100644 --- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java +++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java @@ -31,6 +31,7 @@ import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedVa import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedValueSpec; import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction; import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.InvocationBatchRequest; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import 
org.apache.flink.statefun.sdk.state.Expiration; import org.apache.flink.statefun.sdk.state.PersistedStateRegistry; import org.apache.flink.statefun.sdk.state.RemotePersistedValue; @@ -48,9 +49,15 @@ public final class PersistedRemoteFunctionValues { final ToFunction.PersistedValue.Builder valueBuilder = ToFunction.PersistedValue.newBuilder().setStateName(managedStateEntry.getKey()); - final byte[] stateValue = managedStateEntry.getValue().get(); - if (stateValue != null) { - valueBuilder.setStateValue(ByteString.copyFrom(stateValue)); + final RemotePersistedValue registeredHandle = managedStateEntry.getValue(); + final byte[] stateBytes = registeredHandle.get(); + if (stateBytes != null) { + final TypedValue stateValue = + TypedValue.newBuilder() + .setValue(ByteString.copyFrom(stateBytes)) + .setTypename(registeredHandle.type().toString()) + .build(); + valueBuilder.setStateValue(stateValue); } batchBuilder.addState(valueBuilder); } @@ -67,7 +74,11 @@ public final class PersistedRemoteFunctionValues { } case MODIFY: { - getStateHandleOrThrow(stateName).set(mutate.getStateValue().toByteArray()); + final RemotePersistedValue registeredHandle = getStateHandleOrThrow(stateName); + final TypedValue newStateValue = mutate.getStateValue(); + + validateType(registeredHandle, newStateValue.getTypename()); + registeredHandle.set(newStateValue.getValue().toByteArray()); break; } case UNRECOGNIZED: @@ -102,7 +113,7 @@ public final class PersistedRemoteFunctionValues { if (stateHandle == null) { registerValueState(protocolPersistedValueSpec); } else { - validateType(stateHandle, protocolPersistedValueSpec); + validateType(stateHandle, protocolPersistedValueSpec.getTypeTypename()); } } @@ -112,7 +123,7 @@ public final class PersistedRemoteFunctionValues { final RemotePersistedValue remoteValueState = RemotePersistedValue.of( stateName, - sdkStateType(protocolPersistedValueSpec), + sdkStateType(protocolPersistedValueSpec.getTypeTypename()), 
sdkTtlExpiration(protocolPersistedValueSpec.getExpirationSpec())); managedStates.put(stateName, remoteValueState); @@ -125,23 +136,21 @@ public final class PersistedRemoteFunctionValues { } private void validateType( - RemotePersistedValue previousStateHandle, PersistedValueSpec protocolPersistedValueSpec) { - final TypeName newStateType = sdkStateType(protocolPersistedValueSpec); + RemotePersistedValue previousStateHandle, String protocolTypenameString) { + final TypeName newStateType = sdkStateType(protocolTypenameString); if (!newStateType.equals(previousStateHandle.type())) { throw new RemoteFunctionStateException( - protocolPersistedValueSpec.getStateName(), + previousStateHandle.name(), new RemoteValueTypeMismatchException(previousStateHandle.type(), newStateType)); } } - private static TypeName sdkStateType(PersistedValueSpec protocolPersistedValueSpec) { - final String typeStringPair = protocolPersistedValueSpec.getTypeTypename(); - + private static TypeName sdkStateType(String protocolTypenameString) { // TODO type field may be empty in current master only because SDKs are not yet updated; // TODO once SDKs are updated, we should expect that the type is always specified - return protocolPersistedValueSpec.getTypeTypename().isEmpty() + return protocolTypenameString.isEmpty() ? 
UNSET_STATE_TYPE - : TypeName.parseFrom(typeStringPair); + : TypeName.parseFrom(protocolTypenameString); } private static Expiration sdkTtlExpiration(ExpirationSpec protocolExpirationSpec) { diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java index 51db78c..a577bb0 100644 --- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java +++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java @@ -21,7 +21,6 @@ package org.apache.flink.statefun.flink.core.reqreply; import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.polyglotAddressToSdkAddress; import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.sdkAddressToPolyglotAddress; -import com.google.protobuf.Any; import java.time.Duration; import java.util.Objects; import java.util.concurrent.CompletableFuture; @@ -41,6 +40,7 @@ import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.InvocationR import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction; import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.Invocation; import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.InvocationBatchRequest; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.statefun.sdk.state.PersistedAppendingBuffer; import org.apache.flink.statefun.sdk.state.PersistedValue; import org.apache.flink.types.Either; @@ -87,7 +87,7 @@ public final class RequestReplyFunction implements StatefulFunction { public void invoke(Context context, Object input) { InternalContext castedContext = (InternalContext) context; if (!(input instanceof AsyncOperationResult)) { - onRequest(castedContext, (Any) input); + 
onRequest(castedContext, (TypedValue) input); return; } @SuppressWarnings("unchecked") @@ -96,7 +96,7 @@ public final class RequestReplyFunction implements StatefulFunction { onAsyncResult(castedContext, result); } - private void onRequest(InternalContext context, Any message) { + private void onRequest(InternalContext context, TypedValue message) { Invocation.Builder invocationBuilder = singeInvocationBuilder(context, message); int inflightOrBatched = requestState.getOrDefault(-1); if (inflightOrBatched < 0) { @@ -208,9 +208,9 @@ public final class RequestReplyFunction implements StatefulFunction { private void handleEgressMessages(Context context, InvocationResponse invocationResult) { for (EgressMessage egressMessage : invocationResult.getOutgoingEgressesList()) { - EgressIdentifier<Any> id = + EgressIdentifier<TypedValue> id = new EgressIdentifier<>( - egressMessage.getEgressNamespace(), egressMessage.getEgressType(), Any.class); + egressMessage.getEgressNamespace(), egressMessage.getEgressType(), TypedValue.class); context.send(id, egressMessage.getArgument()); } } @@ -218,7 +218,7 @@ public final class RequestReplyFunction implements StatefulFunction { private void handleOutgoingMessages(Context context, InvocationResponse invocationResult) { for (FromFunction.Invocation invokeCommand : invocationResult.getOutgoingMessagesList()) { final Address to = polyglotAddressToSdkAddress(invokeCommand.getTarget()); - final Any message = invokeCommand.getArgument(); + final TypedValue message = invokeCommand.getArgument(); context.send(to, message); } @@ -228,7 +228,7 @@ public final class RequestReplyFunction implements StatefulFunction { for (FromFunction.DelayedInvocation delayedInvokeCommand : invocationResult.getDelayedInvocationsList()) { final Address to = polyglotAddressToSdkAddress(delayedInvokeCommand.getTarget()); - final Any message = delayedInvokeCommand.getArgument(); + final TypedValue message = delayedInvokeCommand.getArgument(); final long delay = 
delayedInvokeCommand.getDelayInMs(); context.sendAfter(Duration.ofMillis(delay), to, message); @@ -242,7 +242,7 @@ public final class RequestReplyFunction implements StatefulFunction { * Returns an {@link Invocation.Builder} set with the input {@code message} and the caller * information (is present). */ - private static Invocation.Builder singeInvocationBuilder(Context context, Any message) { + private static Invocation.Builder singeInvocationBuilder(Context context, TypedValue message) { Invocation.Builder invocationBuilder = Invocation.newBuilder(); if (context.caller() != null) { invocationBuilder.setCaller(sdkAddressToPolyglotAddress(context.caller())); diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java index 92014a9..cc928b0 100644 --- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java +++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/jsonmodule/JsonModuleTest.java @@ -24,7 +24,6 @@ import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; -import com.google.protobuf.Any; import com.google.protobuf.Message; import java.net.URL; import java.util.Collections; @@ -35,6 +34,7 @@ import org.apache.flink.statefun.flink.core.message.MessageFactoryType; import org.apache.flink.statefun.sdk.FunctionType; import org.apache.flink.statefun.sdk.io.EgressIdentifier; import org.apache.flink.statefun.sdk.io.IngressIdentifier; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.statefun.sdk.spi.StatefulFunctionModule; import org.junit.Test; @@ -97,7 +97,8 @@ public class JsonModuleTest { module.configure(Collections.emptyMap(), universe); assertThat( - 
universe.egress(), hasKey(new EgressIdentifier<>("com.mycomp.foo", "bar", Any.class))); + universe.egress(), + hasKey(new EgressIdentifier<>("com.mycomp.foo", "bar", TypedValue.class))); } private static StatefulFunctionModule fromPath(String path) { diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValuesTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValuesTest.java index b5f2927..81ab98a 100644 --- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValuesTest.java +++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValuesTest.java @@ -31,6 +31,7 @@ import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedVa import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedValueSpec; import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.InvocationBatchRequest; import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.PersistedValue; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.junit.Test; public class PersistedRemoteFunctionValuesTest { @@ -50,8 +51,11 @@ public class PersistedRemoteFunctionValuesTest { // --- update state values values.updateStateValues( Arrays.asList( - protocolPersistedValueModifyMutation("state-1", ByteString.copyFromUtf8("data-1")), - protocolPersistedValueModifyMutation("state-2", ByteString.copyFromUtf8("data-2")))); + protocolPersistedValueModifyMutation( + "state-1", protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data-1"))), + protocolPersistedValueModifyMutation( + "state-2", + protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data-2"))))); final InvocationBatchRequest.Builder builder = InvocationBatchRequest.newBuilder(); 
values.attachStateValues(builder); @@ -61,8 +65,11 @@ public class PersistedRemoteFunctionValuesTest { assertThat( builder.getStateList(), hasItems( - protocolPersistedValue("state-1", ByteString.copyFromUtf8("data-1")), - protocolPersistedValue("state-2", ByteString.copyFromUtf8("data-2")))); + protocolPersistedValue( + "state-1", protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data-1"))), + protocolPersistedValue( + "state-2", + protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data-2"))))); } @Test @@ -82,7 +89,8 @@ public class PersistedRemoteFunctionValuesTest { values.updateStateValues( Collections.singletonList( protocolPersistedValueModifyMutation( - "non-registered-state", ByteString.copyFromUtf8("data")))); + "non-registered-state", + protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data"))))); } @Test @@ -109,7 +117,8 @@ public class PersistedRemoteFunctionValuesTest { // modify and then delete state value values.updateStateValues( Collections.singletonList( - protocolPersistedValueModifyMutation("state", ByteString.copyFromUtf8("data")))); + protocolPersistedValueModifyMutation( + "state", protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data"))))); values.updateStateValues( Collections.singletonList(protocolPersistedValueDeleteMutation("state"))); @@ -128,7 +137,8 @@ public class PersistedRemoteFunctionValuesTest { Collections.singletonList(protocolPersistedValueSpec("state", TEST_STATE_TYPE))); values.updateStateValues( Collections.singletonList( - protocolPersistedValueModifyMutation("state", ByteString.copyFromUtf8("data")))); + protocolPersistedValueModifyMutation( + "state", protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data"))))); // duplicate registration under the same state name values.registerStates( @@ -140,7 +150,9 @@ public class PersistedRemoteFunctionValuesTest { assertThat(builder.getStateList().size(), is(1)); assertThat( builder.getStateList(), - 
hasItems(protocolPersistedValue("state", ByteString.copyFromUtf8("data")))); + hasItems( + protocolPersistedValue( + "state", protocolTypedValue(TEST_STATE_TYPE, ByteString.copyFromUtf8("data"))))); } @Test(expected = RemoteFunctionStateException.class) @@ -155,6 +167,25 @@ public class PersistedRemoteFunctionValuesTest { protocolPersistedValueSpec("state", TypeName.parseFrom("com.foo.bar/type-2")))); } + @Test(expected = RemoteFunctionStateException.class) + public void mutatingStateValueWithMismatchingType() { + final PersistedRemoteFunctionValues values = new PersistedRemoteFunctionValues(); + + values.registerStates( + Collections.singletonList( + protocolPersistedValueSpec("state", TypeName.parseFrom("com.foo.bar/type-1")))); + values.updateStateValues( + Collections.singletonList( + protocolPersistedValueModifyMutation( + "state", + protocolTypedValue( + TypeName.parseFrom("com.foo.bar/type-2"), ByteString.copyFromUtf8("data"))))); + } + + private static TypedValue protocolTypedValue(TypeName typename, ByteString value) { + return TypedValue.newBuilder().setTypename(typename.toString()).setValue(value).build(); + } + private static PersistedValueSpec protocolPersistedValueSpec(String stateName, TypeName type) { return PersistedValueSpec.newBuilder() .setStateName(stateName) @@ -163,7 +194,7 @@ public class PersistedRemoteFunctionValuesTest { } private static PersistedValueMutation protocolPersistedValueModifyMutation( - String stateName, ByteString modifyValue) { + String stateName, TypedValue modifyValue) { return PersistedValueMutation.newBuilder() .setStateName(stateName) .setMutationType(PersistedValueMutation.MutationType.MODIFY) @@ -178,7 +209,7 @@ public class PersistedRemoteFunctionValuesTest { .build(); } - private static PersistedValue protocolPersistedValue(String stateName, ByteString stateValue) { + private static PersistedValue protocolPersistedValue(String stateName, TypedValue stateValue) { final PersistedValue.Builder builder = 
PersistedValue.newBuilder(); builder.setStateName(stateName); diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java index 9b5d9c9..5b3a053 100644 --- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java +++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java @@ -26,7 +26,6 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; -import com.google.protobuf.Any; import com.google.protobuf.ByteString; import java.time.Duration; import java.util.AbstractMap.SimpleImmutableEntry; @@ -38,7 +37,6 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.function.Supplier; import java.util.stream.Collectors; -import org.apache.flink.statefun.flink.core.TestUtils; import org.apache.flink.statefun.flink.core.backpressure.InternalContext; import org.apache.flink.statefun.flink.core.metrics.FunctionTypeMetrics; import org.apache.flink.statefun.flink.core.metrics.RemoteInvocationMetrics; @@ -58,6 +56,7 @@ import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedVa import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedValueSpec; import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction; import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.Invocation; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.junit.Test; public class RequestReplyFunctionTest { @@ -67,11 +66,12 @@ public class RequestReplyFunctionTest { private final FakeContext context = new FakeContext(); private final RequestReplyFunction functionUnderTest = - new 
RequestReplyFunction(testInitialRegisteredState("session"), 10, client); + new RequestReplyFunction( + testInitialRegisteredState("session", "com.foo.bar/myType"), 10, client); @Test public void example() { - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); assertTrue(client.wasSentToFunction.hasInvocation()); assertThat(client.capturedInvocationBatchSize(), is(1)); @@ -80,7 +80,7 @@ public class RequestReplyFunctionTest { @Test public void callerIsSet() { context.caller = FUNCTION_1_ADDR; - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); Invocation anInvocation = client.capturedInvocation(0); Address caller = polyglotAddressToSdkAddress(anInvocation.getCaller()); @@ -90,20 +90,24 @@ public class RequestReplyFunctionTest { @Test public void messageIsSet() { - Any any = Any.pack(TestUtils.DUMMY_PAYLOAD); + TypedValue argument = + TypedValue.newBuilder() + .setTypename("io.statefun.foo/bar") + .setValue(ByteString.copyFromUtf8("Hello!")) + .build(); - functionUnderTest.invoke(context, any); + functionUnderTest.invoke(context, argument); - assertThat(client.capturedInvocation(0).getArgument(), is(any)); + assertThat(client.capturedInvocation(0).getArgument(), is(argument)); } @Test public void batchIsAccumulatedWhileARequestIsInFlight() { // send one message - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); // the following invocations should be queued and sent as a batch - functionUnderTest.invoke(context, Any.getDefaultInstance()); - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); // simulate a successful completion of the first operation functionUnderTest.invoke(context, 
successfulAsyncOperation()); @@ -116,13 +120,13 @@ public class RequestReplyFunctionTest { RequestReplyFunction functionUnderTest = new RequestReplyFunction(2, client); // send one message - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); // the following invocations should be queued - functionUnderTest.invoke(context, Any.getDefaultInstance()); - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); // the following invocations should request backpressure - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); assertThat(context.needsWaiting, is(true)); } @@ -132,24 +136,24 @@ public class RequestReplyFunctionTest { RequestReplyFunction functionUnderTest = new RequestReplyFunction(2, client); // the following invocations should cause backpressure - functionUnderTest.invoke(context, Any.getDefaultInstance()); - functionUnderTest.invoke(context, Any.getDefaultInstance()); - functionUnderTest.invoke(context, Any.getDefaultInstance()); - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); // complete one message, should send a batch of size 3 context.needsWaiting = false; functionUnderTest.invoke(context, successfulAsyncOperation()); // the next message should not cause backpressure. 
- functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); assertThat(context.needsWaiting, is(false)); } @Test public void stateIsModified() { - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); // A message returned from the function // that asks to put "hello" into the session state. @@ -159,20 +163,23 @@ public class RequestReplyFunctionTest { InvocationResponse.newBuilder() .addStateMutations( PersistedValueMutation.newBuilder() - .setStateValue(ByteString.copyFromUtf8("hello")) + .setStateValue( + TypedValue.newBuilder() + .setTypename("com.foo.bar/myType") + .setValue(ByteString.copyFromUtf8("hello"))) .setMutationType(MutationType.MODIFY) .setStateName("session"))) .build(); functionUnderTest.invoke(context, successfulAsyncOperation(response)); - functionUnderTest.invoke(context, Any.getDefaultInstance()); - assertThat(client.capturedState(0), is(ByteString.copyFromUtf8("hello"))); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); + assertThat(client.capturedState(0).getValue(), is(ByteString.copyFromUtf8("hello"))); } @Test public void delayedMessages() { - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); FromFunction response = FromFunction.newBuilder() @@ -180,7 +187,7 @@ public class RequestReplyFunctionTest { InvocationResponse.newBuilder() .addDelayedInvocations( DelayedInvocation.newBuilder() - .setArgument(Any.getDefaultInstance()) + .setArgument(TypedValue.getDefaultInstance()) .setDelayInMs(1) .build())) .build(); @@ -193,7 +200,7 @@ public class RequestReplyFunctionTest { @Test public void egressIsSent() { - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); FromFunction response = FromFunction.newBuilder() @@ 
-201,7 +208,7 @@ public class RequestReplyFunctionTest { InvocationResponse.newBuilder() .addOutgoingEgresses( EgressMessage.newBuilder() - .setArgument(Any.getDefaultInstance()) + .setArgument(TypedValue.getDefaultInstance()) .setEgressNamespace("org.foo") .setEgressType("bar"))) .build(); @@ -210,13 +217,18 @@ public class RequestReplyFunctionTest { assertFalse(context.egresses.isEmpty()); assertEquals( - new EgressIdentifier<>("org.foo", "bar", Any.class), context.egresses.get(0).getKey()); + new EgressIdentifier<>("org.foo", "bar", TypedValue.class), + context.egresses.get(0).getKey()); } @Test public void retryBatchOnIncompleteInvocationContextResponse() { - Any any = Any.pack(TestUtils.DUMMY_PAYLOAD); - functionUnderTest.invoke(context, any); + TypedValue argument = + TypedValue.newBuilder() + .setTypename("io.statefun.foo/bar") + .setValue(ByteString.copyFromUtf8("Hello!")) + .build(); + functionUnderTest.invoke(context, argument); FromFunction response = FromFunction.newBuilder() @@ -237,7 +249,7 @@ public class RequestReplyFunctionTest { // re-sent batch should have identical invocation input messages assertTrue(client.wasSentToFunction.hasInvocation()); assertThat(client.capturedInvocationBatchSize(), is(1)); - assertThat(client.capturedInvocation(0).getArgument(), is(any)); + assertThat(client.capturedInvocation(0).getArgument(), is(argument)); // re-sent batch should have new state as well as originally registered state assertThat(client.capturedStateNames().size(), is(2)); @@ -246,22 +258,22 @@ public class RequestReplyFunctionTest { @Test public void backlogMetricsIncreasedOnInvoke() { - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); // following should be accounted into backlog metrics - functionUnderTest.invoke(context, Any.getDefaultInstance()); - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, 
TypedValue.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); assertThat(context.functionTypeMetrics().numBacklog, is(2)); } @Test public void backlogMetricsDecreasedOnNextSuccess() { - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); // following should be accounted into backlog metrics - functionUnderTest.invoke(context, Any.getDefaultInstance()); - functionUnderTest.invoke(context, Any.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); + functionUnderTest.invoke(context, TypedValue.getDefaultInstance()); // complete one message, should fully consume backlog context.needsWaiting = false; @@ -271,11 +283,14 @@ public class RequestReplyFunctionTest { } private static PersistedRemoteFunctionValues testInitialRegisteredState( - String existingStateName) { + String existingStateName, String typename) { final PersistedRemoteFunctionValues states = new PersistedRemoteFunctionValues(); states.registerStates( Collections.singletonList( - PersistedValueSpec.newBuilder().setStateName(existingStateName).build())); + PersistedValueSpec.newBuilder() + .setTypeTypename(typename) + .setStateName(existingStateName) + .build())); return states; } @@ -318,7 +333,7 @@ public class RequestReplyFunctionTest { return wasSentToFunction.getInvocation().getInvocations(n); } - ByteString capturedState(int n) { + TypedValue capturedState(int n) { return wasSentToFunction.getInvocation().getState(n).getStateValue(); } diff --git a/statefun-flink/statefun-flink-io-bundle/pom.xml b/statefun-flink/statefun-flink-io-bundle/pom.xml index 51acb36..251955a 100644 --- a/statefun-flink/statefun-flink-io-bundle/pom.xml +++ b/statefun-flink/statefun-flink-io-bundle/pom.xml @@ -29,6 +29,10 @@ under the License. 
<artifactId>statefun-flink-io-bundle</artifactId> + <properties> + <additional-sources.dir>target/additional-sources</additional-sources.dir> + </properties> + <dependencies> <!-- Stateful Functions sdk --> <dependency> @@ -37,6 +41,12 @@ under the License. <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>statefun-sdk-protos</artifactId> + <version>${project.version}</version> + </dependency> + <!-- statefun-flink spi --> <dependency> <groupId>org.apache.flink</groupId> diff --git a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaEgressSerializer.java b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaEgressSerializer.java index fb8a484..c232ba3 100644 --- a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaEgressSerializer.java +++ b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaEgressSerializer.java @@ -17,11 +17,12 @@ */ package org.apache.flink.statefun.flink.io.kafka; -import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; import java.nio.charset.StandardCharsets; +import org.apache.flink.statefun.flink.common.types.TypedValueUtil; import org.apache.flink.statefun.sdk.egress.generated.KafkaProducerRecord; import org.apache.flink.statefun.sdk.kafka.KafkaEgressSerializer; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.kafka.clients.producer.ProducerRecord; /** @@ -31,24 +32,24 @@ import org.apache.kafka.clients.producer.ProducerRecord; * <p>This serializer expects Protobuf messages of type {@link KafkaProducerRecord}, and simply * transforms those into Kafka's {@link ProducerRecord}. 
*/ -public final class GenericKafkaEgressSerializer implements KafkaEgressSerializer<Any> { +public final class GenericKafkaEgressSerializer implements KafkaEgressSerializer<TypedValue> { private static final long serialVersionUID = 1L; @Override - public ProducerRecord<byte[], byte[]> serialize(Any any) { - KafkaProducerRecord protobufProducerRecord = asKafkaProducerRecord(any); + public ProducerRecord<byte[], byte[]> serialize(TypedValue message) { + KafkaProducerRecord protobufProducerRecord = asKafkaProducerRecord(message); return toProducerRecord(protobufProducerRecord); } - private static KafkaProducerRecord asKafkaProducerRecord(Any message) { - if (!message.is(KafkaProducerRecord.class)) { + private static KafkaProducerRecord asKafkaProducerRecord(TypedValue message) { + if (!TypedValueUtil.isProtobufTypeOf(message, KafkaProducerRecord.getDescriptor())) { throw new IllegalStateException( "The generic Kafka egress expects only messages of type " + KafkaProducerRecord.class.getName()); } try { - return message.unpack(KafkaProducerRecord.class); + return KafkaProducerRecord.parseFrom(message.getValue()); } catch (InvalidProtocolBufferException e) { throw new RuntimeException( "Unable to unpack message as a " + KafkaProducerRecord.class.getName(), e); diff --git a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProvider.java b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProvider.java index 2590b5f..fd87a69 100644 --- a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProvider.java +++ b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProvider.java @@ -23,7 +23,6 @@ import static org.apache.flink.statefun.flink.io.kafka.KafkaEgressSpecJsonParser import static 
org.apache.flink.statefun.flink.io.kafka.KafkaEgressSpecJsonParser.kafkaClientProperties; import static org.apache.flink.statefun.flink.io.kafka.KafkaEgressSpecJsonParser.optionalDeliverySemantic; -import com.google.protobuf.Any; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; import org.apache.flink.statefun.flink.io.spi.JsonEgressSpec; import org.apache.flink.statefun.flink.io.spi.SinkProvider; @@ -31,6 +30,7 @@ import org.apache.flink.statefun.sdk.io.EgressIdentifier; import org.apache.flink.statefun.sdk.io.EgressSpec; import org.apache.flink.statefun.sdk.kafka.KafkaEgressBuilder; import org.apache.flink.statefun.sdk.kafka.KafkaEgressSpec; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.streaming.api.functions.sink.SinkFunction; final class GenericKafkaSinkProvider implements SinkProvider { @@ -84,10 +84,10 @@ final class GenericKafkaSinkProvider implements SinkProvider { private static void validateConsumedType(EgressIdentifier<?> id) { Class<?> consumedType = id.consumedType(); - if (Any.class != consumedType) { + if (TypedValue.class != consumedType) { throw new IllegalArgumentException( "Generic Kafka egress is only able to consume messages types of " - + Any.class.getName() + + TypedValue.class.getName() + " but " + consumedType.getName() + " is provided."); diff --git a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisEgressSerializer.java b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisEgressSerializer.java index 4b1c522..1459b15 100644 --- a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisEgressSerializer.java +++ b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisEgressSerializer.java @@ -18,18 
+18,19 @@ package org.apache.flink.statefun.flink.io.kinesis.polyglot; -import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; +import org.apache.flink.statefun.flink.common.types.TypedValueUtil; import org.apache.flink.statefun.sdk.egress.generated.KinesisEgressRecord; import org.apache.flink.statefun.sdk.kinesis.egress.EgressRecord; import org.apache.flink.statefun.sdk.kinesis.egress.KinesisEgressSerializer; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; -public final class GenericKinesisEgressSerializer implements KinesisEgressSerializer<Any> { +public final class GenericKinesisEgressSerializer implements KinesisEgressSerializer<TypedValue> { private static final long serialVersionUID = 1L; @Override - public EgressRecord serialize(Any value) { + public EgressRecord serialize(TypedValue value) { final KinesisEgressRecord kinesisEgressRecord = asKinesisEgressRecord(value); final EgressRecord.Builder builder = @@ -46,14 +47,14 @@ public final class GenericKinesisEgressSerializer implements KinesisEgressSerial return builder.build(); } - private static KinesisEgressRecord asKinesisEgressRecord(Any message) { - if (!message.is(KinesisEgressRecord.class)) { + private static KinesisEgressRecord asKinesisEgressRecord(TypedValue message) { + if (!TypedValueUtil.isProtobufTypeOf(message, KinesisEgressRecord.getDescriptor())) { throw new IllegalStateException( "The generic Kinesis egress expects only messages of type " + KinesisEgressRecord.class.getName()); } try { - return message.unpack(KinesisEgressRecord.class); + return KinesisEgressRecord.parseFrom(message.getValue()); } catch (InvalidProtocolBufferException e) { throw new RuntimeException( "Unable to unpack message as a " + KinesisEgressRecord.class.getName(), e); diff --git a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisSinkProvider.java 
b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisSinkProvider.java index ad8fc1f..d5f5f29 100644 --- a/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisSinkProvider.java +++ b/statefun-flink/statefun-flink-io-bundle/src/main/java/org/apache/flink/statefun/flink/io/kinesis/polyglot/GenericKinesisSinkProvider.java @@ -22,7 +22,6 @@ import static org.apache.flink.statefun.flink.io.kinesis.polyglot.AwsAuthSpecJso import static org.apache.flink.statefun.flink.io.kinesis.polyglot.KinesisEgressSpecJsonParser.clientConfigProperties; import static org.apache.flink.statefun.flink.io.kinesis.polyglot.KinesisEgressSpecJsonParser.optionalMaxOutstandingRecords; -import com.google.protobuf.Any; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; import org.apache.flink.statefun.flink.io.kinesis.KinesisSinkProvider; import org.apache.flink.statefun.flink.io.spi.JsonEgressSpec; @@ -31,6 +30,7 @@ import org.apache.flink.statefun.sdk.io.EgressIdentifier; import org.apache.flink.statefun.sdk.io.EgressSpec; import org.apache.flink.statefun.sdk.kinesis.egress.KinesisEgressBuilder; import org.apache.flink.statefun.sdk.kinesis.egress.KinesisEgressSpec; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.streaming.api.functions.sink.SinkFunction; public final class GenericKinesisSinkProvider implements SinkProvider { @@ -74,10 +74,10 @@ public final class GenericKinesisSinkProvider implements SinkProvider { private static void validateConsumedType(EgressIdentifier<?> id) { Class<?> consumedType = id.consumedType(); - if (Any.class != consumedType) { + if (TypedValue.class != consumedType) { throw new IllegalArgumentException( "Generic Kinesis egress is only able to consume messages types of " - + Any.class.getName() + + TypedValue.class.getName() + " but " + 
consumedType.getName() + " is provided."); diff --git a/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProviderTest.java b/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProviderTest.java index 151574d..d0dcc50 100644 --- a/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProviderTest.java +++ b/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kafka/GenericKafkaSinkProviderTest.java @@ -21,10 +21,10 @@ import static org.apache.flink.statefun.flink.io.testutils.YamlUtils.loadAsJsonF import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; -import com.google.protobuf.Any; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; import org.apache.flink.statefun.flink.io.spi.JsonEgressSpec; import org.apache.flink.statefun.sdk.io.EgressIdentifier; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer; import org.junit.Test; @@ -38,7 +38,7 @@ public class GenericKafkaSinkProviderTest { JsonEgressSpec<?> spec = new JsonEgressSpec<>( KafkaEgressTypes.GENERIC_KAFKA_EGRESS_TYPE, - new EgressIdentifier<>("foo", "bar", Any.class), + new EgressIdentifier<>("foo", "bar", TypedValue.class), egressDefinition); GenericKafkaSinkProvider provider = new GenericKafkaSinkProvider(); diff --git a/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kinesis/GenericKinesisSinkProviderTest.java b/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kinesis/GenericKinesisSinkProviderTest.java index 2a6b19b..adfc8f6 100644 --- 
a/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kinesis/GenericKinesisSinkProviderTest.java +++ b/statefun-flink/statefun-flink-io-bundle/src/test/java/org/apache/flink/statefun/flink/io/kinesis/GenericKinesisSinkProviderTest.java @@ -21,11 +21,11 @@ import static org.apache.flink.statefun.flink.io.testutils.YamlUtils.loadAsJsonF import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; -import com.google.protobuf.Any; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; import org.apache.flink.statefun.flink.io.kinesis.polyglot.GenericKinesisSinkProvider; import org.apache.flink.statefun.flink.io.spi.JsonEgressSpec; import org.apache.flink.statefun.sdk.io.EgressIdentifier; +import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue; import org.apache.flink.streaming.api.functions.sink.SinkFunction; import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisProducer; import org.junit.Test; @@ -39,7 +39,7 @@ public class GenericKinesisSinkProviderTest { JsonEgressSpec<?> spec = new JsonEgressSpec<>( PolyglotKinesisIOTypes.GENERIC_KINESIS_EGRESS_TYPE, - new EgressIdentifier<>("foo", "bar", Any.class), + new EgressIdentifier<>("foo", "bar", TypedValue.class), egressDefinition); GenericKinesisSinkProvider provider = new GenericKinesisSinkProvider(); diff --git a/statefun-python-sdk/statefun/core.py b/statefun-python-sdk/statefun/core.py index 8499a71..8e342d0 100644 --- a/statefun-python-sdk/statefun/core.py +++ b/statefun-python-sdk/statefun/core.py @@ -46,6 +46,13 @@ class AnyStateHandle(object): self.modified = False self.deleted = False + # + # TODO This should reflect the actual type URL. + # TODO we can support that only after reworking the SDK. 
+ # + def typename(self): + return "type.googleapis.com/google.protobuf.Any" + def bytes(self): if self.deleted: raise AssertionError("can not obtain the bytes of a delete handle") diff --git a/statefun-python-sdk/statefun/request_reply.py b/statefun-python-sdk/statefun/request_reply.py index 6be41d0..f58e6d7 100644 --- a/statefun-python-sdk/statefun/request_reply.py +++ b/statefun-python-sdk/statefun/request_reply.py @@ -21,14 +21,13 @@ from google.protobuf.any_pb2 import Any from statefun.core import SdkAddress from statefun.core import Expiration -from statefun.core import AnyStateHandle from statefun.core import parse_typename from statefun.core import StateRegistrationError # generated function protocol from statefun.request_reply_pb2 import FromFunction from statefun.request_reply_pb2 import ToFunction - +from statefun.typed_value_utils import to_proto_any, from_proto_any, to_proto_any_state, from_proto_any_state class InvocationContext: def __init__(self, functions): @@ -88,7 +87,7 @@ class InvocationContext: @staticmethod def provided_state_values(to_function): - return {s.state_name: AnyStateHandle(s.state_value) for s in to_function.invocation.state} + return {s.state_name: to_proto_any_state(s.state_value) for s in to_function.invocation.state} @staticmethod def add_outgoing_messages(context, invocation_result): @@ -100,7 +99,7 @@ class InvocationContext: outgoing.target.namespace = namespace outgoing.target.type = type outgoing.target.id = id - outgoing.argument.CopyFrom(message) + outgoing.argument.CopyFrom(from_proto_any(message)) @staticmethod def add_mutations(context, invocation_result): @@ -114,7 +113,7 @@ class InvocationContext: mutation.mutation_type = FromFunction.PersistedValueMutation.MutationType.Value('DELETE') else: mutation.mutation_type = FromFunction.PersistedValueMutation.MutationType.Value('MODIFY') - mutation.state_value = handle.bytes() + mutation.state_value.CopyFrom(from_proto_any_state(handle)) @staticmethod def 
add_delayed_messages(context, invocation_result): @@ -127,7 +126,7 @@ class InvocationContext: outgoing.target.type = type outgoing.target.id = id outgoing.delay_in_ms = delay - outgoing.argument.CopyFrom(message) + outgoing.argument.CopyFrom(from_proto_any(message)) @staticmethod def add_egress(context, invocation_result): @@ -138,7 +137,7 @@ class InvocationContext: namespace, type = parse_typename(typename) outgoing.egress_namespace = namespace outgoing.egress_type = type - outgoing.argument.CopyFrom(message) + outgoing.argument.CopyFrom(from_proto_any(message)) @staticmethod def add_missing_state_specs(missing_state_specs, incomplete_context_response): @@ -147,6 +146,10 @@ class InvocationContext: missing_value = missing_values.add() missing_value.state_name = state_spec.name + # TODO see the comment in typed_value_utils.from_proto_any_state on + # TODO the reason to use this specific typename + missing_value.type_typename = "type.googleapis.com/google.protobuf.Any" + protocol_expiration_spec = FromFunction.ExpirationSpec() sdk_expiration_spec = state_spec.expiration if not sdk_expiration_spec: @@ -181,9 +184,10 @@ class RequestReplyHandler: fun = target_function.func for invocation in batch: context.prepare(invocation) - unpacked = target_function.unpack_any(invocation.argument) + any_arg = to_proto_any(invocation.argument) + unpacked = target_function.unpack_any(any_arg) if not unpacked: - fun(context, invocation.argument) + fun(context, any_arg) else: fun(context, unpacked) @@ -207,9 +211,10 @@ class AsyncRequestReplyHandler: fun = target_function.func for invocation in batch: context.prepare(invocation) - unpacked = target_function.unpack_any(invocation.argument) + any_arg = to_proto_any(invocation.argument) + unpacked = target_function.unpack_any(any_arg) if not unpacked: - await fun(context, invocation.argument) + await fun(context, any_arg) else: await fun(context, unpacked) diff --git a/statefun-python-sdk/statefun/typed_value_utils.py 
b/statefun-python-sdk/statefun/typed_value_utils.py new file mode 100644 index 0000000..8706800 --- /dev/null +++ b/statefun-python-sdk/statefun/typed_value_utils.py @@ -0,0 +1,49 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from google.protobuf.any_pb2 import Any + +from statefun.core import AnyStateHandle +from statefun.request_reply_pb2 import TypedValue + +# +# Utility methods to covert back and forth from Protobuf Any to our TypedValue. +# TODO this conversion needs to take place only because the Python SDK still works with Protobuf Any's +# TODO this would soon go away by letting the SDK work directly with TypedValues. 
+# + +def to_proto_any(typed_value: TypedValue): + proto_any = Any() + proto_any.type_url = typed_value.typename + proto_any.value = typed_value.value + return proto_any + +def from_proto_any(proto_any: Any): + typed_value = TypedValue() + typed_value.typename = proto_any.type_url + typed_value.value = proto_any.value + return typed_value + +def from_proto_any_state(any_state_handle: AnyStateHandle): + typed_value = TypedValue() + typed_value.typename = any_state_handle.typename() + typed_value.value = any_state_handle.bytes() + return typed_value + +def to_proto_any_state(typed_value: TypedValue) -> AnyStateHandle: + return AnyStateHandle(typed_value.value) diff --git a/statefun-python-sdk/tests/request_reply_test.py b/statefun-python-sdk/tests/request_reply_test.py index 80691f9..157bba6 100644 --- a/statefun-python-sdk/tests/request_reply_test.py +++ b/statefun-python-sdk/tests/request_reply_test.py @@ -23,7 +23,7 @@ from google.protobuf.json_format import MessageToDict from google.protobuf.any_pb2 import Any from tests.examples_pb2 import LoginEvent, SeenCount -from statefun.request_reply_pb2 import ToFunction, FromFunction +from statefun.request_reply_pb2 import ToFunction, FromFunction, TypedValue from statefun import RequestReplyHandler, AsyncRequestReplyHandler from statefun import StatefulFunctions, StateSpec, AfterWrite, StateRegistrationError from statefun import kafka_egress_record, kinesis_egress_record @@ -43,9 +43,7 @@ class InvocationBuilder(object): state = self.to_function.invocation.state.add() state.state_name = name if value: - any = Any() - any.Pack(value) - state.state_value = any.SerializeToString() + state.state_value.CopyFrom(self.to_typed_value_any_state(value)) return self def with_invocation(self, arg, caller=None): @@ -53,13 +51,31 @@ class InvocationBuilder(object): if caller: (ns, type, id) = caller InvocationBuilder.set_address(ns, type, id, invocation.caller) - invocation.argument.Pack(arg) + 
invocation.argument.CopyFrom(self.to_typed_value(arg)) return self def SerializeToString(self): return self.to_function.SerializeToString() @staticmethod + def to_typed_value(proto_msg): + any = Any() + any.Pack(proto_msg) + typed_value = TypedValue() + typed_value.typename = any.type_url + typed_value.value = any.value + return typed_value + + @staticmethod + def to_typed_value_any_state(proto_msg): + any = Any() + any.Pack(proto_msg) + typed_value = TypedValue() + typed_value.typename = "type.googleapis.com/google.protobuf.Any" + typed_value.value = any.SerializeToString() + return typed_value + + @staticmethod def set_address(namespace, type, id, address): address.namespace = namespace address.type = type @@ -184,14 +200,14 @@ class RequestReplyTestCase(unittest.TestCase): self.assertEqual(first_out_message['target']['namespace'], 'org.foo') self.assertEqual(first_out_message['target']['type'], 'greeter-java') self.assertEqual(first_out_message['target']['id'], '0') - self.assertEqual(first_out_message['argument']['@type'], 'type.googleapis.com/k8s.demo.SeenCount') + self.assertEqual(first_out_message['argument']['typename'], 'type.googleapis.com/k8s.demo.SeenCount') # assert second outgoing message second_out_message = json_at(result_json, NTH_OUTGOING_MESSAGE(1)) self.assertEqual(second_out_message['target']['namespace'], 'bar.baz') self.assertEqual(second_out_message['target']['type'], 'foo') self.assertEqual(second_out_message['target']['id'], '12345') - self.assertEqual(second_out_message['argument']['@type'], 'type.googleapis.com/k8s.demo.SeenCount') + self.assertEqual(second_out_message['argument']['typename'], 'type.googleapis.com/k8s.demo.SeenCount') # assert state mutations first_mutation = json_at(result_json, NTH_STATE_MUTATION(0)) @@ -207,7 +223,7 @@ class RequestReplyTestCase(unittest.TestCase): first_egress = json_at(result_json, NTH_EGRESS(0)) self.assertEqual(first_egress['egress_namespace'], 'foo.bar.baz') 
self.assertEqual(first_egress['egress_type'], 'my-egress') - self.assertEqual(first_egress['argument']['@type'], 'type.googleapis.com/k8s.demo.SeenCount') + self.assertEqual(first_egress['argument']['typename'], 'type.googleapis.com/k8s.demo.SeenCount') def test_integration_incomplete_context(self): functions = StatefulFunctions() @@ -309,7 +325,7 @@ class AsyncRequestReplyTestCase(unittest.TestCase): self.assertEqual(second_out_message['target']['namespace'], 'bar.baz') self.assertEqual(second_out_message['target']['type'], 'foo') self.assertEqual(second_out_message['target']['id'], '12345') - self.assertEqual(second_out_message['argument']['@type'], 'type.googleapis.com/k8s.demo.SeenCount') + self.assertEqual(second_out_message['argument']['typename'], 'type.googleapis.com/k8s.demo.SeenCount') def test_integration_incomplete_context(self): functions = StatefulFunctions() diff --git a/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto b/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto index 2ebd8f9..e0895a4 100644 --- a/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto +++ b/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto @@ -23,8 +23,6 @@ package io.statefun.sdk.reqreply; option java_package = "org.apache.flink.statefun.sdk.reqreply.generated"; option java_multiple_files = true; -import "google/protobuf/any.proto"; - // ------------------------------------------------------------------------------------------------------------------- // Common message definitions // ------------------------------------------------------------------------------------------------------------------- @@ -39,6 +37,11 @@ message Address { string id = 3; } +message TypedValue { + string typename = 1; + bytes value = 2; +} + // ------------------------------------------------------------------------------------------------------------------- // Messages sent to a Remote Function // 
------------------------------------------------------------------------------------------------------------------- @@ -51,7 +54,7 @@ message ToFunction { // The unique name of the persisted state. string state_name = 1; // The serialized state value - bytes state_value = 2; + TypedValue state_value = 2; } // Invocation represents a remote function call, it associated with an (optional) return address, @@ -60,7 +63,7 @@ message ToFunction { // The address of the function that requested the invocation (possibly absent) Address caller = 1; // The invocation argument (aka the message sent to the target function) - google.protobuf.Any argument = 2; + TypedValue argument = 2; } // InvocationBatchRequest represents a request to invoke a remote function. It is always associated with a target @@ -94,7 +97,7 @@ message FromFunction { } MutationType mutation_type = 1; string state_name = 2; - bytes state_value = 3; + TypedValue state_value = 3; } // Invocation represents a remote function call, it associated with a (mandatory) target address, @@ -103,7 +106,7 @@ message FromFunction { // The target function to invoke Address target = 1; // The invocation argument (aka the message sent to the target function) - google.protobuf.Any argument = 2; + TypedValue argument = 2; } // DelayedInvocation represents a delayed remote function call with a target address, an argument @@ -114,19 +117,19 @@ message FromFunction { // the target address to send this message to Address target = 2; // the invocation argument - google.protobuf.Any argument = 3; + TypedValue argument = 3; } // EgressMessage an argument to forward to an egress. // An egress is identified by a namespace and type (see EgressIdentifier SDK class). - // The argument is a google.protobuf.Any + // The argument is an io.statefun.sdk.reqreply.TypedValue. 
message EgressMessage { // The target egress namespace string egress_namespace = 1; // The target egress type string egress_type = 2; // egress argument - google.protobuf.Any argument = 3; + TypedValue argument = 3; } // InvocationResponse represents a result of an io.statefun.sdk.reqreply.ToFunction.InvocationBatchRequest
