[ https://issues.apache.org/jira/browse/BEAM-2915?focusedWorklogId=121943&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-121943 ]
ASF GitHub Bot logged work on BEAM-2915: ---------------------------------------- Author: ASF GitHub Bot Created on: 11/Jul/18 16:37 Start Date: 11/Jul/18 16:37 Worklog Time Spent: 10m Work Description: lukecwik closed pull request #5445: [BEAM-2915] Add support for handling bag user state to the java-fn-execution library to support runner integration. URL: https://github.com/apache/beam/pull/5445 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/.gitignore b/.gitignore index b37254a1f5e..204f22fde87 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,7 @@ **/.gogradle/**/* **/gogradle.lock **/build/**/* -**/vendor/**/* +sdks/go/**/vendor/**/* **/.gradletasknamecache # Ignore files generated by the Maven build process. diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto index 16cb4e75dbb..3cec68d5842 100644 --- a/model/pipeline/src/main/proto/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/beam_runner_api.proto @@ -339,7 +339,7 @@ message ExecutableStagePayload { // PTransform the ExecutableStagePayload is the payload of. string input = 2; - // The side inputs required for this executable stage. Each Side Input of each PTransform within + // The side inputs required for this executable stage. Each side input of each PTransform within // this ExecutableStagePayload must be represented within this field. repeated SideInputId side_inputs = 3; @@ -355,6 +355,10 @@ message ExecutableStagePayload { // in transforms, and the closure of all of the components they recognize. Components components = 6; + // The user states required for this executable stage. Each user state of each PTransform within + // this ExecutableStagePayload must be represented within this field. + repeated UserStateId user_states = 7; + // A reference to a side input. Side inputs are uniquely identified by PTransform id and // local name. message SideInputId { @@ -364,6 +368,16 @@ message ExecutableStagePayload { // (Required) The local name of this side input from the PTransform that references it. string local_name = 2; } + + // A reference to user state. User states are uniquely identified by PTransform id and + // local name. + message UserStateId { + // (Required) The id of the PTransform that references this user state. + string transform_id = 1; + + // (Required) The local name of this user state for the PTransform that references it. + string local_name = 2; + } } // The payload for the primitive ParDo transform. diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ExecutableStage.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ExecutableStage.java index c08b841fd86..486b3b7a98a 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ExecutableStage.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ExecutableStage.java @@ -26,6 +26,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.Environment; import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload; import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId; +import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.UserStateId; import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; @@ -80,6 +81,9 @@ */ Collection<SideInputReference> getSideInputs(); + /** Returns the set of {@link PTransformNode PTransforms} that contain user state. */ + Collection<UserStateReference> getUserStates(); + /** * Returns the leaf {@link PCollectionNode PCollections} of this {@link ExecutableStage}. * @@ -132,8 +136,14 @@ default PTransform toPTransform(String uniqueName) { payload.addSideInputs( SideInputId.newBuilder() .setTransformId(sideInput.transform().getId()) - .setLocalName(sideInput.localName()) - .build()); + .setLocalName(sideInput.localName())); + } + + for (UserStateReference userState : getUserStates()) { + payload.addUserStates( + UserStateId.newBuilder() + .setTransformId(userState.transform().getId()) + .setLocalName(userState.localName())); } int outputIndex = 0; @@ -179,6 +189,7 @@ default PTransform toPTransform(String uniqueName) { static ExecutableStage fromPayload(ExecutableStagePayload payload) { Components components = payload.getComponents(); Environment environment = payload.getEnvironment(); + PCollectionNode input = PipelineNode.pCollection( payload.getInput(), components.getPcollectionsOrThrow(payload.getInput())); @@ -188,6 +199,12 @@ static ExecutableStage fromPayload(ExecutableStagePayload payload) { .stream() .map(sideInputId -> SideInputReference.fromSideInputId(sideInputId, components)) .collect(Collectors.toList()); + List<UserStateReference> userStates = + payload + .getUserStatesList() + .stream() + .map(userStateId -> UserStateReference.fromUserStateId(userStateId, components)) + .collect(Collectors.toList()); List<PTransformNode> transforms = payload .getTransformsList() @@ -201,6 +218,6 @@ static ExecutableStage fromPayload(ExecutableStagePayload payload) { .map(id -> PipelineNode.pCollection(id, components.getPcollectionsOrThrow(id))) .collect(Collectors.toList()); return ImmutableExecutableStage.of( - components, environment, input, sideInputs, transforms, outputs); + components, environment, input, sideInputs, userStates, transforms, outputs); } } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java index 07fb659769f..64a92f5190c 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPCollectionFusers.java @@ -180,6 +180,10 @@ private static boolean parDoCompatibility( // side inputs can be fused with other transforms in the same environment which are not // upstream of any of the side inputs. return pipeline.getSideInputs(parDo).isEmpty() + // Since we lack the ability to mark upstream transforms as key preserving, we + // purposefully break fusion here to provide runners the opportunity to insert a + // grouping operation + && pipeline.getUserStates(parDo).isEmpty() && compatibleEnvironments(parDo, other, pipeline); } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyStageFuser.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyStageFuser.java index 587d1959baa..89da434c83e 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyStageFuser.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyStageFuser.java @@ -38,8 +38,9 @@ * * <p>A {@link PCollectionNode} is fused into a stage if all of its consumers can be fused into the * stage. A consumer can be fused into a stage if it is executed within the environment of that - * {@link ExecutableStage}, and receives only per-element inputs. PTransforms which consume side - * inputs are always at the root of a stage. + * {@link ExecutableStage}, and receives only per-element inputs. To simplify integration for + * runners, this fuser specifically does not fuse PTransforms which consume side inputs or have user + * state, always making them the root of {@link ExecutableStage}. * * <p>A {@link PCollectionNode} with consumers that execute in an environment other than a stage is * materialized, and its consumers execute in independent stages. @@ -80,6 +81,7 @@ public static ExecutableStage forGrpcPortRead( fusedTransforms.addAll(initialNodes); Set<SideInputReference> sideInputs = new LinkedHashSet<>(); + Set<UserStateReference> userStates = new LinkedHashSet<>(); Set<PCollectionNode> fusedCollections = new LinkedHashSet<>(); Set<PCollectionNode> materializedPCollections = new LinkedHashSet<>(); @@ -87,6 +89,7 @@ public static ExecutableStage forGrpcPortRead( for (PTransformNode initialConsumer : initialNodes) { fusionCandidates.addAll(pipeline.getOutputPCollections(initialConsumer)); sideInputs.addAll(pipeline.getSideInputs(initialConsumer)); + userStates.addAll(pipeline.getUserStates(initialConsumer)); } while (!fusionCandidates.isEmpty()) { PCollectionNode candidate = fusionCandidates.poll(); @@ -130,6 +133,7 @@ public static ExecutableStage forGrpcPortRead( environment, inputPCollection, sideInputs, + userStates, fusedTransforms.build(), materializedPCollections); } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ImmutableExecutableStage.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ImmutableExecutableStage.java index ee77a873e74..5b7dedd15d1 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ImmutableExecutableStage.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/ImmutableExecutableStage.java @@ -35,6 +35,7 @@ public static ImmutableExecutableStage ofFullComponents( Environment environment, PCollectionNode input, Collection<SideInputReference> sideInputs, + Collection<UserStateReference> userStates, Collection<PTransformNode> transforms, Collection<PCollectionNode> outputs) { Components prunedComponents = @@ -46,7 +47,7 @@ public static ImmutableExecutableStage ofFullComponents( .stream() .collect(Collectors.toMap(PTransformNode::getId, PTransformNode::getTransform))) .build(); - return of(prunedComponents, environment, input, sideInputs, transforms, outputs); + return of(prunedComponents, environment, input, sideInputs, userStates, transforms, outputs); } public static ImmutableExecutableStage of( @@ -54,6 +55,7 @@ public static ImmutableExecutableStage of( Environment environment, PCollectionNode input, Collection<SideInputReference> sideInputs, + Collection<UserStateReference> userStates, Collection<PTransformNode> transforms, Collection<PCollectionNode> outputs) { return new AutoValue_ImmutableExecutableStage( @@ -61,6 +63,7 @@ public static ImmutableExecutableStage of( environment, input, ImmutableSet.copyOf(sideInputs), + ImmutableSet.copyOf(userStates), ImmutableSet.copyOf(transforms), ImmutableSet.copyOf(outputs)); } @@ -78,6 +81,9 @@ public static ImmutableExecutableStage of( @Override public abstract Collection<SideInputReference> getSideInputs(); + @Override + public abstract Collection<UserStateReference> getUserStates(); + @Override public abstract Collection<PTransformNode> getTransforms(); diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicator.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicator.java index b4f08a887ce..4b06ab61228 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicator.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicator.java @@ -309,6 +309,7 @@ private static ExecutableStage deduplicateStageOutput( stage.getEnvironment(), stage.getInputPCollection(), stage.getSideInputs(), + stage.getUserStates(), updatedTransforms, updatedOutputs); } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java index be624262bf3..12057a7f27c 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/QueryablePipeline.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; import com.google.common.graph.MutableNetwork; import com.google.common.graph.Network; import com.google.common.graph.NetworkBuilder; @@ -298,6 +299,31 @@ public Components getComponents() { .collect(Collectors.toSet()); } + public Collection<UserStateReference> getUserStates(PTransformNode transform) { + return getLocalUserStateNames(transform.getTransform()) + .stream() + .map( + localName -> { + String transformId = transform.getId(); + PTransform transformProto = components.getTransformsOrThrow(transformId); + // Get the main input PCollection id. + String collectionId = + transform + .getTransform() + .getInputsOrThrow( + Iterables.getOnlyElement( + Sets.difference( + transform.getTransform().getInputsMap().keySet(), + getLocalSideInputNames(transformProto)))); + PCollection collection = components.getPcollectionsOrThrow(collectionId); + return UserStateReference.of( + PipelineNode.pTransform(transformId, transformProto), + localName, + PipelineNode.pCollection(collectionId, collection)); + }) + .collect(Collectors.toSet()); + } + private Set<String> getLocalSideInputNames(PTransform transform) { if (PTransformTranslation.PAR_DO_TRANSFORM_URN.equals(transform.getSpec().getUrn())) { try { @@ -310,6 +336,18 @@ public Components getComponents() { } } + private Set<String> getLocalUserStateNames(PTransform transform) { + if (PTransformTranslation.PAR_DO_TRANSFORM_URN.equals(transform.getSpec().getUrn())) { + try { + return ParDoPayload.parseFrom(transform.getSpec().getPayload()).getStateSpecsMap().keySet(); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + } else { + return Collections.emptySet(); + } + } + public Optional<Environment> getEnvironment(PTransformNode parDo) { return Environments.getEnvironment(parDo.getId(), components); } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/UserStateReference.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/UserStateReference.java new file mode 100644 index 00000000000..08c6a6ac7f3 --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/UserStateReference.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction.graph; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; +import java.util.Collections; +import java.util.Set; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.UserStateId; +import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; +import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; +import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload; +import org.apache.beam.runners.core.construction.PTransformTranslation; +import org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode; +import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode; +import org.apache.beam.vendor.protobuf.v3.com.google.protobuf.InvalidProtocolBufferException; + +/** + * A reference to user state. This includes the PTransform that references the user state as well as + * the local name. Both are necessary in order to fully resolve user state. + */ +@AutoValue +public abstract class UserStateReference { + + /** Create a user state reference. */ + public static UserStateReference of( + PTransformNode transform, String localName, PCollectionNode collection) { + return new AutoValue_UserStateReference(transform, localName, collection); + } + + /** Create a user state reference from a UserStateId proto and components. */ + public static UserStateReference fromUserStateId( + UserStateId userStateId, RunnerApi.Components components) { + String transformId = userStateId.getTransformId(); + String localName = userStateId.getLocalName(); + + PTransform transform = components.getTransformsOrThrow(transformId); + + Set<String> sideInputNames = Collections.emptySet(); + if (PTransformTranslation.PAR_DO_TRANSFORM_URN.equals(transform.getSpec().getUrn())) { + try { + sideInputNames = + ParDoPayload.parseFrom(transform.getSpec().getPayload()).getSideInputsMap().keySet(); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + } + + // Get the main input PCollection id. + String collectionId = + transform.getInputsOrThrow( + Iterables.getOnlyElement( + Sets.difference(transform.getInputsMap().keySet(), sideInputNames))); + PCollection collection = components.getPcollectionsOrThrow(collectionId); + return UserStateReference.of( + PipelineNode.pTransform(transformId, transform), + localName, + PipelineNode.pCollection(collectionId, collection)); + } + + /** The id of the PTransform that uses this user state. */ + public abstract PTransformNode transform(); + /** The local name the referencing PTransform uses to refer to this user state. */ + public abstract String localName(); + /** The PCollection that represents the input to the PTransform. */ + public abstract PCollectionNode collection(); +} diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ExecutableStageTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ExecutableStageTest.java index 2ffbee2211c..e2c731f06bf 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ExecutableStageTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ExecutableStageTest.java @@ -37,6 +37,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload; import org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec; import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput; +import org.apache.beam.model.pipeline.v1.RunnerApi.StateSpec; import org.apache.beam.model.pipeline.v1.RunnerApi.WindowIntoPayload; import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode; @@ -62,6 +63,7 @@ public void testRoundTripToFromTransform() throws Exception { ParDoPayload.newBuilder() .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("foo")) .putSideInputs("side_input", SideInput.getDefaultInstance()) + .putStateSpecs("user_state", StateSpec.getDefaultInstance()) .build() .toByteString())) .build(); @@ -82,12 +84,16 @@ public void testRoundTripToFromTransform() throws Exception { SideInputReference sideInputRef = SideInputReference.of( transformNode, "side_input", PipelineNode.pCollection("sideInput.in", sideInput)); + UserStateReference userStateRef = + UserStateReference.of( + transformNode, "user_state", PipelineNode.pCollection("input.out", input)); ImmutableExecutableStage stage = ImmutableExecutableStage.of( components, env, PipelineNode.pCollection("input.out", input), Collections.singleton(sideInputRef), + Collections.singleton(userStateRef), Collections.singleton(PipelineNode.pTransform("pt", pt)), Collections.singleton(PipelineNode.pCollection("output.out", output))); diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyStageFuserTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyStageFuserTest.java index b0054f7920b..a1175afcffb 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyStageFuserTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyStageFuserTest.java @@ -1029,6 +1029,75 @@ public void executableStageProducingSideInputMaterializesIt() { subgraph.getOutputPCollections(), contains(PipelineNode.pCollection("sidePC", sidePC))); } + @Test + public void userStateIncludedInStage() { + Environment env = Environment.newBuilder().setUrl("common").build(); + PTransform readTransform = + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "read.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .build() + .toByteString())) + .build(); + PTransform parDoTransform = + PTransform.newBuilder() + .putInputs("input", "read.out") + .putOutputs("output", "parDo.out") + .setSpec( + FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload( + ParDoPayload.newBuilder() + .setDoFn(SdkFunctionSpec.newBuilder().setEnvironmentId("common")) + .putStateSpecs("state_spec", StateSpec.getDefaultInstance()) + .build() + .toByteString())) + .build(); + PCollection userStateMainInputPCollection = + PCollection.newBuilder().setUniqueName("read.out").build(); + + QueryablePipeline p = + QueryablePipeline.forPrimitivesIn( + partialComponents + .toBuilder() + .putTransforms("read", readTransform) + .putPcollections("read.out", userStateMainInputPCollection) + .putTransforms( + "user_state", + PTransform.newBuilder() + .putInputs("input", "impulse.out") + .putOutputs("output", "user_state.out") + .build()) + .putPcollections( + "user_state.out", + PCollection.newBuilder().setUniqueName("user_state.out").build()) + .putTransforms("parDo", parDoTransform) + .putPcollections( + "parDo.out", PCollection.newBuilder().setUniqueName("parDo.out").build()) + .putEnvironments("common", env) + .build()); + + PCollectionNode readOutput = + getOnlyElement(p.getOutputPCollections(PipelineNode.pTransform("read", readTransform))); + ExecutableStage subgraph = + GreedyStageFuser.forGrpcPortRead( + p, readOutput, ImmutableSet.of(PipelineNode.pTransform("parDo", parDoTransform))); + PTransformNode parDoNode = PipelineNode.pTransform("parDo", parDoTransform); + UserStateReference userStateRef = + UserStateReference.of( + parDoNode, + "state_spec", + PipelineNode.pCollection("read.out", userStateMainInputPCollection)); + assertThat(subgraph.getUserStates(), contains(userStateRef)); + assertThat(subgraph.getOutputPCollections(), emptyIterable()); + } + @Test public void materializesWithGroupByKeyConsumer() { // (impulse.out) -> read -> read.out -> gbk -> gbk.out diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ImmutableExecutableStageTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ImmutableExecutableStageTest.java index 7a9ccee9600..61b3cea54c9 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ImmutableExecutableStageTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/ImmutableExecutableStageTest.java @@ -57,6 +57,7 @@ public void ofFullComponentsOnlyHasStagePTransforms() throws Exception { ParDoPayload.newBuilder() .setDoFn(RunnerApi.SdkFunctionSpec.newBuilder().setEnvironmentId("foo")) .putSideInputs("side_input", RunnerApi.SideInput.getDefaultInstance()) + .putStateSpecs("user_state", RunnerApi.StateSpec.getDefaultInstance()) .build() .toByteString())) .build(); @@ -78,12 +79,16 @@ public void ofFullComponentsOnlyHasStagePTransforms() throws Exception { SideInputReference sideInputRef = SideInputReference.of( transformNode, "side_input", PipelineNode.pCollection("sideInput.in", sideInput)); + UserStateReference userStateRef = + UserStateReference.of( + transformNode, "user_state", PipelineNode.pCollection("input.out", input)); ImmutableExecutableStage stage = ImmutableExecutableStage.ofFullComponents( components, env, PipelineNode.pCollection("input.out", input), Collections.singleton(sideInputRef), + Collections.singleton(userStateRef), Collections.singleton(PipelineNode.pTransform("pt", pt)), Collections.singleton(PipelineNode.pCollection("output.out", output))); diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicatorTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicatorTest.java index ea5b3ae9969..69a625b7f09 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicatorTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/OutputDeduplicatorTest.java @@ -96,6 +96,7 @@ public void unchangedWithNoDuplicates() { Environment.getDefaultInstance(), PipelineNode.pCollection(redOut.getUniqueName(), redOut), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of(PipelineNode.pTransform("one", one)), ImmutableList.of(PipelineNode.pCollection(oneOut.getUniqueName(), oneOut))); ExecutableStage twoStage = @@ -104,6 +105,7 @@ public void unchangedWithNoDuplicates() { Environment.getDefaultInstance(), PipelineNode.pCollection(redOut.getUniqueName(), redOut), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of(PipelineNode.pTransform("two", two)), ImmutableList.of(PipelineNode.pCollection(twoOut.getUniqueName(), twoOut))); PTransformNode redTransform = PipelineNode.pTransform("red", red); @@ -186,6 +188,7 @@ public void duplicateOverStages() { Environment.getDefaultInstance(), PipelineNode.pCollection(redOut.getUniqueName(), redOut), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of( PipelineNode.pTransform("one", one), PipelineNode.pTransform("shared", shared)), ImmutableList.of(PipelineNode.pCollection(sharedOut.getUniqueName(), sharedOut))); @@ -195,6 +198,7 @@ public void duplicateOverStages() { Environment.getDefaultInstance(), PipelineNode.pCollection(redOut.getUniqueName(), redOut), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of( PipelineNode.pTransform("two", two), PipelineNode.pTransform("shared", shared)), ImmutableList.of(PipelineNode.pCollection(sharedOut.getUniqueName(), sharedOut))); @@ -297,6 +301,7 @@ public void duplicateOverStagesAndTransforms() { Environment.getDefaultInstance(), PipelineNode.pCollection(redOut.getUniqueName(), redOut), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of(PipelineNode.pTransform("one", one), sharedTransform), ImmutableList.of(PipelineNode.pCollection(sharedOut.getUniqueName(), sharedOut))); PTransformNode redTransform = PipelineNode.pTransform("red", red); @@ -436,6 +441,7 @@ public void multipleDuplicatesInStages() { Environment.getDefaultInstance(), PipelineNode.pCollection(redOut.getUniqueName(), redOut), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of( PipelineNode.pTransform("multi", three), PipelineNode.pTransform("shared", shared), @@ -449,6 +455,7 @@ public void multipleDuplicatesInStages() { Environment.getDefaultInstance(), PipelineNode.pCollection(redOut.getUniqueName(), redOut), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of( PipelineNode.pTransform("one", one), PipelineNode.pTransform("shared", shared)), ImmutableList.of(PipelineNode.pCollection(sharedOut.getUniqueName(), sharedOut))); @@ -458,6 +465,7 @@ public void multipleDuplicatesInStages() { Environment.getDefaultInstance(), PipelineNode.pCollection(redOut.getUniqueName(), redOut), ImmutableList.of(), + ImmutableList.of(), ImmutableList.of( PipelineNode.pTransform("two", two), PipelineNode.pTransform("otherShared", otherShared)), diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java index b5d616a6b30..c36409ce32a 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactoryTest.java @@ -239,6 +239,7 @@ private static ExecutableStage createExecutableStage(Collection<SideInputReferen inputCollection, sideInputs, Collections.emptyList(), + Collections.emptyList(), Collections.emptyList()); } diff --git a/runners/java-fn-execution/build.gradle b/runners/java-fn-execution/build.gradle index 432b136c2dd..1979d3351b0 100644 --- a/runners/java-fn-execution/build.gradle +++ b/runners/java-fn-execution/build.gradle @@ -29,6 +29,7 @@ dependencies { shadow project(path: ":beam-sdks-java-core", configuration: "shadow") shadow project(path: ":beam-sdks-java-fn-execution", configuration: "shadow") shadow project(path: ":beam-runners-core-construction-java", configuration: "shadow") + shadow project(path: ":beam-vendor-sdks-java-extensions-protobuf", configuration: "shadow") shadow library.java.slf4j_api testCompile project(":beam-sdks-java-harness") testCompile project(path: ":beam-runners-core-construction-java", configuration: "shadow") diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java index 280e3647761..0338d3e0347 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java @@ -44,6 +44,7 @@ import org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode; import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode; import org.apache.beam.runners.core.construction.graph.SideInputReference; +import org.apache.beam.runners.core.construction.graph.UserStateReference; import org.apache.beam.runners.fnexecution.data.RemoteInputDestination; import org.apache.beam.runners.fnexecution.wire.LengthPrefixUnknownCoders; import org.apache.beam.runners.fnexecution.wire.WireCoders; @@ -55,6 +56,7 @@ import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; import org.apache.beam.sdk.values.KV; import org.apache.beam.vendor.protobuf.v3.com.google.protobuf.InvalidProtocolBufferException; +import org.apache.beam.vendor.sdk.v2.sdk.extensions.protobuf.ByteStringCoder; /** Utility methods for creating {@link ProcessBundleDescriptor} instances. */ // TODO: Rename to ExecutableStages? @@ -131,8 +133,15 @@ private static ExecutableProcessBundleDescriptor fromExecutableStageInternal( .putAllWindowingStrategies(components.getWindowingStrategiesMap()) .putAllTransforms(components.getTransformsMap()); + Map<String, Map<String, BagUserStateSpec>> bagUserStateSpecs = + forBagUserStates(stage, components.build()); + return ExecutableProcessBundleDescriptor.of( - bundleDescriptorBuilder.build(), inputDestination, outputTargetCoders, sideInputSpecs); + bundleDescriptorBuilder.build(), + inputDestination, + outputTargetCoders, + sideInputSpecs, + bagUserStateSpecs); } private static Map<Target, Coder<WindowedValue<?>>> addStageOutputs( @@ -254,6 +263,30 @@ private static TargetEncoding addStageOutput( } } + private static Map<String, Map<String, BagUserStateSpec>> forBagUserStates( + ExecutableStage stage, Components components) throws IOException { + ImmutableTable.Builder<String, String, BagUserStateSpec> idsToSpec = ImmutableTable.builder(); + for (UserStateReference userStateReference : stage.getUserStates()) { + FullWindowedValueCoder<KV<?, ?>> coder = + (FullWindowedValueCoder) + WireCoders.instantiateRunnerWireCoder(userStateReference.collection(), components); + idsToSpec.put( + userStateReference.transform().getId(), + userStateReference.localName(), + BagUserStateSpec.of( + userStateReference.transform().getId(), + userStateReference.localName(), + // We use the ByteString coder to save on encoding and decoding the actual key. + ByteStringCoder.of(), + // Usage of the ByteStringCoder provides a significant simplification for handling + // a logical stream of values by not needing to know where the element boundaries + // actually are. See StateRequestHandlers.java for further details. + ByteStringCoder.of(), + coder.getWindowCoder())); + } + return idsToSpec.build().rowMap(); + } + @AutoValue abstract static class TargetEncoding { abstract BeamFnApi.Target getTarget(); @@ -288,6 +321,33 @@ private static TargetEncoding addStageOutput( public abstract Coder<W> windowCoder(); } + /** + * A container type storing references to the key, value, and window {@link Coder} used when + * handling bag user state requests. + */ + @AutoValue + public abstract static class BagUserStateSpec<K, V, W extends BoundedWindow> { + static <K, V, W extends BoundedWindow> BagUserStateSpec<K, V, W> of( + String transformId, + String userStateId, + Coder<K> keyCoder, + Coder<V> valueCoder, + Coder<W> windowCoder) { + return new AutoValue_ProcessBundleDescriptors_BagUserStateSpec( + transformId, userStateId, keyCoder, valueCoder, windowCoder); + } + + public abstract String transformId(); + + public abstract String userStateId(); + + public abstract Coder<K> keyCoder(); + + public abstract Coder<V> valueCoder(); + + public abstract Coder<W> windowCoder(); + } + /** */ @AutoValue public abstract static class ExecutableProcessBundleDescriptor { @@ -295,18 +355,26 @@ public static ExecutableProcessBundleDescriptor of( ProcessBundleDescriptor descriptor, RemoteInputDestination<WindowedValue<?>> inputDestination, Map<BeamFnApi.Target, Coder<WindowedValue<?>>> outputTargetCoders, - Map<String, Map<String, SideInputSpec>> sideInputSpecs) { + Map<String, Map<String, SideInputSpec>> sideInputSpecs, + Map<String, Map<String, BagUserStateSpec>> bagUserStateSpecs) { ImmutableTable.Builder copyOfSideInputSpecs = ImmutableTable.builder(); for (Map.Entry<String, Map<String, SideInputSpec>> outer : sideInputSpecs.entrySet()) { for (Map.Entry<String, SideInputSpec> inner : outer.getValue().entrySet()) { copyOfSideInputSpecs.put(outer.getKey(), inner.getKey(), inner.getValue()); } } + ImmutableTable.Builder copyOfBagUserStateSpecs = ImmutableTable.builder(); + for (Map.Entry<String, Map<String, BagUserStateSpec>> outer : bagUserStateSpecs.entrySet()) { + for (Map.Entry<String, BagUserStateSpec> inner : outer.getValue().entrySet()) { + copyOfBagUserStateSpecs.put(outer.getKey(), inner.getKey(), inner.getValue()); + } + } return new AutoValue_ProcessBundleDescriptors_ExecutableProcessBundleDescriptor( descriptor, inputDestination, Collections.unmodifiableMap(outputTargetCoders), - copyOfSideInputSpecs.build().rowMap()); + copyOfSideInputSpecs.build().rowMap(), + copyOfBagUserStateSpecs.build().rowMap()); } public abstract ProcessBundleDescriptor getProcessBundleDescriptor(); @@ -328,5 +396,11 @@ public static ExecutableProcessBundleDescriptor of( * are used during execution. */ public abstract Map<String, Map<String, SideInputSpec>> getSideInputSpecs(); + + /** + * Get a mapping from PTransform id to user state input id to {@link BagUserStateSpec bag user + * states} that are used during execution. + */ + public abstract Map<String, Map<String, BagUserStateSpec>> getBagUserStateSpecs(); } } diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/state/StateRequestHandlers.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/state/StateRequestHandlers.java index 8b1f8b8f56d..4b36c25e356 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/state/StateRequestHandlers.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/state/StateRequestHandlers.java @@ -20,7 +20,9 @@ import static com.google.common.base.Preconditions.checkState; +import com.google.common.collect.ImmutableList; import java.util.ArrayList; +import java.util.EnumMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -28,12 +30,15 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.ConcurrentHashMap; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendResponse; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey.TypeCase; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.BagUserStateSpec; import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor; import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.SideInputSpec; import org.apache.beam.sdk.coders.Coder; @@ -42,6 +47,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.common.Reiterable; import org.apache.beam.vendor.protobuf.v3.com.google.protobuf.ByteString; +import org.apache.beam.vendor.sdk.v2.sdk.extensions.protobuf.ByteStringCoder; /** * A set of utility methods which construct {@link StateRequestHandler}s. @@ -168,6 +174,45 @@ static BagUserStateHandlerFactory unsupported() { } } + /** + * Returns a {@link StateRequestHandler} which delegates to the supplied handler depending on the + * {@link StateRequest}s {@link StateKey.TypeCase type}. + * + * <p>An exception is thrown if a corresponding handler is not found. + */ + public static StateRequestHandler delegateBasedUponType( + EnumMap<StateKey.TypeCase, StateRequestHandler> handlers) { + return new StateKeyTypeDelegatingStateRequestHandler(handlers); + } + + /** + * A {@link StateRequestHandler} which delegates to the supplied handler depending on the {@link + * StateRequest}s {@link StateKey.TypeCase type}. + * + * <p>An exception is thrown if a corresponding handler is not found. + */ + static class StateKeyTypeDelegatingStateRequestHandler implements StateRequestHandler { + private final EnumMap<TypeCase, StateRequestHandler> handlers; + + StateKeyTypeDelegatingStateRequestHandler( + EnumMap<StateKey.TypeCase, StateRequestHandler> handlers) { + this.handlers = handlers; + } + + @Override + public CompletionStage<StateResponse.Builder> handle(StateRequest request) throws Exception { + return handlers + .getOrDefault(request.getStateKey().getTypeCase(), this::handlerNotFound) + .handle(request); + } + + private CompletionStage<StateResponse.Builder> handlerNotFound(StateRequest request) { + CompletableFuture<StateResponse.Builder> rval = new CompletableFuture<>(); + rval.completeExceptionally(new IllegalStateException()); + return rval; + } + } + /** * Returns an adapter which converts a {@link SideInputHandlerFactory} to a {@link * StateRequestHandler}. @@ -213,10 +258,9 @@ public static StateRequestHandler forSideInputHandlerFactory( TypeCase.MULTIMAP_SIDE_INPUT); StateKey.MultimapSideInput stateKey = request.getStateKey().getMultimapSideInput(); - SideInputSpec<?, ?, ?> sideInputReferenceSpec = + SideInputSpec<?, ?, ?> referenceSpec = sideInputSpecs.get(stateKey.getPtransformId()).get(stateKey.getSideInputId()); - SideInputHandler<?, ?> handler = - cache.computeIfAbsent(sideInputReferenceSpec, this::createHandler); + SideInputHandler<?, ?> handler = cache.computeIfAbsent(referenceSpec, this::createHandler); switch (request.getRequestCase()) { case GET: @@ -276,4 +320,156 @@ public static StateRequestHandler forSideInputHandlerFactory( cacheKey.windowCoder()); } } + + /** + * Returns an adapter which converts a {@link BagUserStateHandlerFactory} to a {@link + * StateRequestHandler}. + * + * <p>The {@link MultimapSideInputHandlerFactory} is required to handle all multimap side inputs + * contained within the {@link ExecutableProcessBundleDescriptor}. See {@link + * ExecutableProcessBundleDescriptor#getMultimapSideInputSpecs} for the set of multimap side + * inputs that are contained. + * + * <p>Instances of {@link MultimapSideInputHandler}s returned by the {@link + * MultimapSideInputHandlerFactory} are cached. + */ + public static StateRequestHandler forBagUserStateHandlerFactory( + ExecutableProcessBundleDescriptor processBundleDescriptor, + BagUserStateHandlerFactory bagUserStateHandlerFactory) { + return new ByteStringStateRequestHandlerToBagUserStateHandlerFactoryAdapter( + processBundleDescriptor, bagUserStateHandlerFactory); + } + + /** + * An adapter which converts {@link BagUserStateHandlerFactory} to {@link StateRequestHandler}. + */ + static class ByteStringStateRequestHandlerToBagUserStateHandlerFactoryAdapter + implements StateRequestHandler { + + private final ExecutableProcessBundleDescriptor processBundleDescriptor; + private final BagUserStateHandlerFactory handlerFactory; + private final ConcurrentHashMap<BagUserStateSpec, BagUserStateHandler> cache; + + ByteStringStateRequestHandlerToBagUserStateHandlerFactoryAdapter( + ExecutableProcessBundleDescriptor processBundleDescriptor, + BagUserStateHandlerFactory handlerFactory) { + this.processBundleDescriptor = processBundleDescriptor; + this.handlerFactory = handlerFactory; + this.cache = new ConcurrentHashMap<>(); + } + + @Override + public CompletionStage<StateResponse.Builder> handle(StateRequest request) throws Exception { + try { + checkState( + TypeCase.BAG_USER_STATE.equals(request.getStateKey().getTypeCase()), + "Unsupported %s type %s, expected %s", + StateRequest.class.getSimpleName(), + request.getStateKey().getTypeCase(), + TypeCase.BAG_USER_STATE); + + StateKey.BagUserState stateKey = request.getStateKey().getBagUserState(); + BagUserStateSpec<Object, Object, BoundedWindow> referenceSpec = + processBundleDescriptor + .getBagUserStateSpecs() + .get(stateKey.getPtransformId()) + .get(stateKey.getUserStateId()); + + // Note that by using the ByteStringCoder, we simplify the issue of encoding/decoding the + // logical stream because we do not need to maintain knowledge of element boundaries and + // instead we rely on the client to be internally consistent. This allows us to just + // take the append requests and also to serve them back without internal knowledge. + checkState( + ((Coder) referenceSpec.keyCoder()) instanceof ByteStringCoder, + "This %s only supports the %s as the key coder.", + BagUserStateHandlerFactory.class.getSimpleName(), + ByteStringCoder.class.getSimpleName()); + checkState( + ((Coder) referenceSpec.valueCoder()) instanceof ByteStringCoder, + "This %s only supports the %s as the value coder.", + BagUserStateHandlerFactory.class.getSimpleName(), + ByteStringCoder.class.getSimpleName()); + + BagUserStateHandler<ByteString, ByteString, BoundedWindow> handler = + cache.computeIfAbsent(referenceSpec, this::createHandler); + + ByteString key = stateKey.getKey(); + BoundedWindow window = referenceSpec.windowCoder().decode(stateKey.getWindow().newInput()); + + switch (request.getRequestCase()) { + case GET: + return handleGetRequest(request, key, window, handler); + case APPEND: + return handleAppendRequest(request, key, window, handler); + case CLEAR: + return handleClearRequest(request, key, window, handler); + default: + throw new Exception( + String.format( + "Unsupported request type %s for user state.", request.getRequestCase())); + } + } catch (Exception e) { + CompletableFuture f = new CompletableFuture(); + f.completeExceptionally(e); + return f; + } + } + + private static <W extends BoundedWindow> + CompletionStage<StateResponse.Builder> handleGetRequest( + StateRequest request, + ByteString key, + W window, + BagUserStateHandler<ByteString, ByteString, W> handler) { + // TODO: Add support for continuation tokens when handling state if the handler + // returned a {@link Reiterable}. + checkState( + request.getGet().getContinuationToken().isEmpty(), + "Continuation tokens are unsupported."); + + return CompletableFuture.completedFuture( + StateResponse.newBuilder() + .setId(request.getId()) + .setGet( + StateGetResponse.newBuilder() + // Note that this doesn't copy the actual bytes, just the references. + .setData(ByteString.copyFrom(handler.get(key, window))))); + } + + private static <W extends BoundedWindow> + CompletionStage<StateResponse.Builder> handleAppendRequest( + StateRequest request, + ByteString key, + W window, + BagUserStateHandler<ByteString, ByteString, W> handler) { + handler.append(key, window, ImmutableList.of(request.getAppend().getData()).iterator()); + return CompletableFuture.completedFuture( + StateResponse.newBuilder() + .setId(request.getId()) + .setAppend(StateAppendResponse.getDefaultInstance())); + } + + private static <W extends BoundedWindow> + CompletionStage<StateResponse.Builder> handleClearRequest( + StateRequest request, + ByteString key, + W window, + BagUserStateHandler<ByteString, ByteString, W> handler) { + handler.clear(key, window); + return CompletableFuture.completedFuture( + StateResponse.newBuilder() + .setId(request.getId()) + .setClear(StateClearResponse.getDefaultInstance())); + } + + private <K, V, W extends BoundedWindow> BagUserStateHandler<K, V, W> createHandler( + BagUserStateSpec cacheKey) { + return handlerFactory.forUserState( + cacheKey.transformId(), + cacheKey.userStateId(), + cacheKey.keyCoder(), + cacheKey.valueCoder(), + cacheKey.windowCoder()); + } + } } diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java index 7f16631f420..13a533ae2cd 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java @@ -23,7 +23,9 @@ import static org.junit.Assert.assertThat; import com.google.common.base.Optional; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; import com.google.common.util.concurrent.ThreadFactoryBuilder; import java.io.Serializable; import java.time.Duration; @@ -32,6 +34,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -60,6 +63,8 @@ import org.apache.beam.runners.fnexecution.state.GrpcStateService; import org.apache.beam.runners.fnexecution.state.StateRequestHandler; import org.apache.beam.runners.fnexecution.state.StateRequestHandlers; +import org.apache.beam.runners.fnexecution.state.StateRequestHandlers.BagUserStateHandler; +import org.apache.beam.runners.fnexecution.state.StateRequestHandlers.BagUserStateHandlerFactory; import org.apache.beam.runners.fnexecution.state.StateRequestHandlers.SideInputHandler; import org.apache.beam.runners.fnexecution.state.StateRequestHandlers.SideInputHandlerFactory; import org.apache.beam.sdk.Pipeline; @@ -72,6 +77,10 @@ import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; import org.apache.beam.sdk.fn.test.InProcessManagedChannelFactory; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.Impulse; @@ -84,6 +93,9 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.vendor.protobuf.v3.com.google.protobuf.ByteString; +import org.hamcrest.collection.IsEmptyIterable; +import org.hamcrest.collection.IsIterableContainingInOrder; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -212,10 +224,10 @@ public void process(ProcessContext ctxt) { ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage( "my_stage", stage, dataServer.getApiServiceDescriptor()); - // TODO: This cast is nonsense + + @SuppressWarnings({"unchecked", "rawtypes"}) RemoteInputDestination<WindowedValue<byte[]>> remoteDestination = - (RemoteInputDestination<WindowedValue<byte[]>>) - (RemoteInputDestination) descriptor.getRemoteInputDestination(); + (RemoteInputDestination) descriptor.getRemoteInputDestination(); BundleProcessor<byte[]> processor = controlClient.getProcessor(descriptor.getProcessBundleDescriptor(), remoteDestination); @@ -299,10 +311,10 @@ public void processElement(ProcessContext context) { stage, dataServer.getApiServiceDescriptor(), stateServer.getApiServiceDescriptor()); - // TODO: This cast is nonsense + + @SuppressWarnings({"unchecked", "rawtypes"}) RemoteInputDestination<WindowedValue<byte[]>> remoteDestination = - (RemoteInputDestination<WindowedValue<byte[]>>) - (RemoteInputDestination) descriptor.getRemoteInputDestination(); + (RemoteInputDestination) descriptor.getRemoteInputDestination(); BundleProcessor<byte[]> processor = controlClient.getProcessor( @@ -375,6 +387,161 @@ public void processElement(ProcessContext context) { } } + @Test + public void testExecutionWithUserState() throws Exception { + Pipeline p = Pipeline.create(); + final String stateId = "foo"; + final String stateId2 = "foo2"; + + p.apply("impulse", Impulse.create()) + .apply( + "create", + ParDo.of( + new DoFn<byte[], KV<String, String>>() { + @ProcessElement + public void process(ProcessContext ctxt) {} + })) + .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())) + .apply( + "userState", + ParDo.of( + new DoFn<KV<String, String>, KV<String, String>>() { + + @StateId(stateId) + private final StateSpec<BagState<String>> bufferState = + StateSpecs.bag(StringUtf8Coder.of()); + + @StateId(stateId2) + private final StateSpec<BagState<String>> bufferState2 = + StateSpecs.bag(StringUtf8Coder.of()); + + @ProcessElement + public void processElement( + @Element KV<String, String> element, + @StateId(stateId) BagState<String> state, + @StateId(stateId2) BagState<String> state2, + OutputReceiver<KV<String, String>> r) { + ReadableState<Boolean> isEmpty = state.isEmpty(); + for (String value : state.read()) { + r.output(KV.of(element.getKey(), value)); + } + state.add(element.getValue()); + state2.clear(); + } + })) + // Force the output to be materialized + .apply("gbk", GroupByKey.create()); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto); + Optional<ExecutableStage> optionalStage = + Iterables.tryFind( + fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty()); + checkState(optionalStage.isPresent(), "Expected a stage with user state."); + ExecutableStage stage = optionalStage.get(); + + ExecutableProcessBundleDescriptor descriptor = + ProcessBundleDescriptors.fromExecutableStage( + "test_stage", + stage, + dataServer.getApiServiceDescriptor(), + stateServer.getApiServiceDescriptor()); + + @SuppressWarnings({"unchecked", "rawtypes"}) + RemoteInputDestination<WindowedValue<KV<byte[], byte[]>>> remoteDestination = + (RemoteInputDestination) descriptor.getRemoteInputDestination(); + + BundleProcessor<KV<byte[], byte[]>> processor = + controlClient.getProcessor( + descriptor.getProcessBundleDescriptor(), remoteDestination, stateDelegator); + Map<Target, Coder<WindowedValue<?>>> outputTargets = descriptor.getOutputTargetCoders(); + Map<Target, Collection<WindowedValue<?>>> outputValues = new HashMap<>(); + Map<Target, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>(); + for (Entry<Target, Coder<WindowedValue<?>>> targetCoder : outputTargets.entrySet()) { + List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>()); + outputValues.put(targetCoder.getKey(), outputContents); + outputReceivers.put( + targetCoder.getKey(), + RemoteOutputReceiver.of(targetCoder.getValue(), outputContents::add)); + } + + Map<String, List<ByteString>> userStateData = + ImmutableMap.of( + stateId, + new ArrayList( + Arrays.asList( + ByteString.copyFrom( + CoderUtils.encodeToByteArray( + StringUtf8Coder.of(), "A", Coder.Context.NESTED)), + ByteString.copyFrom( + CoderUtils.encodeToByteArray( + StringUtf8Coder.of(), "B", Coder.Context.NESTED)), + ByteString.copyFrom( + CoderUtils.encodeToByteArray( + StringUtf8Coder.of(), "C", Coder.Context.NESTED)))), + stateId2, + new ArrayList( + Arrays.asList( + ByteString.copyFrom( + CoderUtils.encodeToByteArray( + StringUtf8Coder.of(), "D", Coder.Context.NESTED))))); + StateRequestHandler stateRequestHandler = + StateRequestHandlers.forBagUserStateHandlerFactory( + descriptor, + new BagUserStateHandlerFactory() { + @Override + public <K, V, W extends BoundedWindow> BagUserStateHandler<K, V, W> forUserState( + String pTransformId, + String userStateId, + Coder<K> keyCoder, + Coder<V> valueCoder, + Coder<W> windowCoder) { + return new BagUserStateHandler<K, V, W>() { + @Override + public Iterable<V> get(K key, W window) { + return (Iterable) userStateData.get(userStateId); + } + + @Override + public void append(K key, W window, Iterator<V> values) { + Iterators.addAll(userStateData.get(userStateId), (Iterator) values); + } + + @Override + public void clear(K key, W window) { + userStateData.get(userStateId).clear(); + } + }; + } + }); + + try (ActiveBundle<KV<byte[], byte[]>> bundle = + processor.newBundle( + outputReceivers, stateRequestHandler, BundleProgressHandler.unsupported())) { + bundle.getInputReceiver().accept(WindowedValue.valueInGlobalWindow(kvBytes("X", "Y"))); + } + for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) { + assertThat( + windowedValues, + containsInAnyOrder( + WindowedValue.valueInGlobalWindow(kvBytes("X", "A")), + WindowedValue.valueInGlobalWindow(kvBytes("X", "B")), + WindowedValue.valueInGlobalWindow(kvBytes("X", "C")))); + } + assertThat( + userStateData.get(stateId), + IsIterableContainingInOrder.contains( + ByteString.copyFrom( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED)), + ByteString.copyFrom( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED)), + ByteString.copyFrom( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED)), + ByteString.copyFrom( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Y", Coder.Context.NESTED)))); + assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable()); + } + private KV<byte[], byte[]> kvBytes(String key, long value) throws CoderException { return KV.of( CoderUtils.encodeToByteArray(StringUtf8Coder.of(), key), diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/state/StateRequestHandlersTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/state/StateRequestHandlersTest.java new file mode 100644 index 00000000000..7bd4d4e9f17 --- /dev/null +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/state/StateRequestHandlersTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.fnexecution.state; + +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import java.util.EnumMap; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey.MultimapSideInput; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey.TypeCase; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +/** Tests for {@link StateRequestHandlers}. */ +@RunWith(JUnit4.class) +public class StateRequestHandlersTest { + @Test + public void testDelegatingStateHandlerDelegates() throws Exception { + StateRequestHandler mockHandler = Mockito.mock(StateRequestHandler.class); + StateRequestHandler mockHandler2 = Mockito.mock(StateRequestHandler.class); + EnumMap<StateKey.TypeCase, StateRequestHandler> handlers = + new EnumMap<>(StateKey.TypeCase.class); + handlers.put(StateKey.TypeCase.TYPE_NOT_SET, mockHandler); + handlers.put(TypeCase.MULTIMAP_SIDE_INPUT, mockHandler2); + StateRequest request = StateRequest.getDefaultInstance(); + StateRequest request2 = + StateRequest.newBuilder() + .setStateKey( + StateKey.newBuilder().setMultimapSideInput(MultimapSideInput.getDefaultInstance())) + .build(); + StateRequestHandlers.delegateBasedUponType(handlers).handle(request); + StateRequestHandlers.delegateBasedUponType(handlers).handle(request2); + verify(mockHandler).handle(request); + verify(mockHandler2).handle(request2); + verifyNoMoreInteractions(mockHandler, mockHandler2); + } + + @Test + public void testDelegatingStateHandlerThrowsWhenNotFound() throws Exception { + StateRequestHandlers.delegateBasedUponType(new EnumMap<>(StateKey.TypeCase.class)) + .handle(StateRequest.getDefaultInstance()); + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/TimeDomain.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/TimeDomain.java index 1778047a055..685f3e956c0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/TimeDomain.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/TimeDomain.java @@ -33,8 +33,8 @@ EVENT_TIME, /** - * The {@link #PROCESSING_TIME} domain corresponds to the current to the current (system) time. - * This is advanced during execution of the pipeline. + * The {@link #PROCESSING_TIME} domain corresponds to the current (system) time. This is advanced + * during execution of the pipeline. */ PROCESSING_TIME, diff --git a/settings.gradle b/settings.gradle index 07ae1adead7..b44236ee516 100644 --- a/settings.gradle +++ b/settings.gradle @@ -164,3 +164,5 @@ include "beam-sdks-python" project(":beam-sdks-python").dir = file("sdks/python") include "beam-sdks-python-container" project(":beam-sdks-python-container").dir = file("sdks/python/container") +include "beam-vendor-sdks-java-extensions-protobuf" +project(":beam-vendor-sdks-java-extensions-protobuf").dir = file("vendor/sdks-java-extensions-protobuf") diff --git a/vendor/sdks-java-extensions-protobuf/build.gradle b/vendor/sdks-java-extensions-protobuf/build.gradle new file mode 100644 index 00000000000..437a2264c06 --- /dev/null +++ b/vendor/sdks-java-extensions-protobuf/build.gradle @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +apply plugin: org.apache.beam.gradle.BeamModulePlugin +applyJavaNature(shadowClosure: DEFAULT_SHADOW_CLOSURE << { + dependencies { + include(dependency('com.google.guava:guava:20.0')) + include(dependency('com.google.protobuf:protobuf-java:3.5.1')) + } + // We specifically relocate beam-sdks-extensions-protobuf under a vendored namespace + // but also vendor guava and protobuf to the same vendored namespace as the model/* + // implementations allowing the artifacts to encode/decode vendored byte strings and + // vendored protobuf messages + relocate "org.apache.beam.sdk.extensions.protobuf", "org.apache.beam.vendor.sdk.v2.sdk.extensions.protobuf" + + // guava uses the com.google.common and com.google.thirdparty package namespaces + relocate "com.google.common", "org.apache.beam.vendor.guava.v20.com.google.common" + relocate "com.google.thirdparty", "org.apache.beam.vendor.guava.v20.com.google.thirdparty" + + relocate "com.google.protobuf", "org.apache.beam.vendor.protobuf.v3.com.google.protobuf" +}) + +description = "Apache Beam :: Vendored Dependencies :: SDKs :: Java :: Extensions :: Protobuf" +ext.summary = "Add support to Apache Beam for Vendored Google Protobuf." + +/* + * We need to rely on manually specifying these evaluationDependsOn to ensure that + * the following projects are evaluated before we evaluate this project. This is because + * we are attempting to reference the "sourceSets.main.java.srcDirs" directly. + */ +evaluationDependsOn(":beam-sdks-java-extensions-protobuf") + +sourceSets { + main { + java { + srcDirs project(":beam-sdks-java-extensions-protobuf").sourceSets.main.java.srcDirs + } + } +} + +dependencies { + compile 'com.google.guava:guava:20.0' + compile 'com.google.protobuf:protobuf-java:3.5.1' + shadow project(path: ":beam-sdks-java-core", configuration: "shadow") +} + +task('validateShadedJarDoesntLeakNonOrgApacheBeamClasses', dependsOn: ['shadowJar', 'shadowTestJar']) { + inputs.files configurations.shadow.artifacts.files + inputs.files configurations.shadowTest.artifacts.files + doLast { + (configurations.shadow.artifacts.files + configurations.shadowTest.artifacts.files).each { + FileTree exposedClasses = zipTree(it).matching { + include "**/*.class" + exclude "org/apache/beam/**" + } + if (exposedClasses.files) { + throw new GradleException("$it exposed classes outside of org.apache.beam namespace: ${exposedClasses.files}") + } + } + } +} +tasks.check.dependsOn tasks.validateShadedJarDoesntLeakNonOrgApacheBeamClasses ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org Issue Time Tracking ------------------- Worklog Id: (was: 121943) Time Spent: 3h 40m (was: 3.5h) > Java SDK support for portable user state > ---------------------------------------- > > Key: BEAM-2915 > URL: https://issues.apache.org/jira/browse/BEAM-2915 > Project: Beam > Issue Type: Sub-task > Components: sdk-java-core > Reporter: Henning Rohde > Assignee: Luke Cwik > Priority: Minor > Labels: portability > Time Spent: 3h 40m > Remaining Estimate: 0h > -- This message was sent by Atlassian JIRA (v7.6.3#76005)