Repository: incubator-beam Updated Branches: refs/heads/master e3f2d9564 -> 15a8334f9
Cleanup: move toFnWithContext() to CombineFnUtil Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/0d9dee3e Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/0d9dee3e Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/0d9dee3e Branch: refs/heads/master Commit: 0d9dee3e1028a96678b3108eff9ae9e650424c02 Parents: e3f2d95 Author: Pei He <[email protected]> Authored: Thu Mar 17 15:14:33 2016 -0700 Committer: bchambers <[email protected]> Committed: Wed Apr 6 12:54:04 2016 -0700 ---------------------------------------------------------------------- .../dataflow/sdk/transforms/CombineFns.java | 102 +------------------ .../cloud/dataflow/sdk/util/CombineFnUtil.java | 51 +++++++++- .../dataflow/sdk/util/CombineFnUtilTest.java | 40 ++++++++ 3 files changed, 95 insertions(+), 98 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0d9dee3e/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java index 7af5292..8120de3 100644 --- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java +++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java @@ -31,6 +31,7 @@ import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn; import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext; import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.util.CombineFnUtil; import com.google.cloud.dataflow.sdk.util.PropertyNames; import com.google.cloud.dataflow.sdk.values.TupleTag; import com.google.common.collect.ImmutableList; @@ -371,7 +372,7 @@ public class CombineFns { checkUniqueness(outputTags, outputTag); List<CombineFnWithContext<Object, Object, Object>> fnsWithContext = Lists.newArrayList(); for (CombineFn<Object, Object, Object> fn : combineFns) { - fnsWithContext.add(toFnWithContext(fn)); + fnsWithContext.add(CombineFnUtil.toFnWithContext(fn)); } return new ComposedCombineFnWithContext<>( ImmutableList.<SerializableFunction<DataT, ?>>builder() @@ -512,7 +513,7 @@ public class CombineFns { .build(), ImmutableList.<CombineFnWithContext<?, ?, ?>>builder() .addAll(combineFnWithContexts) - .add(toFnWithContext(globalCombineFn)) + .add(CombineFnUtil.toFnWithContext(globalCombineFn)) .build(), ImmutableList.<TupleTag<?>>builder() .addAll(outputTags) @@ -662,7 +663,7 @@ public class CombineFns { List<KeyedCombineFnWithContext<K, Object, Object, Object>> fnsWithContext = Lists.newArrayList(); for (KeyedCombineFn<K, Object, Object, Object> fn : keyedCombineFns) { - fnsWithContext.add(toFnWithContext(fn)); + fnsWithContext.add(CombineFnUtil.toFnWithContext(fn)); } return new ComposedKeyedCombineFnWithContext<>( ImmutableList.<SerializableFunction<DataT, ?>>builder() @@ -826,7 +827,7 @@ public class CombineFns { .build(), ImmutableList.<KeyedCombineFnWithContext<K, ?, ?, ?>>builder() .addAll(keyedCombineFns) - .add(toFnWithContext(perKeyCombineFn)) + .add(CombineFnUtil.toFnWithContext(perKeyCombineFn)) .build(), ImmutableList.<TupleTag<?>>builder() .addAll(outputTags) @@ -999,99 +1000,6 @@ public class CombineFns { } } - @SuppressWarnings("unchecked") - private static <InputT, AccumT, OutputT> CombineFnWithContext<InputT, AccumT, OutputT> - toFnWithContext(GlobalCombineFn<InputT, AccumT, OutputT> globalCombineFn) { - if (globalCombineFn instanceof CombineFnWithContext) { - return (CombineFnWithContext<InputT, AccumT, OutputT>) globalCombineFn; - } else { - final CombineFn<InputT, AccumT, OutputT> combineFn = - (CombineFn<InputT, AccumT, OutputT>) globalCombineFn; - return new CombineFnWithContext<InputT, AccumT, OutputT>() { - @Override - public AccumT createAccumulator(Context c) { - return combineFn.createAccumulator(); - } - @Override - public AccumT addInput(AccumT accumulator, InputT input, Context c) { - return combineFn.addInput(accumulator, input); - } - @Override - public AccumT mergeAccumulators(Iterable<AccumT> accumulators, Context c) { - return combineFn.mergeAccumulators(accumulators); - } - @Override - public OutputT extractOutput(AccumT accumulator, Context c) { - return combineFn.extractOutput(accumulator); - } - @Override - public AccumT compact(AccumT accumulator, Context c) { - return combineFn.compact(accumulator); - } - @Override - public OutputT defaultValue() { - return combineFn.defaultValue(); - } - @Override - public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<InputT> inputCoder) - throws CannotProvideCoderException { - return combineFn.getAccumulatorCoder(registry, inputCoder); - } - @Override - public Coder<OutputT> getDefaultOutputCoder( - CoderRegistry registry, Coder<InputT> inputCoder) throws CannotProvideCoderException { - return combineFn.getDefaultOutputCoder(registry, inputCoder); - } - }; - } - } - - private static <K, InputT, AccumT, OutputT> KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> - toFnWithContext(PerKeyCombineFn<K, InputT, AccumT, OutputT> perKeyCombineFn) { - if (perKeyCombineFn instanceof KeyedCombineFnWithContext) { - @SuppressWarnings("unchecked") - KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> keyedCombineFnWithContext = - (KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>) perKeyCombineFn; - return keyedCombineFnWithContext; - } else { - @SuppressWarnings("unchecked") - final KeyedCombineFn<K, InputT, AccumT, OutputT> keyedCombineFn = - (KeyedCombineFn<K, InputT, AccumT, OutputT>) perKeyCombineFn; - return new KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() { - @Override - public AccumT createAccumulator(K key, Context c) { - return keyedCombineFn.createAccumulator(key); - } - @Override - public AccumT addInput(K key, AccumT accumulator, InputT value, Context c) { - return keyedCombineFn.addInput(key, accumulator, value); - } - @Override - public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, Context c) { - return keyedCombineFn.mergeAccumulators(key, accumulators); - } - @Override - public OutputT extractOutput(K key, AccumT accumulator, Context c) { - return keyedCombineFn.extractOutput(key, accumulator); - } - @Override - public AccumT compact(K key, AccumT accumulator, Context c) { - return keyedCombineFn.compact(key, accumulator); - } - @Override - public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<K> keyCoder, - Coder<InputT> inputCoder) throws CannotProvideCoderException { - return keyedCombineFn.getAccumulatorCoder(registry, keyCoder, inputCoder); - } - @Override - public Coder<OutputT> getDefaultOutputCoder(CoderRegistry registry, Coder<K> keyCoder, - Coder<InputT> inputCoder) throws CannotProvideCoderException { - return keyedCombineFn.getDefaultOutputCoder(registry, keyCoder, inputCoder); - } - }; - } - } - private static <OutputT> void checkUniqueness( List<TupleTag<?>> registeredTags, TupleTag<OutputT> outputTag) { checkArgument( http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0d9dee3e/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java index 097bae3..ed07efa 100644 --- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java +++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -24,6 +23,7 @@ import com.google.cloud.dataflow.sdk.coders.CoderRegistry; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.GlobalCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn; import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext; import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; @@ -105,6 +105,55 @@ public class CombineFnUtil { } } + /** + * Return a {@link KeyedCombineFnWithContext} from the given {@link PerKeyCombineFn}. + */ + public static <K, InputT, AccumT, OutputT> KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> + toFnWithContext(PerKeyCombineFn<K, InputT, AccumT, OutputT> perKeyCombineFn) { + if (perKeyCombineFn instanceof KeyedCombineFnWithContext) { + @SuppressWarnings("unchecked") + KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> keyedCombineFnWithContext = + (KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>) perKeyCombineFn; + return keyedCombineFnWithContext; + } else { + @SuppressWarnings("unchecked") + final KeyedCombineFn<K, InputT, AccumT, OutputT> keyedCombineFn = + (KeyedCombineFn<K, InputT, AccumT, OutputT>) perKeyCombineFn; + return new KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() { + @Override + public AccumT createAccumulator(K key, Context c) { + return keyedCombineFn.createAccumulator(key); + } + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value, Context c) { + return keyedCombineFn.addInput(key, accumulator, value); + } + @Override + public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, Context c) { + return keyedCombineFn.mergeAccumulators(key, accumulators); + } + @Override + public OutputT extractOutput(K key, AccumT accumulator, Context c) { + return keyedCombineFn.extractOutput(key, accumulator); + } + @Override + public AccumT compact(K key, AccumT accumulator, Context c) { + return keyedCombineFn.compact(key, accumulator); + } + @Override + public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<K> keyCoder, + Coder<InputT> inputCoder) throws CannotProvideCoderException { + return keyedCombineFn.getAccumulatorCoder(registry, keyCoder, inputCoder); + } + @Override + public Coder<OutputT> getDefaultOutputCoder(CoderRegistry registry, Coder<K> keyCoder, + Coder<InputT> inputCoder) throws CannotProvideCoderException { + return keyedCombineFn.getDefaultOutputCoder(registry, keyCoder, inputCoder); + } + }; + } + } + private static class NonSerializableBoundedKeyedCombineFn<K, InputT, AccumT, OutputT> extends KeyedCombineFn<K, InputT, AccumT, OutputT> { private final KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0d9dee3e/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java index 40b8900..173c1fa 100644 --- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java +++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java @@ -17,11 +17,17 @@ */ package com.google.cloud.dataflow.sdk.util; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.withSettings; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.Sum; import com.google.cloud.dataflow.sdk.util.state.StateContexts; +import com.google.common.collect.ImmutableList; import org.junit.Before; import org.junit.Rule; @@ -33,6 +39,7 @@ import org.junit.runners.JUnit4; import java.io.ByteArrayOutputStream; import java.io.NotSerializableException; import java.io.ObjectOutputStream; +import java.util.List; /** * Unit tests for {@link CombineFnUtil}. @@ -61,4 +68,37 @@ public class CombineFnUtilTest { ObjectOutputStream oos = new ObjectOutputStream(buffer); oos.writeObject(CombineFnUtil.bindContext(mockCombineFn, StateContexts.nullContext())); } + + @Test + public void testToFnWithContextIdempotent() throws Exception { + CombineFnWithContext<Integer, int[], Integer> fnWithContext = + CombineFnUtil.toFnWithContext(new Sum.SumIntegerFn()); + assertTrue(fnWithContext == CombineFnUtil.toFnWithContext(fnWithContext)); + + KeyedCombineFnWithContext<Object, Integer, int[], Integer> keyedFnWithContext = + CombineFnUtil.toFnWithContext(new Sum.SumIntegerFn().asKeyedFn()); + assertTrue(keyedFnWithContext == CombineFnUtil.toFnWithContext(keyedFnWithContext)); + } + + @Test + public void testToFnWithContext() throws Exception { + CombineFnWithContext<Integer, int[], Integer> fnWithContext = + CombineFnUtil.toFnWithContext(new Sum.SumIntegerFn()); + List<Integer> inputs = ImmutableList.of(1, 2, 3, 4); + Context nullContext = CombineContextFactory.nullContext(); + int[] accum = fnWithContext.createAccumulator(nullContext); + for (Integer i : inputs) { + accum = fnWithContext.addInput(accum, i, nullContext); + } + assertEquals(10, fnWithContext.extractOutput(accum, nullContext).intValue()); + + KeyedCombineFnWithContext<String, Integer, int[], Integer> keyedFnWithContext = + CombineFnUtil.toFnWithContext(new Sum.SumIntegerFn().<String>asKeyedFn()); + String key = "key"; + accum = keyedFnWithContext.createAccumulator(key, nullContext); + for (Integer i : inputs) { + accum = keyedFnWithContext.addInput(key, accum, i, nullContext); + } + assertEquals(10, keyedFnWithContext.extractOutput(key, accum, nullContext).intValue()); + } }
