Support CombineFnWithContext in GroupAlsoByWindows. This requires plumbing pipeline options and side inputs through to the State API, including: 1. adding bindKeyedCombiningValueWithContext() to StateTag.StateBinder (StateTag.java); 2. adding a state() overload that takes a StateContext to StateInternals.java; 3. plumbing the new context through the remaining files.
----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=115408034 Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/6613031b Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/6613031b Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/6613031b Branch: refs/heads/master Commit: 6613031bfed700567bdc2d4eab6b754c4392ce77 Parents: 6926d8e Author: peihe <[email protected]> Authored: Tue Feb 23 18:49:00 2016 -0800 Committer: Davor Bonaci <[email protected]> Committed: Thu Feb 25 23:58:27 2016 -0800 ---------------------------------------------------------------------- .../sdk/util/CombineContextFactory.java | 18 +++ .../cloud/dataflow/sdk/util/CombineFnUtil.java | 97 +++++++++++++ .../cloud/dataflow/sdk/util/DoFnRunnerBase.java | 5 + .../util/GroupAlsoByWindowViaWindowSetDoFn.java | 3 +- .../GroupAlsoByWindowsViaOutputBufferDoFn.java | 3 +- .../sdk/util/ReduceFnContextFactory.java | 64 +++++---- .../cloud/dataflow/sdk/util/ReduceFnRunner.java | 6 +- .../cloud/dataflow/sdk/util/SystemReduceFn.java | 19 +-- .../sdk/util/TriggerContextFactory.java | 4 +- .../cloud/dataflow/sdk/util/TriggerRunner.java | 9 +- .../dataflow/sdk/util/WindowingInternals.java | 5 + .../CopyOnAccessInMemoryStateInternals.java | 72 +++++++--- .../sdk/util/state/InMemoryStateInternals.java | 35 +++-- .../dataflow/sdk/util/state/StateContext.java | 41 ++++++ .../dataflow/sdk/util/state/StateContexts.java | 107 +++++++++++++++ .../dataflow/sdk/util/state/StateInternals.java | 7 + .../dataflow/sdk/util/state/StateTable.java | 8 +- .../cloud/dataflow/sdk/util/state/StateTag.java | 7 + .../dataflow/sdk/util/state/StateTags.java | 83 ++++++++++- .../dataflow/sdk/util/CombineFnUtilTest.java | 62 +++++++++ .../dataflow/sdk/util/ReduceFnRunnerTest.java | 137 ++++++++++++++++++- .../cloud/dataflow/sdk/util/ReduceFnTester.java | 58 +++++++- 22 files 
changed, 763 insertions(+), 87 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineContextFactory.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineContextFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineContextFactory.java index bf09587..6f2b89b 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineContextFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineContextFactory.java @@ -19,6 +19,7 @@ import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.state.StateContext; import com.google.cloud.dataflow.sdk.values.PCollectionView; /** @@ -63,6 +64,23 @@ public class CombineContextFactory { } /** + * Returns a {@code Combine.Context} that wraps a {@link StateContext}. + */ + public static Context createFromStateContext(final StateContext<?> c) { + return new Context() { + @Override + public PipelineOptions getPipelineOptions() { + return c.getPipelineOptions(); + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + return c.sideInput(view); + } + }; + } + + /** * Returns a {@code Combine.Context} from {@code PipelineOptions}, {@code SideInputReader}, * and the main input window. 
*/ http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java new file mode 100644 index 0000000..6201e6e --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/CombineFnUtil.java @@ -0,0 +1,97 @@ + +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderRegistry; +import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.util.state.StateContext; + +import java.io.IOException; +import java.io.NotSerializableException; +import java.io.ObjectOutputStream; + +/** + * Static utility methods that create combine function instances. 
+ */ +public class CombineFnUtil { + /** + * Returns the partial application of the {@link KeyedCombineFnWithContext} to a specific + * context to produce a {@link KeyedCombineFn}. + * + * <p>The returned {@link KeyedCombineFn} cannot be serialized. + */ + public static <K, InputT, AccumT, OutputT> KeyedCombineFn<K, InputT, AccumT, OutputT> + bindContext( + KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn, + StateContext<?> stateContext) { + Context context = CombineContextFactory.createFromStateContext(stateContext); + return new NonSerializableBoundedKeyedCombineFn<>(combineFn, context); + } + + private static class NonSerializableBoundedKeyedCombineFn<K, InputT, AccumT, OutputT> + extends KeyedCombineFn<K, InputT, AccumT, OutputT> { + private final KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn; + private final Context context; + + private NonSerializableBoundedKeyedCombineFn( + KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn, + Context context) { + this.combineFn = combineFn; + this.context = context; + } + @Override + public AccumT createAccumulator(K key) { + return combineFn.createAccumulator(key, context); + } + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value) { + return combineFn.addInput(key, accumulator, value, context); + } + @Override + public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(key, accumulators, context); + } + @Override + public OutputT extractOutput(K key, AccumT accumulator) { + return combineFn.extractOutput(key, accumulator, context); + } + @Override + public AccumT compact(K key, AccumT accumulator) { + return combineFn.compact(key, accumulator, context); + } + @Override + public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<K> keyCoder, + Coder<InputT> inputCoder) throws CannotProvideCoderException { + return combineFn.getAccumulatorCoder(registry, keyCoder, inputCoder); + } + @Override 
+ public Coder<OutputT> getDefaultOutputCoder(CoderRegistry registry, Coder<K> keyCoder, + Coder<InputT> inputCoder) throws CannotProvideCoderException { + return combineFn.getDefaultOutputCoder(registry, keyCoder, inputCoder); + } + + private void writeObject(@SuppressWarnings("unused") ObjectOutputStream out) + throws IOException { + throw new NotSerializableException( + "Cannot serialize the CombineFn resulting from CombineFnUtil.bindContext."); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunnerBase.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunnerBase.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunnerBase.java index 25ead03..04ec59f 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunnerBase.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/DoFnRunnerBase.java @@ -540,6 +540,11 @@ public abstract class DoFnRunnerBase<InputT, OutputT> implements DoFnRunner<Inpu public StateInternals<?> stateInternals() { return context.stepContext.stateInternals(); } + + @Override + public <T> T sideInput(PCollectionView<T> view, BoundedWindow mainInputWindow) { + return context.sideInput(view, mainInputWindow); + } }; } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowViaWindowSetDoFn.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowViaWindowSetDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowViaWindowSetDoFn.java index ac2df24..f6246d1 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowViaWindowSetDoFn.java +++ 
b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowViaWindowSetDoFn.java @@ -77,7 +77,8 @@ public class GroupAlsoByWindowViaWindowSetDoFn< timerInternals, c.windowingInternals(), droppedDueToClosedWindow, - reduceFn); + reduceFn, + c.getPipelineOptions()); for (TimerData timer : element.timersIterable()) { reduceFnRunner.onTimer(timer); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFn.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFn.java index 1d1afe3..d394e81 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFn.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/GroupAlsoByWindowsViaOutputBufferDoFn.java @@ -67,7 +67,8 @@ public class GroupAlsoByWindowsViaOutputBufferDoFn<K, InputT, OutputT, W extends timerInternals, c.windowingInternals(), droppedDueToClosedWindow, - reduceFn); + reduceFn, + c.getPipelineOptions()); Iterable<List<WindowedValue<InputT>>> chunks = Iterables.partition(c.element().getValue(), 1000); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnContextFactory.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnContextFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnContextFactory.java index b2ab752..bdbaf10 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnContextFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnContextFactory.java @@ -18,6 +18,7 @@ package 
com.google.cloud.dataflow.sdk.util; import static com.google.common.base.Preconditions.checkNotNull; import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; import com.google.cloud.dataflow.sdk.util.TimerInternals.TimerData; @@ -25,6 +26,8 @@ import com.google.cloud.dataflow.sdk.util.state.MergingStateAccessor; import com.google.cloud.dataflow.sdk.util.state.ReadableState; import com.google.cloud.dataflow.sdk.util.state.State; import com.google.cloud.dataflow.sdk.util.state.StateAccessor; +import com.google.cloud.dataflow.sdk.util.state.StateContext; +import com.google.cloud.dataflow.sdk.util.state.StateContexts; import com.google.cloud.dataflow.sdk.util.state.StateInternals; import com.google.cloud.dataflow.sdk.util.state.StateNamespace; import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; @@ -51,19 +54,24 @@ class ReduceFnContextFactory<K, InputT, OutputT, W extends BoundedWindow> { private final K key; private final ReduceFn<K, InputT, OutputT, W> reduceFn; private final WindowingStrategy<?, W> windowingStrategy; - private StateInternals<K> stateInternals; - private ActiveWindowSet<W> activeWindows; - private TimerInternals timerInternals; + private final StateInternals<K> stateInternals; + private final ActiveWindowSet<W> activeWindows; + private final TimerInternals timerInternals; + private final WindowingInternals<?, ?> windowingInternals; + private final PipelineOptions options; ReduceFnContextFactory(K key, ReduceFn<K, InputT, OutputT, W> reduceFn, WindowingStrategy<?, W> windowingStrategy, StateInternals<K> stateInternals, - ActiveWindowSet<W> activeWindows, TimerInternals timerInternals) { + ActiveWindowSet<W> activeWindows, TimerInternals timerInternals, + WindowingInternals<?, ?> windowingInternals, PipelineOptions options) { this.key = key; 
this.reduceFn = reduceFn; this.windowingStrategy = windowingStrategy; this.stateInternals = stateInternals; this.activeWindows = activeWindows; this.timerInternals = timerInternals; + this.windowingInternals = windowingInternals; + this.options = options; } /** Where should we look for state associated with a given window? */ @@ -74,24 +82,25 @@ class ReduceFnContextFactory<K, InputT, OutputT, W extends BoundedWindow> { RENAMED } - private StateAccessorImpl<K, W> stateContext(W window, StateStyle style) { + private StateAccessorImpl<K, W> stateAccessor(W window, StateStyle style) { return new StateAccessorImpl<K, W>( activeWindows, windowingStrategy.getWindowFn().windowCoder(), - stateInternals, window, style); + stateInternals, StateContexts.createFromComponents(options, windowingInternals, window), + style); } public ReduceFn<K, InputT, OutputT, W>.Context base(W window, StateStyle style) { - return new ContextImpl(stateContext(window, style)); + return new ContextImpl(stateAccessor(window, style)); } public ReduceFn<K, InputT, OutputT, W>.ProcessValueContext forValue( W window, InputT value, Instant timestamp, StateStyle style) { - return new ProcessValueContextImpl(stateContext(window, style), value, timestamp); + return new ProcessValueContextImpl(stateAccessor(window, style), value, timestamp); } public ReduceFn<K, InputT, OutputT, W>.OnTriggerContext forTrigger(W window, ReadableState<PaneInfo> pane, StateStyle style, OnTriggerCallbacks<OutputT> callbacks) { - return new OnTriggerContextImpl(stateContext(window, style), pane, callbacks); + return new OnTriggerContextImpl(stateAccessor(window, style), pane, callbacks); } public ReduceFn<K, InputT, OutputT, W>.OnMergeContext forMerge( @@ -150,20 +159,20 @@ class ReduceFnContextFactory<K, InputT, OutputT, W extends BoundedWindow> { protected final ActiveWindowSet<W> activeWindows; - protected final W window; + protected final StateContext<W> context; protected final StateNamespace windowNamespace; protected 
final Coder<W> windowCoder; protected final StateInternals<K> stateInternals; protected final StateStyle style; public StateAccessorImpl(ActiveWindowSet<W> activeWindows, Coder<W> windowCoder, - StateInternals<K> stateInternals, W window, StateStyle style) { + StateInternals<K> stateInternals, StateContext<W> context, StateStyle style) { this.activeWindows = activeWindows; this.windowCoder = windowCoder; this.stateInternals = stateInternals; - this.window = checkNotNull(window); - this.windowNamespace = namespaceFor(window); + this.context = checkNotNull(context); + this.windowNamespace = namespaceFor(context.window()); this.style = style; } @@ -176,7 +185,7 @@ class ReduceFnContextFactory<K, InputT, OutputT, W extends BoundedWindow> { } W window() { - return window; + return context.window(); } StateNamespace namespace() { @@ -187,10 +196,10 @@ class ReduceFnContextFactory<K, InputT, OutputT, W extends BoundedWindow> { public <StateT extends State> StateT access(StateTag<? super K, StateT> address) { switch (style) { case DIRECT: - return stateInternals.state(windowNamespace(), address); + return stateInternals.state(windowNamespace(), address, context); case RENAMED: return stateInternals.state( - namespaceFor(activeWindows.writeStateAddress(window)), address); + namespaceFor(activeWindows.writeStateAddress(context.window())), address, context); } throw new RuntimeException(); // cases are exhaustive. 
} @@ -203,7 +212,8 @@ class ReduceFnContextFactory<K, InputT, OutputT, W extends BoundedWindow> { public MergingStateAccessorImpl(ActiveWindowSet<W> activeWindows, Coder<W> windowCoder, StateInternals<K> stateInternals, StateStyle style, Collection<W> activeToBeMerged, W mergeResult) { - super(activeWindows, windowCoder, stateInternals, mergeResult, style); + super(activeWindows, windowCoder, stateInternals, + StateContexts.windowOnly(mergeResult), style); this.activeToBeMerged = activeToBeMerged; } @@ -211,11 +221,13 @@ class ReduceFnContextFactory<K, InputT, OutputT, W extends BoundedWindow> { public <StateT extends State> StateT access(StateTag<? super K, StateT> address) { switch (style) { case DIRECT: - return stateInternals.state(windowNamespace(), address); + return stateInternals.state(windowNamespace(), address, context); case RENAMED: return stateInternals.state( - namespaceFor(activeWindows.mergedWriteStateAddress(activeToBeMerged, window)), - address); + namespaceFor(activeWindows.mergedWriteStateAddress( + activeToBeMerged, context.window())), + address, + context); } throw new RuntimeException(); // cases are exhaustive. } @@ -235,7 +247,7 @@ class ReduceFnContextFactory<K, InputT, OutputT, W extends BoundedWindow> { break; } Preconditions.checkNotNull(namespace); // cases are exhaustive. 
- builder.put(mergingWindow, stateInternals.state(namespace, address)); + builder.put(mergingWindow, stateInternals.state(namespace, address, context)); } return builder.build(); } @@ -245,19 +257,21 @@ class ReduceFnContextFactory<K, InputT, OutputT, W extends BoundedWindow> { extends StateAccessorImpl<K, W> implements MergingStateAccessor<K, W> { public PremergingStateAccessorImpl(ActiveWindowSet<W> activeWindows, Coder<W> windowCoder, StateInternals<K> stateInternals, W window) { - super(activeWindows, windowCoder, stateInternals, window, StateStyle.RENAMED); + super(activeWindows, windowCoder, stateInternals, + StateContexts.windowOnly(window), StateStyle.RENAMED); } Collection<W> mergingWindows() { - return activeWindows.readStateAddresses(window); + return activeWindows.readStateAddresses(context.window()); } @Override public <StateT extends State> Map<W, StateT> accessInEachMergingWindow( StateTag<? super K, StateT> address) { ImmutableMap.Builder<W, StateT> builder = ImmutableMap.builder(); - for (W stateAddressWindow : activeWindows.readStateAddresses(window)) { - StateT stateForWindow = stateInternals.state(namespaceFor(stateAddressWindow), address); + for (W stateAddressWindow : activeWindows.readStateAddresses(context.window())) { + StateT stateForWindow = + stateInternals.state(namespaceFor(stateAddressWindow), address, context); builder.put(stateAddressWindow, stateForWindow); } return builder.build(); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunner.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunner.java index ec83688..fe5c474 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunner.java +++ 
b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunner.java @@ -15,6 +15,7 @@ */ package com.google.cloud.dataflow.sdk.util; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; import com.google.cloud.dataflow.sdk.transforms.Aggregator; import com.google.cloud.dataflow.sdk.transforms.DoFn; import com.google.cloud.dataflow.sdk.transforms.GroupByKey.GroupByKeyOnly; @@ -203,7 +204,8 @@ public class ReduceFnRunner<K, InputT, OutputT, W extends BoundedWindow> { TimerInternals timerInternals, WindowingInternals<?, KV<K, OutputT>> windowingInternals, Aggregator<Long, Long> droppedDueToClosedWindow, - ReduceFn<K, InputT, OutputT, W> reduceFn) { + ReduceFn<K, InputT, OutputT, W> reduceFn, + PipelineOptions options) { this.key = key; this.timerInternals = timerInternals; this.paneInfoTracker = new PaneInfoTracker(timerInternals); @@ -224,7 +226,7 @@ public class ReduceFnRunner<K, InputT, OutputT, W extends BoundedWindow> { this.contextFactory = new ReduceFnContextFactory<K, InputT, OutputT, W>(key, reduceFn, this.windowingStrategy, - stateInternals, this.activeWindows, timerInternals); + stateInternals, this.activeWindows, timerInternals, windowingInternals, options); this.watermarkHold = new WatermarkHold<>(timerInternals, windowingStrategy); this.triggerRunner = http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemReduceFn.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemReduceFn.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemReduceFn.java index d5d9126..1665792 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemReduceFn.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/SystemReduceFn.java @@ -15,12 +15,11 @@ */ package com.google.cloud.dataflow.sdk.util; -import static com.google.common.base.Preconditions.checkArgument; 
import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; -import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.RequiresContextInternal; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; import com.google.cloud.dataflow.sdk.transforms.GroupByKey; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.util.state.AccumulatorCombiningState; @@ -74,15 +73,19 @@ public abstract class SystemReduceFn<K, InputT, AccumT, OutputT, W extends Bound AccumT, OutputT, W> combining( final Coder<K> keyCoder, final AppliedCombineFn<K, InputT, AccumT, OutputT> combineFn) { - checkArgument( - !(combineFn.getFn() instanceof RequiresContextInternal), - "Combiner lifting is not supported for combine functions with contexts: %s", - combineFn.getFn().getClass().getName()); - final StateTag<K, AccumulatorCombiningState<InputT, AccumT, OutputT>> bufferTag = - StateTags.makeSystemTagInternal( + final StateTag<K, AccumulatorCombiningState<InputT, AccumT, OutputT>> bufferTag; + if (combineFn.getFn() instanceof KeyedCombineFnWithContext) { + bufferTag = StateTags.makeSystemTagInternal( + StateTags.<K, InputT, AccumT, OutputT>keyedCombiningValueWithContext( + BUFFER_NAME, combineFn.getAccumulatorCoder(), + (KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>) combineFn.getFn())); + + } else { + bufferTag = StateTags.makeSystemTagInternal( StateTags.<K, InputT, AccumT, OutputT>keyedCombiningValue( BUFFER_NAME, combineFn.getAccumulatorCoder(), (KeyedCombineFn<K, InputT, AccumT, OutputT>) combineFn.getFn())); + } return new SystemReduceFn<K, InputT, AccumT, OutputT, W>(bufferTag) { @Override public void prefetchOnMerge(MergingStateAccessor<K, W> state) throws Exception { 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerContextFactory.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerContextFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerContextFactory.java index 87e8b00..64ff402 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerContextFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerContextFactory.java @@ -80,11 +80,11 @@ public class TriggerContextFactory<W extends BoundedWindow> { return new OnMergeContextImpl(window, timers, rootTrigger, finishedSet, finishedSets); } - public StateAccessor<?> createStateContext(W window, ExecutableTrigger<W> trigger) { + public StateAccessor<?> createStateAccessor(W window, ExecutableTrigger<W> trigger) { return new StateAccessorImpl(window, trigger); } - public MergingStateAccessor<?, W> createMergingStateContext( + public MergingStateAccessor<?, W> createMergingStateAccessor( W mergeResult, Collection<W> mergingWindows, ExecutableTrigger<W> trigger) { return new MergingStateAccessorImpl(trigger, mergingWindows, mergeResult); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerRunner.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerRunner.java index 1b78ddc..dcfd035 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerRunner.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/TriggerRunner.java @@ -91,14 +91,15 @@ public class TriggerRunner<W extends BoundedWindow> { if (isFinishedSetNeeded()) { state.access(FINISHED_BITS_TAG).readLater(); } - 
rootTrigger.getSpec().prefetchOnElement(contextFactory.createStateContext(window, rootTrigger)); + rootTrigger.getSpec().prefetchOnElement( + contextFactory.createStateAccessor(window, rootTrigger)); } public void prefetchOnFire(W window, StateAccessor<?> state) { if (isFinishedSetNeeded()) { state.access(FINISHED_BITS_TAG).readLater(); } - rootTrigger.getSpec().prefetchOnFire(contextFactory.createStateContext(window, rootTrigger)); + rootTrigger.getSpec().prefetchOnFire(contextFactory.createStateAccessor(window, rootTrigger)); } public void prefetchShouldFire(W window, StateAccessor<?> state) { @@ -106,7 +107,7 @@ public class TriggerRunner<W extends BoundedWindow> { state.access(FINISHED_BITS_TAG).readLater(); } rootTrigger.getSpec().prefetchShouldFire( - contextFactory.createStateContext(window, rootTrigger)); + contextFactory.createStateAccessor(window, rootTrigger)); } /** @@ -130,7 +131,7 @@ public class TriggerRunner<W extends BoundedWindow> { value.readLater(); } } - rootTrigger.getSpec().prefetchOnMerge(contextFactory.createMergingStateContext( + rootTrigger.getSpec().prefetchOnMerge(contextFactory.createMergingStateAccessor( window, mergingWindows, rootTrigger)); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingInternals.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingInternals.java index 9ffdbee..12fcd53 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingInternals.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/WindowingInternals.java @@ -74,4 +74,9 @@ public interface WindowingInternals<InputT, OutputT> { TupleTag<?> tag, Iterable<WindowedValue<T>> data, Coder<T> elemCoder) throws IOException; + + /** + * Return the value of the side input for the window 
of a main input element. + */ + <T> T sideInput(PCollectionView<T> view, BoundedWindow mainInputWindow); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternals.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternals.java index 19e45d6..3683b74 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternals.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/CopyOnAccessInMemoryStateInternals.java @@ -20,8 +20,10 @@ import static com.google.common.base.Preconditions.checkState; import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.util.CombineFnUtil; import com.google.cloud.dataflow.sdk.util.state.InMemoryStateInternals.InMemoryState; import com.google.cloud.dataflow.sdk.util.state.StateTag.StateBinder; import com.google.common.base.Optional; @@ -97,7 +99,13 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> @Override public <T extends State> T state(StateNamespace namespace, StateTag<? super K, T> address) { - return table.get(namespace, address); + return state(namespace, address, StateContexts.nullContext()); + } + + @Override + public <T extends State> T state( + StateNamespace namespace, StateTag<? 
super K, T> address, StateContext<?> c) { + return table.get(namespace, address, c); } @Override @@ -220,12 +228,12 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> } @Override - protected StateBinder<K> binderForNamespace(final StateNamespace namespace) { - return binderFactory.forNamespace(namespace); + protected StateBinder<K> binderForNamespace(final StateNamespace namespace, StateContext<?> c) { + return binderFactory.forNamespace(namespace, c); } private static interface StateBinderFactory<K> { - StateBinder<K> forNamespace(StateNamespace namespace); + StateBinder<K> forNamespace(StateNamespace namespace, StateContext<?> c); } /** @@ -246,7 +254,7 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> } @Override - public StateBinder<K> forNamespace(final StateNamespace namespace) { + public StateBinder<K> forNamespace(final StateNamespace namespace, final StateContext<?> c) { return new StateBinder<K>() { @Override public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark( @@ -256,7 +264,7 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> @SuppressWarnings("unchecked") InMemoryState<? extends WatermarkHoldState<W>> existingState = (InMemoryStateInternals.InMemoryState<? extends WatermarkHoldState<W>>) - underlying.get().get(namespace, address); + underlying.get().get(namespace, address, c); return existingState.copy(); } else { return new InMemoryStateInternals.InMemoryWatermarkHold<>( @@ -271,7 +279,7 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> @SuppressWarnings("unchecked") InMemoryState<? extends ValueState<T>> existingState = (InMemoryStateInternals.InMemoryState<? 
extends ValueState<T>>) - underlying.get().get(namespace, address); + underlying.get().get(namespace, address, c); return existingState.copy(); } else { return new InMemoryStateInternals.InMemoryValue<>(); @@ -289,7 +297,7 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> existingState = ( InMemoryStateInternals .InMemoryState<? extends AccumulatorCombiningState<InputT, AccumT, - OutputT>>) underlying.get().get(namespace, address); + OutputT>>) underlying.get().get(namespace, address, c); return existingState.copy(); } else { return new InMemoryStateInternals.InMemoryCombiningValue<>( @@ -304,7 +312,7 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> @SuppressWarnings("unchecked") InMemoryState<? extends BagState<T>> existingState = (InMemoryStateInternals.InMemoryState<? extends BagState<T>>) - underlying.get().get(namespace, address); + underlying.get().get(namespace, address, c); return existingState.copy(); } else { return new InMemoryStateInternals.InMemoryBag<>(); @@ -323,12 +331,22 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> existingState = ( InMemoryStateInternals .InMemoryState<? extends AccumulatorCombiningState<InputT, AccumT, - OutputT>>) underlying.get().get(namespace, address); + OutputT>>) underlying.get().get(namespace, address, c); return existingState.copy(); } else { return new InMemoryStateInternals.InMemoryCombiningValue<>(key, combineFn); } } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> + bindKeyedCombiningValueWithContext( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + KeyedCombineFnWithContext<? 
super K, InputT, AccumT, OutputT> combineFn) { + return bindKeyedCombiningValue( + address, accumCoder, CombineFnUtil.bindContext(combineFn, c)); + } }; } } @@ -354,7 +372,8 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> // Only read through non-cleared values to ensure that completed windows are // eventually discarded, and remember the earliest watermark hold from among those // values. - State state = readTo.get(namespace, existingState.getKey()); + State state = + readTo.get(namespace, existingState.getKey(), StateContexts.nullContext()); if (state instanceof WatermarkHoldState) { Instant hold = ((WatermarkHoldState<?>) state).read(); if (hold != null && hold.isBefore(earliestHold)) { @@ -368,19 +387,19 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> } @Override - public StateBinder<K> forNamespace(final StateNamespace namespace) { + public StateBinder<K> forNamespace(final StateNamespace namespace, final StateContext<?> c) { return new StateBinder<K>() { @Override public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark( StateTag<? super K, WatermarkHoldState<W>> address, OutputTimeFn<? super W> outputTimeFn) { - return underlying.get(namespace, address); + return underlying.get(namespace, address, c); } @Override public <T> ValueState<T> bindValue( StateTag<? super K, ValueState<T>> address, Coder<T> coder) { - return underlying.get(namespace, address); + return underlying.get(namespace, address, c); } @Override @@ -388,13 +407,13 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> bindCombiningValue( StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, Coder<AccumT> accumCoder, CombineFn<InputT, AccumT, OutputT> combineFn) { - return underlying.get(namespace, address); + return underlying.get(namespace, address, c); } @Override public <T> BagState<T> bindBag( StateTag<? 
super K, BagState<T>> address, Coder<T> elemCoder) { - return underlying.get(namespace, address); + return underlying.get(namespace, address, c); } @Override @@ -403,24 +422,33 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, Coder<AccumT> accumCoder, KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { - return underlying.get(namespace, address); + return underlying.get(namespace, address, c); + } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> + bindKeyedCombiningValueWithContext( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn) { + return bindKeyedCombiningValue( + address, accumCoder, CombineFnUtil.bindContext(combineFn, c)); } }; } } private static class InMemoryStateBinderFactory<K> implements StateBinderFactory<K> { - private final InMemoryStateInternals.InMemoryStateBinder<K> inMemoryStateBinder; + private final K key; public InMemoryStateBinderFactory(K key) { - inMemoryStateBinder = new InMemoryStateInternals.InMemoryStateBinder<>(key); + this.key = key; } @Override - public StateBinder<K> forNamespace(StateNamespace namespace) { - return inMemoryStateBinder; + public StateBinder<K> forNamespace(StateNamespace namespace, StateContext<?> c) { + return new InMemoryStateInternals.InMemoryStateBinder<>(key, c); } } } - } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternals.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternals.java index 4a2555f..8404801 
100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternals.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/InMemoryStateInternals.java @@ -20,8 +20,10 @@ import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; +import com.google.cloud.dataflow.sdk.util.CombineFnUtil; import com.google.cloud.dataflow.sdk.util.state.StateTag.StateBinder; import org.joda.time.Instant; @@ -62,8 +64,8 @@ public class InMemoryStateInternals<K> implements StateInternals<K> { protected final StateTable<K> inMemoryState = new StateTable<K>() { @Override - protected StateBinder<K> binderForNamespace(final StateNamespace namespace) { - return new InMemoryStateBinder<K>(key); + protected StateBinder<K> binderForNamespace(StateNamespace namespace, StateContext<?> c) { + return new InMemoryStateBinder<K>(key, c); } }; @@ -81,7 +83,13 @@ public class InMemoryStateInternals<K> implements StateInternals<K> { @Override public <T extends State> T state(StateNamespace namespace, StateTag<? super K, T> address) { - return inMemoryState.get(namespace, address); + return inMemoryState.get(namespace, address, StateContexts.nullContext()); + } + + @Override + public <T extends State> T state( + StateNamespace namespace, StateTag<? 
super K, T> address, final StateContext<?> c) { + return inMemoryState.get(namespace, address, c); } /** @@ -89,9 +97,11 @@ public class InMemoryStateInternals<K> implements StateInternals<K> { */ static class InMemoryStateBinder<K> implements StateBinder<K> { private final K key; + private final StateContext<?> c; - InMemoryStateBinder(K key) { + InMemoryStateBinder(K key, StateContext<?> c) { this.key = key; + this.c = c; } @Override @@ -109,8 +119,8 @@ public class InMemoryStateInternals<K> implements StateInternals<K> { @Override public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindCombiningValue( - StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> - address, Coder<AccumT> accumCoder, + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, final CombineFn<InputT, AccumT, OutputT> combineFn) { return new InMemoryCombiningValue<K, InputT, AccumT, OutputT>(key, combineFn.<K>asKeyedFn()); } @@ -125,11 +135,20 @@ public class InMemoryStateInternals<K> implements StateInternals<K> { @Override public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( - StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> - address, Coder<AccumT> accumCoder, + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { return new InMemoryCombiningValue<K, InputT, AccumT, OutputT>(key, combineFn); } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> + bindKeyedCombiningValueWithContext( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + KeyedCombineFnWithContext<? 
super K, InputT, AccumT, OutputT> combineFn) { + return bindKeyedCombiningValue(address, accumCoder, CombineFnUtil.bindContext(combineFn, c)); + } } static final class InMemoryValue<T> implements ValueState<T>, InMemoryState<InMemoryValue<T>> { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContext.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContext.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContext.java new file mode 100644 index 0000000..96387d8 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContext.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +/** + * Information accessible the state API. + */ +public interface StateContext<W extends BoundedWindow> { + /** + * Returns the {@code PipelineOptions} specified with the + * {@link com.google.cloud.dataflow.sdk.runners.PipelineRunner}. 
+ */ + public abstract PipelineOptions getPipelineOptions(); + + /** + * Returns the value of the side input for the corresponding state window. + */ + public abstract <T> T sideInput(PCollectionView<T> view); + + /** + * Returns the window corresponding to the state. + */ + public abstract W window(); +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContexts.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContexts.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContexts.java new file mode 100644 index 0000000..e301d43 --- /dev/null +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateContexts.java @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util.state; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; + +import javax.annotation.Nullable; + +/** + * Factory that produces {@link StateContext} based on different inputs. 
+ */ +public class StateContexts { + private static final StateContext<BoundedWindow> NULL_CONTEXT = + new StateContext<BoundedWindow>() { + @Override + public PipelineOptions getPipelineOptions() { + throw new IllegalArgumentException("cannot call getPipelineOptions() in a null context"); + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + throw new IllegalArgumentException("cannot call sideInput() in a null context"); + } + + @Override + public BoundedWindow window() { + throw new IllegalArgumentException("cannot call window() in a null context"); + }}; + + /** + * Returns a fake {@link StateContext}. + */ + @SuppressWarnings("unchecked") + public static <W extends BoundedWindow> StateContext<W> nullContext() { + return (StateContext<W>) NULL_CONTEXT; + } + + /** + * Returns a {@link StateContext} that only contains the state window. + */ + public static <W extends BoundedWindow> StateContext<W> windowOnly(final W window) { + return new StateContext<W>() { + @Override + public PipelineOptions getPipelineOptions() { + throw new IllegalArgumentException( + "cannot call getPipelineOptions() in a window only context"); + } + @Override + public <T> T sideInput(PCollectionView<T> view) { + throw new IllegalArgumentException("cannot call sideInput() in a window only context"); + } + @Override + public W window() { + return window; + } + }; + } + + /** + * Returns a {@link StateContext} from {@code PipelineOptions}, {@link WindowingInternals}, + * and the state window. 
+ */ + public static <W extends BoundedWindow> StateContext<W> createFromComponents( + @Nullable final PipelineOptions options, + final WindowingInternals<?, ?> windowingInternals, + final W window) { + @SuppressWarnings("unchecked") + StateContext<W> typedNullContext = (StateContext<W>) NULL_CONTEXT; + if (options == null) { + return typedNullContext; + } else { + return new StateContext<W>() { + + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + return windowingInternals.sideInput(view, window); + } + + @Override + public W window() { + return window; + } + }; + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateInternals.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateInternals.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateInternals.java index 78aed87..b31afb4 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateInternals.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateInternals.java @@ -45,4 +45,11 @@ public interface StateInternals<K> { * Return the state associated with {@code address} in the specified {@code namespace}. */ <T extends State> T state(StateNamespace namespace, StateTag<? super K, T> address); + + /** + * Return the state associated with {@code address} in the specified {@code namespace} + * with the {@link StateContext}. + */ + <T extends State> T state( + StateNamespace namespace, StateTag<? 
super K, T> address, StateContext<?> c); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTable.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTable.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTable.java index 0f1209a..edd1dae 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTable.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTable.java @@ -40,11 +40,11 @@ public abstract class StateTable<K> { /** * Gets the {@link State} in the specified {@link StateNamespace} with the specified {@link - * StateTag}, binding it using the {@link #binderForNamespace(StateNamespace)} if it is not + * StateTag}, binding it using the {@link #binderForNamespace} if it is not * already present in this {@link StateTable}. */ public <StateT extends State> StateT get( - StateNamespace namespace, StateTag<? super K, StateT> tag) { + StateNamespace namespace, StateTag<? super K, StateT> tag, StateContext<?> c) { State storage = stateTable.get(namespace, tag); if (storage != null) { @SuppressWarnings("unchecked") @@ -52,7 +52,7 @@ public abstract class StateTable<K> { return typedStorage; } - StateT typedStorage = tag.bind(binderForNamespace(namespace)); + StateT typedStorage = tag.bind(binderForNamespace(namespace, c)); stateTable.put(namespace, tag, typedStorage); return typedStorage; } @@ -85,5 +85,5 @@ public abstract class StateTable<K> { * Provide the {@code StateBinder} to use for creating {@code Storage} instances * in the specified {@code namespace}. 
*/ - protected abstract StateBinder<K> binderForNamespace(StateNamespace namespace); + protected abstract StateBinder<K> binderForNamespace(StateNamespace namespace, StateContext<?> c); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTag.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTag.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTag.java index 2924763..c87bdb7 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTag.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTag.java @@ -20,6 +20,7 @@ import com.google.cloud.dataflow.sdk.annotations.Experimental.Kind; import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; import com.google.cloud.dataflow.sdk.transforms.GroupByKey; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; @@ -63,6 +64,12 @@ public interface StateTag<K, StateT extends State> extends Serializable { StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, Coder<AccumT> accumCoder, KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn); + <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> + bindKeyedCombiningValueWithContext( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn); + /** * Bind to a watermark {@link StateTag}. 
* http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTags.java ---------------------------------------------------------------------- diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTags.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTags.java index c1efb60..0cbaa52 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTags.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/state/StateTags.java @@ -22,6 +22,7 @@ import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.coders.CoderRegistry; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFn; import com.google.common.base.MoreObjects; @@ -90,6 +91,24 @@ public class StateTags { } /** + * Create a state tag for values that use a {@link KeyedCombineFnWithContext} to automatically + * merge multiple {@code InputT}s into a single {@code OutputT}. The key provided to the + * {@link KeyedCombineFn} comes from the keyed {@link StateAccessor}, the context provided comes + * from the {@link StateContext}. 
+ */ + public static <K, InputT, AccumT, OutputT> + StateTag<K, AccumulatorCombiningState<InputT, AccumT, OutputT>> + keyedCombiningValueWithContext( + String id, + Coder<AccumT> accumCoder, + KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn) { + return new KeyedCombiningValueWithContextStateTag<K, InputT, AccumT, OutputT>( + new StructuredId(id), + accumCoder, + combineFn); + } + + /** * Create a state tag for values that use a {@link CombineFn} to automatically merge * multiple {@code InputT}s into a single {@code OutputT}. * @@ -337,9 +356,67 @@ public class StateTags { } } + /** + * A state cell for values that are combined according to a {@link KeyedCombineFnWithContext}. + * + * @param <K> the type of keys + * @param <InputT> the type of input values + * @param <AccumT> type of mutable accumulator values + * @param <OutputT> type of output values + */ + private static class KeyedCombiningValueWithContextStateTag<K, InputT, AccumT, OutputT> + extends StateTagBase<K, AccumulatorCombiningState<InputT, AccumT, OutputT>> + implements SystemStateTag<K, AccumulatorCombiningState<InputT, AccumT, OutputT>> { + + private final Coder<AccumT> accumCoder; + private final KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn; + + protected KeyedCombiningValueWithContextStateTag( + StructuredId id, + Coder<AccumT> accumCoder, + KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn) { + super(id); + this.combineFn = combineFn; + this.accumCoder = accumCoder; + } + + @Override + public AccumulatorCombiningState<InputT, AccumT, OutputT> bind( + StateBinder<? 
extends K> visitor) { + return visitor.bindKeyedCombiningValueWithContext(this, accumCoder, combineFn); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof KeyedCombiningValueWithContextStateTag)) { + return false; + } + + KeyedCombiningValueWithContextStateTag<?, ?, ?, ?> that = + (KeyedCombiningValueWithContextStateTag<?, ?, ?, ?>) obj; + return Objects.equals(this.id, that.id) + && Objects.equals(this.accumCoder, that.accumCoder); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), id, accumCoder); + } + + @Override + public StateTag<K, AccumulatorCombiningState<InputT, AccumT, OutputT>> asKind( + StateKind kind) { + return new KeyedCombiningValueWithContextStateTag<>( + id.asKind(kind), accumCoder, combineFn); + } + } /** - * A general purpose state cell for values of type {@code T}. + * A state cell for values that are combined according to a {@link KeyedCombineFn}. * * @param <K> the type of keys * @param <InputT> the type of input values @@ -355,9 +432,9 @@ public class StateTags { protected KeyedCombiningValueStateTag( StructuredId id, - Coder<AccumT> accumCoder, KeyedCombineFn<K, InputT, AccumT, OutputT> combineFn) { + Coder<AccumT> accumCoder, KeyedCombineFn<K, InputT, AccumT, OutputT> keyedCombineFn) { super(id); - this.keyedCombineFn = combineFn; + this.keyedCombineFn = keyedCombineFn; this.accumCoder = accumCoder; } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java new file mode 100644 index 0000000..978ee50 --- /dev/null +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/CombineFnUtilTest.java @@ -0,0 +1,62 @@ +/* + * 
Copyright (C) 2016 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.dataflow.sdk.util; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.withSettings; + +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import com.google.cloud.dataflow.sdk.util.state.StateContexts; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayOutputStream; +import java.io.NotSerializableException; +import java.io.ObjectOutputStream; + +/** + * Unit tests for {@link CombineFnUtil}. 
+ */ +@RunWith(JUnit4.class) +public class CombineFnUtilTest { + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + KeyedCombineFnWithContext<Integer, Integer, Integer, Integer> mockCombineFn; + + @SuppressWarnings("unchecked") + @Before + public void setUp() { + mockCombineFn = mock(KeyedCombineFnWithContext.class, withSettings().serializable()); + } + + @Test + public void testNonSerializable() throws Exception { + expectedException.expect(NotSerializableException.class); + expectedException.expectMessage( + "Cannot serialize the CombineFn resulting from CombineFnUtil.bindContext."); + + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(buffer); + oos.writeObject(CombineFnUtil.bindContext(mockCombineFn, StateContexts.nullContext())); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunnerTest.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunnerTest.java index ddc33b8..c85b1ca 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunnerTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnRunnerTest.java @@ -25,18 +25,26 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; import com.google.cloud.dataflow.sdk.WindowMatchers; import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import 
com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context; import com.google.cloud.dataflow.sdk.transforms.Sum; import com.google.cloud.dataflow.sdk.transforms.windowing.AfterEach; import com.google.cloud.dataflow.sdk.transforms.windowing.AfterFirst; import com.google.cloud.dataflow.sdk.transforms.windowing.AfterPane; import com.google.cloud.dataflow.sdk.transforms.windowing.AfterProcessingTime; import com.google.cloud.dataflow.sdk.transforms.windowing.AfterWatermark; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; import com.google.cloud.dataflow.sdk.transforms.windowing.IntervalWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; @@ -47,7 +55,9 @@ import com.google.cloud.dataflow.sdk.transforms.windowing.SlidingWindows; import com.google.cloud.dataflow.sdk.transforms.windowing.Trigger; import com.google.cloud.dataflow.sdk.transforms.windowing.Window.ClosingBehavior; import com.google.cloud.dataflow.sdk.util.WindowingStrategy.AccumulationMode; +import com.google.cloud.dataflow.sdk.values.PCollectionView; import com.google.cloud.dataflow.sdk.values.TimestampedValue; +import com.google.common.base.Preconditions; import org.joda.time.Duration; import org.joda.time.Instant; @@ -55,12 +65,14 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import java.util.Iterator; import java.util.List; /** @@ -71,8 +83,10 @@ import java.util.List; */ 
@RunWith(JUnit4.class) public class ReduceFnRunnerTest { - @Mock + @Mock private SideInputReader mockSideInputReader; private Trigger<IntervalWindow> mockTrigger; + private PCollectionView<Integer> mockView; + private IntervalWindow firstWindow; private static Trigger<IntervalWindow>.TriggerContext anyTriggerContext() { @@ -85,7 +99,17 @@ public class ReduceFnRunnerTest { @Before public void setUp() { MockitoAnnotations.initMocks(this); + + @SuppressWarnings("unchecked") + Trigger<IntervalWindow> mockTriggerUnchecked = + mock(Trigger.class, withSettings().serializable()); + mockTrigger = mockTriggerUnchecked; when(mockTrigger.buildTrigger()).thenReturn(mockTrigger); + + @SuppressWarnings("unchecked") + PCollectionView<Integer> mockViewUnchecked = + mock(PCollectionView.class, withSettings().serializable()); + mockView = mockViewUnchecked; firstWindow = new IntervalWindow(new Instant(0), new Instant(10)); } @@ -228,6 +252,53 @@ public class ReduceFnRunnerTest { } @Test + public void testOnElementCombiningWithContext() throws Exception { + Integer expectedValue = 5; + WindowingStrategy<?, IntervalWindow> windowingStrategy = WindowingStrategy + .of(FixedWindows.of(Duration.millis(10))) + .withTrigger(mockTrigger) + .withMode(AccumulationMode.DISCARDING_FIRED_PANES) + .withAllowedLateness(Duration.millis(100)); + + TestOptions options = PipelineOptionsFactory.as(TestOptions.class); + options.setValue(expectedValue); + + when(mockSideInputReader.contains(Matchers.<PCollectionView<Integer>>any())).thenReturn(true); + when(mockSideInputReader.get(Matchers.<PCollectionView<Integer>>any(), + any(BoundedWindow.class))).thenReturn(expectedValue); + + @SuppressWarnings({"rawtypes", "unchecked", "unused"}) + Object suppressWarningsVar = when(mockView.getWindowingStrategyInternal()) + .thenReturn((WindowingStrategy) windowingStrategy); + + SumAndVerifyContextFn combineFn = new SumAndVerifyContextFn(mockView, 
expectedValue); + // Test combining with a CombineFnWithContext (options, side input) in
discarding mode. + ReduceFnTester<Integer, Integer, IntervalWindow> tester = ReduceFnTester.combining( + windowingStrategy, combineFn.<String>asKeyedFn(), + VarIntCoder.of(), options, mockSideInputReader); + + injectElement(tester, 2); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + injectElement(tester, 3); + + when(mockTrigger.shouldFire(anyTriggerContext())).thenReturn(true); + triggerShouldFinish(mockTrigger); + injectElement(tester, 4); + + // This element shouldn't be seen, because the trigger has finished + injectElement(tester, 6); + + assertThat( + tester.extractOutput(), + contains( + isSingleWindowedValue(equalTo(5), 2, 0, 10), + isSingleWindowedValue(equalTo(4), 4, 0, 10))); + assertTrue(tester.isMarkedFinished(firstWindow)); + tester.assertHasOnlyGlobalAndFinishedSetsFor(firstWindow); + } + + @Test public void testWatermarkHoldAndLateData() throws Exception { // Test handling of late data. Specifically, ensure the watermark hold is correct. ReduceFnTester<Integer, Iterable<Integer>, IntervalWindow> tester = @@ -873,4 +944,68 @@ public class ReduceFnRunnerTest { output.get(3), WindowMatchers.valueWithPaneInfo(PaneInfo.createPane(false, true, Timing.LATE, 3, 2))); } + + private static class SumAndVerifyContextFn extends CombineFnWithContext<Integer, int[], Integer> { + + private final PCollectionView<Integer> view; + private final int expectedValue; + + private SumAndVerifyContextFn(PCollectionView<Integer> view, int expectedValue) { + this.view = view; + this.expectedValue = expectedValue; + } + @Override + public int[] createAccumulator(Context c) { + Preconditions.checkArgument( + c.getPipelineOptions().as(TestOptions.class).getValue() == expectedValue); + Preconditions.checkArgument(c.sideInput(view) == expectedValue); + return wrap(0); + } + + @Override + public int[] addInput(int[] accumulator, Integer input, Context c) { + Preconditions.checkArgument( + 
c.getPipelineOptions().as(TestOptions.class).getValue() ==
expectedValue); + Preconditions.checkArgument(c.sideInput(view) == expectedValue); + accumulator[0] += input.intValue(); + return accumulator; + } + + @Override + public int[] mergeAccumulators(Iterable<int[]> accumulators, Context c) { + Preconditions.checkArgument( + c.getPipelineOptions().as(TestOptions.class).getValue() == expectedValue); + Preconditions.checkArgument(c.sideInput(view) == expectedValue); + Iterator<int[]> iter = accumulators.iterator(); + if (!iter.hasNext()) { + return createAccumulator(c); + } else { + int[] running = iter.next(); + while (iter.hasNext()) { + running[0] += iter.next()[0]; + } + return running; + } + } + + @Override + public Integer extractOutput(int[] accumulator, Context c) { + Preconditions.checkArgument( + c.getPipelineOptions().as(TestOptions.class).getValue() == expectedValue); + Preconditions.checkArgument(c.sideInput(view) == expectedValue); + return accumulator[0]; + } + + private int[] wrap(int value) { + return new int[] { value }; + } + } + + /** + * A {@link PipelineOptions} to test combining with context.
+ */ +public interface TestOptions extends PipelineOptions { + Integer getValue(); + void setValue(Integer value); + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6613031b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnTester.java ---------------------------------------------------------------------- diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnTester.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnTester.java index bade9f9..d4620a7 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnTester.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/util/ReduceFnTester.java @@ -26,9 +26,12 @@ import com.google.cloud.dataflow.sdk.coders.IterableCoder; import com.google.cloud.dataflow.sdk.coders.KvCoder; import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; import com.google.cloud.dataflow.sdk.coders.VarIntCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; import com.google.cloud.dataflow.sdk.transforms.Aggregator; import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn; import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn; +import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; import com.google.cloud.dataflow.sdk.transforms.Sum; import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; @@ -47,6 +50,7 @@ import com.google.cloud.dataflow.sdk.util.state.StateNamespaces; import com.google.cloud.dataflow.sdk.util.state.StateTag; import com.google.cloud.dataflow.sdk.util.state.WatermarkHoldState; import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollectionView; import com.google.cloud.dataflow.sdk.values.TimestampedValue; import com.google.cloud.dataflow.sdk.values.TupleTag; import
com.google.common.base.Function; @@ -99,6 +103,7 @@ public class ReduceFnTester<InputT, OutputT, W extends BoundedWindow> { private final Coder<OutputT> outputCoder; private final WindowingStrategy<Object, W> objectStrategy; private final ReduceFn<String, InputT, OutputT, W> reduceFn; + private final PipelineOptions options; /** * If true, the output watermark is automatically advanced to the latest possible @@ -118,7 +123,9 @@ public class ReduceFnTester<InputT, OutputT, W extends BoundedWindow> { return new ReduceFnTester<Integer, Iterable<Integer>, W>( windowingStrategy, SystemReduceFn.<String, Integer, W>buffering(VarIntCoder.of()), - IterableCoder.of(VarIntCoder.of())); + IterableCoder.of(VarIntCoder.of()), + PipelineOptionsFactory.create(), + NullSideInputReader.empty()); } public static <W extends BoundedWindow> ReduceFnTester<Integer, Iterable<Integer>, W> @@ -147,10 +154,31 @@ public class ReduceFnTester<InputT, OutputT, W extends BoundedWindow> { return new ReduceFnTester<Integer, OutputT, W>( strategy, SystemReduceFn.<String, Integer, AccumT, OutputT, W>combining(StringUtf8Coder.of(), fn), - outputCoder); + outputCoder, + PipelineOptionsFactory.create(), + NullSideInputReader.empty()); } public static <W extends BoundedWindow, AccumT, OutputT> ReduceFnTester<Integer, OutputT, W> + combining(WindowingStrategy<?, W> strategy, + KeyedCombineFnWithContext<String, Integer, AccumT, OutputT> combineFn, + Coder<OutputT> outputCoder, + PipelineOptions options, + SideInputReader sideInputReader) throws Exception { + CoderRegistry registry = new CoderRegistry(); + registry.registerStandardCoders(); + AppliedCombineFn<String, Integer, AccumT, OutputT> fn = + AppliedCombineFn.<String, Integer, AccumT, OutputT>withInputCoder( + combineFn, registry, KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())); + + return new ReduceFnTester<Integer, OutputT, W>( + strategy, + SystemReduceFn.<String, Integer, AccumT, OutputT, W>combining(StringUtf8Coder.of(), fn), + outputCoder,
+ options, + sideInputReader); + } + public static <W extends BoundedWindow, AccumT, OutputT> ReduceFnTester<Integer, OutputT, W> combining(WindowFn<?, W> windowFn, Trigger<W> trigger, AccumulationMode mode, KeyedCombineFn<String, Integer, AccumT, OutputT> combineFn, Coder<OutputT> outputCoder, Duration allowedDataLateness) throws Exception { @@ -163,17 +191,19 @@ public class ReduceFnTester<InputT, OutputT, W extends BoundedWindow> { } private ReduceFnTester(WindowingStrategy<?, W> wildcardStrategy, - ReduceFn<String, InputT, OutputT, W> reduceFn, Coder<OutputT> outputCoder) throws Exception { + ReduceFn<String, InputT, OutputT, W> reduceFn, Coder<OutputT> outputCoder, + PipelineOptions options, SideInputReader sideInputReader) throws Exception { @SuppressWarnings("unchecked") WindowingStrategy<Object, W> objectStrategy = (WindowingStrategy<Object, W>) wildcardStrategy; this.objectStrategy = objectStrategy; this.reduceFn = reduceFn; this.windowFn = objectStrategy.getWindowFn(); - this.windowingInternals = new TestWindowingInternals(); + this.windowingInternals = new TestWindowingInternals(sideInputReader); this.outputCoder = outputCoder; this.autoAdvanceOutputWatermark = true; - executableTrigger = wildcardStrategy.getTrigger(); + this.executableTrigger = wildcardStrategy.getTrigger(); + this.options = options; } public void setAutoAdvanceOutputWatermark(boolean autoAdvanceOutputWatermark) { @@ -193,7 +223,8 @@ public class ReduceFnTester<InputT, OutputT, W extends BoundedWindow> { timerInternals, windowingInternals, droppedDueToClosedWindow, - reduceFn); + reduceFn, + options); } public ExecutableTrigger<W> getTrigger() { @@ -432,6 +463,11 @@ public class ReduceFnTester<InputT, OutputT, W extends BoundedWindow> { */ private class TestWindowingInternals implements WindowingInternals<InputT, KV<String, OutputT>> { private List<WindowedValue<KV<String, OutputT>>> outputs = new ArrayList<>(); + private final SideInputReader sideInputReader; + + private
TestWindowingInternals(SideInputReader sideInputReader) { + this.sideInputReader = sideInputReader; + } @Override public void outputWindowedValue(KV<String, OutputT> output, Instant timestamp, @@ -476,6 +512,16 @@ public class ReduceFnTester<InputT, OutputT, W extends BoundedWindow> { (TestInMemoryStateInternals) stateInternals; return untypedStateInternals; } + + @Override + public <T> T sideInput(PCollectionView<T> view, BoundedWindow mainInputWindow) { + if (!sideInputReader.contains(view)) { + throw new IllegalArgumentException("calling sideInput() with unknown view"); + } + BoundedWindow sideInputWindow = + view.getWindowingStrategyInternal().getWindowFn().getSideInputWindow(mainInputWindow); + return sideInputReader.get(view, sideInputWindow); + } } private static class TestAssignContext<W extends BoundedWindow>
