This is an automated email from the ASF dual-hosted git repository. echauchot pushed a commit to branch spark-runner_structured-streaming in repository https://gitbox.apache.org/repos/asf/beam.git
commit 29f7e93c954cc26425a052c0f1c19ec6e6c9fe66 Author: Etienne Chauchot <[email protected]> AuthorDate: Fri Sep 27 11:55:20 2019 +0200 Apply new Encoders to AggregatorCombiner --- .../translation/batch/AggregatorCombiner.java | 22 +++++++++++++++++----- .../batch/CombinePerKeyTranslatorBatch.java | 20 ++++++++++++++++---- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java index 0e3229e..d14569a 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java @@ -27,6 +27,8 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; @@ -52,13 +54,25 @@ class AggregatorCombiner<K, InputT, AccumT, OutputT, W extends BoundedWindow> private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn; private WindowingStrategy<InputT, W> windowingStrategy; private TimestampCombiner timestampCombiner; + private IterableCoder<WindowedValue<AccumT>> accumulatorCoder; + private IterableCoder<WindowedValue<OutputT>> outputCoder; public AggregatorCombiner( Combine.CombineFn<InputT, AccumT, OutputT> combineFn, - WindowingStrategy<?, ?> windowingStrategy) { + WindowingStrategy<?, ?> windowingStrategy, + Coder<AccumT> accumulatorCoder, + Coder<OutputT> outputCoder) { this.combineFn = combineFn; this.windowingStrategy = (WindowingStrategy<InputT, W>) windowingStrategy; this.timestampCombiner = windowingStrategy.getTimestampCombiner(); + this.accumulatorCoder = + IterableCoder.of( + WindowedValue.FullWindowedValueCoder.of( + accumulatorCoder, windowingStrategy.getWindowFn().windowCoder())); + this.outputCoder = + IterableCoder.of( + WindowedValue.FullWindowedValueCoder.of( + outputCoder, windowingStrategy.getWindowFn().windowCoder())); } @Override @@ -142,14 +156,12 @@ class AggregatorCombiner<K, InputT, AccumT, OutputT, W extends BoundedWindow> @Override public Encoder<Iterable<WindowedValue<AccumT>>> bufferEncoder() { - // TODO replace with accumulatorCoder if possible - return EncoderHelpers.genericEncoder(); + return EncoderHelpers.fromBeamCoder(accumulatorCoder); } @Override public Encoder<Iterable<WindowedValue<OutputT>>> outputEncoder() { - // TODO replace with outputCoder if possible - return EncoderHelpers.genericEncoder(); + return EncoderHelpers.fromBeamCoder(outputCoder); } private Set<W> collectAccumulatorsWindows(Iterable<WindowedValue<AccumT>> accumulators) { diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java index 33b037a..be238b5 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java @@ -23,6 +23,7 @@ import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTr import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext; import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.transforms.Combine; @@ -58,20 +59,31 @@ class CombinePerKeyTranslatorBatch<K, InputT, AccumT, OutputT> Dataset<WindowedValue<KV<K, InputT>>> inputDataset = context.getDataset(input); - Coder<K> keyCoder = (Coder<K>) input.getCoder().getCoderArguments().get(0); - Coder<OutputT> outputTCoder = (Coder<OutputT>) output.getCoder().getCoderArguments().get(1); + KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder(); + Coder<K> keyCoder = inputCoder.getKeyCoder(); + KvCoder<K, OutputT> outputKVCoder = (KvCoder<K, OutputT>) output.getCoder(); + Coder<OutputT> outputCoder = outputKVCoder.getValueCoder(); KeyValueGroupedDataset<K, WindowedValue<KV<K, InputT>>> groupedDataset = inputDataset.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder)); + Coder<AccumT> accumulatorCoder = null; + try { + accumulatorCoder = + combineFn.getAccumulatorCoder( + input.getPipeline().getCoderRegistry(), inputCoder.getValueCoder()); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } + Dataset<Tuple2<K, Iterable<WindowedValue<OutputT>>>> combinedDataset = groupedDataset.agg( new AggregatorCombiner<K, InputT, AccumT, OutputT, BoundedWindow>( - combineFn, windowingStrategy) + combineFn, windowingStrategy, accumulatorCoder, outputCoder) .toColumn()); // expand the list into separate elements and put the key back into the elements - Coder<KV<K, OutputT>> kvCoder = KvCoder.of(keyCoder, outputTCoder); + Coder<KV<K, OutputT>> kvCoder = KvCoder.of(keyCoder, outputCoder); WindowedValue.WindowedValueCoder<KV<K, OutputT>> wvCoder = WindowedValue.FullWindowedValueCoder.of( kvCoder, input.getWindowingStrategy().getWindowFn().windowCoder());
