Repository: beam Updated Branches: refs/heads/master f03f6ac19 -> 137fee95b
http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java new file mode 100644 index 0000000..9033ba7 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -0,0 +1,1053 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import com.google.common.collect.Lists; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.InstantCoder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineWithContext; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.CombineContextFactory; +import org.apache.beam.sdk.util.state.AccumulatorCombiningState; +import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.MapState; +import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.SetState; +import org.apache.beam.sdk.util.state.State; +import org.apache.beam.sdk.util.state.StateContext; +import org.apache.beam.sdk.util.state.StateContexts; +import org.apache.beam.sdk.util.state.ValueState; +import org.apache.beam.sdk.util.state.WatermarkHoldState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.joda.time.Instant; + +/** + * {@link StateInternals} that uses a Flink {@link KeyedStateBackend} to manage state. + * + * <p>Note: In the Flink streaming runner, the key is always encoded + * using a {@link Coder} and stored in a {@link ByteBuffer}.
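 + *
 + * <p>For illustration only (this mirrors {@code getKey()} below and adds no new
 + * API), the round trip between a Beam key and the Flink key is roughly:
 + * <pre>{@code
 + * // encoding: Beam key -> Flink key
 + * ByteBuffer flinkKey = ByteBuffer.wrap(CoderUtils.encodeToByteArray(keyCoder, key));
 + * // decoding: Flink key -> Beam key
 + * K decodedKey = CoderUtils.decodeFromByteArray(keyCoder, flinkKey.array());
 + * }</pre>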
+ */ +public class FlinkStateInternals<K> implements StateInternals<K> { + + private final KeyedStateBackend<ByteBuffer> flinkStateBackend; + private Coder<K> keyCoder; + + // on recovery, these will not be properly set because we don't + // know which watermark hold states there are in the Flink State Backend + private final Map<String, Instant> watermarkHolds = new HashMap<>(); + + public FlinkStateInternals(KeyedStateBackend<ByteBuffer> flinkStateBackend, Coder<K> keyCoder) { + this.flinkStateBackend = flinkStateBackend; + this.keyCoder = keyCoder; + } + + /** + * Returns the minimum over all watermark holds. + */ + public Instant watermarkHold() { + long min = Long.MAX_VALUE; + for (Instant hold: watermarkHolds.values()) { + min = Math.min(min, hold.getMillis()); + } + return new Instant(min); + } + + @Override + public K getKey() { + ByteBuffer keyBytes = flinkStateBackend.getCurrentKey(); + try { + return CoderUtils.decodeFromByteArray(keyCoder, keyBytes.array()); + } catch (CoderException e) { + throw new RuntimeException("Error decoding key.", e); + } + } + + @Override + public <T extends State> T state( + final StateNamespace namespace, + StateTag<? super K, T> address) { + + return state(namespace, address, StateContexts.nullContext()); + } + + @Override + public <T extends State> T state( + final StateNamespace namespace, + StateTag<? super K, T> address, + final StateContext<?> context) { + + return address.bind(new StateTag.StateBinder<K>() { + + @Override + public <T> ValueState<T> bindValue( + StateTag<? super K, ValueState<T>> address, + Coder<T> coder) { + + return new FlinkValueState<>(flinkStateBackend, address, namespace, coder); + } + + @Override + public <T> BagState<T> bindBag( + StateTag<? super K, BagState<T>> address, + Coder<T> elemCoder) { + + return new FlinkBagState<>(flinkStateBackend, address, namespace, elemCoder); + } + + @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, + Coder<T> elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); + } + + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MapState.class.getSimpleName())); + } + + @Override + public <InputT, AccumT, OutputT> + AccumulatorCombiningState<InputT, AccumT, OutputT> + bindCombiningValue( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + + return new FlinkAccumulatorCombiningState<>( + flinkStateBackend, address, combineFn, namespace, accumCoder); + } + + @Override + public <InputT, AccumT, OutputT> + AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { + return new FlinkKeyedAccumulatorCombiningState<>( + flinkStateBackend, + address, + combineFn, + namespace, + accumCoder, + FlinkStateInternals.this); + } + + @Override + public <InputT, AccumT, OutputT> + AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext( + StateTag<? 
super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + CombineWithContext.KeyedCombineFnWithContext< + ? super K, InputT, AccumT, OutputT> combineFn) { + return new FlinkAccumulatorCombiningStateWithContext<>( + flinkStateBackend, + address, + combineFn, + namespace, + accumCoder, + FlinkStateInternals.this, + CombineContextFactory.createFromStateContext(context)); + } + + @Override + public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark( + StateTag<? super K, WatermarkHoldState<W>> address, + OutputTimeFn<? super W> outputTimeFn) { + + return new FlinkWatermarkHoldState<>( + flinkStateBackend, FlinkStateInternals.this, address, namespace, outputTimeFn); + } + }); + } + + private static class FlinkValueState<K, T> implements ValueState<T> { + + private final StateNamespace namespace; + private final StateTag<? super K, ValueState<T>> address; + private final ValueStateDescriptor<T> flinkStateDescriptor; + private final KeyedStateBackend<ByteBuffer> flinkStateBackend; + + FlinkValueState( + KeyedStateBackend<ByteBuffer> flinkStateBackend, + StateTag<? super K, ValueState<T>> address, + StateNamespace namespace, + Coder<T> coder) { + + this.namespace = namespace; + this.address = address; + this.flinkStateBackend = flinkStateBackend; + + CoderTypeInformation<T> typeInfo = new CoderTypeInformation<>(coder); + + flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + } + + @Override + public void write(T input) { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).update(input); + } catch (Exception e) { + throw new RuntimeException("Error updating state.", e); + } + } + + @Override + public ValueState<T> readLater() { + return this; + } + + @Override + public T read() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).value(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public void clear() { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkValueState<?, ?> that = (FlinkValueState<?, ?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + private static class FlinkBagState<K, T> implements BagState<T> { + + private final StateNamespace namespace; + private final StateTag<? super K, BagState<T>> address; + private final ListStateDescriptor<T> flinkStateDescriptor; + private final KeyedStateBackend<ByteBuffer> flinkStateBackend; + + FlinkBagState( + KeyedStateBackend<ByteBuffer> flinkStateBackend, + StateTag<? 
super K, BagState<T>> address, + StateNamespace namespace, + Coder<T> coder) { + + this.namespace = namespace; + this.address = address; + this.flinkStateBackend = flinkStateBackend; + + CoderTypeInformation<T> typeInfo = new CoderTypeInformation<>(coder); + + flinkStateDescriptor = new ListStateDescriptor<>(address.getId(), typeInfo); + } + + @Override + public void add(T input) { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).add(input); + } catch (Exception e) { + throw new RuntimeException("Error adding to bag state.", e); + } + } + + @Override + public BagState<T> readLater() { + return this; + } + + @Override + public Iterable<T> read() { + try { + Iterable<T> result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(); + + return result != null ? result : Collections.<T>emptyList(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + Iterable<T> result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(); + return result == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkBagState<?, ?> that = (FlinkBagState<?, ?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + private static class FlinkAccumulatorCombiningState<K, InputT, AccumT, OutputT> + implements AccumulatorCombiningState<InputT, AccumT, OutputT> { + + private final StateNamespace namespace; + private final StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address; + private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn; + private final ValueStateDescriptor<AccumT> flinkStateDescriptor; + private final KeyedStateBackend<ByteBuffer> flinkStateBackend; + + FlinkAccumulatorCombiningState( + KeyedStateBackend<ByteBuffer> flinkStateBackend, + StateTag<? 
super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn, + StateNamespace namespace, + Coder<AccumT> accumCoder) { + + this.namespace = namespace; + this.address = address; + this.combineFn = combineFn; + this.flinkStateBackend = flinkStateBackend; + + CoderTypeInformation<AccumT> typeInfo = new CoderTypeInformation<>(accumCoder); + + flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + } + + @Override + public AccumulatorCombiningState<InputT, AccumT, OutputT> readLater() { + return this; + } + + @Override + public void add(InputT value) { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + current = combineFn.createAccumulator(); + } + current = combineFn.addInput(current, value); + state.update(current); + } catch (Exception e) { + throw new RuntimeException("Error adding to state." , e); + } + } + + @Override + public void addAccum(AccumT accum) { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + state.update(accum); + } else { + current = combineFn.mergeAccumulators(Lists.newArrayList(current, accum)); + state.update(current); + } + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public AccumT getAccum() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).value(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(accumulators); + } + + @Override + public OutputT read() { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT accum = state.value(); + if (accum != null) { + return combineFn.extractOutput(accum); + } else { + return combineFn.extractOutput(combineFn.createAccumulator()); + } + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).value() == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkAccumulatorCombiningState<?, ?, ?, ?> that = + (FlinkAccumulatorCombiningState<?, ?, ?, 
?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + private static class FlinkKeyedAccumulatorCombiningState<K, InputT, AccumT, OutputT> + implements AccumulatorCombiningState<InputT, AccumT, OutputT> { + + private final StateNamespace namespace; + private final StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address; + private final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn; + private final ValueStateDescriptor<AccumT> flinkStateDescriptor; + private final KeyedStateBackend<ByteBuffer> flinkStateBackend; + private final FlinkStateInternals<K> flinkStateInternals; + + FlinkKeyedAccumulatorCombiningState( + KeyedStateBackend<ByteBuffer> flinkStateBackend, + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn, + StateNamespace namespace, + Coder<AccumT> accumCoder, + FlinkStateInternals<K> flinkStateInternals) { + + this.namespace = namespace; + this.address = address; + this.combineFn = combineFn; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateInternals = flinkStateInternals; + + CoderTypeInformation<AccumT> typeInfo = new CoderTypeInformation<>(accumCoder); + + flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + } + + @Override + public AccumulatorCombiningState<InputT, AccumT, OutputT> readLater() { + return this; + } + + @Override + public void add(InputT value) { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + current = combineFn.createAccumulator(flinkStateInternals.getKey()); + } + current = combineFn.addInput(flinkStateInternals.getKey(), current, value); + state.update(current); + } catch (Exception e) { + throw new RuntimeException("Error adding to state." 
, e); + } + } + + @Override + public void addAccum(AccumT accum) { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + state.update(accum); + } else { + current = combineFn.mergeAccumulators( + flinkStateInternals.getKey(), + Lists.newArrayList(current, accum)); + state.update(current); + } + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public AccumT getAccum() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).value(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators); + } + + @Override + public OutputT read() { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT accum = state.value(); + if (accum != null) { + return combineFn.extractOutput(flinkStateInternals.getKey(), accum); + } else { + return combineFn.extractOutput( + flinkStateInternals.getKey(), + combineFn.createAccumulator(flinkStateInternals.getKey())); + } + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).value() == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkKeyedAccumulatorCombiningState<?, ?, ?, ?> that = + (FlinkKeyedAccumulatorCombiningState<?, ?, ?, ?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + private static class FlinkAccumulatorCombiningStateWithContext<K, InputT, AccumT, OutputT> + implements AccumulatorCombiningState<InputT, AccumT, OutputT> { + + private final StateNamespace namespace; + private final StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address; + private final CombineWithContext.KeyedCombineFnWithContext< + ? 
super K, InputT, AccumT, OutputT> combineFn; + private final ValueStateDescriptor<AccumT> flinkStateDescriptor; + private final KeyedStateBackend<ByteBuffer> flinkStateBackend; + private final FlinkStateInternals<K> flinkStateInternals; + private final CombineWithContext.Context context; + + FlinkAccumulatorCombiningStateWithContext( + KeyedStateBackend<ByteBuffer> flinkStateBackend, + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + CombineWithContext.KeyedCombineFnWithContext< + ? super K, InputT, AccumT, OutputT> combineFn, + StateNamespace namespace, + Coder<AccumT> accumCoder, + FlinkStateInternals<K> flinkStateInternals, + CombineWithContext.Context context) { + + this.namespace = namespace; + this.address = address; + this.combineFn = combineFn; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateInternals = flinkStateInternals; + this.context = context; + + CoderTypeInformation<AccumT> typeInfo = new CoderTypeInformation<>(accumCoder); + + flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + } + + @Override + public AccumulatorCombiningState<InputT, AccumT, OutputT> readLater() { + return this; + } + + @Override + public void add(InputT value) { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + current = combineFn.createAccumulator(flinkStateInternals.getKey(), context); + } + current = combineFn.addInput(flinkStateInternals.getKey(), current, value, context); + state.update(current); + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public void addAccum(AccumT accum) { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + state.update(accum); + } else { + current = combineFn.mergeAccumulators( + flinkStateInternals.getKey(), + Lists.newArrayList(current, accum), + context); + state.update(current); + } + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public AccumT getAccum() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).value(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators, context); + } + + @Override + public OutputT read() { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT accum = state.value(); + if (accum == null) { + // mirror the null handling of the other combining-state classes above + accum = combineFn.createAccumulator(flinkStateInternals.getKey(), context); + } + return combineFn.extractOutput(flinkStateInternals.getKey(), accum, context); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).value() == null; + 
} catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkAccumulatorCombiningStateWithContext<?, ?, ?, ?> that = + (FlinkAccumulatorCombiningStateWithContext<?, ?, ?, ?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + private static class FlinkWatermarkHoldState<K, W extends BoundedWindow> + implements WatermarkHoldState<W> { + private final StateTag<? super K, WatermarkHoldState<W>> address; + private final OutputTimeFn<? super W> outputTimeFn; + private final StateNamespace namespace; + private final KeyedStateBackend<ByteBuffer> flinkStateBackend; + private final FlinkStateInternals<K> flinkStateInternals; + private final ValueStateDescriptor<Instant> flinkStateDescriptor; + + public FlinkWatermarkHoldState( + KeyedStateBackend<ByteBuffer> flinkStateBackend, + FlinkStateInternals<K> flinkStateInternals, + StateTag<? super K, WatermarkHoldState<W>> address, + StateNamespace namespace, + OutputTimeFn<? super W> outputTimeFn) { + this.address = address; + this.outputTimeFn = outputTimeFn; + this.namespace = namespace; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateInternals = flinkStateInternals; + + CoderTypeInformation<Instant> typeInfo = new CoderTypeInformation<>(InstantCoder.of()); + flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + } + + @Override + public OutputTimeFn<? 
super W> getOutputTimeFn() { + return outputTimeFn; + } + + @Override + public WatermarkHoldState<W> readLater() { + return this; + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).value() == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + + } + + @Override + public void add(Instant value) { + try { + org.apache.flink.api.common.state.ValueState<Instant> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + Instant current = state.value(); + if (current == null) { + state.update(value); + flinkStateInternals.watermarkHolds.put(namespace.stringKey(), value); + } else { + Instant combined = outputTimeFn.combine(current, value); + state.update(combined); + flinkStateInternals.watermarkHolds.put(namespace.stringKey(), combined); + } + } catch (Exception e) { + throw new RuntimeException("Error updating state.", e); + } + } + + @Override + public Instant read() { + try { + org.apache.flink.api.common.state.ValueState<Instant> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + return state.value(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public void clear() { + flinkStateInternals.watermarkHolds.remove(namespace.stringKey()); + try { + org.apache.flink.api.common.state.ValueState<Instant> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + state.clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkWatermarkHoldState<?, ?> that = (FlinkWatermarkHoldState<?, ?>) o; + + if (!address.equals(that.address)) { + return false; + } + if (!outputTimeFn.equals(that.outputTimeFn)) { + return false; + } + return namespace.equals(that.namespace); + + } + + @Override + public int hashCode() { + int result = address.hashCode(); + result = 31 * result + outputTimeFn.hashCode(); + result = 31 * result + namespace.hashCode(); + return result; + } + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupCheckpointedOperator.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupCheckpointedOperator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupCheckpointedOperator.java new file mode 100644 index 0000000..b38a520 --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupCheckpointedOperator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import java.io.DataOutputStream; + +/** + * This interface is used to checkpoint per-key-group state. + */ +public interface KeyGroupCheckpointedOperator extends KeyGroupRestoringOperator { + /** + * Snapshots the state for a given {@code keyGroupIndex}. + * + * <p>{@code AbstractStreamOperator} calls this hook in its {@code snapshotState()} + * method while iterating over the key groups. + * @param keyGroupIndex the id of the key-group to be put in the snapshot. + * @param out the stream to write to. + */ + void snapshotKeyGroupState(int keyGroupIndex, DataOutputStream out) throws Exception; +} http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupRestoringOperator.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupRestoringOperator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupRestoringOperator.java new file mode 100644 index 0000000..2bdfc6e --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/KeyGroupRestoringOperator.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import java.io.DataInputStream; + +/** + * This interface is used to restore per-key-group state. + */ +public interface KeyGroupRestoringOperator { + /** + * Restores the state for a given {@code keyGroupIndex}. + * @param keyGroupIndex the id of the key-group whose state is to be restored. + * @param in the stream to read from. 
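 + *
 + * <p>A minimal sketch of an implementation, assuming a hypothetical
 + * {@code stateByKeyGroup} map field and a matching snapshot hook that wrote a
 + * length-prefixed byte array:
 + * <pre>{@code
 + * public void restoreKeyGroupState(int keyGroupIndex, DataInputStream in) throws Exception {
 + *   // read back exactly what snapshotKeyGroupState(keyGroupIndex, out) wrote
 + *   int numBytes = in.readInt();
 + *   byte[] rawState = new byte[numBytes];
 + *   in.readFully(rawState);
 + *   stateByKeyGroup.put(keyGroupIndex, rawState); // hypothetical per-key-group map
 + * }
 + * }</pre>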
+ */ + void restoreKeyGroupState(int keyGroupIndex, DataInputStream in) throws Exception; +} http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/package-info.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/package-info.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/package-info.java new file mode 100644 index 0000000..0004e9e --- /dev/null +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Internal state implementation of the Beam runner for Apache Flink. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java index d07861c..2cb3dd3 100644 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/PipelineOptionsTest.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.HashMap; import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions; import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.options.Default; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptions; @@ -110,12 +111,12 @@ public class PipelineOptionsTest { @Test(expected = Exception.class) public void parDoBaseClassPipelineOptionsNullTest() { - DoFnOperator<Object, Object, Object> doFnOperator = new DoFnOperator<>( + DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>( new TestDoFn(), - TypeInformation.of(new TypeHint<WindowedValue<Object>>() {}), - new TupleTag<>("main-output"), + WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()), + new TupleTag<String>("main-output"), Collections.<TupleTag<?>>emptyList(), - new DoFnOperator.DefaultOutputManagerFactory<>(), + new DoFnOperator.DefaultOutputManagerFactory<String>(), WindowingStrategy.globalDefault(), new HashMap<Integer, 
PCollectionView<?>>(), Collections.<PCollectionView<?>>emptyList(), @@ -130,12 +131,12 @@ public class PipelineOptionsTest { @Test public void parDoBaseClassPipelineOptionsSerializationTest() throws Exception { - DoFnOperator<Object, Object, Object> doFnOperator = new DoFnOperator<>( + DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>( new TestDoFn(), - TypeInformation.of(new TypeHint<WindowedValue<Object>>() {}), - new TupleTag<>("main-output"), + WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()), + new TupleTag<String>("main-output"), Collections.<TupleTag<?>>emptyList(), - new DoFnOperator.DefaultOutputManagerFactory<>(), + new DoFnOperator.DefaultOutputManagerFactory<String>(), WindowingStrategy.globalDefault(), new HashMap<Integer, PCollectionView<?>>(), Collections.<PCollectionView<?>>emptyList(), @@ -148,8 +149,12 @@ public class PipelineOptionsTest { DoFnOperator<Object, Object, Object> deserialized = (DoFnOperator<Object, Object, Object>) SerializationUtils.deserialize(serialized); + TypeInformation<WindowedValue<Object>> typeInformation = TypeInformation.of( + new TypeHint<WindowedValue<Object>>() {}); + OneInputStreamOperatorTestHarness<WindowedValue<Object>, Object> testHarness = - new OneInputStreamOperatorTestHarness<>(deserialized, new ExecutionConfig()); + new OneInputStreamOperatorTestHarness<>(deserialized, + typeInformation.createSerializer(new ExecutionConfig())); testHarness.open(); @@ -166,7 +171,7 @@ public class PipelineOptionsTest { } - private static class TestDoFn extends DoFn<Object, Object> { + private static class TestDoFn extends DoFn<String, String> { @ProcessElement public void processElement(ProcessContext c) throws Exception { http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java index 3598d10..7d14a87 100644 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java @@ -29,8 +29,8 @@ import java.util.Collections; import java.util.HashMap; import javax.annotation.Nullable; import org.apache.beam.runners.flink.FlinkPipelineOptions; -import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PCollectionViewTesting; @@ -44,12 +44,15 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.TwoInputStreamOperatorTestHarness; 
import org.joda.time.Duration; import org.joda.time.Instant; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -89,14 +92,11 @@ public class DoFnOperatorTest { WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder = WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()); - CoderTypeInformation<WindowedValue<String>> coderTypeInfo = - new CoderTypeInformation<>(windowedValueCoder); - TupleTag<String> outputTag = new TupleTag<>("main-output"); DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>( new IdentityDoFn<String>(), - coderTypeInfo, + windowedValueCoder, outputTag, Collections.<TupleTag<?>>emptyList(), new DoFnOperator.DefaultOutputManagerFactory(), @@ -127,9 +127,6 @@ public class DoFnOperatorTest { WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder = WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()); - CoderTypeInformation<WindowedValue<String>> coderTypeInfo = - new CoderTypeInformation<>(windowedValueCoder); - TupleTag<String> mainOutput = new TupleTag<>("main-output"); TupleTag<String> sideOutput1 = new TupleTag<>("side-output-1"); TupleTag<String> sideOutput2 = new TupleTag<>("side-output-2"); @@ -141,7 +138,7 @@ public class DoFnOperatorTest { DoFnOperator<String, String, RawUnionValue> doFnOperator = new DoFnOperator<>( new MultiOutputDoFn(sideOutput1, sideOutput2), - coderTypeInfo, + windowedValueCoder, mainOutput, ImmutableList.<TupleTag<?>>of(sideOutput1, sideOutput2), new DoFnOperator.MultiOutputOutputManagerFactory(outputMapping), @@ -172,26 +169,11 @@ public class DoFnOperatorTest { testHarness.close(); } - /** - * For now, this test doesn't work because {@link TwoInputStreamOperatorTestHarness} is not - * sufficiently well equipped to handle more complex operators that require a state backend. - * We have to revisit this once we update to a newer version of Flink and also add some more - * tests that verify pushback behaviour and watermark hold behaviour. - * - * <p>The behaviour that we would test here is also exercised by the - * {@link org.apache.beam.sdk.testing.RunnableOnService} tests, so the code is not untested. 
- */ - @Test - @Ignore - @SuppressWarnings("unchecked") - public void testSideInputs() throws Exception { + public void testSideInputs(boolean keyed) throws Exception { WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder = WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()); - CoderTypeInformation<WindowedValue<String>> coderTypeInfo = - new CoderTypeInformation<>(windowedValueCoder); - TupleTag<String> outputTag = new TupleTag<>("main-output"); ImmutableMap<Integer, PCollectionView<?>> sideInputMapping = ImmutableMap.<Integer, PCollectionView<?>>builder() .put(1, view1) .put(2, view2) .build(); + Coder<String> keyCoder = null; + if (keyed) { + keyCoder = StringUtf8Coder.of(); + } + DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>( new IdentityDoFn<String>(), - coderTypeInfo, + windowedValueCoder, outputTag, Collections.<TupleTag<?>>emptyList(), - new DoFnOperator.DefaultOutputManagerFactory(), + new DoFnOperator.DefaultOutputManagerFactory<String>(), WindowingStrategy.globalDefault(), sideInputMapping, /* side-input mapping */ ImmutableList.<PCollectionView<?>>of(view1, view2), /* side inputs */ PipelineOptionsFactory.as(FlinkPipelineOptions.class), - null); + keyCoder); TwoInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue, String> testHarness = new TwoInputStreamOperatorTestHarness<>(doFnOperator); + if (keyed) { + // we use a dummy key for the second input since it is considered to be broadcast + testHarness = new KeyedTwoInputStreamOperatorTestHarness<>( + doFnOperator, + new StringKeySelector(), + new DummyKeySelector(), + BasicTypeInfo.STRING_TYPE_INFO); + } + testHarness.open(); IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(100)); + IntervalWindow secondWindow = new IntervalWindow(new Instant(0), new Instant(500)); - // push in some side-input elements + // push in some side-input elements, which the operator must keep (store) testHarness.processElement2( new StreamRecord<>( new RawUnionValue( 1, valuesInWindow(ImmutableList.of("hello", "ciao"), new Instant(0), firstWindow)))); - testHarness.processElement2( new StreamRecord<>( new RawUnionValue( 2, - valuesInWindow(ImmutableList.of("foo", "bar"), new Instant(0), firstWindow)))); + valuesInWindow(ImmutableList.of("foo", "bar"), new Instant(0), secondWindow)))); // push in some regular elements - testHarness.processElement1(new StreamRecord<>(WindowedValue.valueInGlobalWindow("Hello"))); + WindowedValue<String> helloElement = valueInWindow("Hello", new Instant(0), firstWindow); + WindowedValue<String> worldElement = valueInWindow("World", new Instant(1000), firstWindow); + testHarness.processElement1(new StreamRecord<>(helloElement)); + testHarness.processElement1(new StreamRecord<>(worldElement)); + + // push in more side-input elements and verify that pushed-back elements are kept + testHarness.processElement2( + new StreamRecord<>( + new RawUnionValue( + 1, + valuesInWindow(ImmutableList.of("hello", "ciao"), + new Instant(1000), firstWindow)))); + testHarness.processElement2( + new StreamRecord<>( + new RawUnionValue( + 2, + valuesInWindow(ImmutableList.of("foo", "bar"), new Instant(1000), secondWindow)))); assertThat( this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()), - contains(WindowedValue.valueInGlobalWindow("Hello"))); + contains(helloElement, worldElement)); testHarness.close(); + + } + + /** + * {@link TwoInputStreamOperatorTestHarness} supports an OperatorStateBackend + * but not a KeyedStateBackend, so this only tests side inputs of a normal (non-keyed) ParDo. 
+ */ + @Test + @SuppressWarnings("unchecked") + public void testNormalParDoSideInputs() throws Exception { + testSideInputs(false); + } + + @Test + public void testKeyedSideInputs() throws Exception { + testSideInputs(true); } private <T> Iterable<WindowedValue<T>> stripStreamRecordFromWindowedValue( @@ -325,4 +353,17 @@ public class DoFnOperatorTest { } + private static class DummyKeySelector implements KeySelector<RawUnionValue, String> { + @Override + public String getKey(RawUnionValue stringWindowedValue) throws Exception { + return "dummy_key"; + } + } + + private static class StringKeySelector implements KeySelector<WindowedValue<String>, String> { + @Override + public String getKey(WindowedValue<String> stringWindowedValue) throws Exception { + return stringWindowedValue.getValue(); + } + } } http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java new file mode 100644 index 0000000..db02cb3 --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.streaming; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThat; + +import java.util.Arrays; +import org.apache.beam.runners.core.StateMerging; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateNamespaceForTest; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.runners.core.StateTags; +import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.util.state.AccumulatorCombiningState; +import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.CombiningState; +import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.ValueState; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link FlinkBroadcastStateInternals}. This is based on the tests for + * {@code InMemoryStateInternals}. + */ +@RunWith(JUnit4.class) +public class FlinkBroadcastStateInternalsTest { + private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); + private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); + private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); + + private static final StateTag<Object, ValueState<String>> STRING_VALUE_ADDR = + StateTags.value("stringValue", StringUtf8Coder.of()); + private static final StateTag<Object, AccumulatorCombiningState<Integer, int[], Integer>> + SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( + "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); + private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR = + StateTags.bag("stringBag", StringUtf8Coder.of()); + + FlinkBroadcastStateInternals<String> underTest; + + @Before + public void initStateInternals() { + MemoryStateBackend backend = new MemoryStateBackend(); + try { + OperatorStateBackend operatorStateBackend = + backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), ""); + underTest = new FlinkBroadcastStateInternals<>(1, operatorStateBackend); + + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Test + public void testValue() throws Exception { + ValueState<String> value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR); + + assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value); + assertNotEquals( + underTest.state(NAMESPACE_2, STRING_VALUE_ADDR), + value); + + assertThat(value.read(), Matchers.nullValue()); + value.write("hello"); + assertThat(value.read(), Matchers.equalTo("hello")); + value.write("world"); + assertThat(value.read(), Matchers.equalTo("world")); + + value.clear(); + assertThat(value.read(), Matchers.nullValue()); + assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value); + + } + + @Test + public void testBag() throws Exception { + BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + + 
assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); + + assertThat(value.read(), Matchers.emptyIterable()); + value.add("hello"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello")); + + value.add("world"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); + + value.clear(); + assertThat(value.read(), Matchers.emptyIterable()); + assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value); + + } + + @Test + public void testBagIsEmpty() throws Exception { + BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState<Boolean> readFuture = value.isEmpty(); + value.add("hello"); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeBagIntoSource() throws Exception { + BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); + + StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); + + // Reading the merged bag gets both the contents + assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(bag2.read(), Matchers.emptyIterable()); + } + + @Test + public void testMergeBagIntoNewNamespace() throws Exception { + BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); + + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); + + StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); + + // Reading the merged bag gets both the contents + assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(bag1.read(), Matchers.emptyIterable()); + assertThat(bag2.read(), Matchers.emptyIterable()); + } + + @Test + public void testCombiningValue() throws Exception { + CombiningState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + + // State instances are cached, but depend on the namespace. 
+ assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR))); + + assertThat(value.read(), Matchers.equalTo(0)); + value.add(2); + assertThat(value.read(), Matchers.equalTo(2)); + + value.add(3); + assertThat(value.read(), Matchers.equalTo(5)); + + value.clear(); + assertThat(value.read(), Matchers.equalTo(0)); + assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value); + } + + @Test + public void testCombiningIsEmpty() throws Exception { + CombiningState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState<Boolean> readFuture = value.isEmpty(); + value.add(5); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeCombiningValueIntoSource() throws Exception { + AccumulatorCombiningState<Integer, int[], Integer> value1 = + underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + AccumulatorCombiningState<Integer, int[], Integer> value2 = + underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); + + value1.add(5); + value2.add(10); + value1.add(6); + + assertThat(value1.read(), Matchers.equalTo(11)); + assertThat(value2.read(), Matchers.equalTo(10)); + + // Merging clears the old values and updates the result value. + StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1); + + assertThat(value1.read(), Matchers.equalTo(21)); + assertThat(value2.read(), Matchers.equalTo(0)); + } + + @Test + public void testMergeCombiningValueIntoNewNamespace() throws Exception { + AccumulatorCombiningState<Integer, int[], Integer> value1 = + underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + AccumulatorCombiningState<Integer, int[], Integer> value2 = + underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); + AccumulatorCombiningState<Integer, int[], Integer> value3 = + underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR); + + value1.add(5); + value2.add(10); + value1.add(6); + + StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3); + + // Merging clears the old values and updates the result value. + assertThat(value1.read(), Matchers.equalTo(0)); + assertThat(value2.read(), Matchers.equalTo(0)); + assertThat(value3.read(), Matchers.equalTo(21)); + } + +} http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java new file mode 100644 index 0000000..5433d07 --- /dev/null +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
new file mode 100644
index 0000000..5433d07
--- /dev/null
+++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import org.apache.beam.runners.core.StateMerging;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkKeyGroupStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.streaming.api.operators.KeyContext;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkKeyGroupStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkKeyGroupStateInternalsTest {
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+  private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+
+  FlinkKeyGroupStateInternals<String> underTest;
+  private KeyedStateBackend keyedStateBackend;
+
+  @Before
+  public void initStateInternals() {
+    try {
+      keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+      underTest = new FlinkKeyGroupStateInternals<>(StringUtf8Coder.of(), keyedStateBackend);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private KeyedStateBackend getKeyedStateBackend(int numberOfKeyGroups,
+      KeyGroupRange keyGroupRange) {
+    MemoryStateBackend backend = new MemoryStateBackend();
+    try {
+      AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
+          new DummyEnvironment("test", 1, 0),
+          new JobID(),
+          "test_op",
+          new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()),
+          numberOfKeyGroups,
+          keyGroupRange,
+          new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+      keyedStateBackend.setCurrentKey(ByteBuffer.wrap(
+          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "1")));
+      return keyedStateBackend;
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+
+  }
+
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeBagIntoSource() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+    // Reading the merged bag returns the contents of both bags.
+    assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testMergeBagIntoNewNamespace() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+    BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+    // Reading the merged bag returns the contents of all source bags.
+    assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag1.read(), Matchers.emptyIterable());
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testKeyGroupAndCheckpoint() throws Exception {
+    // assign to keyGroup 0
+    ByteBuffer key0 = ByteBuffer.wrap(
+        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "11111111"));
+    // assign to keyGroup 1
+    ByteBuffer key1 = ByteBuffer.wrap(
+        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "22222222"));
+    FlinkKeyGroupStateInternals<String> allState;
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+      allState = new FlinkKeyGroupStateInternals<>(
+          StringUtf8Coder.of(), keyedStateBackend);
+      BagState<String> valueForNamespace0 = allState.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = allState.state(NAMESPACE_2, STRING_BAG_ADDR);
+      keyedStateBackend.setCurrentKey(key0);
+      valueForNamespace0.add("0");
+      valueForNamespace1.add("2");
+      keyedStateBackend.setCurrentKey(key1);
+      valueForNamespace0.add("1");
+      valueForNamespace1.add("3");
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3"));
+    }
+
+    ClassLoader classLoader = FlinkKeyGroupStateInternalsTest.class.getClassLoader();
+
+    // 1. scale up
+    ByteArrayOutputStream out0 = new ByteArrayOutputStream();
+    allState.snapshotKeyGroupState(0, new DataOutputStream(out0));
+    DataInputStream in0 = new DataInputStream(
+        new ByteArrayInputStream(out0.toByteArray()));
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 0));
+      FlinkKeyGroupStateInternals<String> state0 =
+          new FlinkKeyGroupStateInternals<>(
+              StringUtf8Coder.of(), keyedStateBackend);
+      state0.restoreKeyGroupState(0, in0, classLoader);
+      BagState<String> valueForNamespace0 = state0.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = state0.state(NAMESPACE_2, STRING_BAG_ADDR);
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2"));
+    }
+
+    ByteArrayOutputStream out1 = new ByteArrayOutputStream();
+    allState.snapshotKeyGroupState(1, new DataOutputStream(out1));
+    DataInputStream in1 = new DataInputStream(
+        new ByteArrayInputStream(out1.toByteArray()));
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(1, 1));
+      FlinkKeyGroupStateInternals<String> state1 =
+          new FlinkKeyGroupStateInternals<>(
+              StringUtf8Coder.of(), keyedStateBackend);
+      state1.restoreKeyGroupState(1, in1, classLoader);
+      BagState<String> valueForNamespace0 = state1.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = state1.state(NAMESPACE_2, STRING_BAG_ADDR);
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("1"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("3"));
+    }
+
+    // 2. scale down
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+      FlinkKeyGroupStateInternals<String> newAllState = new FlinkKeyGroupStateInternals<>(
+          StringUtf8Coder.of(), keyedStateBackend);
+      in0.reset();
+      in1.reset();
+      newAllState.restoreKeyGroupState(0, in0, classLoader);
+      newAllState.restoreKeyGroupState(1, in1, classLoader);
+      BagState<String> valueForNamespace0 = newAllState.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = newAllState.state(NAMESPACE_2, STRING_BAG_ADDR);
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3"));
+    }
+
+  }
+
+  private static class TestKeyContext implements KeyContext {
+
+    private Object key;
+
+    @Override
+    public void setCurrentKey(Object key) {
+      this.key = key;
+    }
+
+    @Override
+    public Object getCurrentKey() {
+      return key;
+    }
+  }
+
+}
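testKeyGroupAndCheckpoint is the heart of this new class: state lives partitioned by key group, each group serializes independently through snapshotKeyGroupState, and any subset of groups can be replayed into a fresh instance through restoreKeyGroupState, which is exactly what makes the scale-up and scale-down halves of the test work. A self-contained sketch of the same round trip (KeyGroupSnapshotSketch is a hypothetical stand-in for FlinkKeyGroupStateInternals, using Java serialization where the real class uses Beam coders):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

class KeyGroupSnapshotSketch {
  // key group -> (state name -> bag contents); each group snapshots on its own.
  private final Map<Integer, HashMap<String, ArrayList<String>>> byGroup = new HashMap<>();

  void add(int keyGroup, String stateName, String value) {
    byGroup.computeIfAbsent(keyGroup, g -> new HashMap<>())
        .computeIfAbsent(stateName, n -> new ArrayList<>())
        .add(value);
  }

  byte[] snapshotKeyGroup(int keyGroup) throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    ObjectOutputStream out = new ObjectOutputStream(bytes);
    out.writeObject(byGroup.getOrDefault(keyGroup, new HashMap<>()));
    out.flush();
    return bytes.toByteArray();
  }

  @SuppressWarnings("unchecked")
  void restoreKeyGroup(int keyGroup, byte[] snapshot)
      throws IOException, ClassNotFoundException {
    ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(snapshot));
    byGroup.put(keyGroup, (HashMap<String, ArrayList<String>>) in.readObject());
  }

  public static void main(String[] args) throws Exception {
    KeyGroupSnapshotSketch all = new KeyGroupSnapshotSketch();
    all.add(0, "stringBag", "0");
    all.add(1, "stringBag", "1");
    byte[] group0 = all.snapshotKeyGroup(0);
    byte[] group1 = all.snapshotKeyGroup(1);

    // Scale up: each new instance restores only the key groups it now owns.
    KeyGroupSnapshotSketch narrow = new KeyGroupSnapshotSketch();
    narrow.restoreKeyGroup(0, group0);

    // Scale down: one instance restores both groups again.
    KeyGroupSnapshotSketch wide = new KeyGroupSnapshotSketch();
    wide.restoreKeyGroup(0, group0);
    wide.restoreKeyGroup(1, group1);
  }
}

Because the snapshot granularity is the key group, redistribution after rescaling never has to split or merge serialized blobs.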
http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
new file mode 100644
index 0000000..08ae0c4
--- /dev/null
+++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkSplitStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkSplitStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkSplitStateInternalsTest {
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+
+  FlinkSplitStateInternals<String> underTest;
+
+  @Before
+  public void initStateInternals() {
+    MemoryStateBackend backend = new MemoryStateBackend();
+    try {
+      OperatorStateBackend operatorStateBackend =
+          backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), "");
+      underTest = new FlinkSplitStateInternals<>(operatorStateBackend);
+
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+
+  }
+
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+}
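Every testBagIsEmpty in this commit asserts the same easy-to-miss contract: the ReadableState<Boolean> returned by isEmpty() is a live view over the underlying bag, not a snapshot, so a later add() or clear() changes what an earlier handle reads. The core of that behavior in a few lines of plain Java (a hypothetical stand-in, not the Beam interface):

import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

class LiveViewSketch {
  public static void main(String[] args) {
    List<String> bag = new ArrayList<>();
    // Like BagState.isEmpty(): captures the bag, evaluates only at read time.
    Supplier<Boolean> readFuture = bag::isEmpty;

    System.out.println(readFuture.get()); // true
    bag.add("hello");
    System.out.println(readFuture.get()); // false -- same handle, new answer
    bag.clear();
    System.out.println(readFuture.get()); // true again
  }
}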
http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
index 465dad3..7839cf3 100644
--- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
+++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
@@ -29,8 +29,7 @@ import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.StateNamespaceForTest;
 import org.apache.beam.runners.core.StateTag;
 import org.apache.beam.runners.core.StateTags;
-import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkStateInternals;
-import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.transforms.Sum;
@@ -45,8 +44,13 @@ import org.apache.beam.sdk.util.state.ReadableState;
 import org.apache.beam.sdk.util.state.ValueState;
 import org.apache.beam.sdk.util.state.WatermarkHoldState;
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.hamcrest.Matchers;
 import org.joda.time.Instant;
@@ -88,18 +92,19 @@ public class FlinkStateInternalsTest {
   public void initStateInternals() {
     MemoryStateBackend backend = new MemoryStateBackend();
     try {
-      backend.initializeForJob(
+      AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
           new DummyEnvironment("test", 1, 0),
+          new JobID(),
           "test_op",
-          new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()));
-    } catch (Exception e) {
-      throw new RuntimeException(e);
-    }
-    underTest = new FlinkStateInternals<>(backend, StringUtf8Coder.of());
-    try {
-      backend.setCurrentKey(
+          new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()),
+          1,
+          new KeyGroupRange(0, 0),
+          new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+      underTest = new FlinkStateInternals<>(keyedStateBackend, StringUtf8Coder.of());
+
+      keyedStateBackend.setCurrentKey(
          ByteBuffer.wrap(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Hello")));
-    } catch (CoderException e) {
+    } catch (Exception e) {
       throw new RuntimeException(e);
     }
   }
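The recurring ByteBuffer.wrap(CoderUtils.encodeToByteArray(...)) pattern in this diff and the tests above is how a Beam key becomes a Flink key: the key is encoded with its Beam Coder and wrapped in a ByteBuffer, which supplies content-based equals/hashCode for Flink's key-group assignment (byte arrays compare by identity, so they would not work). Isolated for clarity, with encodeKey as a hypothetical helper around the real Beam classes used in the diff:

import java.nio.ByteBuffer;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.util.CoderUtils;

class KeyEncodingSketch {
  // Encode a Beam key into the ByteBuffer form the Flink backend keys on.
  static ByteBuffer encodeKey(String key) throws CoderException {
    return ByteBuffer.wrap(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), key));
  }

  public static void main(String[] args) throws CoderException {
    // Equal keys encode to equal buffers, so they land in the same key group.
    System.out.println(encodeKey("Hello").equals(encodeKey("Hello"))); // true
  }
}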
http://git-wip-us.apache.org/repos/asf/beam/blob/7d32b93e/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
index b0be98b..5b3d088 100644
--- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
+++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
@@ -44,8 +44,10 @@ import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService;
 import org.apache.flink.util.InstantiationUtil;
 import org.junit.Test;
 import org.junit.experimental.runners.Enclosed;
@@ -123,6 +125,10 @@ public class UnboundedSourceWrapperTest {
       }
 
       @Override
+      public void emitLatencyMarker(LatencyMarker latencyMarker) {
+      }
+
+      @Override
       public void collect(
           StreamRecord<WindowedValue<KV<Integer, Integer>>> windowedValueStreamRecord) {
 
@@ -191,6 +197,10 @@ public class UnboundedSourceWrapperTest {
       }
 
      @Override
+      public void emitLatencyMarker(LatencyMarker latencyMarker) {
+      }
+
+      @Override
       public void collect(
           StreamRecord<WindowedValue<KV<Integer, Integer>>> windowedValueStreamRecord) {
 
@@ -256,6 +266,10 @@ public class UnboundedSourceWrapperTest {
       }
 
       @Override
+      public void emitLatencyMarker(LatencyMarker latencyMarker) {
+      }
+
+      @Override
       public void collect(
           StreamRecord<WindowedValue<KV<Integer, Integer>>> windowedValueStreamRecord) {
         emittedElements.add(windowedValueStreamRecord.getValue().getValue());
@@ -300,6 +314,8 @@ public class UnboundedSourceWrapperTest {
     when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
     when(mockTask.getAccumulatorMap())
         .thenReturn(Collections.<String, Accumulator<?, ?>>emptyMap());
+    TestProcessingTimeService testProcessingTimeService = new TestProcessingTimeService();
+    when(mockTask.getProcessingTimeService()).thenReturn(testProcessingTimeService);
 
     operator.setup(mockTask, cfg, (Output<StreamRecord<T>>) mock(Output.class));
   }
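The three identical emitLatencyMarker additions above are upgrade noise rather than new behavior: Flink's Output interface grew the method, and these anonymous test outputs only care about collect(), so the override is a deliberate no-op. The shape of the pattern, with a simplified hypothetical Output interface standing in for Flink's:

import java.util.ArrayList;
import java.util.List;

// Simplified stand-in for org.apache.flink.streaming.api.operators.Output.
interface Output<T> {
  void collect(T record);
  void emitLatencyMarker(long markedTime); // added by the Flink upgrade
}

class CollectingOutput<T> implements Output<T> {
  final List<T> emittedElements = new ArrayList<>();

  @Override
  public void collect(T record) {
    emittedElements.add(record); // the only behavior the tests assert on
  }

  @Override
  public void emitLatencyMarker(long markedTime) {
    // intentionally empty: latency markers are irrelevant to these assertions
  }
}

The TestProcessingTimeService stubbing in the last hunk follows the same logic: the upgraded StreamTask exposes a processing-time service, so the mock must return one even though no test here schedules timers.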