Repository: incubator-beam Updated Branches: refs/heads/master ac63fd6d4 -> c30326007
[BEAM-96] Add composed `CombineFn` builders in `CombineFns` * `compose()` or `composeKeyed()` are used to start composition * `with()` is used to add an input-transformation, a `CombineFn` and an output `TupleTag`. * A non-`CombineFn` initial builder is used to ensure that every composition includes at least one item * Duplicate output tags are not allowed in the same composition Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/23b43780 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/23b43780 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/23b43780 Branch: refs/heads/master Commit: 23b437802546f32a167b38f8d0bc7a566abde224 Parents: ac63fd6 Author: Pei He <pe...@google.com> Authored: Fri Mar 4 13:54:34 2016 -0800 Committer: bchambers <bchamb...@google.com> Committed: Thu Mar 17 13:54:40 2016 -0700 ---------------------------------------------------------------------- .../dataflow/sdk/transforms/CombineFns.java | 1100 ++++++++++++++++++ .../dataflow/sdk/transforms/CombineFnsTest.java | 413 +++++++ 2 files changed, 1513 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/23b43780/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java new file mode 100644 index 0000000..656c010 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java @@ -0,0 +1,1100 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.transforms; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.GlobalCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Iterator; +import java.util.List; 
+import java.util.Map; +import java.util.Map.Entry; + +/** + * Static utility methods that create combine function instances. + */ +public class CombineFns { + + /** + * Returns a {@link ComposeKeyedCombineFnBuilder} to construct a composed + * {@link PerKeyCombineFn}. + * + * <p>The same {@link TupleTag} cannot be used in a composition multiple times. + * + * <p>Example: + * <pre>{ @code + * PCollection<KV<K, Integer>> latencies = ...; + * + * TupleTag<Integer> maxLatencyTag = new TupleTag<Integer>(); + * TupleTag<Double> meanLatencyTag = new TupleTag<Double>(); + * + * SimpleFunction<Integer, Integer> identityFn = + * new SimpleFunction<Integer, Integer>() { + * @Override + * public Integer apply(Integer input) { + * return input; + * }}; + * PCollection<KV<K, CoCombineResult>> maxAndMean = latencies.apply( + * Combine.perKey( + * CombineFns.composeKeyed() + * .with(identityFn, new MaxIntegerFn(), maxLatencyTag) + * .with(identityFn, new MeanFn<Integer>(), meanLatencyTag))); + * + * PCollection<T> finalResultCollection = maxAndMean + * .apply(ParDo.of( + * new DoFn<KV<K, CoCombineResult>, T>() { + * @Override + * public void processElement(ProcessContext c) throws Exception { + * KV<K, CoCombineResult> e = c.element(); + * Integer maxLatency = e.getValue().get(maxLatencyTag); + * Double meanLatency = e.getValue().get(meanLatencyTag); + * .... Do Something .... + * c.output(...some T...); + * } + * })); + * } </pre> + */ + public static ComposeKeyedCombineFnBuilder composeKeyed() { + return new ComposeKeyedCombineFnBuilder(); + } + + /** + * Returns a {@link ComposeCombineFnBuilder} to construct a composed + * {@link GlobalCombineFn}. + * + * <p>The same {@link TupleTag} cannot be used in a composition multiple times. 
+ * + * <p>Example: + * <pre>{ @code + * PCollection<Integer> globalLatencies = ...; + * + * TupleTag<Integer> maxLatencyTag = new TupleTag<Integer>(); + * TupleTag<Double> meanLatencyTag = new TupleTag<Double>(); + * + * SimpleFunction<Integer, Integer> identityFn = + * new SimpleFunction<Integer, Integer>() { + * @Override + * public Integer apply(Integer input) { + * return input; + * }}; + * PCollection<CoCombineResult> maxAndMean = globalLatencies.apply( + * Combine.globally( + * CombineFns.compose() + * .with(identityFn, new MaxIntegerFn(), maxLatencyTag) + * .with(identityFn, new MeanFn<Integer>(), meanLatencyTag))); + * + * PCollection<T> finalResultCollection = maxAndMean + * .apply(ParDo.of( + * new DoFn<CoCombineResult, T>() { + * @Override + * public void processElement(ProcessContext c) throws Exception { + * CoCombineResult e = c.element(); + * Integer maxLatency = e.get(maxLatencyTag); + * Double meanLatency = e.get(meanLatencyTag); + * .... Do Something .... + * c.output(...some T...); + * } + * })); + * } </pre> + */ + public static ComposeCombineFnBuilder compose() { + return new ComposeCombineFnBuilder(); + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A builder class to construct a composed {@link PerKeyCombineFn}. + */ + public static class ComposeKeyedCombineFnBuilder { + /** + * Returns a {@link ComposedKeyedCombineFn} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + * + * <p>The {@link ComposedKeyedCombineFn} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code keyedCombineFn}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. 
+ */ + public <K, DataT, InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with( + SimpleFunction<DataT, InputT> extractInputFn, + KeyedCombineFn<K, InputT, ?, OutputT> keyedCombineFn, + TupleTag<OutputT> outputTag) { + return new ComposedKeyedCombineFn<DataT, K>() + .with(extractInputFn, keyedCombineFn, outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + * + * <p>The {@link ComposedKeyedCombineFnWithContext} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code keyedCombineFnWithContext}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public <K, DataT, InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with( + SimpleFunction<DataT, InputT> extractInputFn, + KeyedCombineFnWithContext<K, InputT, ?, OutputT> keyedCombineFnWithContext, + TupleTag<OutputT> outputTag) { + return new ComposedKeyedCombineFnWithContext<DataT, K>() + .with(extractInputFn, keyedCombineFnWithContext, outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFn} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. + */ + public <K, DataT, InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with( + SimpleFunction<DataT, InputT> extractInputFn, + CombineFn<InputT, ?, OutputT> combineFn, + TupleTag<OutputT> outputTag) { + return with(extractInputFn, combineFn.<K>asKeyedFn(), outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} that can take additional + * {@link PerKeyCombineFn PerKeyCombineFns} and apply them as a single combine function. 
+ */ + public <K, DataT, InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with( + SimpleFunction<DataT, InputT> extractInputFn, + CombineFnWithContext<InputT, ?, OutputT> combineFnWithContext, + TupleTag<OutputT> outputTag) { + return with(extractInputFn, combineFnWithContext.<K>asKeyedFn(), outputTag); + } + } + + /** + * A builder class to construct a composed {@link GlobalCombineFn}. + */ + public static class ComposeCombineFnBuilder { + /** + * Returns a {@link ComposedCombineFn} that can take additional + * {@link GlobalCombineFn GlobalCombineFns} and apply them as a single combine function. + * + * <p>The {@link ComposedCombineFn} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code combineFn}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public <DataT, InputT, OutputT> ComposedCombineFn<DataT> with( + SimpleFunction<DataT, InputT> extractInputFn, + CombineFn<InputT, ?, OutputT> combineFn, + TupleTag<OutputT> outputTag) { + return new ComposedCombineFn<DataT>() + .with(extractInputFn, combineFn, outputTag); + } + + /** + * Returns a {@link ComposedCombineFnWithContext} that can take additional + * {@link GlobalCombineFn GlobalCombineFns} and apply them as a single combine function. + * + * <p>The {@link ComposedCombineFnWithContext} extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them with the {@code combineFnWithContext}, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. 
+ */ + public <DataT, InputT, OutputT> ComposedCombineFnWithContext<DataT> with( + SimpleFunction<DataT, InputT> extractInputFn, + CombineFnWithContext<InputT, ?, OutputT> combineFnWithContext, + TupleTag<OutputT> outputTag) { + return new ComposedCombineFnWithContext<DataT>() + .with(extractInputFn, combineFnWithContext, outputTag); + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A tuple of outputs produced by a composed combine function. + * + * <p>See {@link #compose()} or {@link #composeKeyed()} for details. + */ + public static class CoCombineResult implements Serializable { + + private enum NullValue { + INSTANCE; + } + + private final Map<TupleTag<?>, Object> valuesMap; + + /** + * The constructor of {@link CoCombineResult}. + * + * <p>Null values in {@code valuesMap} are allowed. Each null value is stored internally as a + * placeholder, so its {@link TupleTag} remains in the key set and {@link #get(TupleTag)} + * returns {@code null} for it. + * + * @throws NullPointerException if any key in {@code valuesMap} is null + */ + CoCombineResult(Map<TupleTag<?>, Object> valuesMap) { + ImmutableMap.Builder<TupleTag<?>, Object> builder = ImmutableMap.builder(); + for (Entry<TupleTag<?>, Object> entry : valuesMap.entrySet()) { + if (entry.getValue() != null) { + builder.put(entry); + } else { + builder.put(entry.getKey(), NullValue.INSTANCE); + } + } + this.valuesMap = builder.build(); + } + + /** + * Returns the value represented by the given {@link TupleTag}. + * + * <p>It is an error to request a tuple tag that does not exist in the {@link CoCombineResult}. 
+ */ + @SuppressWarnings("unchecked") + public <V> V get(TupleTag<V> tag) { + checkArgument( + valuesMap.keySet().contains(tag), "TupleTag " + tag + " is not in the CoCombineResult"); + Object value = valuesMap.get(tag); + if (value == NullValue.INSTANCE) { + return null; + } else { + return (V) value; + } + } + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * A composed {@link CombineFn} that applies multiple {@link CombineFn CombineFns}. + * + * <p>For each {@link CombineFn} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. + */ + public static class ComposedCombineFn<DataT> extends CombineFn<DataT, Object[], CoCombineResult> { + + private final List<CombineFn<Object, Object, Object>> combineFns; + private final List<SerializableFunction<DataT, Object>> extractInputFns; + private final List<TupleTag<?>> outputTags; + private final int combineFnCount; + + private ComposedCombineFn() { + this.extractInputFns = ImmutableList.of(); + this.combineFns = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedCombineFn( + ImmutableList<SerializableFunction<DataT, ?>> extractInputFns, + ImmutableList<CombineFn<?, ?, ?>> combineFns, + ImmutableList<TupleTag<?>> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List<SerializableFunction<DataT, Object>> castedExtractInputFns = (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"unchecked", "rawtypes"}) + List<CombineFn<Object, Object, Object>> castedCombineFns = (List) combineFns; + this.combineFns = castedCombineFns; + + this.outputTags = outputTags; + this.combineFnCount = this.combineFns.size(); + } + + /** + * Returns a {@link ComposedCombineFn} with an additional {@link CombineFn}. 
+ */ + public <InputT, OutputT> ComposedCombineFn<DataT> with( + SimpleFunction<DataT, InputT> extractInputFn, + CombineFn<InputT, ?, OutputT> combineFn, + TupleTag<OutputT> outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedCombineFn<>( + ImmutableList.<SerializableFunction<DataT, ?>>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.<CombineFn<?, ?, ?>>builder() + .addAll(combineFns) + .add(combineFn) + .build(), + ImmutableList.<TupleTag<?>>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedCombineFnWithContext} with an additional + * {@link CombineFnWithContext}. + */ + public <InputT, OutputT> ComposedCombineFnWithContext<DataT> with( + SimpleFunction<DataT, InputT> extractInputFn, + CombineFnWithContext<InputT, ?, OutputT> combineFn, + TupleTag<OutputT> outputTag) { + checkUniqueness(outputTags, outputTag); + List<CombineFnWithContext<Object, Object, Object>> fnsWithContext = Lists.newArrayList(); + for (CombineFn<Object, Object, Object> fn : combineFns) { + fnsWithContext.add(toFnWithContext(fn)); + } + return new ComposedCombineFnWithContext<>( + ImmutableList.<SerializableFunction<DataT, ?>>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.<CombineFnWithContext<?, ?, ?>>builder() + .addAll(fnsWithContext) + .add(combineFn) + .build(), + ImmutableList.<TupleTag<?>>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + @Override + public Object[] createAccumulator() { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = combineFns.get(i).createAccumulator(); + } + return accumsArray; + } + + @Override + public Object[] addInput(Object[] accumulator, DataT value) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = combineFns.get(i).addInput(accumulator[i], input); 
+ } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(Iterable<Object[]> accumulators) { + Iterator<Object[]> iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. + Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = combineFns.get(i).mergeAccumulators(new ProjectionIterable(accumulators, i)); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(Object[] accumulator) { + Map<TupleTag<?>, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + combineFns.get(i).extractOutput(accumulator[i])); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(Object[] accumulator) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = combineFns.get(i).compact(accumulator[i]); + } + return accumulator; + } + + @Override + public Coder<Object[]> getAccumulatorCoder(CoderRegistry registry, Coder<DataT> dataCoder) + throws CannotProvideCoderException { + List<Coder<Object>> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder<Object> inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(combineFns.get(i).getAccumulatorCoder(registry, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + /** + * A composed {@link CombineFnWithContext} that applies multiple + * {@link CombineFnWithContext CombineFnWithContexts}. + * + * <p>For each {@link CombineFnWithContext} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. 
+ */ + public static class ComposedCombineFnWithContext<DataT> + extends CombineFnWithContext<DataT, Object[], CoCombineResult> { + + private final List<SerializableFunction<DataT, Object>> extractInputFns; + private final List<CombineFnWithContext<Object, Object, Object>> combineFnWithContexts; + private final List<TupleTag<?>> outputTags; + private final int combineFnCount; + + private ComposedCombineFnWithContext() { + this.extractInputFns = ImmutableList.of(); + this.combineFnWithContexts = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedCombineFnWithContext( + ImmutableList<SerializableFunction<DataT, ?>> extractInputFns, + ImmutableList<CombineFnWithContext<?, ?, ?>> combineFnWithContexts, + ImmutableList<TupleTag<?>> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List<SerializableFunction<DataT, Object>> castedExtractInputFns = + (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"rawtypes", "unchecked"}) + List<CombineFnWithContext<Object, Object, Object>> castedCombineFnWithContexts + = (List) combineFnWithContexts; + this.combineFnWithContexts = castedCombineFnWithContexts; + + this.outputTags = outputTags; + this.combineFnCount = this.combineFnWithContexts.size(); + } + + /** + * Returns a {@link ComposedCombineFnWithContext} with an additional {@link GlobalCombineFn}. 
+ */ + public <InputT, OutputT> ComposedCombineFnWithContext<DataT> with( + SimpleFunction<DataT, InputT> extractInputFn, + GlobalCombineFn<InputT, ?, OutputT> globalCombineFn, + TupleTag<OutputT> outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedCombineFnWithContext<>( + ImmutableList.<SerializableFunction<DataT, ?>>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.<CombineFnWithContext<?, ?, ?>>builder() + .addAll(combineFnWithContexts) + .add(toFnWithContext(globalCombineFn)) + .build(), + ImmutableList.<TupleTag<?>>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + @Override + public Object[] createAccumulator(Context c) { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = combineFnWithContexts.get(i).createAccumulator(c); + } + return accumsArray; + } + + @Override + public Object[] addInput(Object[] accumulator, DataT value, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = combineFnWithContexts.get(i).addInput(accumulator[i], input, c); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(Iterable<Object[]> accumulators, Context c) { + Iterator<Object[]> iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(c); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. 
+ Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = combineFnWithContexts.get(i).mergeAccumulators( + new ProjectionIterable(accumulators, i), c); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(Object[] accumulator, Context c) { + Map<TupleTag<?>, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + combineFnWithContexts.get(i).extractOutput(accumulator[i], c)); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(Object[] accumulator, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = combineFnWithContexts.get(i).compact(accumulator[i], c); + } + return accumulator; + } + + @Override + public Coder<Object[]> getAccumulatorCoder(CoderRegistry registry, Coder<DataT> dataCoder) + throws CannotProvideCoderException { + List<Coder<Object>> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder<Object> inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(combineFnWithContexts.get(i).getAccumulatorCoder(registry, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + /** + * A composed {@link KeyedCombineFn} that applies multiple {@link KeyedCombineFn KeyedCombineFns}. + * + * <p>For each {@link KeyedCombineFn} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. 
+ */ + public static class ComposedKeyedCombineFn<DataT, K> + extends KeyedCombineFn<K, DataT, Object[], CoCombineResult> { + + private final List<SerializableFunction<DataT, Object>> extractInputFns; + private final List<KeyedCombineFn<K, Object, Object, Object>> keyedCombineFns; + private final List<TupleTag<?>> outputTags; + private final int combineFnCount; + + private ComposedKeyedCombineFn() { + this.extractInputFns = ImmutableList.of(); + this.keyedCombineFns = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedKeyedCombineFn( + ImmutableList<SerializableFunction<DataT, ?>> extractInputFns, + ImmutableList<KeyedCombineFn<K, ?, ?, ?>> keyedCombineFns, + ImmutableList<TupleTag<?>> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List<SerializableFunction<DataT, Object>> castedExtractInputFns = (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"unchecked", "rawtypes"}) + List<KeyedCombineFn<K, Object, Object, Object>> castedKeyedCombineFns = + (List) keyedCombineFns; + this.keyedCombineFns = castedKeyedCombineFns; + this.outputTags = outputTags; + this.combineFnCount = this.keyedCombineFns.size(); + } + + /** + * Returns a {@link ComposedKeyedCombineFn} with an additional {@link KeyedCombineFn}. 
+ */ + public <InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with( + SimpleFunction<DataT, InputT> extractInputFn, + KeyedCombineFn<K, InputT, ?, OutputT> keyedCombineFn, + TupleTag<OutputT> outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedKeyedCombineFn<>( + ImmutableList.<SerializableFunction<DataT, ?>>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.<KeyedCombineFn<K, ?, ?, ?>>builder() + .addAll(keyedCombineFns) + .add(keyedCombineFn) + .build(), + ImmutableList.<TupleTag<?>>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link KeyedCombineFnWithContext}. + */ + public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with( + SimpleFunction<DataT, InputT> extractInputFn, + KeyedCombineFnWithContext<K, InputT, ?, OutputT> keyedCombineFn, + TupleTag<OutputT> outputTag) { + checkUniqueness(outputTags, outputTag); + List<KeyedCombineFnWithContext<K, Object, Object, Object>> fnsWithContext = + Lists.newArrayList(); + for (KeyedCombineFn<K, Object, Object, Object> fn : keyedCombineFns) { + fnsWithContext.add(toFnWithContext(fn)); + } + return new ComposedKeyedCombineFnWithContext<>( + ImmutableList.<SerializableFunction<DataT, ?>>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.<KeyedCombineFnWithContext<K, ?, ?, ?>>builder() + .addAll(fnsWithContext) + .add(keyedCombineFn) + .build(), + ImmutableList.<TupleTag<?>>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedKeyedCombineFn} with an additional {@link CombineFn}. 
+ */ + public <InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with( + SimpleFunction<DataT, InputT> extractInputFn, + CombineFn<InputT, ?, OutputT> keyedCombineFn, + TupleTag<OutputT> outputTag) { + return with(extractInputFn, keyedCombineFn.<K>asKeyedFn(), outputTag); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link CombineFnWithContext}. + */ + public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with( + SimpleFunction<DataT, InputT> extractInputFn, + CombineFnWithContext<InputT, ?, OutputT> keyedCombineFn, + TupleTag<OutputT> outputTag) { + return with(extractInputFn, keyedCombineFn.<K>asKeyedFn(), outputTag); + } + + @Override + public Object[] createAccumulator(K key) { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = keyedCombineFns.get(i).createAccumulator(key); + } + return accumsArray; + } + + @Override + public Object[] addInput(K key, Object[] accumulator, DataT value) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = keyedCombineFns.get(i).addInput(key, accumulator[i], input); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(K key, final Iterable<Object[]> accumulators) { + Iterator<Object[]> iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(key); + } else { + // Reuses the first accumulator, and overwrites its values. + // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. 
+ Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = keyedCombineFns.get(i).mergeAccumulators( + key, new ProjectionIterable(accumulators, i)); + } + return accum; + } + } + + @Override + public CoCombineResult extractOutput(K key, Object[] accumulator) { + Map<TupleTag<?>, Object> valuesMap = Maps.newHashMap(); + for (int i = 0; i < combineFnCount; ++i) { + valuesMap.put( + outputTags.get(i), + keyedCombineFns.get(i).extractOutput(key, accumulator[i])); + } + return new CoCombineResult(valuesMap); + } + + @Override + public Object[] compact(K key, Object[] accumulator) { + for (int i = 0; i < combineFnCount; ++i) { + accumulator[i] = keyedCombineFns.get(i).compact(key, accumulator[i]); + } + return accumulator; + } + + @Override + public Coder<Object[]> getAccumulatorCoder( + CoderRegistry registry, Coder<K> keyCoder, Coder<DataT> dataCoder) + throws CannotProvideCoderException { + List<Coder<Object>> coders = Lists.newArrayList(); + for (int i = 0; i < combineFnCount; ++i) { + Coder<Object> inputCoder = + registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder); + coders.add(keyedCombineFns.get(i).getAccumulatorCoder(registry, keyCoder, inputCoder)); + } + return new ComposedAccumulatorCoder(coders); + } + } + + /** + * A composed {@link KeyedCombineFnWithContext} that applies multiple + * {@link KeyedCombineFnWithContext KeyedCombineFnWithContexts}. + * + * <p>For each {@link KeyedCombineFnWithContext} it extracts inputs from {@code DataT} with + * the {@code extractInputFn} and combines them, + * and then it outputs each combined value with a {@link TupleTag} to a + * {@link CoCombineResult}. 
+ */ + public static class ComposedKeyedCombineFnWithContext<DataT, K> + extends KeyedCombineFnWithContext<K, DataT, Object[], CoCombineResult> { + + private final List<SerializableFunction<DataT, Object>> extractInputFns; + private final List<KeyedCombineFnWithContext<K, Object, Object, Object>> keyedCombineFns; + private final List<TupleTag<?>> outputTags; + private final int combineFnCount; + + private ComposedKeyedCombineFnWithContext() { + this.extractInputFns = ImmutableList.of(); + this.keyedCombineFns = ImmutableList.of(); + this.outputTags = ImmutableList.of(); + this.combineFnCount = 0; + } + + private ComposedKeyedCombineFnWithContext( + ImmutableList<SerializableFunction<DataT, ?>> extractInputFns, + ImmutableList<KeyedCombineFnWithContext<K, ?, ?, ?>> keyedCombineFns, + ImmutableList<TupleTag<?>> outputTags) { + @SuppressWarnings({"unchecked", "rawtypes"}) + List<SerializableFunction<DataT, Object>> castedExtractInputFns = + (List) extractInputFns; + this.extractInputFns = castedExtractInputFns; + + @SuppressWarnings({"unchecked", "rawtypes"}) + List<KeyedCombineFnWithContext<K, Object, Object, Object>> castedKeyedCombineFns = + (List) keyedCombineFns; + this.keyedCombineFns = castedKeyedCombineFns; + this.outputTags = outputTags; + this.combineFnCount = this.keyedCombineFns.size(); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link PerKeyCombineFn}. 
+ */ + public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with( + SimpleFunction<DataT, InputT> extractInputFn, + PerKeyCombineFn<K, InputT, ?, OutputT> perKeyCombineFn, + TupleTag<OutputT> outputTag) { + checkUniqueness(outputTags, outputTag); + return new ComposedKeyedCombineFnWithContext<>( + ImmutableList.<SerializableFunction<DataT, ?>>builder() + .addAll(extractInputFns) + .add(extractInputFn) + .build(), + ImmutableList.<KeyedCombineFnWithContext<K, ?, ?, ?>>builder() + .addAll(keyedCombineFns) + .add(toFnWithContext(perKeyCombineFn)) + .build(), + ImmutableList.<TupleTag<?>>builder() + .addAll(outputTags) + .add(outputTag) + .build()); + } + + /** + * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional + * {@link GlobalCombineFn}. + */ + public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with( + SimpleFunction<DataT, InputT> extractInputFn, + GlobalCombineFn<InputT, ?, OutputT> perKeyCombineFn, + TupleTag<OutputT> outputTag) { + return with(extractInputFn, perKeyCombineFn.<K>asKeyedFn(), outputTag); + } + + @Override + public Object[] createAccumulator(K key, Context c) { + Object[] accumsArray = new Object[combineFnCount]; + for (int i = 0; i < combineFnCount; ++i) { + accumsArray[i] = keyedCombineFns.get(i).createAccumulator(key, c); + } + return accumsArray; + } + + @Override + public Object[] addInput(K key, Object[] accumulator, DataT value, Context c) { + for (int i = 0; i < combineFnCount; ++i) { + Object input = extractInputFns.get(i).apply(value); + accumulator[i] = keyedCombineFns.get(i).addInput(key, accumulator[i], input, c); + } + return accumulator; + } + + @Override + public Object[] mergeAccumulators(K key, Iterable<Object[]> accumulators, Context c) { + Iterator<Object[]> iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(key, c); + } else { + // Reuses the first accumulator, and overwrites its values. 
+ // It is safe because {@code accum[i]} only depends on + // the i-th component of each accumulator. + Object[] accum = iter.next(); + for (int i = 0; i < combineFnCount; ++i) { + accum[i] = keyedCombineFns.get(i).mergeAccumulators( + key, new ProjectionIterable(accumulators, i), c); + } + return accum; + } + }

    /**
     * Extracts the output of every component combine function from its slot of
     * the composed accumulator and stores it under the output tag registered at
     * the same position.
     *
     * @param key the key being combined
     * @param accumulator composed accumulator, one slot per component function
     * @param c the combine context, forwarded to every component function
     * @return a result holding one output per output tag
     */
    @Override
    public CoCombineResult extractOutput(K key, Object[] accumulator, Context c) {
      Map<TupleTag<?>, Object> outputsByTag = Maps.newHashMap();
      for (int index = 0; index < combineFnCount; index++) {
        // Look the tag up first so the evaluation order matches a direct put(tag, output) call.
        TupleTag<?> tag = outputTags.get(index);
        Object componentOutput =
            keyedCombineFns.get(index).extractOutput(key, accumulator[index], c);
        outputsByTag.put(tag, componentOutput);
      }
      return new CoCombineResult(outputsByTag);
    }

    /**
     * Compacts each slot of the composed accumulator in place with the
     * component combine function at the same position.
     *
     * @return the same array instance that was passed in
     */
    @Override
    public Object[] compact(K key, Object[] accumulator, Context c) {
      for (int index = 0; index < combineFnCount; index++) {
        Object compacted = keyedCombineFns.get(index).compact(key, accumulator[index], c);
        accumulator[index] = compacted;
      }
      return accumulator;
    }

    /**
     * Builds the coder for the composed accumulator from the accumulator coders
     * of the component functions. The input coder of each component is the
     * default output coder of its input-extraction function.
     *
     * @throws CannotProvideCoderException if any component coder cannot be provided
     */
    @Override
    public Coder<Object[]> getAccumulatorCoder(
        CoderRegistry registry, Coder<K> keyCoder, Coder<DataT> dataCoder)
        throws CannotProvideCoderException {
      List<Coder<Object>> componentCoders = Lists.newArrayList();
      for (int index = 0; index < combineFnCount; index++) {
        Coder<Object> inputCoder =
            registry.getDefaultOutputCoder(extractInputFns.get(index), dataCoder);
        Coder<Object> accumulatorCoder =
            keyedCombineFns.get(index).getAccumulatorCoder(registry, keyCoder, inputCoder);
        componentCoders.add(accumulatorCoder);
      }
      return new ComposedAccumulatorCoder(componentCoders);
    }
  }

  /////////////////////////////////////////////////////////////////////////////

  /**
   * A view over an iterable of {@code Object[]} rows that exposes only the
   * element stored at one fixed index of every row.
   */
  private static class ProjectionIterable implements Iterable<Object> {
    private final Iterable<Object[]> iterable;
    private final int column;

    private ProjectionIterable(Iterable<Object[]> iterable, int column) {
      this.iterable = iterable;
      this.column = column;
    }

    @Override
    public Iterator<Object> iterator() {
      return new ProjectionIterator(iterable.iterator(), column);
    }

    /** Iterator that yields the selected element of each underlying row. */
    private static class ProjectionIterator implements Iterator<Object> {
      private final Iterator<Object[]> rows;
      private final int selectedColumn;

      private ProjectionIterator(Iterator<Object[]> rows, int selectedColumn) {
        this.rows = rows;
        this.selectedColumn = selectedColumn;
      }

      @Override
      public boolean hasNext() {
        return rows.hasNext();
      }

      @Override
      public Object next() {
        Object[] row = rows.next();
        return row[selectedColumn];
      }

      /** Removal is not supported by this view. */
      @Override
      public void remove() {
        throw new UnsupportedOperationException();
      }
    }
  }

  /**
   * Coder for a composed accumulator: the i-th array element is encoded and
   * decoded with the i-th component coder, always in the nested context.
   */
  private static class ComposedAccumulatorCoder extends StandardCoder<Object[]> {
    private List<Coder<Object>> coders;
    private int codersCount;

    public ComposedAccumulatorCoder(List<Coder<Object>> coders) {
      // Take an immutable snapshot so later changes to the caller's list are not observed.
      List<Coder<Object>> snapshot = ImmutableList.copyOf(coders);
      this.coders = snapshot;
      this.codersCount = coders.size();
    }

    /** JSON factory that rebuilds the coder from its component encodings. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    @JsonCreator
    public static ComposedAccumulatorCoder of(
        @JsonProperty(PropertyNames.COMPONENT_ENCODINGS)
        List<Coder<?>> components) {
      List rawComponents = components;
      return new ComposedAccumulatorCoder(rawComponents);
    }

    /**
     * Writes every element of {@code value} with the coder at the same index.
     *
     * @throws IllegalArgumentException if {@code value} does not have exactly one
     *     element per component coder
     */
    @Override
    public void encode(Object[] value, OutputStream outStream, Context context)
        throws CoderException, IOException {
      checkArgument(value.length == codersCount);
      Context nested = context.nested();
      for (int index = 0; index < codersCount; index++) {
        Coder<Object> componentCoder = coders.get(index);
        componentCoder.encode(value[index], outStream, nested);
      }
    }

    /** Reads one element per component coder, in coder order. */
    @Override
    public Object[] decode(InputStream inStream, Context context)
        throws CoderException, IOException {
      Context nested = context.nested();
      Object[] decoded = new Object[codersCount];
      for (int index = 0; index < codersCount; index++) {
        Coder<Object> componentCoder = coders.get(index);
        decoded[index] = componentCoder.decode(inStream, nested);
      }
      return decoded;
    }

    @Override
    public List<? extends Coder<?>> getCoderArguments() {
      return coders;
    }

    /** Delegates the determinism check to every component coder, in order. */
    @Override
    public void verifyDeterministic() throws NonDeterministicException {
      for (Coder<Object> componentCoder : coders) {
        componentCoder.verifyDeterministic();
      }
    }
  }

  /**
   * Returns the given function as a {@code CombineFnWithContext}. A function
   * that already is one is returned as is; otherwise it is cast to
   * {@code CombineFn} and wrapped in an adapter that ignores the context and
   * delegates every call.
   */
  @SuppressWarnings("unchecked")
  private static <InputT, AccumT, OutputT> CombineFnWithContext<InputT, AccumT, OutputT>
      toFnWithContext(GlobalCombineFn<InputT, AccumT, OutputT> globalCombineFn) {
    if (globalCombineFn instanceof CombineFnWithContext) {
      return (CombineFnWithContext<InputT, AccumT, OutputT>) globalCombineFn;
    }

    final CombineFn<InputT, AccumT, OutputT> delegate =
        (CombineFn<InputT, AccumT, OutputT>) globalCombineFn;
    return new CombineFnWithContext<InputT, AccumT, OutputT>() {
      @Override
      public AccumT createAccumulator(Context c) {
        return delegate.createAccumulator();
      }

      @Override
      public AccumT addInput(AccumT accumulator, InputT input, Context c) {
        return delegate.addInput(accumulator, input);
      }

      @Override
      public AccumT mergeAccumulators(Iterable<AccumT> accumulators, Context c) {
        return delegate.mergeAccumulators(accumulators);
      }

      @Override
      public OutputT extractOutput(AccumT accumulator, Context c) {
        return delegate.extractOutput(accumulator);
      }

      @Override
      public AccumT compact(AccumT accumulator, Context c) {
        return delegate.compact(accumulator);
      }

      @Override
      public OutputT defaultValue() {
        return delegate.defaultValue();
      }

      @Override
      public Coder<AccumT> getAccumulatorCoder(
          CoderRegistry registry, Coder<InputT> inputCoder)
          throws CannotProvideCoderException {
        return delegate.getAccumulatorCoder(registry, inputCoder);
      }

      @Override
      public Coder<OutputT> getDefaultOutputCoder(
          CoderRegistry registry, Coder<InputT> inputCoder)
          throws CannotProvideCoderException {
        return delegate.getDefaultOutputCoder(registry, inputCoder);
      }
    };
  }

  /**
   * Returns the given function as a {@code KeyedCombineFnWithContext}. A
   * function that already is one is returned as is; otherwise it is cast to
   * {@code KeyedCombineFn} and wrapped in an adapter that ignores the context
   * and delegates every call.
   */
  private static <K, InputT, AccumT, OutputT>
      KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>
      toFnWithContext(PerKeyCombineFn<K, InputT, AccumT, OutputT> perKeyCombineFn) {
    if (perKeyCombineFn instanceof KeyedCombineFnWithContext) {
      @SuppressWarnings("unchecked")
      KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> alreadyContextual =
          (KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>) perKeyCombineFn;
      return alreadyContextual;
    }

    @SuppressWarnings("unchecked")
    final KeyedCombineFn<K, InputT, AccumT, OutputT> delegate =
        (KeyedCombineFn<K, InputT, AccumT, OutputT>) perKeyCombineFn;
    return new KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() {
      @Override
      public AccumT createAccumulator(K key, Context c) {
        return delegate.createAccumulator(key);
      }

      @Override
      public AccumT addInput(K key, AccumT accumulator, InputT value, Context c) {
        return delegate.addInput(key, accumulator, value);
      }

      @Override
      public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, Context c) {
        return delegate.mergeAccumulators(key, accumulators);
      }

      @Override
      public OutputT extractOutput(K key, AccumT accumulator, Context c) {
        return delegate.extractOutput(key, accumulator);
      }

      @Override
      public AccumT compact(K key, AccumT accumulator, Context c) {
        return delegate.compact(key, accumulator);
      }

      @Override
      public Coder<AccumT> getAccumulatorCoder(
          CoderRegistry registry, Coder<K> keyCoder, Coder<InputT> inputCoder)
          throws CannotProvideCoderException {
        return delegate.getAccumulatorCoder(registry, keyCoder, inputCoder);
      }

      @Override
      public Coder<OutputT> getDefaultOutputCoder(
          CoderRegistry registry, Coder<K> keyCoder, Coder<InputT> inputCoder)
          throws CannotProvideCoderException {
        return delegate.getDefaultOutputCoder(registry, keyCoder, inputCoder);
      }
    };
  }

  /**
   * Rejects an output tag that is already part of the composition.
   *
   * @throws IllegalArgumentException if {@code registeredTags} contains {@code outputTag}
   */
  private static <OutputT> void checkUniqueness(
      List<TupleTag<?>> registeredTags, TupleTag<OutputT> outputTag) {
    boolean alreadyRegistered = registeredTags.contains(outputTag);
    checkArgument(
        !alreadyRegistered,
        "Cannot compose with tuple tag %s because it is already present in the composition.",
        outputTag);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/23b43780/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java new file mode 100644 index 0000000..ad37708 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java @@ -0,0 +1,413 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.dataflow.sdk.transforms; + +import static org.junit.Assert.assertThat; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.NullableCoder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.testing.RunnableOnService; +import com.google.cloud.dataflow.sdk.testing.TestPipeline; +import com.google.cloud.dataflow.sdk.transforms.Combine.BinaryCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFns.CoCombineResult; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn; +import com.google.cloud.dataflow.sdk.transforms.Min.MinIntegerFn; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; + +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +/** + * Unit tests for {@link CombineFns}. 
+ */ +@RunWith(JUnit4.class) +public class CombineFnsTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testDuplicatedTags() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag<Integer> tag = new TupleTag<Integer>(); + CombineFns.compose() + .with(new GetIntegerFunction(), new MaxIntegerFn(), tag) + .with(new GetIntegerFunction(), new MinIntegerFn(), tag); + } + + @Test + public void testDuplicatedTagsKeyed() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag<Integer> tag = new TupleTag<Integer>(); + CombineFns.composeKeyed() + .with(new GetIntegerFunction(), new MaxIntegerFn(), tag) + .with(new GetIntegerFunction(), new MinIntegerFn(), tag); + } + + @Test + public void testDuplicatedTagsWithContext() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag<UserString> tag = new TupleTag<UserString>(); + CombineFns.compose() + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */).forKey("G", StringUtf8Coder.of()), + tag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */).forKey("G", StringUtf8Coder.of()), + tag); + } + + @Test + public void testDuplicatedTagsWithContextKeyed() { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("it is already present in the composition"); + + TupleTag<UserString> tag = new TupleTag<UserString>(); + CombineFns.composeKeyed() + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */), + tag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(null /* view */), + tag); + } + + @Test + @Category(RunnableOnService.class) + 
public void testComposedCombine() { + Pipeline p = TestPipeline.create(); + p.getCoderRegistry().registerCoder(UserString.class, UserStringCoder.of()); + + PCollection<KV<String, KV<Integer, UserString>>> perKeyInput = p.apply( + Create.timestamped( + Arrays.asList( + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(4, UserString.of("4"))), + KV.of("b", KV.of(1, UserString.of("1"))), + KV.of("b", KV.of(13, UserString.of("13")))), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of( + StringUtf8Coder.of(), + KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of())))); + + TupleTag<Integer> maxIntTag = new TupleTag<Integer>(); + TupleTag<UserString> concatStringTag = new TupleTag<UserString>(); + PCollection<KV<String, KV<Integer, String>>> combineGlobally = perKeyInput + .apply(Values.<KV<Integer, UserString>>create()) + .apply(Combine.globally(CombineFns.compose() + .with( + new GetIntegerFunction(), + new MaxIntegerFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatString(), + concatStringTag))) + .apply(WithKeys.<String, CoCombineResult>of("global")) + .apply( + "ExtractGloballyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + + PCollection<KV<String, KV<Integer, String>>> combinePerKey = perKeyInput + .apply(Combine.perKey(CombineFns.composeKeyed() + .with( + new GetIntegerFunction(), + new MaxIntegerFn().<String>asKeyedFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatString().<String>asKeyedFn(), + concatStringTag))) + .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + DataflowAssert.that(combineGlobally).containsInAnyOrder( + KV.of("global", KV.of(13, "111134"))); + DataflowAssert.that(combinePerKey).containsInAnyOrder( + KV.of("a", KV.of(4, "114")), + KV.of("b", KV.of(13, "113"))); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void 
testComposedCombineWithContext() { + Pipeline p = TestPipeline.create(); + p.getCoderRegistry().registerCoder(UserString.class, UserStringCoder.of()); + + PCollectionView<String> view = p + .apply(Create.of("I")) + .apply(View.<String>asSingleton()); + + PCollection<KV<String, KV<Integer, UserString>>> perKeyInput = p.apply( + Create.timestamped( + Arrays.asList( + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(4, UserString.of("4"))), + KV.of("b", KV.of(1, UserString.of("1"))), + KV.of("b", KV.of(13, UserString.of("13")))), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of( + StringUtf8Coder.of(), + KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of())))); + + TupleTag<Integer> maxIntTag = new TupleTag<Integer>(); + TupleTag<UserString> concatStringTag = new TupleTag<UserString>(); + PCollection<KV<String, KV<Integer, String>>> combineGlobally = perKeyInput + .apply(Values.<KV<Integer, UserString>>create()) + .apply(Combine.globally(CombineFns.compose() + .with( + new GetIntegerFunction(), + new MaxIntegerFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(view).forKey("G", StringUtf8Coder.of()), + concatStringTag)) + .withoutDefaults() + .withSideInputs(ImmutableList.of(view))) + .apply(WithKeys.<String, CoCombineResult>of("global")) + .apply( + "ExtractGloballyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + + PCollection<KV<String, KV<Integer, String>>> combinePerKey = perKeyInput + .apply(Combine.perKey(CombineFns.composeKeyed() + .with( + new GetIntegerFunction(), + new MaxIntegerFn().<String>asKeyedFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new ConcatStringWithContext(view), + concatStringTag)) + .withSideInputs(ImmutableList.of(view))) + .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + DataflowAssert.that(combineGlobally).containsInAnyOrder( + KV.of("global", 
KV.of(13, "111134GI"))); + DataflowAssert.that(combinePerKey).containsInAnyOrder( + KV.of("a", KV.of(4, "114Ia")), + KV.of("b", KV.of(13, "113Ib"))); + p.run(); + } + + @Test + @Category(RunnableOnService.class) + public void testComposedCombineNullValues() { + Pipeline p = TestPipeline.create(); + p.getCoderRegistry().registerCoder(UserString.class, NullableCoder.of(UserStringCoder.of())); + p.getCoderRegistry().registerCoder(String.class, NullableCoder.of(StringUtf8Coder.of())); + + PCollection<KV<String, KV<Integer, UserString>>> perKeyInput = p.apply( + Create.timestamped( + Arrays.asList( + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(1, UserString.of("1"))), + KV.of("a", KV.of(4, UserString.of("4"))), + KV.of("b", KV.of(1, UserString.of("1"))), + KV.of("b", KV.of(13, UserString.of("13")))), + Arrays.asList(0L, 4L, 7L, 10L, 16L)) + .withCoder(KvCoder.of( + StringUtf8Coder.of(), + KvCoder.of( + BigEndianIntegerCoder.of(), NullableCoder.of(UserStringCoder.of()))))); + + TupleTag<Integer> maxIntTag = new TupleTag<Integer>(); + TupleTag<UserString> concatStringTag = new TupleTag<UserString>(); + + PCollection<KV<String, KV<Integer, String>>> combinePerKey = perKeyInput + .apply(Combine.perKey(CombineFns.composeKeyed() + .with( + new GetIntegerFunction(), + new MaxIntegerFn().<String>asKeyedFn(), + maxIntTag) + .with( + new GetUserStringFunction(), + new OutputNullString().<String>asKeyedFn(), + concatStringTag))) + .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag))); + DataflowAssert.that(combinePerKey).containsInAnyOrder( + KV.of("a", KV.of(4, (String) null)), + KV.of("b", KV.of(13, (String) null))); + p.run(); + } + + private static class UserString implements Serializable { + private String strValue; + + static UserString of(String strValue) { + UserString ret = new UserString(); + ret.strValue = strValue; + return ret; + } + } + + private static class UserStringCoder extends StandardCoder<UserString> { + 
public static UserStringCoder of() { + return INSTANCE; + } + + private static final UserStringCoder INSTANCE = new UserStringCoder(); + + @Override + public void encode(UserString value, OutputStream outStream, Context context) + throws CoderException, IOException { + StringUtf8Coder.of().encode(value.strValue, outStream, context); + } + + @Override + public UserString decode(InputStream inStream, Context context) + throws CoderException, IOException { + return UserString.of(StringUtf8Coder.of().decode(inStream, context)); + } + + @Override + public List<? extends Coder<?>> getCoderArguments() { + return null; + } + + @Override + public void verifyDeterministic() throws NonDeterministicException {} + } + + private static class GetIntegerFunction + extends SimpleFunction<KV<Integer, UserString>, Integer> { + @Override + public Integer apply(KV<Integer, UserString> input) { + return input.getKey(); + } + } + + private static class GetUserStringFunction + extends SimpleFunction<KV<Integer, UserString>, UserString> { + @Override + public UserString apply(KV<Integer, UserString> input) { + return input.getValue(); + } + } + + private static class ConcatString extends BinaryCombineFn<UserString> { + @Override + public UserString apply(UserString left, UserString right) { + String retStr = left.strValue + right.strValue; + char[] chars = retStr.toCharArray(); + Arrays.sort(chars); + return UserString.of(new String(chars)); + } + } + + private static class OutputNullString extends BinaryCombineFn<UserString> { + @Override + public UserString apply(UserString left, UserString right) { + return null; + } + } + + private static class ConcatStringWithContext + extends KeyedCombineFnWithContext<String, UserString, UserString, UserString> { + private final PCollectionView<String> view; + + private ConcatStringWithContext(PCollectionView<String> view) { + this.view = view; + } + + @Override + public UserString createAccumulator(String key, CombineWithContext.Context c) { + 
return UserString.of(key + c.sideInput(view)); + } + + @Override + public UserString addInput( + String key, UserString accumulator, UserString input, CombineWithContext.Context c) { + assertThat(accumulator.strValue, Matchers.startsWith(key + c.sideInput(view))); + accumulator.strValue += input.strValue; + return accumulator; + } + + @Override + public UserString mergeAccumulators( + String key, Iterable<UserString> accumulators, CombineWithContext.Context c) { + String keyPrefix = key + c.sideInput(view); + String all = keyPrefix; + for (UserString accumulator : accumulators) { + assertThat(accumulator.strValue, Matchers.startsWith(keyPrefix)); + all += accumulator.strValue.substring(keyPrefix.length()); + accumulator.strValue = "cleared in mergeAccumulators"; + } + return UserString.of(all); + } + + @Override + public UserString extractOutput( + String key, UserString accumulator, CombineWithContext.Context c) { + assertThat(accumulator.strValue, Matchers.startsWith(key + c.sideInput(view))); + char[] chars = accumulator.strValue.toCharArray(); + Arrays.sort(chars); + return UserString.of(new String(chars)); + } + } + + private static class ExtractResultDoFn + extends DoFn<KV<String, CoCombineResult>, KV<String, KV<Integer, String>>>{ + + private final TupleTag<Integer> maxIntTag; + private final TupleTag<UserString> concatStringTag; + + ExtractResultDoFn(TupleTag<Integer> maxIntTag, TupleTag<UserString> concatStringTag) { + this.maxIntTag = maxIntTag; + this.concatStringTag = concatStringTag; + } + + @Override + public void processElement(ProcessContext c) throws Exception { + UserString userString = c.element().getValue().get(concatStringTag); + KV<Integer, String> value = KV.of( + c.element().getValue().get(maxIntTag), + userString == null ? null : userString.strValue); + c.output(KV.of(c.element().getKey(), value)); + } + } +}