Repository: beam Updated Branches: refs/heads/master 13db84bb0 -> 2a2337460
[BEAM-1036] Support for new State API in Flink Batch Runner Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/ec5a8262 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/ec5a8262 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/ec5a8262 Branch: refs/heads/master Commit: ec5a82620916b2297a8b349a605af8fadeb2ceb7 Parents: 13db84b Author: JingsongLi <[email protected]> Authored: Tue Feb 28 02:07:33 2017 +0800 Committer: Aljoscha Krettek <[email protected]> Committed: Tue Feb 28 11:02:48 2017 +0100 ---------------------------------------------------------------------- runners/flink/runner/pom.xml | 1 - .../flink/FlinkBatchTransformTranslators.java | 130 +++++++++++------- .../functions/FlinkDoFnFunction.java | 52 ++++++- .../functions/FlinkMultiOutputDoFnFunction.java | 131 ------------------ .../FlinkMultiOutputPruningFunction.java | 2 +- .../functions/FlinkStatefulDoFnFunction.java | 134 +++++++++++++++++++ 6 files changed, 265 insertions(+), 185 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/pom.xml ---------------------------------------------------------------------- diff --git a/runners/flink/runner/pom.xml b/runners/flink/runner/pom.xml index c00b328..8cc65b0 100644 --- a/runners/flink/runner/pom.xml +++ b/runners/flink/runner/pom.xml @@ -55,7 +55,6 @@ <groups>org.apache.beam.sdk.testing.RunnableOnService</groups> <excludedGroups> org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders, - org.apache.beam.sdk.testing.UsesStatefulParDo, org.apache.beam.sdk.testing.UsesTimersInParDo, org.apache.beam.sdk.testing.UsesSplittableParDo, org.apache.beam.sdk.testing.UsesAttemptedMetrics, http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java index 99651c3..ed2f4aa 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java @@ -32,10 +32,10 @@ import org.apache.beam.runners.flink.translation.functions.FlinkDoFnFunction; import org.apache.beam.runners.flink.translation.functions.FlinkMergingNonShuffleReduceFunction; import org.apache.beam.runners.flink.translation.functions.FlinkMergingPartialReduceFunction; import org.apache.beam.runners.flink.translation.functions.FlinkMergingReduceFunction; -import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputDoFnFunction; import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputPruningFunction; import org.apache.beam.runners.flink.translation.functions.FlinkPartialReduceFunction; import org.apache.beam.runners.flink.translation.functions.FlinkReduceFunction; +import org.apache.beam.runners.flink.translation.functions.FlinkStatefulDoFnFunction; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.types.KvKeySelector; import org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat; @@ -498,19 +498,9 @@ class FlinkBatchTransformTranslators { } } - private static void rejectStateAndTimers(DoFn<?, ?> doFn) { + private static void rejectTimers(DoFn<?, ?> doFn) { DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); - if (signature.stateDeclarations().size() > 0) { - throw new UnsupportedOperationException( - String.format( - "Found %s annotations on %s, but %s cannot yet be used with state in the %s.", - DoFn.StateId.class.getSimpleName(), - doFn.getClass().getName(), - DoFn.class.getSimpleName(), - FlinkRunner.class.getSimpleName())); - } - if (signature.timerDeclarations().size() > 0) { throw new UnsupportedOperationException( String.format( @@ -527,13 +517,14 @@ class FlinkBatchTransformTranslators { ParDo.Bound<InputT, OutputT>> { @Override + @SuppressWarnings("unchecked") public void translateNode( ParDo.Bound<InputT, OutputT> transform, FlinkBatchTranslationContext context) { DoFn<InputT, OutputT> doFn = transform.getFn(); rejectSplittable(doFn); - rejectStateAndTimers(doFn); + rejectTimers(doFn); DataSet<WindowedValue<InputT>> inputDataSet = context.getInputDataSet(context.getInput(transform)); @@ -550,23 +541,48 @@ class FlinkBatchTransformTranslators { sideInputStrategies.put(sideInput, sideInput.getWindowingStrategyInternal()); } - FlinkDoFnFunction<InputT, OutputT> doFnWrapper = - new FlinkDoFnFunction<>( - doFn, - context.getOutput(transform).getWindowingStrategy(), - sideInputStrategies, - context.getPipelineOptions()); + WindowingStrategy<?, ?> windowingStrategy = + context.getOutput(transform).getWindowingStrategy(); + + SingleInputUdfOperator<WindowedValue<InputT>, WindowedValue<OutputT>, ?> outputDataSet; + DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass()); + if (signature.stateDeclarations().size() > 0 + || signature.timerDeclarations().size() > 0) { + + // Based on the fact that the signature is stateful, DoFnSignatures ensures + // that it is also keyed + KvCoder<?, InputT> inputCoder = + (KvCoder<?, InputT>) context.getInput(transform).getCoder(); + + FlinkStatefulDoFnFunction<?, ?, OutputT> doFnWrapper = new FlinkStatefulDoFnFunction<>( + (DoFn) doFn, windowingStrategy, sideInputStrategies, context.getPipelineOptions(), + null, new TupleTag<OutputT>() + ); - MapPartitionOperator<WindowedValue<InputT>, WindowedValue<OutputT>> outputDataSet = - new MapPartitionOperator<>( - inputDataSet, - typeInformation, - doFnWrapper, - transform.getName()); + Grouping<WindowedValue<InputT>> grouping = + inputDataSet.groupBy(new KvKeySelector(inputCoder.getKeyCoder())); + + outputDataSet = new GroupReduceOperator( + grouping, typeInformation, doFnWrapper, transform.getName()); + + } else { + FlinkDoFnFunction<InputT, OutputT> doFnWrapper = + new FlinkDoFnFunction<>( + doFn, + windowingStrategy, + sideInputStrategies, + context.getPipelineOptions(), + null, new TupleTag<OutputT>()); + + outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, + transform.getName()); + + } transformSideInputs(sideInputs, outputDataSet, context); context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } } @@ -575,12 +591,13 @@ class FlinkBatchTransformTranslators { ParDo.BoundMulti<InputT, OutputT>> { @Override + @SuppressWarnings("unchecked") public void translateNode( ParDo.BoundMulti<InputT, OutputT> transform, FlinkBatchTranslationContext context) { DoFn<InputT, OutputT> doFn = transform.getFn(); rejectSplittable(doFn); - rejectStateAndTimers(doFn); + rejectTimers(doFn); DataSet<WindowedValue<InputT>> inputDataSet = context.getInputDataSet(context.getInput(transform)); @@ -633,36 +650,57 @@ class FlinkBatchTransformTranslators { sideInputStrategies.put(sideInput, sideInput.getWindowingStrategyInternal()); } - @SuppressWarnings("unchecked") - FlinkMultiOutputDoFnFunction<InputT, OutputT> doFnWrapper = - new FlinkMultiOutputDoFnFunction( - doFn, - windowingStrategy, - sideInputStrategies, - context.getPipelineOptions(), - outputMap, - transform.getMainOutputTag()); - - MapPartitionOperator<WindowedValue<InputT>, WindowedValue<RawUnionValue>> taggedDataSet = - new MapPartitionOperator<>( - inputDataSet, - typeInformation, - doFnWrapper, - transform.getName()); - - transformSideInputs(sideInputs, taggedDataSet, context); + SingleInputUdfOperator<WindowedValue<InputT>, WindowedValue<RawUnionValue>, ?> outputDataSet; + DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass()); + if (signature.stateDeclarations().size() > 0 + || signature.timerDeclarations().size() > 0) { + + // Based on the fact that the signature is stateful, DoFnSignatures ensures + // that it is also keyed + KvCoder<?, InputT> inputCoder = + (KvCoder<?, InputT>) context.getInput(transform).getCoder(); + + FlinkStatefulDoFnFunction<?, ?, OutputT> doFnWrapper = new FlinkStatefulDoFnFunction<>( + (DoFn) doFn, windowingStrategy, sideInputStrategies, context.getPipelineOptions(), + outputMap, transform.getMainOutputTag() + ); + + Grouping<WindowedValue<InputT>> grouping = + inputDataSet.groupBy(new KvKeySelector(inputCoder.getKeyCoder())); + + outputDataSet = + new GroupReduceOperator(grouping, typeInformation, doFnWrapper, transform.getName()); + + } else { + FlinkDoFnFunction<InputT, RawUnionValue> doFnWrapper = + new FlinkDoFnFunction( + doFn, + windowingStrategy, + sideInputStrategies, + context.getPipelineOptions(), + outputMap, + transform.getMainOutputTag()); + + outputDataSet = new MapPartitionOperator<>( + inputDataSet, typeInformation, + doFnWrapper, transform.getName()); + + } + + transformSideInputs(sideInputs, outputDataSet, context); for (TaggedPValue output : outputs) { pruneOutput( - taggedDataSet, + outputDataSet, context, outputMap.get(output.getTag()), (PCollection) output.getValue()); } + } private <T> void pruneOutput( - MapPartitionOperator<WindowedValue<InputT>, WindowedValue<RawUnionValue>> taggedDataSet, + DataSet<WindowedValue<RawUnionValue>> taggedDataSet, FlinkBatchTranslationContext context, int integerTag, PCollection<T> collection) { http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java index 7081aad..9687478 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java @@ -24,6 +24,7 @@ import org.apache.beam.runners.core.DoFnRunners; import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.join.RawUnionValue; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.util.WindowedValue; @@ -38,6 +39,10 @@ import org.apache.flink.util.Collector; /** * Encapsulates a {@link DoFn} * inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}. + * + * <p>We get a mapping from {@link org.apache.beam.sdk.values.TupleTag} to output index + * and must tag all outputs with the output number. Afterwards a filter will filter out + * those elements that are not to be in a specific output. */ public class FlinkDoFnFunction<InputT, OutputT> extends RichMapPartitionFunction<WindowedValue<InputT>, WindowedValue<OutputT>> { @@ -49,18 +54,25 @@ public class FlinkDoFnFunction<InputT, OutputT> private final WindowingStrategy<?, ?> windowingStrategy; + private final Map<TupleTag<?>, Integer> outputMap; + private final TupleTag<OutputT> mainOutputTag; + private transient DoFnInvoker<InputT, OutputT> doFnInvoker; public FlinkDoFnFunction( DoFn<InputT, OutputT> doFn, WindowingStrategy<?, ?> windowingStrategy, Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs, - PipelineOptions options) { + PipelineOptions options, + Map<TupleTag<?>, Integer> outputMap, + TupleTag<OutputT> mainOutputTag) { this.doFn = doFn; this.sideInputs = sideInputs; this.serializedOptions = new SerializedPipelineOptions(options); this.windowingStrategy = windowingStrategy; + this.outputMap = outputMap; + this.mainOutputTag = mainOutputTag; } @@ -71,12 +83,21 @@ public class FlinkDoFnFunction<InputT, OutputT> RuntimeContext runtimeContext = getRuntimeContext(); + DoFnRunners.OutputManager outputManager; + if (outputMap == null) { + outputManager = new FlinkDoFnFunction.DoFnOutputManager(out); + } else { + // it has some sideOutputs + outputManager = + new FlinkDoFnFunction.MultiDoFnOutputManager((Collector) out, outputMap); + } + DoFnRunner<InputT, OutputT> doFnRunner = DoFnRunners.simpleRunner( serializedOptions.getPipelineOptions(), doFn, new FlinkSideInputReader(sideInputs, runtimeContext), - new DoFnOutputManager(out), - new TupleTag<OutputT>() { - }, + outputManager, + mainOutputTag, + // see SimpleDoFnRunner, just use it to limit number of side outputs Collections.<TupleTag<?>>emptyList(), new FlinkNoOpStepContext(), new FlinkAggregatorFactory(runtimeContext), @@ -102,12 +123,12 @@ public class FlinkDoFnFunction<InputT, OutputT> doFnInvoker.invokeTeardown(); } - private class DoFnOutputManager + static class DoFnOutputManager implements DoFnRunners.OutputManager { private Collector collector; - DoFnOutputManager(Collector<WindowedValue<OutputT>> collector) { + DoFnOutputManager(Collector collector) { this.collector = collector; } @@ -118,4 +139,23 @@ public class FlinkDoFnFunction<InputT, OutputT> } } + static class MultiDoFnOutputManager + implements DoFnRunners.OutputManager { + + private Collector<WindowedValue<RawUnionValue>> collector; + private Map<TupleTag<?>, Integer> outputMap; + + MultiDoFnOutputManager(Collector<WindowedValue<RawUnionValue>> collector, + Map<TupleTag<?>, Integer> outputMap) { + this.collector = collector; + this.outputMap = outputMap; + } + + @Override + public <T> void output(TupleTag<T> tag, WindowedValue<T> output) { + collector.collect(WindowedValue.of(new RawUnionValue(outputMap.get(tag), output.getValue()), + output.getTimestamp(), output.getWindows(), output.getPane())); + } + } + } http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java deleted file mode 100644 index 27ba5ac..0000000 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.flink.translation.functions; - -import java.util.Collections; -import java.util.Map; -import org.apache.beam.runners.core.DoFnRunner; -import org.apache.beam.runners.core.DoFnRunners; -import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.join.RawUnionValue; -import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; -import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.util.WindowingStrategy; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.flink.api.common.functions.RichMapPartitionFunction; -import org.apache.flink.api.common.functions.RuntimeContext; -import org.apache.flink.configuration.Configuration; -import org.apache.flink.util.Collector; - -/** - * Encapsulates a {@link DoFn} that can emit to multiple - * outputs inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}. - * - * <p>We get a mapping from {@link org.apache.beam.sdk.values.TupleTag} to output index - * and must tag all outputs with the output number. Afterwards a filter will filter out - * those elements that are not to be in a specific output. - */ -public class FlinkMultiOutputDoFnFunction<InputT, OutputT> - extends RichMapPartitionFunction<WindowedValue<InputT>, WindowedValue<RawUnionValue>> { - - private final DoFn<InputT, OutputT> doFn; - private final SerializedPipelineOptions serializedOptions; - - private final Map<TupleTag<?>, Integer> outputMap; - - private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs; - private final WindowingStrategy<?, ?> windowingStrategy; - private TupleTag<OutputT> mainOutputTag; - private transient DoFnInvoker<InputT, OutputT> doFnInvoker; - - public FlinkMultiOutputDoFnFunction( - DoFn<InputT, OutputT> doFn, - WindowingStrategy<?, ?> windowingStrategy, - Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs, - PipelineOptions options, - Map<TupleTag<?>, Integer> outputMap, - TupleTag<OutputT> mainOutputTag) { - this.doFn = doFn; - this.serializedOptions = new SerializedPipelineOptions(options); - this.outputMap = outputMap; - - this.windowingStrategy = windowingStrategy; - this.sideInputs = sideInputs; - this.mainOutputTag = mainOutputTag; - } - - @Override - public void mapPartition( - Iterable<WindowedValue<InputT>> values, - Collector<WindowedValue<RawUnionValue>> out) throws Exception { - - RuntimeContext runtimeContext = getRuntimeContext(); - - DoFnRunner<InputT, OutputT> doFnRunner = DoFnRunners.simpleRunner( - serializedOptions.getPipelineOptions(), doFn, - new FlinkSideInputReader(sideInputs, runtimeContext), - new DoFnOutputManager(out), - mainOutputTag, - // see SimpleDoFnRunner, just use it to limit number of side outputs - Collections.<TupleTag<?>>emptyList(), - new FlinkNoOpStepContext(), - new FlinkAggregatorFactory(runtimeContext), - windowingStrategy); - - doFnRunner.startBundle(); - - for (WindowedValue<InputT> value : values) { - doFnRunner.processElement(value); - } - - doFnRunner.finishBundle(); - - } - - @Override - public void open(Configuration parameters) throws Exception { - doFnInvoker = DoFnInvokers.invokerFor(doFn); - doFnInvoker.invokeSetup(); - } - - @Override - public void close() throws Exception { - doFnInvoker.invokeTeardown(); - } - - private class DoFnOutputManager - implements DoFnRunners.OutputManager { - - private Collector<WindowedValue<RawUnionValue>> collector; - - DoFnOutputManager(Collector<WindowedValue<RawUnionValue>> collector) { - this.collector = collector; - } - - @Override - public <T> void output(TupleTag<T> tag, WindowedValue<T> output) { - collector.collect(WindowedValue.of(new RawUnionValue(outputMap.get(tag), output.getValue()), - output.getTimestamp(), output.getWindows(), output.getPane())); - } - } - -} http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java index b72750a..9071cc5 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java @@ -25,7 +25,7 @@ import org.apache.flink.util.Collector; /** * A {@link FlatMapFunction} function that filters out those elements that don't belong in this * output. We need this to implement MultiOutput ParDo functions in combination with - * {@link FlinkMultiOutputDoFnFunction}. + * {@link FlinkDoFnFunction}. */ public class FlinkMultiOutputPruningFunction<T> implements FlatMapFunction<WindowedValue<RawUnionValue>, WindowedValue<T>> { http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java new file mode 100644 index 0000000..fca7691 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.InMemoryStateInternals; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.flink.api.common.functions.RichGroupReduceFunction; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.util.Collector; + +/** + * A {@link RichGroupReduceFunction} for stateful {@link ParDo} in Flink Batch Runner. + */ +public class FlinkStatefulDoFnFunction<K, V, OutputT> + extends RichGroupReduceFunction<WindowedValue<KV<K, V>>, WindowedValue<OutputT>> { + + private final DoFn<KV<K, V>, OutputT> dofn; + private final WindowingStrategy<?, ?> windowingStrategy; + private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs; + private final SerializedPipelineOptions serializedOptions; + private final Map<TupleTag<?>, Integer> outputMap; + private final TupleTag<OutputT> mainOutputTag; + private transient DoFnInvoker doFnInvoker; + + public FlinkStatefulDoFnFunction( + DoFn<KV<K, V>, OutputT> dofn, + WindowingStrategy<?, ?> windowingStrategy, + Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs, + PipelineOptions pipelineOptions, + Map<TupleTag<?>, Integer> outputMap, + TupleTag<OutputT> mainOutputTag) { + + this.dofn = dofn; + this.windowingStrategy = windowingStrategy; + this.sideInputs = sideInputs; + this.serializedOptions = new SerializedPipelineOptions(pipelineOptions); + this.outputMap = outputMap; + this.mainOutputTag = mainOutputTag; + } + + @Override + public void reduce( + Iterable<WindowedValue<KV<K, V>>> values, + Collector<WindowedValue<OutputT>> out) throws Exception { + RuntimeContext runtimeContext = getRuntimeContext(); + + DoFnRunners.OutputManager outputManager; + if (outputMap == null) { + outputManager = new FlinkDoFnFunction.DoFnOutputManager(out); + } else { + // it has some sideOutputs + outputManager = + new FlinkDoFnFunction.MultiDoFnOutputManager((Collector) out, outputMap); + } + + final Iterator<WindowedValue<KV<K, V>>> iterator = values.iterator(); + + // get the first value, we need this for initializing the state internals with the key. + // we are guaranteed to have a first value, otherwise reduce() would not have been called. + WindowedValue<KV<K, V>> currentValue = iterator.next(); + final K key = currentValue.getValue().getKey(); + + final InMemoryStateInternals<K> stateInternals = InMemoryStateInternals.forKey(key); + DoFnRunner<KV<K, V>, OutputT> doFnRunner = DoFnRunners.simpleRunner( + serializedOptions.getPipelineOptions(), dofn, + new FlinkSideInputReader(sideInputs, runtimeContext), + outputManager, + mainOutputTag, + // see SimpleDoFnRunner, just use it to limit number of side outputs + Collections.<TupleTag<?>>emptyList(), + new FlinkNoOpStepContext() { + @Override + public StateInternals<?> stateInternals() { + return stateInternals; + } + }, + new FlinkAggregatorFactory(runtimeContext), + windowingStrategy); + + doFnRunner.startBundle(); + + doFnRunner.processElement(currentValue); + while (iterator.hasNext()) { + currentValue = iterator.next(); + doFnRunner.processElement(currentValue); + } + + doFnRunner.finishBundle(); + } + + @Override + public void open(Configuration parameters) throws Exception { + doFnInvoker = DoFnInvokers.invokerFor(dofn); + doFnInvoker.invokeSetup(); + } + + @Override + public void close() throws Exception { + doFnInvoker.invokeTeardown(); + } + +}
