BEAM-784 Checkpointing for StateInternals
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1db4ff63 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1db4ff63 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1db4ff63 Branch: refs/heads/master Commit: 1db4ff631736172882976c33316bc089d58483af Parents: 0a1b278 Author: Thomas Weise <[email protected]> Authored: Tue Oct 25 08:32:23 2016 -0700 Committer: Thomas Weise <[email protected]> Committed: Tue Oct 25 10:06:12 2016 -0700 ---------------------------------------------------------------------- .../apex/translators/GroupByKeyTranslator.java | 3 +- .../translators/ParDoBoundMultiTranslator.java | 4 +- .../apex/translators/ParDoBoundTranslator.java | 4 +- .../apex/translators/TranslationContext.java | 10 + .../functions/ApexGroupByKeyOperator.java | 12 +- .../functions/ApexParDoOperator.java | 11 +- .../translators/utils/ApexStateInternals.java | 438 +++++++++++++++++++ .../translators/ParDoBoundTranslatorTest.java | 62 ++- .../utils/ApexStateInternalsTest.java | 361 +++++++++++++++ 9 files changed, 883 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/GroupByKeyTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/GroupByKeyTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/GroupByKeyTranslator.java index d3e7d2d..cb78579 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/GroupByKeyTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/GroupByKeyTranslator.java @@ -33,7 +33,8 @@ public class GroupByKeyTranslator<K, V> implements 
TransformTranslator<GroupByKe public void translate(GroupByKey<K, V> transform, TranslationContext context) { PCollection<KV<K, V>> input = context.getInput(); ApexGroupByKeyOperator<K, V> group = new ApexGroupByKeyOperator<>(context.getPipelineOptions(), - input); + input, context.<K>stateInternalsFactory() + ); context.addOperator(group, group.output); context.addStream(input, group.input); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java index 13f07c1..2678869 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java @@ -64,7 +64,9 @@ public class ParDoBoundMultiTranslator<InputT, OutputT> ApexParDoOperator<InputT, OutputT> operator = new ApexParDoOperator<>( context.getPipelineOptions(), doFn, transform.getMainOutputTag(), transform.getSideOutputTags().getAll(), - context.<PCollection<?>>getInput().getWindowingStrategy(), sideInputs, wvInputCoder); + context.<PCollection<?>>getInput().getWindowingStrategy(), sideInputs, wvInputCoder, + context.<Void>stateInternalsFactory() + ); Map<TupleTag<?>, PCollection<?>> outputs = output.getAll(); Map<PCollection<?>, OutputPort<?>> ports = Maps.newHashMapWithExpectedSize(outputs.size()); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java ---------------------------------------------------------------------- diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java index bd7115e..92567a6 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java @@ -52,7 +52,9 @@ public class ParDoBoundTranslator<InputT, OutputT> implements ApexParDoOperator<InputT, OutputT> operator = new ApexParDoOperator<>( context.getPipelineOptions(), doFn, new TupleTag<OutputT>(), TupleTagList.empty().getAll() /*sideOutputTags*/, - output.getWindowingStrategy(), sideInputs, wvInputCoder); + output.getWindowingStrategy(), sideInputs, wvInputCoder, + context.<Void>stateInternalsFactory() + ); context.addOperator(operator, operator.output); context.addStream(context.getInput(), operator.input); if (!sideInputs.isEmpty()) { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java index ddacc29..07c6494 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java @@ -31,6 +31,7 @@ import java.util.List; import java.util.Map; import org.apache.beam.runners.apex.ApexPipelineOptions; +import org.apache.beam.runners.apex.translators.utils.ApexStateInternals; import org.apache.beam.runners.apex.translators.utils.ApexStreamTuple; import org.apache.beam.runners.apex.translators.utils.CoderAdapterStreamCodec; import 
org.apache.beam.sdk.coders.Coder; @@ -38,6 +39,7 @@ import org.apache.beam.sdk.runners.TransformTreeNode; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder; +import org.apache.beam.sdk.util.state.StateInternalsFactory; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; @@ -165,4 +167,12 @@ public class TranslationContext { } } + /** + * Return the {@link StateInternalsFactory} for the pipeline translation. + * @return + */ + public <K> StateInternalsFactory<K> stateInternalsFactory() { + return new ApexStateInternals.ApexStateInternalsFactory(); + } + } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexGroupByKeyOperator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexGroupByKeyOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexGroupByKeyOperator.java index 845618d..d69aeab 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexGroupByKeyOperator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexGroupByKeyOperator.java @@ -64,7 +64,6 @@ import org.apache.beam.sdk.util.TimerInternals; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowingInternals; import org.apache.beam.sdk.util.WindowingStrategy; -import org.apache.beam.sdk.util.state.InMemoryStateInternals; import org.apache.beam.sdk.util.state.StateInternals; import org.apache.beam.sdk.util.state.StateInternalsFactory; import org.apache.beam.sdk.values.KV; @@ -97,8 +96,8 @@ public class ApexGroupByKeyOperator<K, V> 
implements Operator { @Bind(JavaSerializer.class) private final SerializablePipelineOptions serializedOptions; @Bind(JavaSerializer.class) -// TODO: InMemoryStateInternals not serializable - private transient Map<ByteBuffer, StateInternals<K>> perKeyStateInternals = new HashMap<>(); + private final StateInternalsFactory<K> stateInternalsFactory; + private Map<ByteBuffer, StateInternals<K>> perKeyStateInternals = new HashMap<>(); private Map<ByteBuffer, Set<TimerInternals.TimerData>> activeTimers = new HashMap<>(); private transient ProcessContext context; @@ -137,17 +136,20 @@ public class ApexGroupByKeyOperator<K, V> implements Operator { output = new DefaultOutputPort<>(); @SuppressWarnings("unchecked") - public ApexGroupByKeyOperator(ApexPipelineOptions pipelineOptions, PCollection<KV<K, V>> input) { + public ApexGroupByKeyOperator(ApexPipelineOptions pipelineOptions, PCollection<KV<K, V>> input, + StateInternalsFactory<K> stateInternalsFactory) { checkNotNull(pipelineOptions); this.serializedOptions = new SerializablePipelineOptions(pipelineOptions); this.windowingStrategy = (WindowingStrategy<V, BoundedWindow>) input.getWindowingStrategy(); this.keyCoder = ((KvCoder<K, V>) input.getCoder()).getKeyCoder(); this.valueCoder = ((KvCoder<K, V>) input.getCoder()).getValueCoder(); + this.stateInternalsFactory = stateInternalsFactory; } @SuppressWarnings("unused") // for Kryo private ApexGroupByKeyOperator() { this.serializedOptions = null; + this.stateInternalsFactory = null; } @Override @@ -230,7 +232,7 @@ public class ApexGroupByKeyOperator<K, V> implements Operator { } StateInternals<K> stateInternals = perKeyStateInternals.get(keyBytes); if (stateInternals == null) { - stateInternals = InMemoryStateInternals.forKey(key); + stateInternals = stateInternalsFactory.stateInternalsForKey(key); perKeyStateInternals.put(keyBytes, stateInternals); } return stateInternals; 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java index 9e8f3dc..43384d6 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java @@ -57,8 +57,8 @@ import org.apache.beam.sdk.util.SideInputReader; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowingStrategy; -import org.apache.beam.sdk.util.state.InMemoryStateInternals; import org.apache.beam.sdk.util.state.StateInternals; +import org.apache.beam.sdk.util.state.StateInternalsFactory; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.slf4j.Logger; @@ -84,9 +84,7 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements @Bind(JavaSerializer.class) private final List<PCollectionView<?>> sideInputs; -// TODO: not Kryo serializable, integrate codec - private transient StateInternals<Void> sideInputStateInternals = InMemoryStateInternals - .forKey(null); + private final StateInternals<Void> sideInputStateInternals; private final ValueAndCoderKryoSerializable<List<WindowedValue<InputT>>> pushedBack; private LongMin pushedBackWatermark = new LongMin(); private long currentInputWatermark = Long.MIN_VALUE; @@ -104,7 +102,8 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements List<TupleTag<?>> sideOutputTags, WindowingStrategy<?, ?> windowingStrategy, 
List<PCollectionView<?>> sideInputs, - Coder<WindowedValue<InputT>> inputCoder + Coder<WindowedValue<InputT>> inputCoder, + StateInternalsFactory<Void> stateInternalsFactory ) { this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions); this.doFn = doFn; @@ -112,6 +111,7 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements this.sideOutputTags = sideOutputTags; this.windowingStrategy = windowingStrategy; this.sideInputs = sideInputs; + this.sideInputStateInternals = stateInternalsFactory.stateInternalsForKey(null); if (sideOutputTags.size() > sideOutputPorts.length) { String msg = String.format("Too many side outputs (currently only supporting %s).", @@ -134,6 +134,7 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements this.windowingStrategy = null; this.sideInputs = null; this.pushedBack = null; + this.sideInputStateInternals = null; } public final transient DefaultInputPort<ApexStreamTuple<WindowedValue<InputT>>> input = http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternals.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternals.java new file mode 100644 index 0000000..edc1220 --- /dev/null +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternals.java @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.apex.translators.utils; + +import com.esotericsoftware.kryo.DefaultSerializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.serializers.JavaSerializer; +import com.google.common.collect.HashBasedTable; +import com.google.common.collect.Table; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.Coder.Context; +import org.apache.beam.sdk.coders.InstantCoder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.Combine.KeyedCombineFn; +import org.apache.beam.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; +import org.apache.beam.sdk.util.CombineFnUtil; +import org.apache.beam.sdk.util.state.AccumulatorCombiningState; +import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.State; +import org.apache.beam.sdk.util.state.StateContext; +import org.apache.beam.sdk.util.state.StateContexts; +import 
org.apache.beam.sdk.util.state.StateInternals; +import org.apache.beam.sdk.util.state.StateInternalsFactory; +import org.apache.beam.sdk.util.state.StateNamespace; +import org.apache.beam.sdk.util.state.StateTag; +import org.apache.beam.sdk.util.state.StateTag.StateBinder; +import org.apache.beam.sdk.util.state.ValueState; +import org.apache.beam.sdk.util.state.WatermarkHoldState; +import org.joda.time.Instant; + +/** + * Implementation of {@link StateInternals} that can be serialized and + * checkpointed with the operator. Suitable for small states, in the future this + * should be based on the incremental state saving components in the Apex + * library. + */ +@DefaultSerializer(JavaSerializer.class) +public class ApexStateInternals<K> implements StateInternals<K>, Serializable { + private static final long serialVersionUID = 1L; + public static <K> ApexStateInternals<K> forKey(K key) { + return new ApexStateInternals<>(key); + } + + private final K key; + + protected ApexStateInternals(K key) { + this.key = key; + } + + @Override + public K getKey() { + return key; + } + + /** + * Serializable state for internals (namespace -> state tag -> coded value). + */ + private final Table<String, String, byte[]> stateTable = HashBasedTable.create(); + + @Override + public <T extends State> T state(StateNamespace namespace, StateTag<? super K, T> address) { + return state(namespace, address, StateContexts.nullContext()); + } + + @Override + public <T extends State> T state( + StateNamespace namespace, StateTag<? super K, T> address, final StateContext<?> c) { + return address.bind(new ApexStateBinder(key, namespace, address, c)); + } + + /** + * A {@link StateBinder} that returns {@link State} wrappers for serialized state. + */ + private class ApexStateBinder implements StateBinder<K> { + private final K key; + private final StateNamespace namespace; + private final StateContext<?> c; + + private ApexStateBinder(K key, StateNamespace namespace, StateTag<? 
super K, ?> address, + StateContext<?> c) { + this.key = key; + this.namespace = namespace; + this.c = c; + } + + @Override + public <T> ValueState<T> bindValue( + StateTag<? super K, ValueState<T>> address, Coder<T> coder) { + return new ApexValueState<T>(namespace, address, coder); + } + + @Override + public <T> BagState<T> bindBag( + final StateTag<? super K, BagState<T>> address, Coder<T> elemCoder) { + return new ApexBagState<T>(namespace, address, elemCoder); + } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> + bindCombiningValue( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + final CombineFn<InputT, AccumT, OutputT> combineFn) { + return new ApexAccumulatorCombiningState<K, InputT, AccumT, OutputT>( + namespace, + address, + accumCoder, + key, + combineFn.<K>asKeyedFn() + ); + } + + @Override + public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark( + StateTag<? super K, WatermarkHoldState<W>> address, + OutputTimeFn<? super W> outputTimeFn) { + return new ApexWatermarkHoldState<W>(namespace, address, outputTimeFn); + } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> + bindKeyedCombiningValue( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { + return new ApexAccumulatorCombiningState<K, InputT, AccumT, OutputT>( + namespace, + address, + accumCoder, + key, combineFn); + } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> + bindKeyedCombiningValueWithContext( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + KeyedCombineFnWithContext<? 
super K, InputT, AccumT, OutputT> combineFn) { + return bindKeyedCombiningValue(address, accumCoder, CombineFnUtil.bindContext(combineFn, c)); + } + } + + private class AbstractState<T> { + protected final StateNamespace namespace; + protected final StateTag<?, ? extends State> address; + protected final Coder<T> coder; + + private AbstractState( + StateNamespace namespace, + StateTag<?, ? extends State> address, + Coder<T> coder) { + this.namespace = namespace; + this.address = address; + this.coder = coder; + } + + protected T readValue() { + T value = null; + byte[] buf = stateTable.get(namespace.stringKey(), address.getId()); + if (buf != null) { + // TODO: reuse input + Input input = new Input(buf); + try { + return coder.decode(input, Context.OUTER); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return value; + } + + public void writeValue(T input) { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + try { + coder.encode(input, output, Context.OUTER); + stateTable.put(namespace.stringKey(), address.getId(), output.toByteArray()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public void clear() { + stateTable.remove(namespace.stringKey(), address.getId()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + @SuppressWarnings("unchecked") + AbstractState<?> that = (AbstractState<?>) o; + return namespace.equals(that.namespace) && address.equals(that.address); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + private class ApexValueState<T> extends AbstractState<T> implements ValueState<T> { + + private ApexValueState( + StateNamespace namespace, + StateTag<?, ValueState<T>> address, + Coder<T> coder) { + super(namespace, address, coder); + } + + @Override + public ApexValueState<T> 
readLater() { + return this; + } + + @Override + public T read() { + return readValue(); + } + + @Override + public void write(T input) { + writeValue(input); + } + } + + private final class ApexWatermarkHoldState<W extends BoundedWindow> + extends AbstractState<Instant> implements WatermarkHoldState<W> { + + private final OutputTimeFn<? super W> outputTimeFn; + + public ApexWatermarkHoldState( + StateNamespace namespace, + StateTag<?, WatermarkHoldState<W>> address, + OutputTimeFn<? super W> outputTimeFn) { + super(namespace, address, InstantCoder.of()); + this.outputTimeFn = outputTimeFn; + } + + @Override + public ApexWatermarkHoldState<W> readLater() { + return this; + } + + @Override + public Instant read() { + return readValue(); + } + + @Override + public void add(Instant outputTime) { + Instant combined = read(); + combined = (combined == null) ? outputTime : outputTimeFn.combine(combined, outputTime); + writeValue(combined); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public ReadableState<Boolean> readLater() { + return this; + } + @Override + public Boolean read() { + return stateTable.get(namespace.stringKey(), address.getId()) == null; + } + }; + } + + @Override + public OutputTimeFn<? super W> getOutputTimeFn() { + return outputTimeFn; + } + + } + + private final class ApexAccumulatorCombiningState<K, InputT, AccumT, OutputT> + extends AbstractState<AccumT> + implements AccumulatorCombiningState<InputT, AccumT, OutputT> { + private final K key; + private final KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn; + + private ApexAccumulatorCombiningState(StateNamespace namespace, + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> coder, + K key, KeyedCombineFn<? 
super K, InputT, AccumT, OutputT> combineFn) { + super(namespace, address, coder); + this.key = key; + this.combineFn = combineFn; + } + + @Override + public ApexAccumulatorCombiningState<K, InputT, AccumT, OutputT> readLater() { + return this; + } + + @Override + public OutputT read() { + return combineFn.extractOutput(key, getAccum()); + } + + @Override + public void add(InputT input) { + AccumT accum = getAccum(); + combineFn.addInput(key, accum, input); + writeValue(accum); + } + + @Override + public AccumT getAccum() { + AccumT accum = readValue(); + if (accum == null) { + accum = combineFn.createAccumulator(key); + } + return accum; + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public ReadableState<Boolean> readLater() { + return this; + } + @Override + public Boolean read() { + return stateTable.get(namespace.stringKey(), address.getId()) == null; + } + }; + } + + @Override + public void addAccum(AccumT accum) { + accum = combineFn.mergeAccumulators(key, Arrays.asList(getAccum(), accum)); + writeValue(accum); + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(key, accumulators); + } + + } + + private final class ApexBagState<T> extends AbstractState<List<T>> implements BagState<T> { + private ApexBagState( + StateNamespace namespace, + StateTag<?, BagState<T>> address, + Coder<T> coder) { + super(namespace, address, ListCoder.of(coder)); + } + + @Override + public ApexBagState<T> readLater() { + return this; + } + + @Override + public List<T> read() { + List<T> value = super.readValue(); + if (value == null) { + value = new ArrayList<T>(); + } + return value; + } + + @Override + public void add(T input) { + List<T> value = read(); + value.add(input); + writeValue(value); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public ReadableState<Boolean> 
readLater() { + return this; + } + + @Override + public Boolean read() { + return stateTable.get(namespace.stringKey(), address.getId()) == null; + } + }; + } + } + + /** + * Factory for {@link ApexStateInternals}. + * + * @param <K> + */ + public static class ApexStateInternalsFactory<K> + implements StateInternalsFactory<K>, Serializable { + private static final long serialVersionUID = 1L; + + @Override + public StateInternals<K> stateInternalsForKey(K key) { + return ApexStateInternals.forKey(key); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java index ad22acd..9ea4233 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import com.datatorrent.api.DAG; +import com.datatorrent.api.Sink; import com.datatorrent.lib.util.KryoCloneUtils; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -37,6 +38,7 @@ import org.apache.beam.runners.apex.ApexRunnerResult; import org.apache.beam.runners.apex.TestApexRunner; import org.apache.beam.runners.apex.translators.functions.ApexParDoOperator; import org.apache.beam.runners.apex.translators.io.ApexReadUnboundedInputOperator; +import org.apache.beam.runners.apex.translators.utils.ApexStateInternals; import org.apache.beam.runners.apex.translators.utils.ApexStreamTuple; import org.apache.beam.sdk.Pipeline; import 
org.apache.beam.sdk.coders.Coder; @@ -107,14 +109,22 @@ public class ParDoBoundTranslatorTest { @SuppressWarnings("serial") private static class Add extends OldDoFn<Integer, Integer> { - private final Integer number; + private Integer number; + private PCollectionView<Integer> sideInputView; - public Add(Integer number) { + private Add(Integer number) { this.number = number; } + private Add(PCollectionView<Integer> sideInputView) { + this.sideInputView = sideInputView; + } + @Override public void processElement(ProcessContext c) throws Exception { + if (sideInputView != null) { + number = c.sideInput(sideInputView); + } c.output(c.element() + number); } } @@ -190,17 +200,51 @@ public class ParDoBoundTranslatorTest { .apply(Sum.integersGlobally().asSingletonView()); ApexParDoOperator<Integer, Integer> operator = new ApexParDoOperator<>(options, - new Add(0), new TupleTag<Integer>(), TupleTagList.empty().getAll(), + new Add(singletonView), new TupleTag<Integer>(), TupleTagList.empty().getAll(), WindowingStrategy.globalDefault(), Collections.<PCollectionView<?>>singletonList(singletonView), - coder); + coder, + new ApexStateInternals.ApexStateInternalsFactory<Void>() + ); operator.setup(null); operator.beginWindow(0); - WindowedValue<Integer> wv = WindowedValue.valueInGlobalWindow(0); - operator.input.process(ApexStreamTuple.DataTuple.of(wv)); - operator.input.process(ApexStreamTuple.WatermarkTuple.<WindowedValue<Integer>>of(0)); - operator.endWindow(); - Assert.assertNotNull("Serialization", KryoCloneUtils.cloneObject(operator)); + WindowedValue<Integer> wv1 = WindowedValue.valueInGlobalWindow(1); + WindowedValue<Iterable<?>> sideInput = WindowedValue.<Iterable<?>>valueInGlobalWindow( + Lists.<Integer>newArrayList(22)); + operator.input.process(ApexStreamTuple.DataTuple.of(wv1)); // pushed back input + + final List<Object> results = Lists.newArrayList(); + Sink<Object> sink = new Sink<Object>() { + @Override + public void put(Object tuple) { + results.add(tuple); + } 
+ @Override + public int getCount(boolean reset) { + return 0; + } + }; + // verify pushed back input checkpointing + Assert.assertNotNull("Serialization", operator = KryoCloneUtils.cloneObject(operator)); + operator.output.setSink(sink); + operator.setup(null); + operator.beginWindow(1); + WindowedValue<Integer> wv2 = WindowedValue.valueInGlobalWindow(2); + operator.sideInput1.process(ApexStreamTuple.DataTuple.of(sideInput)); + Assert.assertEquals("number outputs", 1, results.size()); + Assert.assertEquals("result", WindowedValue.valueInGlobalWindow(23), + ((ApexStreamTuple.DataTuple) results.get(0)).getValue()); + + // verify side input checkpointing + results.clear(); + Assert.assertNotNull("Serialization", operator = KryoCloneUtils.cloneObject(operator)); + operator.output.setSink(sink); + operator.setup(null); + operator.beginWindow(2); + operator.input.process(ApexStreamTuple.DataTuple.of(wv2)); + Assert.assertEquals("number outputs", 1, results.size()); + Assert.assertEquals("result", WindowedValue.valueInGlobalWindow(24), + ((ApexStreamTuple.DataTuple) results.get(0)).getValue()); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternalsTest.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternalsTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternalsTest.java new file mode 100644 index 0000000..055d98c --- /dev/null +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternalsTest.java @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.apex.translators.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; + +import com.datatorrent.lib.util.KryoCloneUtils; + +import java.util.Arrays; + +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.OutputTimeFns; +import org.apache.beam.sdk.util.state.AccumulatorCombiningState; +import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.CombiningState; +import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.StateMerging; +import org.apache.beam.sdk.util.state.StateNamespace; +import org.apache.beam.sdk.util.state.StateNamespaceForTest; +import org.apache.beam.sdk.util.state.StateTag; +import org.apache.beam.sdk.util.state.StateTags; +import org.apache.beam.sdk.util.state.ValueState; +import org.apache.beam.sdk.util.state.WatermarkHoldState; +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; + +/** + * Tests for {@link 
ApexStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ *
+ * <p>Each test runs against a fresh {@link ApexStateInternals} created in
+ * {@link #initStateInternals()}, so state written by one test is not visible to another.
+ */
+public class ApexStateInternalsTest {
+  /** Window passed to {@link StateMerging#mergeWatermarks} as the window of the merged hold. */
+  private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10));
+
+  // Distinct namespaces, used to obtain independent state cells for the same tag.
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+  private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+  private static final StateTag<Object, ValueState<String>> STRING_VALUE_ADDR =
+      StateTags.value("stringValue", StringUtf8Coder.of());
+  private static final StateTag<Object, AccumulatorCombiningState<Integer, int[], Integer>>
+      SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal(
+          "sumInteger", VarIntCoder.of(), new Sum.SumIntegerFn());
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+
+  // The three watermark tags share the id "watermark" and differ only in their OutputTimeFn.
+  // Each test below uses exactly one of them.
+  private static final StateTag<Object, WatermarkHoldState<BoundedWindow>>
+      WATERMARK_EARLIEST_ADDR =
+      StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEarliestInputTimestamp());
+  private static final StateTag<Object, WatermarkHoldState<BoundedWindow>>
+      WATERMARK_LATEST_ADDR =
+      StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtLatestInputTimestamp());
+  private static final StateTag<Object, WatermarkHoldState<BoundedWindow>> WATERMARK_EOW_ADDR =
+      StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEndOfWindow());
+
+  private ApexStateInternals<String> underTest;
+
+  /**
+   * Creates the state internals under test before every test method.
+   *
+   * <p>NOTE(review): the constructor argument is passed as {@code null}; its role is not visible
+   * from this test. Confirm that the state operations exercised here do not depend on it.
+   */
+  @Before
+  public void initStateInternals() {
+    underTest = new ApexStateInternals<>(null);
+  }
+
+  /** Bag state: caching per namespace, add/read in any order, and clear. */
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+  }
+
+  /** A {@code ReadableState} from {@code isEmpty()} reflects later adds and clears. */
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  /** Merging two bags into one of the sources keeps all elements and empties the other source. */
+  @Test
+  public void testMergeBagIntoSource() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+    // Reading the merged bag returns the contents of both bags.
+    assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  /** Merging into a third namespace moves all elements there and empties the sources. */
+  @Test
+  public void testMergeBagIntoNewNamespace() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+    BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    // bag3 is both an (empty) source and the result.
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+    // Reading the merged bag returns the contents of both non-empty bags.
+    assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag1.read(), Matchers.emptyIterable());
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  /** Combining (sum) state: caching per namespace, accumulation, and reset to 0 on clear. */
+  @Test
+  public void testCombiningValue() throws Exception {
+    CombiningState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR)));
+
+    assertThat(value.read(), Matchers.equalTo(0));
+    value.add(2);
+    assertThat(value.read(), Matchers.equalTo(2));
+
+    value.add(3);
+    assertThat(value.read(), Matchers.equalTo(5));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(0));
+    assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value);
+  }
+
+  /** A {@code ReadableState} from {@code isEmpty()} reflects later adds and clears. */
+  @Test
+  public void testCombiningIsEmpty() throws Exception {
+    CombiningState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add(5);
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  /** Merging sums into one of the sources leaves the total there and resets the other. */
+  @Test
+  public void testMergeCombiningValueIntoSource() throws Exception {
+    AccumulatorCombiningState<Integer, int[], Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+    AccumulatorCombiningState<Integer, int[], Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    assertThat(value1.read(), Matchers.equalTo(11));
+    assertThat(value2.read(), Matchers.equalTo(10));
+
+    // Merging clears the old values and updates the result value.
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1);
+
+    assertThat(value1.read(), Matchers.equalTo(21));
+    assertThat(value2.read(), Matchers.equalTo(0));
+  }
+
+  /** Merging sums into a third namespace puts the total there and resets both sources. */
+  @Test
+  public void testMergeCombiningValueIntoNewNamespace() throws Exception {
+    AccumulatorCombiningState<Integer, int[], Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+    AccumulatorCombiningState<Integer, int[], Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+    AccumulatorCombiningState<Integer, int[], Integer> value3 =
+        underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR);
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3);
+
+    // Merging clears the old values and updates the result value.
+    assertThat(value1.read(), Matchers.equalTo(0));
+    assertThat(value2.read(), Matchers.equalTo(0));
+    assertThat(value3.read(), Matchers.equalTo(21));
+  }
+
+  /** An earliest-timestamp hold keeps the minimum instant added so far. */
+  @Test
+  public void testWatermarkEarliestState() throws Exception {
+    WatermarkHoldState<BoundedWindow> value =
+        underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR)));
+
+    assertThat(value.read(), Matchers.nullValue());
+    value.add(new Instant(2000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.add(new Instant(3000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.add(new Instant(1000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(1000)));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(null));
+    assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), value);
+  }
+
+  /** A latest-timestamp hold keeps the maximum instant added so far. */
+  @Test
+  public void testWatermarkLatestState() throws Exception {
+    WatermarkHoldState<BoundedWindow> value =
+        underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR)));
+
+    assertThat(value.read(), Matchers.nullValue());
+    value.add(new Instant(2000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.add(new Instant(3000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(3000)));
+
+    value.add(new Instant(1000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(3000)));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(null));
+    assertEquals(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), value);
+  }
+
+  /** An end-of-window hold returns the added instant on a plain read and is cleared on clear. */
+  @Test
+  public void testWatermarkEndOfWindowState() throws Exception {
+    WatermarkHoldState<BoundedWindow> value = underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EOW_ADDR)));
+
+    assertThat(value.read(), Matchers.nullValue());
+    value.add(new Instant(2000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(null));
+    assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), value);
+  }
+
+  /** A {@code ReadableState} from {@code isEmpty()} reflects later adds and clears. */
+  @Test
+  public void testWatermarkStateIsEmpty() throws Exception {
+    WatermarkHoldState<BoundedWindow> value =
+        underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add(new Instant(1000));
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  /** Merging earliest holds into one of the sources keeps the overall minimum there. */
+  @Test
+  public void testMergeEarliestWatermarkIntoSource() throws Exception {
+    WatermarkHoldState<BoundedWindow> value1 =
+        underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+    WatermarkHoldState<BoundedWindow> value2 =
+        underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR);
+
+    value1.add(new Instant(3000));
+    value2.add(new Instant(5000));
+    value1.add(new Instant(4000));
+    value2.add(new Instant(2000));
+
+    // Merging clears the old values and updates the merged value.
+    StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1);
+
+    assertThat(value1.read(), Matchers.equalTo(new Instant(2000)));
+    assertThat(value2.read(), Matchers.equalTo(null));
+  }
+
+  /**
+   * Merging latest holds into one of the sources keeps the overall maximum there.
+   *
+   * <p>Mirrors {@link #testMergeEarliestWatermarkIntoSource()}: the result holder is
+   * {@code value1}, which is also one of the merged sources.
+   */
+  @Test
+  public void testMergeLatestWatermarkIntoSource() throws Exception {
+    WatermarkHoldState<BoundedWindow> value1 =
+        underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR);
+    WatermarkHoldState<BoundedWindow> value2 =
+        underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR);
+
+    value1.add(new Instant(3000));
+    value2.add(new Instant(5000));
+    value1.add(new Instant(4000));
+    value2.add(new Instant(2000));
+
+    // Merging clears the old values and updates the merged value.
+    StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1);
+
+    // value1 held 4000 and value2 held 5000; the latest of the two wins.
+    assertThat(value1.read(), Matchers.equalTo(new Instant(5000)));
+    assertThat(value2.read(), Matchers.equalTo(null));
+  }
+
+  /** Merging latest holds into a third namespace puts the maximum there and clears both sources. */
+  @Test
+  public void testMergeLatestWatermarkIntoNewNamespace() throws Exception {
+    WatermarkHoldState<BoundedWindow> value1 =
+        underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR);
+    WatermarkHoldState<BoundedWindow> value2 =
+        underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR);
+    WatermarkHoldState<BoundedWindow> value3 =
+        underTest.state(NAMESPACE_3, WATERMARK_LATEST_ADDR);
+
+    value1.add(new Instant(3000));
+    value2.add(new Instant(5000));
+    value1.add(new Instant(4000));
+    value2.add(new Instant(2000));
+
+    // Merging clears the old values and updates the result value.
+    StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1);
+
+    assertThat(value3.read(), Matchers.equalTo(new Instant(5000)));
+    assertThat(value1.read(), Matchers.equalTo(null));
+    assertThat(value2.read(), Matchers.equalTo(null));
+  }
+
+  /**
+   * State written before a Kryo clone is readable from the clone. The test also checks that
+   * state obtained from the clone compares equal to the original's state for the same namespace
+   * and tag.
+   */
+  @Test
+  public void testSerialization() throws Exception {
+    ApexStateInternals<String> original = new ApexStateInternals<String>(null);
+    ValueState<String> value = original.state(NAMESPACE_1, STRING_VALUE_ADDR);
+    assertEquals(original.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+    value.write("hello");
+
+    ApexStateInternals<String> cloned;
+    assertNotNull("Serialization", cloned = KryoCloneUtils.cloneObject(original));
+    ValueState<String> clonedValue = cloned.state(NAMESPACE_1, STRING_VALUE_ADDR);
+    assertThat(clonedValue.read(), Matchers.equalTo("hello"));
+    assertEquals(cloned.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+  }
+
+}
