[ https://issues.apache.org/jira/browse/BEAM-2918?focusedWorklogId=158759&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-158759 ]
ASF GitHub Bot logged work on BEAM-2918:
----------------------------------------
Author: ASF GitHub Bot
Created on: 25/Oct/18 16:07
Start Date: 25/Oct/18 16:07
Worklog Time Spent: 10m
Work Description: mxm closed pull request #6740: [BEAM-2918] Add state support for batch in portable FlinkRunner
URL: https://github.com/apache/beam/pull/6740
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:
diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
index 4a93e1b54cf..201b4a50adb 100644
--- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
+++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy
@@ -219,7 +219,8 @@ class BeamModulePlugin implements Plugin<Project> {
excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage'
excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle'
- excludeCategories 'org.apache.beam.sdk.testing.UsesStatefulParDo'
+ excludeCategories 'org.apache.beam.sdk.testing.UsesMapState'
+ excludeCategories 'org.apache.beam.sdk.testing.UsesSetState'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
excludeCategories 'org.apache.beam.sdk.testing.UsesTimersInParDo'
//SplitableDoFnTests
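
For context, these Gradle excludes key off JUnit category markers carried by ValidatesRunner tests: dropping UsesStatefulParDo enables the stateful ParDo suite for this runner, while the narrower UsesMapState/UsesSetState excludes keep only the map- and set-state tests disabled. A minimal sketch of how such a test is tagged (hypothetical test class, not part of this PR):

import org.apache.beam.sdk.testing.UsesStatefulParDo;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.junit.Test;
import org.junit.experimental.categories.Category;

public class ExampleStatefulParDoTest {

  // A test is skipped only if one of its categories appears in excludeCategories,
  // so removing UsesStatefulParDo from the excludes lets this test run.
  @Test
  @Category({ValidatesRunner.class, UsesStatefulParDo.class})
  public void testValueStateInStatefulParDo() {
    // ... pipeline exercising a stateful DoFn, elided ...
  }
}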
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
index bb3a8903100..252a624156b 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
@@ -18,6 +18,7 @@
package org.apache.beam.runners.flink;
import static com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.runners.flink.translation.utils.FlinkPipelineTranslatorUtils.instantiateCoder;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableMap;
@@ -31,6 +32,7 @@
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
+import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
@@ -86,6 +88,7 @@
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.Grouping;
import org.apache.flink.api.java.operators.MapPartitionOperator;
+import org.apache.flink.api.java.operators.SingleInputUdfOperator;
/**
* A translator that translates bounded portable pipelines into executable Flink pipelines.
@@ -333,19 +336,46 @@ public void translate(BatchTranslationContext context, RunnerApi.Pipeline pipeli
} catch (IOException e) {
throw new RuntimeException(e);
}
- FlinkExecutableStageFunction<InputT> function =
+
+ String inputPCollectionId = stagePayload.getInput();
+ DataSet<WindowedValue<InputT>> inputDataSet = context.getDataSetOrThrow(inputPCollectionId);
+
+ final boolean stateful = stagePayload.getUserStatesCount() > 0;
+ final FlinkExecutableStageFunction<InputT> function =
new FlinkExecutableStageFunction<>(
stagePayload,
context.getJobInfo(),
outputMap,
- FlinkExecutableStageContext.factory(context.getPipelineOptions()));
-
- DataSet<WindowedValue<InputT>> inputDataSet =
- context.getDataSetOrThrow(stagePayload.getInput());
+ FlinkExecutableStageContext.factory(context.getPipelineOptions()),
+ stateful);
+
+ final SingleInputUdfOperator taggedDataset;
+ if (stateful) {
+ Coder<WindowedValue<InputT>> windowedInputCoder =
+ instantiateCoder(inputPCollectionId, components);
+ Coder valueCoder =
+ ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
+ // Stateful stages are only allowed for KV input so that we can group on the key
+ if (!(valueCoder instanceof KvCoder)) {
+ throw new IllegalStateException(
+ String.format(
+ Locale.ENGLISH,
+ "The element coder for stateful DoFn '%s' must be KvCoder but
is: %s",
+ inputPCollectionId,
+ valueCoder.getClass().getSimpleName()));
+ }
+ Coder keyCoder = ((KvCoder) valueCoder).getKeyCoder();
- MapPartitionOperator<WindowedValue<InputT>, RawUnionValue> taggedDataset =
- new MapPartitionOperator<>(
- inputDataSet, typeInformation, function, transform.getTransform().getUniqueName());
+ Grouping<WindowedValue<InputT>> groupedInput =
+ inputDataSet.groupBy(new KvKeySelector<>(keyCoder));
+ taggedDataset =
+ new GroupReduceOperator<>(
+ groupedInput, typeInformation, function, transform.getTransform().getUniqueName());
+ } else {
+ taggedDataset =
+ new MapPartitionOperator<>(
+ inputDataSet, typeInformation, function, transform.getTransform().getUniqueName());
+ }
for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
String collectionId =
@@ -372,7 +402,7 @@ public void translate(BatchTranslationContext context, RunnerApi.Pipeline pipeli
// no-op sink to each to make sure they are materialized by Flink. However, some SDK-executed
// stages have no runner-visible output after fusion. We handle this case by adding a sink
// here.
- taggedDataset.output(new DiscardingOutputFormat());
+ taggedDataset.output(new DiscardingOutputFormat<>());
}
}
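
The stateful branch above works because Flink co-locates all values for a key before handing them to the GroupReduceOperator; the grouping key is extracted by a key selector over the KV input. A minimal sketch of the idea, where EncodedKvKeySelector and its encoded-bytes key are illustrative stand-ins for the PR's KvKeySelector:

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.flink.api.java.functions.KeySelector;

// Encodes the KV key to bytes so Flink groups on the encoded form, which is
// deterministic even for key types without usable equals()/hashCode().
class EncodedKvKeySelector<K, V> implements KeySelector<WindowedValue<KV<K, V>>, ByteBuffer> {

  private final Coder<K> keyCoder;

  EncodedKvKeySelector(Coder<K> keyCoder) {
    this.keyCoder = keyCoder;
  }

  @Override
  public ByteBuffer getKey(WindowedValue<KV<K, V>> value) throws Exception {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    keyCoder.encode(value.getValue().getKey(), out);
    // ByteBuffer has value-based equals/hashCode, unlike raw byte arrays.
    return ByteBuffer.wrap(out.toByteArray());
  }
}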
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
index 08e520c97b2..942cc114d90 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
@@ -584,7 +584,7 @@ public void translateNode(
sideInputStrategies,
context.getPipelineOptions(),
outputMap,
- (TupleTag<OutputT>) mainOutputTag,
+ mainOutputTag,
inputCoder,
outputCoderMap);
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
index f087223eb9d..dc497744bbf 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.runners.flink;
+import static org.apache.beam.runners.flink.translation.utils.FlinkPipelineTranslatorUtils.instantiateCoder;
+
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.auto.service.AutoService;
@@ -60,7 +62,6 @@
import org.apache.beam.runners.flink.translation.wrappers.streaming.WorkItemKeySelector;
import org.apache.beam.runners.flink.translation.wrappers.streaming.io.StreamingImpulseSource;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
-import org.apache.beam.runners.fnexecution.wire.WireCoders;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
@@ -787,15 +788,4 @@ private TransformedSideInputs transformSideInputs(
return value.withValue(KV.of(null, value.getValue()));
}
}
-
- static <T> Coder<WindowedValue<T>> instantiateCoder(
- String collectionId, RunnerApi.Components components) {
- PipelineNode.PCollectionNode collectionNode =
- PipelineNode.pCollection(collectionId, components.getPcollectionsOrThrow(collectionId));
- try {
- return WireCoders.instantiateRunnerWireCoder(collectionNode, components);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java
index b912b2a7c4f..1e8a4025012 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkBatchSideInputHandlerFactory.java
@@ -158,7 +158,7 @@ private FlinkBatchSideInputHandlerFactory(
}
}
- return new MultimapSideInputHandler(multimap.build(), keyCoder, valueCoder, windowCoder);
+ return new MultimapSideInputHandler<>(multimap.build(), keyCoder, valueCoder, windowCoder);
}
private static class MultimapSideInputHandler<K, V, W extends BoundedWindow>
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
index 09c45fa719a..429c00d27c2 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
@@ -22,9 +22,20 @@
import com.google.common.collect.Iterables;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.EnumMap;
+import java.util.Iterator;
+import java.util.List;
import java.util.Map;
import javax.annotation.concurrent.GuardedBy;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.InMemoryStateInternals;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaces;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
@@ -34,12 +45,17 @@
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.functions.GroupReduceFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
@@ -54,8 +70,9 @@
* coder. The coder's tags are determined by the output coder map. The resulting data set should be
* further processed by a {@link FlinkExecutableStagePruningFunction}.
*/
-public class FlinkExecutableStageFunction<InputT>
- extends RichMapPartitionFunction<WindowedValue<InputT>, RawUnionValue> {
+public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
+ implements MapPartitionFunction<WindowedValue<InputT>, RawUnionValue>,
+ GroupReduceFunction<WindowedValue<InputT>, RawUnionValue> {
private static final Logger LOG =
LoggerFactory.getLogger(FlinkExecutableStageFunction.class);
// Main constructor fields. All must be Serializable because Flink distributes Functions to
@@ -68,6 +85,7 @@
// Map from PCollection id to the union tag used to represent this PCollection in the output.
private final Map<String, Integer> outputMap;
private final FlinkExecutableStageContext.Factory contextFactory;
+ private final boolean stateful;
// Worker-local fields. These should only be constructed and consumed on Flink TaskManagers.
private transient RuntimeContext runtimeContext;
@@ -75,16 +93,20 @@
private transient FlinkExecutableStageContext stageContext;
private transient StageBundleFactory stageBundleFactory;
private transient BundleProgressHandler progressHandler;
+ // Only initialized when the ExecutableStage is stateful
+ private transient InMemoryBagUserStateFactory bagUserStateHandlerFactory;
public FlinkExecutableStageFunction(
RunnerApi.ExecutableStagePayload stagePayload,
JobInfo jobInfo,
Map<String, Integer> outputMap,
- FlinkExecutableStageContext.Factory contextFactory) {
+ FlinkExecutableStageContext.Factory contextFactory,
+ boolean stateful) {
this.stagePayload = stagePayload;
this.jobInfo = jobInfo;
this.outputMap = outputMap;
this.contextFactory = contextFactory;
+ this.stateful = stateful;
}
@Override
@@ -96,30 +118,68 @@ public void open(Configuration parameters) throws Exception {
runtimeContext = getRuntimeContext();
// TODO: Wire this into the distributed cache and make it pluggable.
stageContext = contextFactory.get(jobInfo);
+ stageBundleFactory = stageContext.getStageBundleFactory(executableStage);
// NOTE: It's safe to reuse the state handler between partitions because each partition uses the
// same backing runtime context and broadcast variables. We use checkState below to catch errors
// in backward-incompatible Flink changes.
- stateRequestHandler = getStateRequestHandler(executableStage, runtimeContext);
- stageBundleFactory = stageContext.getStageBundleFactory(executableStage);
+ stateRequestHandler =
+ getStateRequestHandler(
+ executableStage, stageBundleFactory.getProcessBundleDescriptor(), runtimeContext);
progressHandler = BundleProgressHandler.unsupported();
}
- private static StateRequestHandler getStateRequestHandler(
- ExecutableStage executableStage, RuntimeContext runtimeContext) {
+ private StateRequestHandler getStateRequestHandler(
+ ExecutableStage executableStage,
+ ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor,
+ RuntimeContext runtimeContext) {
+ final StateRequestHandler sideInputHandler;
StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
FlinkBatchSideInputHandlerFactory.forStage(executableStage, runtimeContext);
try {
- return StateRequestHandlers.forSideInputHandlerFactory(
- ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory);
+ sideInputHandler =
+ StateRequestHandlers.forSideInputHandlerFactory(
+ ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory);
} catch (IOException e) {
- throw new RuntimeException(e);
+ throw new RuntimeException("Failed to setup state handler", e);
+ }
+
+ final StateRequestHandler userStateHandler;
+ if (stateful) {
+ bagUserStateHandlerFactory = new InMemoryBagUserStateFactory();
+ userStateHandler =
+ StateRequestHandlers.forBagUserStateHandlerFactory(
+ processBundleDescriptor, bagUserStateHandlerFactory);
+ } else {
+ userStateHandler = StateRequestHandler.unsupported();
}
+
+ EnumMap<StateKey.TypeCase, StateRequestHandler> handlerMap =
+ new EnumMap<>(StateKey.TypeCase.class);
+ handlerMap.put(StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputHandler);
+ handlerMap.put(StateKey.TypeCase.BAG_USER_STATE, userStateHandler);
+
+ return StateRequestHandlers.delegateBasedUponType(handlerMap);
}
+ /** For non-stateful processing via a simple MapPartitionFunction. */
@Override
public void mapPartition(
Iterable<WindowedValue<InputT>> iterable, Collector<RawUnionValue> collector)
throws Exception {
+ processElements(iterable, collector);
+ }
+
+ /** For stateful processing via a GroupReduceFunction. */
+ @Override
+ public void reduce(Iterable<WindowedValue<InputT>> iterable, Collector<RawUnionValue> collector)
+ throws Exception {
+ bagUserStateHandlerFactory.resetForNewKey();
+ processElements(iterable, collector);
+ }
+
+ private void processElements(
+ Iterable<WindowedValue<InputT>> iterable, Collector<RawUnionValue> collector)
+ throws Exception {
checkState(
runtimeContext == getRuntimeContext(),
"RuntimeContext changed from under us. State handler invalid.");
@@ -188,4 +248,93 @@ public void close() throws Exception {
};
}
}
+
+ /**
+ * Holds user state in memory if the ExecutableStage is stateful. Only one key is active at a time
+ * due to the GroupReduceFunction being called once per key. Needs to be reset via {@code
+ * resetForNewKey()} before processing a new key.
+ */
+ private static class InMemoryBagUserStateFactory
+ implements StateRequestHandlers.BagUserStateHandlerFactory {
+
+ private List<InMemorySingleKeyBagState> handlers;
+
+ private InMemoryBagUserStateFactory() {
+ handlers = new ArrayList<>();
+ }
+
+ @Override
+ public <K, V, W extends BoundedWindow>
+ StateRequestHandlers.BagUserStateHandler<K, V, W> forUserState(
+ String pTransformId,
+ String userStateId,
+ Coder<K> keyCoder,
+ Coder<V> valueCoder,
+ Coder<W> windowCoder) {
+
+ InMemorySingleKeyBagState<K, V, W> bagUserStateHandler =
+ new InMemorySingleKeyBagState<>(userStateId, valueCoder, windowCoder);
+ handlers.add(bagUserStateHandler);
+
+ return bagUserStateHandler;
+ }
+
+ /** Prepares previously emitted state handlers for processing a new key. */
+ void resetForNewKey() {
+ for (InMemorySingleKeyBagState stateBags : handlers) {
+ stateBags.reset();
+ }
+ }
+
+ static class InMemorySingleKeyBagState<K, V, W extends BoundedWindow>
+ implements StateRequestHandlers.BagUserStateHandler<K, V, W> {
+
+ private final StateTag<BagState<V>> stateTag;
+ private final Coder<W> windowCoder;
+
+ /* Lazily initialized state internals upon first access */
+ private volatile StateInternals stateInternals;
+
+ InMemorySingleKeyBagState(String userStateId, Coder<V> valueCoder, Coder<W> windowCoder) {
+ this.windowCoder = windowCoder;
+ this.stateTag = StateTags.bag(userStateId, valueCoder);
+ }
+
+ @Override
+ public Iterable<V> get(K key, W window) {
+ initStateInternals(key);
+ StateNamespace namespace = StateNamespaces.window(windowCoder, window);
+ BagState<V> bagState = stateInternals.state(namespace, stateTag);
+ return bagState.read();
+ }
+
+ @Override
+ public void append(K key, W window, Iterator<V> values) {
+ initStateInternals(key);
+ StateNamespace namespace = StateNamespaces.window(windowCoder, window);
+ BagState<V> bagState = stateInternals.state(namespace, stateTag);
+ while (values.hasNext()) {
+ bagState.add(values.next());
+ }
+ }
+
+ @Override
+ public void clear(K key, W window) {
+ initStateInternals(key);
+ StateNamespace namespace = StateNamespaces.window(windowCoder, window);
+ BagState<V> bagState = stateInternals.state(namespace, stateTag);
+ bagState.clear();
+ }
+
+ private void initStateInternals(K key) {
+ if (stateInternals == null) {
+ stateInternals = InMemoryStateInternals.forKey(key);
+ }
+ }
+
+ void reset() {
+ stateInternals = null;
+ }
+ }
+ }
}
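
Two techniques above are worth spelling out: StateRequestHandlers.delegateBasedUponType routes each incoming state request by its StateKey type (MULTIMAP_SIDE_INPUT to the side-input handler, BAG_USER_STATE to the in-memory bags), and each bag is backed by runners-core state primitives scoped to a (key, window) pair. A self-contained sketch of that underlying bag-state mechanism (standalone illustration, not code from the PR):

import org.apache.beam.runners.core.InMemoryStateInternals;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.runners.core.StateTags;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;

public class BagStateSketch {
  public static void main(String[] args) {
    // State internals are created per key; the window forms the namespace.
    // This mirrors InMemorySingleKeyBagState.get/append/clear above.
    StateInternals stateInternals = InMemoryStateInternals.forKey("key1");
    StateTag<BagState<Integer>> stateTag = StateTags.bag("valueState", VarIntCoder.of());
    StateNamespace namespace =
        StateNamespaces.window(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE);

    BagState<Integer> bagState = stateInternals.state(namespace, stateTag);
    bagState.add(1);
    bagState.add(2);
    System.out.println(bagState.read()); // [1, 2]
    bagState.clear();
    // Dropping the internals (cf. reset() above) discards the key's state entirely,
    // which is safe because reduce() sees each key exactly once.
  }
}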
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPipelineTranslatorUtils.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPipelineTranslatorUtils.java
index 05cbffb7d3e..1710a27da7c 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPipelineTranslatorUtils.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/utils/FlinkPipelineTranslatorUtils.java
@@ -20,6 +20,12 @@
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.Sets;
+import java.io.IOException;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.construction.graph.PipelineNode;
+import org.apache.beam.runners.fnexecution.wire.WireCoders;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.util.WindowedValue;
/** Utilities for pipeline translation. */
public final class FlinkPipelineTranslatorUtils {
@@ -36,4 +42,16 @@ private FlinkPipelineTranslatorUtils() {}
}
return builder.build();
}
+
+ /** Creates a coder for a given PCollection id from the Proto definition. */
+ public static <T> Coder<WindowedValue<T>> instantiateCoder(
+ String collectionId, RunnerApi.Components components) {
+ PipelineNode.PCollectionNode collectionNode =
+ PipelineNode.pCollection(collectionId, components.getPcollectionsOrThrow(collectionId));
+ try {
+ return WireCoders.instantiateRunnerWireCoder(collectionNode, components);
+ } catch (IOException e) {
+ throw new RuntimeException("Could not instantiate Coder", e);
+ }
+ }
}
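
Callers usually unwrap the returned windowed coder further; the batch translator above digs down to the KV key coder before grouping a stateful stage. A hedged sketch of that unwrapping (the helper keyCoderFor is hypothetical, mirroring the translator's logic):

import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.flink.translation.utils.FlinkPipelineTranslatorUtils;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;

class CoderUnwrapSketch {

  // Resolves the wire coder for a PCollection and unwraps it to the KV key coder,
  // failing fast when the input is not KV-shaped (as stateful stages require).
  static <K, V> Coder<K> keyCoderFor(String collectionId, RunnerApi.Components components) {
    Coder<WindowedValue<KV<K, V>>> windowedCoder =
        FlinkPipelineTranslatorUtils.instantiateCoder(collectionId, components);
    Coder<KV<K, V>> valueCoder =
        ((WindowedValue.FullWindowedValueCoder<KV<K, V>>) windowedCoder).getValueCoder();
    if (!(valueCoder instanceof KvCoder)) {
      throw new IllegalStateException("Stateful stage input must use KvCoder: " + collectionId);
    }
    return ((KvCoder<K, V>) valueCoder).getKeyCoder();
  }
}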
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableStateExecutionTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableStateExecutionTest.java
index bd133d3c73e..58aeb7d3fa4 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableStateExecutionTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/PortableStateExecutionTest.java
@@ -63,7 +63,7 @@
@Parameters
public static Object[] data() {
- return new Object[] {true};
+ return new Object[] {true, false};
}
@Parameter public boolean isStreaming;
@@ -80,11 +80,14 @@ public void tearDown() {
flinkJobExecutor.shutdown();
}
- private static final Map<String, Integer> stateValues = new HashMap<>();
+ // State -> Key -> Value
+ private static final Map<String, Map<String, Integer>> stateValuesMap = new HashMap<>();
@Before
public void before() {
- stateValues.clear();
+ stateValuesMap.clear();
+ stateValuesMap.put("valueState", new HashMap<>());
+ stateValuesMap.put("valueState2", new HashMap<>());
}
// Special values which clear / write out state
@@ -131,9 +134,22 @@ public void process(ProcessContext ctx) {
private final StateSpec<ValueState<Integer>> valueStateSpec =
StateSpecs.value(VarIntCoder.of());
+ @StateId("valueState2")
+ private final StateSpec<ValueState<Integer>> valueStateSpec2 =
+ StateSpecs.value(VarIntCoder.of());
+
@ProcessElement
public void process(
- ProcessContext ctx, @StateId("valueState")
ValueState<Integer> valueState) {
+ ProcessContext ctx,
+ @StateId("valueState") ValueState<Integer> valueState,
+ @StateId("valueState2") ValueState<Integer> valueState2)
{
+ performStateUpdates("valueState", ctx, valueState);
+ performStateUpdates("valueState2", ctx, valueState2);
+ }
+
+ private void performStateUpdates(
+ String stateId, ProcessContext ctx, ValueState<Integer> valueState) {
+ Map<String, Integer> stateValues = stateValuesMap.get(stateId);
Integer value = ctx.element().getValue();
if (value == null) {
throw new IllegalStateException();
@@ -181,6 +197,8 @@ public void process(
expected.put("bla2", 64);
expected.put("clearedState", null);
- assertThat(stateValues, equalTo(expected));
+ for (Map<String, Integer> statesValues : stateValuesMap.values()) {
+ assertThat(statesValues, equalTo(expected));
+ }
}
}
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
index bbb5302ae3f..e357f5e8101 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
@@ -52,15 +52,23 @@
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
+import org.junit.runners.Parameterized;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.internal.util.reflection.Whitebox;
/** Tests for {@link FlinkExecutableStageFunction}. */
-@RunWith(JUnit4.class)
+@RunWith(Parameterized.class)
public class FlinkExecutableStageFunctionTest {
+
+ @Parameterized.Parameters
+ public static Object[] data() {
+ return new Object[] {true, false};
+ }
+
+ @Parameterized.Parameter public boolean isStateful;
+
@Rule public ExpectedException thrown = ExpectedException.none();
@Mock private RuntimeContext runtimeContext;
@@ -202,7 +210,11 @@ public void close() throws Exception {}
FlinkExecutableStageFunction<Integer> function = getFunction(outputTagMap);
function.open(new Configuration());
- function.mapPartition(Collections.emptyList(), collector);
+ if (isStateful) {
+ function.reduce(Collections.emptyList(), collector);
+ } else {
+ function.mapPartition(Collections.emptyList(), collector);
+ }
// Ensure that the tagged values sent to the collector have the correct union tags as specified
// in the output map.
verify(collector).collect(new RawUnionValue(1, three));
@@ -216,6 +228,7 @@ public void testStageBundleClosed() throws Exception {
FlinkExecutableStageFunction<Integer> function = getFunction(Collections.emptyMap());
function.open(new Configuration());
function.close();
+ verify(stageBundleFactory).getProcessBundleDescriptor();
verify(stageBundleFactory).close();
verifyNoMoreInteractions(stageBundleFactory);
}
@@ -230,7 +243,8 @@ public void testStageBundleClosed() throws Exception {
Mockito.mock(FlinkExecutableStageContext.Factory.class);
when(contextFactory.get(any())).thenReturn(stageContext);
FlinkExecutableStageFunction<Integer> function =
- new FlinkExecutableStageFunction<>(stagePayload, jobInfo, outputMap, contextFactory);
+ new FlinkExecutableStageFunction<>(
+ stagePayload, jobInfo, outputMap, contextFactory, isStateful);
function.setRuntimeContext(runtimeContext);
Whitebox.setInternalState(function, "stateRequestHandler",
stateRequestHandler);
return function;
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 06559bc198c..a99dfb3b989 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -255,7 +255,18 @@ public void add(T value) {
@Override
public ReadableState<Boolean> isEmpty() {
- return ReadableStates.immediate(!impl.get().iterator().hasNext());
+ return new ReadableState<Boolean>() {
+ @Nullable
+ @Override
+ public Boolean read() {
+ return !impl.get().iterator().hasNext();
+ }
+
+ @Override
+ public ReadableState<Boolean> readLater() {
+ return this;
+ }
+ };
}
@Override
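
The replacement above makes isEmpty() lazy: ReadableStates.immediate(...) fixed the answer when the ReadableState was constructed, whereas the anonymous class re-checks the bag each time read() is called, so later additions are observed. A standalone sketch of the pattern (assumed illustration, not the harness's code):

import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.beam.sdk.state.ReadableState;

// Defers evaluation to read() so the result reflects mutations made after
// the ReadableState was handed out, unlike an eagerly computed constant.
class DeferredBoolean implements ReadableState<Boolean> {

  private final Supplier<Boolean> source;

  DeferredBoolean(Supplier<Boolean> source) {
    this.source = source;
  }

  @Nullable
  @Override
  public Boolean read() {
    return source.get(); // evaluated on each call, not at construction
  }

  @Override
  public ReadableState<Boolean> readLater() {
    return this; // nothing to prefetch for an in-memory supplier
  }
}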
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 2d3ae54d781..28f6aa2a287 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -106,12 +106,6 @@ def test_read(self):
def test_no_subtransform_composite(self):
raise unittest.SkipTest("BEAM-4781")
- def test_pardo_state_only(self):
- if streaming:
- super(FlinkRunnerTest, self).test_pardo_state_only()
- else:
- raise unittest.SkipTest("BEAM-2918 - User state not yet supported.")
-
def test_pardo_timers(self):
raise unittest.SkipTest("BEAM-4681 - User timers not yet supported.")
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 158759)
Time Spent: 13h 10m (was: 13h)
> Flink support for portable user state
> -------------------------------------
>
> Key: BEAM-2918
> URL: https://issues.apache.org/jira/browse/BEAM-2918
> Project: Beam
> Issue Type: Sub-task
> Components: runner-flink
> Reporter: Henning Rohde
> Assignee: Maximilian Michels
> Priority: Minor
> Labels: portability
> Fix For: 2.9.0
>
> Time Spent: 13h 10m
> Remaining Estimate: 0h
>
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)