http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index a0b015b..f0d3278 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming.state; -import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import java.nio.ByteBuffer; import java.util.Collections; @@ -26,6 +25,7 @@ import java.util.Map; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTag; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; @@ -34,7 +34,6 @@ import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.ReadableState; -import org.apache.beam.sdk.state.ReadableStates; import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateContext; @@ -50,7 +49,6 @@ import org.apache.beam.sdk.util.CombineContextFactory; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeutils.base.BooleanSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.runtime.state.KeyedStateBackend; import org.joda.time.Instant; @@ -130,8 +128,8 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public <T> SetState<T> bindSet( StateTag<SetState<T>> address, Coder<T> elemCoder) { - return new FlinkSetState<>( - flinkStateBackend, address, namespace, elemCoder); + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); } @Override @@ -198,8 +196,9 @@ public class FlinkStateInternals<K> implements StateInternals { this.address = address; this.flinkStateBackend = flinkStateBackend; - flinkStateDescriptor = new ValueStateDescriptor<>( - address.getId(), new CoderTypeSerializer<>(coder)); + CoderTypeInformation<T> typeInfo = new CoderTypeInformation<>(coder); + + flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); } @Override @@ -283,8 +282,9 @@ public class FlinkStateInternals<K> implements StateInternals { this.address = address; this.flinkStateBackend = flinkStateBackend; - flinkStateDescriptor = new ListStateDescriptor<>( - address.getId(), new CoderTypeSerializer<>(coder)); + CoderTypeInformation<T> typeInfo = new CoderTypeInformation<>(coder); + + flinkStateDescriptor = new ListStateDescriptor<>(address.getId(), 
typeInfo); } @Override @@ -398,8 +398,9 @@ public class FlinkStateInternals<K> implements StateInternals { this.combineFn = combineFn; this.flinkStateBackend = flinkStateBackend; - flinkStateDescriptor = new ValueStateDescriptor<>( - address.getId(), new CoderTypeSerializer<>(accumCoder)); + CoderTypeInformation<AccumT> typeInfo = new CoderTypeInformation<>(accumCoder); + + flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); } @Override @@ -544,6 +545,179 @@ public class FlinkStateInternals<K> implements StateInternals { } } + private static class FlinkKeyedCombiningState<K, InputT, AccumT, OutputT> + implements CombiningState<InputT, AccumT, OutputT> { + + private final StateNamespace namespace; + private final StateTag<CombiningState<InputT, AccumT, OutputT>> address; + private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn; + private final ValueStateDescriptor<AccumT> flinkStateDescriptor; + private final KeyedStateBackend<ByteBuffer> flinkStateBackend; + private final FlinkStateInternals<K> flinkStateInternals; + + FlinkKeyedCombiningState( + KeyedStateBackend<ByteBuffer> flinkStateBackend, + StateTag<CombiningState<InputT, AccumT, OutputT>> address, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn, + StateNamespace namespace, + Coder<AccumT> accumCoder, + FlinkStateInternals<K> flinkStateInternals) { + + this.namespace = namespace; + this.address = address; + this.combineFn = combineFn; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateInternals = flinkStateInternals; + + CoderTypeInformation<AccumT> typeInfo = new CoderTypeInformation<>(accumCoder); + + flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); + } + + @Override + public CombiningState<InputT, AccumT, OutputT> readLater() { + return this; + } + + @Override + public void add(InputT value) { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + current = combineFn.createAccumulator(); + } + current = combineFn.addInput(current, value); + state.update(current); + } catch (RuntimeException re) { + throw re; + } catch (Exception e) { + throw new RuntimeException("Error adding to state." 
, e); + } + } + + @Override + public void addAccum(AccumT accum) { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT current = state.value(); + if (current == null) { + state.update(accum); + } else { + current = combineFn.mergeAccumulators(Lists.newArrayList(current, accum)); + state.update(current); + } + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public AccumT getAccum() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).value(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(accumulators); + } + + @Override + public OutputT read() { + try { + org.apache.flink.api.common.state.ValueState<AccumT> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + + AccumT accum = state.value(); + if (accum != null) { + return combineFn.extractOutput(accum); + } else { + return combineFn.extractOutput(combineFn.createAccumulator()); + } + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).value() == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkKeyedCombiningState<?, ?, ?, ?> that = + (FlinkKeyedCombiningState<?, ?, ?, ?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + private static class FlinkCombiningStateWithContext<K, InputT, AccumT, OutputT> implements CombiningState<InputT, AccumT, OutputT> { @@ -571,8 +745,9 @@ public class FlinkStateInternals<K> implements StateInternals { this.flinkStateInternals = flinkStateInternals; this.context = context; - flinkStateDescriptor = new ValueStateDescriptor<>( - address.getId(), new CoderTypeSerializer<>(accumCoder)); + CoderTypeInformation<AccumT> typeInfo = new CoderTypeInformation<>(accumCoder); + + flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); } @Override @@ -738,8 +913,8 @@ public class FlinkStateInternals<K> implements StateInternals { this.flinkStateBackend = flinkStateBackend; this.flinkStateInternals = flinkStateInternals; - flinkStateDescriptor = new ValueStateDescriptor<>( - address.getId(), new 
CoderTypeSerializer<>(InstantCoder.of())); + CoderTypeInformation<Instant> typeInfo = new CoderTypeInformation<>(InstantCoder.of()); + flinkStateDescriptor = new ValueStateDescriptor<>(address.getId(), typeInfo, null); } @Override @@ -878,15 +1053,24 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public ReadableState<ValueT> get(final KeyT input) { - try { - return ReadableStates.immediate( - flinkStateBackend.getPartitionedState( + return new ReadableState<ValueT>() { + @Override + public ValueT read() { + try { + return flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, - flinkStateDescriptor).get(input)); - } catch (Exception e) { - throw new RuntimeException("Error get from state.", e); - } + flinkStateDescriptor).get(input); + } catch (Exception e) { + throw new RuntimeException("Error get from state.", e); + } + } + + @Override + public ReadableState<ValueT> readLater() { + return this; + } + }; } @Override @@ -903,22 +1087,32 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public ReadableState<ValueT> putIfAbsent(final KeyT key, final ValueT value) { - try { - ValueT current = flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).get(key); + return new ReadableState<ValueT>() { + @Override + public ValueT read() { + try { + ValueT current = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(key); + + if (current == null) { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).put(key, value); + } + return current; + } catch (Exception e) { + throw new RuntimeException("Error put kv to state.", e); + } + } - if (current == null) { - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).put(key, value); + @Override + public ReadableState<ValueT> readLater() { + return this; } - return ReadableStates.immediate(current); - } catch (Exception e) { - throw new RuntimeException("Error put kv to state.", e); - } + }; } @Override @@ -939,11 +1133,10 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public Iterable<KeyT> read() { try { - Iterable<KeyT> result = flinkStateBackend.getPartitionedState( + return flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).keys(); - return result != null ? result : Collections.<KeyT>emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state keys.", e); } @@ -962,11 +1155,10 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public Iterable<ValueT> read() { try { - Iterable<ValueT> result = flinkStateBackend.getPartitionedState( + return flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).values(); - return result != null ? 
result : Collections.<ValueT>emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state values.", e); } @@ -985,11 +1177,10 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public Iterable<Map.Entry<KeyT, ValueT>> read() { try { - Iterable<Map.Entry<KeyT, ValueT>> result = flinkStateBackend.getPartitionedState( + return flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).entries(); - return result != null ? result : Collections.<Map.Entry<KeyT, ValueT>>emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state entries.", e); } @@ -1037,154 +1228,4 @@ public class FlinkStateInternals<K> implements StateInternals { } } - private static class FlinkSetState<T> implements SetState<T> { - - private final StateNamespace namespace; - private final StateTag<SetState<T>> address; - private final MapStateDescriptor<T, Boolean> flinkStateDescriptor; - private final KeyedStateBackend<ByteBuffer> flinkStateBackend; - - FlinkSetState( - KeyedStateBackend<ByteBuffer> flinkStateBackend, - StateTag<SetState<T>> address, - StateNamespace namespace, - Coder<T> coder) { - this.namespace = namespace; - this.address = address; - this.flinkStateBackend = flinkStateBackend; - this.flinkStateDescriptor = new MapStateDescriptor<>(address.getId(), - new CoderTypeSerializer<>(coder), new BooleanSerializer()); - } - - @Override - public ReadableState<Boolean> contains(final T t) { - try { - Boolean result = flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).get(t); - return ReadableStates.immediate(result != null ? result : false); - } catch (Exception e) { - throw new RuntimeException("Error contains value from state.", e); - } - } - - @Override - public ReadableState<Boolean> addIfAbsent(final T t) { - try { - org.apache.flink.api.common.state.MapState<T, Boolean> state = - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor); - boolean alreadyContained = state.contains(t); - if (!alreadyContained) { - state.put(t, true); - } - return ReadableStates.immediate(!alreadyContained); - } catch (Exception e) { - throw new RuntimeException("Error addIfAbsent value to state.", e); - } - } - - @Override - public void remove(T t) { - try { - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).remove(t); - } catch (Exception e) { - throw new RuntimeException("Error remove value to state.", e); - } - } - - @Override - public SetState<T> readLater() { - return this; - } - - @Override - public void add(T value) { - try { - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).put(value, true); - } catch (Exception e) { - throw new RuntimeException("Error add value to state.", e); - } - } - - @Override - public ReadableState<Boolean> isEmpty() { - return new ReadableState<Boolean>() { - @Override - public Boolean read() { - try { - Iterable<T> result = flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).keys(); - return result == null || Iterables.isEmpty(result); - } catch (Exception e) { - throw new RuntimeException("Error isEmpty from state.", e); - } - } - - @Override - public ReadableState<Boolean> readLater() { - return this; - } - }; - } - - @Override - public Iterable<T> 
read() { - try { - Iterable<T> result = flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).keys(); - return result != null ? result : Collections.<T>emptyList(); - } catch (Exception e) { - throw new RuntimeException("Error read from state.", e); - } - } - - @Override - public void clear() { - try { - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).clear(); - } catch (Exception e) { - throw new RuntimeException("Error clearing state.", e); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - FlinkSetState<?> that = (FlinkSetState<?>) o; - - return namespace.equals(that.namespace) && address.equals(that.address); - - } - - @Override - public int hashCode() { - int result = namespace.hashCode(); - result = 31 * result + address.hashCode(); - return result; - } - } - }
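For context on how the refactored FlinkStateInternals above is exercised, here is a minimal wiring sketch drawn from the FlinkStateInternalsTest setup later in this commit; it is illustrative only, the MemoryStateBackend, DummyEnvironment and the "Hello" key are test-only stand-ins, the "ns1"/"stringValue" names are hypothetical, and the classes used are the same ones that test file imports:

    // Sketch only: create a keyed state backend over an in-memory (test) backend,
    // wrap it in FlinkStateInternals, and read/write a Beam ValueState cell.
    MemoryStateBackend backend = new MemoryStateBackend();
    AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
        new DummyEnvironment("test", 1, 0),
        new JobID(),
        "test_op",
        new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()),
        1,
        new KeyGroupRange(0, 0),
        new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));

    FlinkStateInternals<String> stateInternals =
        new FlinkStateInternals<>(keyedStateBackend, StringUtf8Coder.of());

    // The current key must be set (as a coder-encoded ByteBuffer) before any state access.
    keyedStateBackend.setCurrentKey(
        ByteBuffer.wrap(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Hello")));

    // State cells are addressed by (namespace, tag); the names here are illustrative.
    ValueState<String> value = stateInternals.state(
        new StateNamespaceForTest("ns1"),
        StateTags.value("stringValue", StringUtf8Coder.of()));
    value.write("hello");
    assert "hello".equals(value.read());

The tests that follow use exactly this wiring, then verify value, bag, combining and watermark-hold state against the Flink keyed state backend.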
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java index 3409d27..2b96d91 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java @@ -17,87 +17,229 @@ */ package org.apache.beam.runners.flink.streaming; -import org.apache.beam.runners.core.StateInternals; -import org.apache.beam.runners.core.StateInternalsTest; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThat; + +import java.util.Arrays; +import org.apache.beam.runners.core.StateMerging; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateNamespaceForTest; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.GroupingState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.Sum; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.memory.MemoryStateBackend; -import org.junit.Ignore; +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** * Tests for {@link FlinkBroadcastStateInternals}. This is based on the tests for - * {@code StateInternalsTest}. - * - * <p>Just test value, bag and combining. + * {@code InMemoryStateInternals}. 
*/ @RunWith(JUnit4.class) -public class FlinkBroadcastStateInternalsTest extends StateInternalsTest { - - @Override - protected StateInternals createStateInternals() { +public class FlinkBroadcastStateInternalsTest { + private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); + private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); + private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); + + private static final StateTag<ValueState<String>> STRING_VALUE_ADDR = + StateTags.value("stringValue", StringUtf8Coder.of()); + private static final StateTag<CombiningState<Integer, int[], Integer>> + SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( + "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); + private static final StateTag<BagState<String>> STRING_BAG_ADDR = + StateTags.bag("stringBag", StringUtf8Coder.of()); + + FlinkBroadcastStateInternals<String> underTest; + + @Before + public void initStateInternals() { MemoryStateBackend backend = new MemoryStateBackend(); try { OperatorStateBackend operatorStateBackend = backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), ""); - return new FlinkBroadcastStateInternals<>(1, operatorStateBackend); + underTest = new FlinkBroadcastStateInternals<>(1, operatorStateBackend); + } catch (Exception e) { throw new RuntimeException(e); } } - @Override - @Ignore - public void testSet() {} + @Test + public void testValue() throws Exception { + ValueState<String> value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR); + + assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value); + assertNotEquals( + underTest.state(NAMESPACE_2, STRING_VALUE_ADDR), + value); + + assertThat(value.read(), Matchers.nullValue()); + value.write("hello"); + assertThat(value.read(), Matchers.equalTo("hello")); + value.write("world"); + assertThat(value.read(), Matchers.equalTo("world")); + + value.clear(); + assertThat(value.read(), Matchers.nullValue()); + assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value); + + } + + @Test + public void testBag() throws Exception { + BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - @Override - @Ignore - public void testSetIsEmpty() {} + assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); - @Override - @Ignore - public void testMergeSetIntoSource() {} + assertThat(value.read(), Matchers.emptyIterable()); + value.add("hello"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello")); - @Override - @Ignore - public void testMergeSetIntoNewNamespace() {} + value.add("world"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); - @Override - @Ignore - public void testMap() {} + value.clear(); + assertThat(value.read(), Matchers.emptyIterable()); + assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value); - @Override - @Ignore - public void testWatermarkEarliestState() {} + } + + @Test + public void testBagIsEmpty() throws Exception { + BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - @Override - @Ignore - public void testWatermarkLatestState() {} + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState<Boolean> readFuture = value.isEmpty(); + value.add("hello"); + assertThat(readFuture.read(), Matchers.is(false)); - @Override - @Ignore - public void testWatermarkEndOfWindowState() {} + value.clear(); + 
assertThat(readFuture.read(), Matchers.is(true)); + } - @Override - @Ignore - public void testWatermarkStateIsEmpty() {} + @Test + public void testMergeBagIntoSource() throws Exception { + BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); - @Override - @Ignore - public void testMergeEarliestWatermarkIntoSource() {} + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); - @Override - @Ignore - public void testMergeLatestWatermarkIntoSource() {} + StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); + + // Reading the merged bag gets both the contents + assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(bag2.read(), Matchers.emptyIterable()); + } - @Override - @Ignore - public void testSetReadable() {} + @Test + public void testMergeBagIntoNewNamespace() throws Exception { + BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); + + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); + + StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); + + // Reading the merged bag gets both the contents + assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(bag1.read(), Matchers.emptyIterable()); + assertThat(bag2.read(), Matchers.emptyIterable()); + } - @Override - @Ignore - public void testMapReadable() {} + @Test + public void testCombiningValue() throws Exception { + GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + + // State instances are cached, but depend on the namespace. + assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR))); + + assertThat(value.read(), Matchers.equalTo(0)); + value.add(2); + assertThat(value.read(), Matchers.equalTo(2)); + + value.add(3); + assertThat(value.read(), Matchers.equalTo(5)); + + value.clear(); + assertThat(value.read(), Matchers.equalTo(0)); + assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value); + } + + @Test + public void testCombiningIsEmpty() throws Exception { + GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState<Boolean> readFuture = value.isEmpty(); + value.add(5); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeCombiningValueIntoSource() throws Exception { + CombiningState<Integer, int[], Integer> value1 = + underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + CombiningState<Integer, int[], Integer> value2 = + underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); + + value1.add(5); + value2.add(10); + value1.add(6); + + assertThat(value1.read(), Matchers.equalTo(11)); + assertThat(value2.read(), Matchers.equalTo(10)); + + // Merging clears the old values and updates the result value. 
+ StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1); + + assertThat(value1.read(), Matchers.equalTo(21)); + assertThat(value2.read(), Matchers.equalTo(0)); + } + + @Test + public void testMergeCombiningValueIntoNewNamespace() throws Exception { + CombiningState<Integer, int[], Integer> value1 = + underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + CombiningState<Integer, int[], Integer> value2 = + underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); + CombiningState<Integer, int[], Integer> value3 = + underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR); + + value1.add(5); + value2.add(10); + value1.add(6); + + StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3); + + // Merging clears the old values and updates the result value. + assertThat(value1.read(), Matchers.equalTo(0)); + assertThat(value2.read(), Matchers.equalTo(0)); + assertThat(value3.read(), Matchers.equalTo(21)); + } } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java index aed14f3..4012373 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.flink.streaming; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import java.io.ByteArrayInputStream; @@ -24,8 +26,8 @@ import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.nio.ByteBuffer; -import org.apache.beam.runners.core.StateInternals; -import org.apache.beam.runners.core.StateInternalsTest; +import java.util.Arrays; +import org.apache.beam.runners.core.StateMerging; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateNamespaceForTest; import org.apache.beam.runners.core.StateTag; @@ -33,6 +35,7 @@ import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkKeyGroupStateInternals; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.ReadableState; import org.apache.beam.sdk.util.CoderUtils; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; @@ -44,219 +47,215 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.streaming.api.operators.KeyContext; import org.hamcrest.Matchers; -import org.junit.Ignore; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.junit.runners.Suite; /** * Tests for {@link FlinkKeyGroupStateInternals}. This is based on the tests for - * {@code StateInternalsTest}. + * {@code InMemoryStateInternals}. 
*/ -@RunWith(Suite.class) [email protected]({ - FlinkKeyGroupStateInternalsTest.StandardStateInternalsTests.class, - FlinkKeyGroupStateInternalsTest.OtherTests.class -}) +@RunWith(JUnit4.class) public class FlinkKeyGroupStateInternalsTest { + private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); + private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); + private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); - /** - * A standard StateInternals test. Just test BagState. - */ - @RunWith(JUnit4.class) - public static class StandardStateInternalsTests extends StateInternalsTest { - @Override - protected StateInternals createStateInternals() { - KeyedStateBackend keyedStateBackend = - getKeyedStateBackend(2, new KeyGroupRange(0, 1)); - return new FlinkKeyGroupStateInternals<>(StringUtf8Coder.of(), keyedStateBackend); + private static final StateTag<BagState<String>> STRING_BAG_ADDR = + StateTags.bag("stringBag", StringUtf8Coder.of()); + + FlinkKeyGroupStateInternals<String> underTest; + private KeyedStateBackend keyedStateBackend; + + @Before + public void initStateInternals() { + try { + keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1)); + underTest = new FlinkKeyGroupStateInternals<>(StringUtf8Coder.of(), keyedStateBackend); + } catch (Exception e) { + throw new RuntimeException(e); } + } - @Override - @Ignore - public void testValue() {} + private KeyedStateBackend getKeyedStateBackend(int numberOfKeyGroups, + KeyGroupRange keyGroupRange) { + MemoryStateBackend backend = new MemoryStateBackend(); + try { + AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend( + new DummyEnvironment("test", 1, 0), + new JobID(), + "test_op", + new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()), + numberOfKeyGroups, + keyGroupRange, + new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID())); + keyedStateBackend.setCurrentKey(ByteBuffer.wrap( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "1"))); + return keyedStateBackend; + } catch (Exception e) { + throw new RuntimeException(e); + } + } - @Override - @Ignore - public void testSet() {} + @Test + public void testBag() throws Exception { + BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - @Override - @Ignore - public void testSetIsEmpty() {} + assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); - @Override - @Ignore - public void testMergeSetIntoSource() {} + assertThat(value.read(), Matchers.emptyIterable()); + value.add("hello"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello")); - @Override - @Ignore - public void testMergeSetIntoNewNamespace() {} + value.add("world"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); - @Override - @Ignore - public void testMap() {} + value.clear(); + assertThat(value.read(), Matchers.emptyIterable()); + assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value); - @Override - @Ignore - public void testCombiningValue() {} + } - @Override - @Ignore - public void testCombiningIsEmpty() {} + @Test + public void testBagIsEmpty() throws Exception { + BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - @Override - @Ignore - public void testMergeCombiningValueIntoSource() {} + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState<Boolean> 
readFuture = value.isEmpty(); + value.add("hello"); + assertThat(readFuture.read(), Matchers.is(false)); - @Override - @Ignore - public void testMergeCombiningValueIntoNewNamespace() {} + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } - @Override - @Ignore - public void testWatermarkEarliestState() {} + @Test + public void testMergeBagIntoSource() throws Exception { + BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); - @Override - @Ignore - public void testWatermarkLatestState() {} + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); - @Override - @Ignore - public void testWatermarkEndOfWindowState() {} + StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); - @Override - @Ignore - public void testWatermarkStateIsEmpty() {} + // Reading the merged bag gets both the contents + assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(bag2.read(), Matchers.emptyIterable()); + } - @Override - @Ignore - public void testMergeEarliestWatermarkIntoSource() {} + @Test + public void testMergeBagIntoNewNamespace() throws Exception { + BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); - @Override - @Ignore - public void testMergeLatestWatermarkIntoSource() {} + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); - @Override - @Ignore - public void testSetReadable() {} + StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); - @Override - @Ignore - public void testMapReadable() {} + // Reading the merged bag gets both the contents + assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(bag1.read(), Matchers.emptyIterable()); + assertThat(bag2.read(), Matchers.emptyIterable()); } - /** - * A specific test of FlinkKeyGroupStateInternalsTest. 
- */ - @RunWith(JUnit4.class) - public static class OtherTests { - - private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); - private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); - private static final StateTag<BagState<String>> STRING_BAG_ADDR = - StateTags.bag("stringBag", StringUtf8Coder.of()); - - @Test - public void testKeyGroupAndCheckpoint() throws Exception { - // assign to keyGroup 0 - ByteBuffer key0 = ByteBuffer.wrap( - CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "11111111")); - // assign to keyGroup 1 - ByteBuffer key1 = ByteBuffer.wrap( - CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "22222222")); - FlinkKeyGroupStateInternals<String> allState; - { - KeyedStateBackend<ByteBuffer> keyedStateBackend = - getKeyedStateBackend(2, new KeyGroupRange(0, 1)); - allState = new FlinkKeyGroupStateInternals<>( - StringUtf8Coder.of(), keyedStateBackend); - BagState<String> valueForNamespace0 = allState.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState<String> valueForNamespace1 = allState.state(NAMESPACE_2, STRING_BAG_ADDR); - keyedStateBackend.setCurrentKey(key0); - valueForNamespace0.add("0"); - valueForNamespace1.add("2"); - keyedStateBackend.setCurrentKey(key1); - valueForNamespace0.add("1"); - valueForNamespace1.add("3"); - assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1")); - assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3")); - } - - ClassLoader classLoader = FlinkKeyGroupStateInternalsTest.class.getClassLoader(); - - // 1. scale up - ByteArrayOutputStream out0 = new ByteArrayOutputStream(); - allState.snapshotKeyGroupState(0, new DataOutputStream(out0)); - DataInputStream in0 = new DataInputStream( - new ByteArrayInputStream(out0.toByteArray())); - { - KeyedStateBackend<ByteBuffer> keyedStateBackend = - getKeyedStateBackend(2, new KeyGroupRange(0, 0)); - FlinkKeyGroupStateInternals<String> state0 = - new FlinkKeyGroupStateInternals<>( - StringUtf8Coder.of(), keyedStateBackend); - state0.restoreKeyGroupState(0, in0, classLoader); - BagState<String> valueForNamespace0 = state0.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState<String> valueForNamespace1 = state0.state(NAMESPACE_2, STRING_BAG_ADDR); - assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0")); - assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2")); - } - - ByteArrayOutputStream out1 = new ByteArrayOutputStream(); - allState.snapshotKeyGroupState(1, new DataOutputStream(out1)); - DataInputStream in1 = new DataInputStream( - new ByteArrayInputStream(out1.toByteArray())); - { - KeyedStateBackend<ByteBuffer> keyedStateBackend = - getKeyedStateBackend(2, new KeyGroupRange(1, 1)); - FlinkKeyGroupStateInternals<String> state1 = - new FlinkKeyGroupStateInternals<>( - StringUtf8Coder.of(), keyedStateBackend); - state1.restoreKeyGroupState(1, in1, classLoader); - BagState<String> valueForNamespace0 = state1.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState<String> valueForNamespace1 = state1.state(NAMESPACE_2, STRING_BAG_ADDR); - assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("1")); - assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("3")); - } - - // 2. 
scale down - { - KeyedStateBackend<ByteBuffer> keyedStateBackend = - getKeyedStateBackend(2, new KeyGroupRange(0, 1)); - FlinkKeyGroupStateInternals<String> newAllState = new FlinkKeyGroupStateInternals<>( - StringUtf8Coder.of(), keyedStateBackend); - in0.reset(); - in1.reset(); - newAllState.restoreKeyGroupState(0, in0, classLoader); - newAllState.restoreKeyGroupState(1, in1, classLoader); - BagState<String> valueForNamespace0 = newAllState.state(NAMESPACE_1, STRING_BAG_ADDR); - BagState<String> valueForNamespace1 = newAllState.state(NAMESPACE_2, STRING_BAG_ADDR); - assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1")); - assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3")); - } + @Test + public void testKeyGroupAndCheckpoint() throws Exception { + // assign to keyGroup 0 + ByteBuffer key0 = ByteBuffer.wrap( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "11111111")); + // assign to keyGroup 1 + ByteBuffer key1 = ByteBuffer.wrap( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "22222222")); + FlinkKeyGroupStateInternals<String> allState; + { + KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1)); + allState = new FlinkKeyGroupStateInternals<>( + StringUtf8Coder.of(), keyedStateBackend); + BagState<String> valueForNamespace0 = allState.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> valueForNamespace1 = allState.state(NAMESPACE_2, STRING_BAG_ADDR); + keyedStateBackend.setCurrentKey(key0); + valueForNamespace0.add("0"); + valueForNamespace1.add("2"); + keyedStateBackend.setCurrentKey(key1); + valueForNamespace0.add("1"); + valueForNamespace1.add("3"); + assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1")); + assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3")); + } + + ClassLoader classLoader = FlinkKeyGroupStateInternalsTest.class.getClassLoader(); + + // 1. 
scale up + ByteArrayOutputStream out0 = new ByteArrayOutputStream(); + allState.snapshotKeyGroupState(0, new DataOutputStream(out0)); + DataInputStream in0 = new DataInputStream( + new ByteArrayInputStream(out0.toByteArray())); + { + KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 0)); + FlinkKeyGroupStateInternals<String> state0 = + new FlinkKeyGroupStateInternals<>( + StringUtf8Coder.of(), keyedStateBackend); + state0.restoreKeyGroupState(0, in0, classLoader); + BagState<String> valueForNamespace0 = state0.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> valueForNamespace1 = state0.state(NAMESPACE_2, STRING_BAG_ADDR); + assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0")); + assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2")); + } + + ByteArrayOutputStream out1 = new ByteArrayOutputStream(); + allState.snapshotKeyGroupState(1, new DataOutputStream(out1)); + DataInputStream in1 = new DataInputStream( + new ByteArrayInputStream(out1.toByteArray())); + { + KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(1, 1)); + FlinkKeyGroupStateInternals<String> state1 = + new FlinkKeyGroupStateInternals<>( + StringUtf8Coder.of(), keyedStateBackend); + state1.restoreKeyGroupState(1, in1, classLoader); + BagState<String> valueForNamespace0 = state1.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> valueForNamespace1 = state1.state(NAMESPACE_2, STRING_BAG_ADDR); + assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("1")); + assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("3")); + } + + // 2. scale down + { + KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1)); + FlinkKeyGroupStateInternals<String> newAllState = new FlinkKeyGroupStateInternals<>( + StringUtf8Coder.of(), keyedStateBackend); + in0.reset(); + in1.reset(); + newAllState.restoreKeyGroupState(0, in0, classLoader); + newAllState.restoreKeyGroupState(1, in1, classLoader); + BagState<String> valueForNamespace0 = newAllState.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> valueForNamespace1 = newAllState.state(NAMESPACE_2, STRING_BAG_ADDR); + assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1")); + assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3")); } } - private static KeyedStateBackend<ByteBuffer> getKeyedStateBackend(int numberOfKeyGroups, - KeyGroupRange keyGroupRange) { - MemoryStateBackend backend = new MemoryStateBackend(); - try { - AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend( - new DummyEnvironment("test", 1, 0), - new JobID(), - "test_op", - new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()), - numberOfKeyGroups, - keyGroupRange, - new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID())); - keyedStateBackend.setCurrentKey(ByteBuffer.wrap( - CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "1"))); - return keyedStateBackend; - } catch (Exception e) { - throw new RuntimeException(e); + private static class TestKeyContext implements KeyContext { + + private Object key; + + @Override + public void setCurrentKey(Object key) { + this.key = key; + } + + @Override + public Object getCurrentKey() { + return key; } } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java 
---------------------------------------------------------------------- diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java index 667b5ba..17cd3f5 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java @@ -17,115 +17,85 @@ */ package org.apache.beam.runners.flink.streaming; -import org.apache.beam.runners.core.StateInternals; -import org.apache.beam.runners.core.StateInternalsTest; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; + +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateNamespaceForTest; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkSplitStateInternals; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.ReadableState; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.memory.MemoryStateBackend; -import org.junit.Ignore; +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** * Tests for {@link FlinkSplitStateInternals}. This is based on the tests for - * {@code StateInternalsTest}. - * - * <p>Just test testBag and testBagIsEmpty. + * {@code InMemoryStateInternals}. 
*/ @RunWith(JUnit4.class) -public class FlinkSplitStateInternalsTest extends StateInternalsTest { +public class FlinkSplitStateInternalsTest { + private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); + private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); + + private static final StateTag<BagState<String>> STRING_BAG_ADDR = + StateTags.bag("stringBag", StringUtf8Coder.of()); + + FlinkSplitStateInternals<String> underTest; - @Override - protected StateInternals createStateInternals() { + @Before + public void initStateInternals() { MemoryStateBackend backend = new MemoryStateBackend(); try { OperatorStateBackend operatorStateBackend = backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), ""); - return new FlinkSplitStateInternals<>(operatorStateBackend); + underTest = new FlinkSplitStateInternals<>(operatorStateBackend); + } catch (Exception e) { throw new RuntimeException(e); } } - @Override - @Ignore - public void testMergeBagIntoSource() {} - - @Override - @Ignore - public void testMergeBagIntoNewNamespace() {} - - @Override - @Ignore - public void testValue() {} - - @Override - @Ignore - public void testSet() {} - - @Override - @Ignore - public void testSetIsEmpty() {} - - @Override - @Ignore - public void testMergeSetIntoSource() {} + @Test + public void testBag() throws Exception { + BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - @Override - @Ignore - public void testMergeSetIntoNewNamespace() {} + assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); - @Override - @Ignore - public void testMap() {} + assertThat(value.read(), Matchers.emptyIterable()); + value.add("hello"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello")); - @Override - @Ignore - public void testCombiningValue() {} + value.add("world"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); - @Override - @Ignore - public void testCombiningIsEmpty() {} + value.clear(); + assertThat(value.read(), Matchers.emptyIterable()); + assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value); - @Override - @Ignore - public void testMergeCombiningValueIntoSource() {} - - @Override - @Ignore - public void testMergeCombiningValueIntoNewNamespace() {} - - @Override - @Ignore - public void testWatermarkEarliestState() {} - - @Override - @Ignore - public void testWatermarkLatestState() {} - - @Override - @Ignore - public void testWatermarkEndOfWindowState() {} - - @Override - @Ignore - public void testWatermarkStateIsEmpty() {} - - @Override - @Ignore - public void testMergeEarliestWatermarkIntoSource() {} + } - @Override - @Ignore - public void testMergeLatestWatermarkIntoSource() {} + @Test + public void testBagIsEmpty() throws Exception { + BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); - @Override - @Ignore - public void testSetReadable() {} + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState<Boolean> readFuture = value.isEmpty(); + value.add("hello"); + assertThat(readFuture.read(), Matchers.is(false)); - @Override - @Ignore - public void testMapReadable() {} + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java 
---------------------------------------------------------------------- diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index b8d41de..35d2b78 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -17,11 +17,31 @@ */ package org.apache.beam.runners.flink.streaming; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThat; + import java.nio.ByteBuffer; -import org.apache.beam.runners.core.StateInternals; -import org.apache.beam.runners.core.StateInternalsTest; +import java.util.Arrays; +import org.apache.beam.runners.core.StateMerging; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateNamespaceForTest; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.GroupingState; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.CoderUtils; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; @@ -32,17 +52,42 @@ import org.apache.flink.runtime.query.KvStateRegistry; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** - * Tests for {@link FlinkStateInternals}. This is based on {@link StateInternalsTest}. + * Tests for {@link FlinkStateInternals}. This is based on the tests for + * {@code InMemoryStateInternals}. 
*/ @RunWith(JUnit4.class) -public class FlinkStateInternalsTest extends StateInternalsTest { +public class FlinkStateInternalsTest { + private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10)); + private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1"); + private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2"); + private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3"); + + private static final StateTag<ValueState<String>> STRING_VALUE_ADDR = + StateTags.value("stringValue", StringUtf8Coder.of()); + private static final StateTag<CombiningState<Integer, int[], Integer>> + SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal( + "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); + private static final StateTag<BagState<String>> STRING_BAG_ADDR = + StateTags.bag("stringBag", StringUtf8Coder.of()); + private static final StateTag<WatermarkHoldState> WATERMARK_EARLIEST_ADDR = + StateTags.watermarkStateInternal("watermark", TimestampCombiner.EARLIEST); + private static final StateTag<WatermarkHoldState> WATERMARK_LATEST_ADDR = + StateTags.watermarkStateInternal("watermark", TimestampCombiner.LATEST); + private static final StateTag<WatermarkHoldState> WATERMARK_EOW_ADDR = + StateTags.watermarkStateInternal("watermark", TimestampCombiner.END_OF_WINDOW); + + FlinkStateInternals<String> underTest; - @Override - protected StateInternals createStateInternals() { + @Before + public void initStateInternals() { MemoryStateBackend backend = new MemoryStateBackend(); try { AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend( @@ -53,14 +98,296 @@ public class FlinkStateInternalsTest extends StateInternalsTest { 1, new KeyGroupRange(0, 0), new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID())); + underTest = new FlinkStateInternals<>(keyedStateBackend, StringUtf8Coder.of()); keyedStateBackend.setCurrentKey( ByteBuffer.wrap(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Hello"))); - - return new FlinkStateInternals<>(keyedStateBackend, StringUtf8Coder.of()); } catch (Exception e) { throw new RuntimeException(e); } } + @Test + public void testValue() throws Exception { + ValueState<String> value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR); + + assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value); + assertNotEquals( + underTest.state(NAMESPACE_2, STRING_VALUE_ADDR), + value); + + assertThat(value.read(), Matchers.nullValue()); + value.write("hello"); + assertThat(value.read(), Matchers.equalTo("hello")); + value.write("world"); + assertThat(value.read(), Matchers.equalTo("world")); + + value.clear(); + assertThat(value.read(), Matchers.nullValue()); + assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value); + + } + + @Test + public void testBag() throws Exception { + BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + + assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); + + assertThat(value.read(), Matchers.emptyIterable()); + value.add("hello"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello")); + + value.add("world"); + assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world")); + + value.clear(); + assertThat(value.read(), Matchers.emptyIterable()); + assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value); + + } + + @Test + public void 
testBagIsEmpty() throws Exception { + BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState<Boolean> readFuture = value.isEmpty(); + value.add("hello"); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeBagIntoSource() throws Exception { + BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); + + StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1); + + // Reading the merged bag gets both the contents + assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(bag2.read(), Matchers.emptyIterable()); + } + + @Test + public void testMergeBagIntoNewNamespace() throws Exception { + BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); + BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR); + BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR); + + bag1.add("Hello"); + bag2.add("World"); + bag1.add("!"); + + StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3); + + // Reading the merged bag gets both the contents + assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(bag1.read(), Matchers.emptyIterable()); + assertThat(bag2.read(), Matchers.emptyIterable()); + } + + @Test + public void testCombiningValue() throws Exception { + GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + + // State instances are cached, but depend on the namespace. + assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR))); + + assertThat(value.read(), Matchers.equalTo(0)); + value.add(2); + assertThat(value.read(), Matchers.equalTo(2)); + + value.add(3); + assertThat(value.read(), Matchers.equalTo(5)); + + value.clear(); + assertThat(value.read(), Matchers.equalTo(0)); + assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value); + } + + @Test + public void testCombiningIsEmpty() throws Exception { + GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState<Boolean> readFuture = value.isEmpty(); + value.add(5); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeCombiningValueIntoSource() throws Exception { + CombiningState<Integer, int[], Integer> value1 = + underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + CombiningState<Integer, int[], Integer> value2 = + underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); + + value1.add(5); + value2.add(10); + value1.add(6); + + assertThat(value1.read(), Matchers.equalTo(11)); + assertThat(value2.read(), Matchers.equalTo(10)); + + // Merging clears the old values and updates the result value. 
+ StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1); + + assertThat(value1.read(), Matchers.equalTo(21)); + assertThat(value2.read(), Matchers.equalTo(0)); + } + + @Test + public void testMergeCombiningValueIntoNewNamespace() throws Exception { + CombiningState<Integer, int[], Integer> value1 = + underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); + CombiningState<Integer, int[], Integer> value2 = + underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR); + CombiningState<Integer, int[], Integer> value3 = + underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR); + + value1.add(5); + value2.add(10); + value1.add(6); + + StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3); + + // Merging clears the old values and updates the result value. + assertThat(value1.read(), Matchers.equalTo(0)); + assertThat(value2.read(), Matchers.equalTo(0)); + assertThat(value3.read(), Matchers.equalTo(21)); + } + + @Test + public void testWatermarkEarliestState() throws Exception { + WatermarkHoldState value = + underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); + + // State instances are cached, but depend on the namespace. + assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR))); + + assertThat(value.read(), Matchers.nullValue()); + value.add(new Instant(2000)); + assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + + value.add(new Instant(3000)); + assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + + value.add(new Instant(1000)); + assertThat(value.read(), Matchers.equalTo(new Instant(1000))); + + value.clear(); + assertThat(value.read(), Matchers.equalTo(null)); + assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), value); + } + + @Test + public void testWatermarkLatestState() throws Exception { + WatermarkHoldState value = + underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); + + // State instances are cached, but depend on the namespace. + assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR))); + + assertThat(value.read(), Matchers.nullValue()); + value.add(new Instant(2000)); + assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + + value.add(new Instant(3000)); + assertThat(value.read(), Matchers.equalTo(new Instant(3000))); + + value.add(new Instant(1000)); + assertThat(value.read(), Matchers.equalTo(new Instant(3000))); + + value.clear(); + assertThat(value.read(), Matchers.equalTo(null)); + assertEquals(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), value); + } + + @Test + public void testWatermarkEndOfWindowState() throws Exception { + WatermarkHoldState value = underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR); + + // State instances are cached, but depend on the namespace. 
+ assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR)); + assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EOW_ADDR))); + + assertThat(value.read(), Matchers.nullValue()); + value.add(new Instant(2000)); + assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + + value.clear(); + assertThat(value.read(), Matchers.equalTo(null)); + assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), value); + } + + @Test + public void testWatermarkStateIsEmpty() throws Exception { + WatermarkHoldState value = + underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState<Boolean> readFuture = value.isEmpty(); + value.add(new Instant(1000)); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeEarliestWatermarkIntoSource() throws Exception { + WatermarkHoldState value1 = + underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR); + WatermarkHoldState value2 = + underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR); + + value1.add(new Instant(3000)); + value2.add(new Instant(5000)); + value1.add(new Instant(4000)); + value2.add(new Instant(2000)); + + // Merging clears the old values and updates the merged value. + StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1); + + assertThat(value1.read(), Matchers.equalTo(new Instant(2000))); + assertThat(value2.read(), Matchers.equalTo(null)); + } + + @Test + public void testMergeLatestWatermarkIntoSource() throws Exception { + WatermarkHoldState value1 = + underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR); + WatermarkHoldState value2 = + underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR); + WatermarkHoldState value3 = + underTest.state(NAMESPACE_3, WATERMARK_LATEST_ADDR); + + value1.add(new Instant(3000)); + value2.add(new Instant(5000)); + value1.add(new Instant(4000)); + value2.add(new Instant(2000)); + + // Merging clears the old values and updates the result value. + StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1); + + // Merging clears the old values and updates the result value. 
+ assertThat(value3.read(), Matchers.equalTo(new Instant(5000))); + assertThat(value1.read(), Matchers.equalTo(null)); + assertThat(value2.read(), Matchers.equalTo(null)); + } } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/google-cloud-dataflow-java/pom.xml ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml index c8d63ac..92c94a8 100644 --- a/runners/google-cloud-dataflow-java/pom.xml +++ b/runners/google-cloud-dataflow-java/pom.xml @@ -22,7 +22,7 @@ <parent> <groupId>org.apache.beam</groupId> <artifactId>beam-runners-parent</artifactId> - <version>2.2.0-SNAPSHOT</version> + <version>2.1.0-SNAPSHOT</version> <relativePath>../pom.xml</relativePath> </parent> @@ -33,7 +33,7 @@ <packaging>jar</packaging> <properties> - <dataflow.container_version>beam-master-20170706</dataflow.container_version> + <dataflow.container_version>beam-master-20170530</dataflow.container_version> <dataflow.fnapi_environment_major_version>1</dataflow.fnapi_environment_major_version> <dataflow.legacy_environment_major_version>6</dataflow.legacy_environment_major_version> </properties> @@ -216,17 +216,13 @@ <execution> <id>validates-runner-tests</id> <configuration> - <!-- - UsesSplittableParDoWithWindowedSideInputs because of - https://issues.apache.org/jira/browse/BEAM-2476 - --> <excludedGroups> org.apache.beam.sdk.testing.LargeKeys$Above10MB, org.apache.beam.sdk.testing.UsesDistributionMetrics, org.apache.beam.sdk.testing.UsesGaugeMetrics, org.apache.beam.sdk.testing.UsesSetState, org.apache.beam.sdk.testing.UsesMapState, - org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs, + org.apache.beam.sdk.testing.UsesSplittableParDo, org.apache.beam.sdk.testing.UsesUnboundedPCollections, org.apache.beam.sdk.testing.UsesTestStream, </excludedGroups> http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java index 7309f61..4d9a57f 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java @@ -145,8 +145,6 @@ public class BatchStatefulParDoOverrides { public PCollection<OutputT> expand(PCollection<KV<K, InputT>> input) { DoFn<KV<K, InputT>, OutputT> fn = originalParDo.getFn(); verifyFnIsStateful(fn); - DataflowRunner.verifyStateSupported(fn); - DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy()); PTransform< PCollection<? extends KV<K, Iterable<KV<Instant, WindowedValue<KV<K, InputT>>>>>>, @@ -171,8 +169,6 @@ public class BatchStatefulParDoOverrides { public PCollectionTuple expand(PCollection<KV<K, InputT>> input) { DoFn<KV<K, InputT>, OutputT> fn = originalParDo.getFn(); verifyFnIsStateful(fn); - DataflowRunner.verifyStateSupported(fn); - DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy()); PTransform< PCollection<? 
extends KV<K, Iterable<KV<Instant, WindowedValue<KV<K, InputT>>>>>>, http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java index ad3faed..b4a6e64 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java @@ -39,6 +39,8 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import org.apache.beam.runners.core.construction.PTransformReplacements; +import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; import org.apache.beam.runners.dataflow.internal.IsmFormat; import org.apache.beam.runners.dataflow.internal.IsmFormat.IsmRecord; import org.apache.beam.runners.dataflow.internal.IsmFormat.IsmRecordCoder; @@ -55,11 +57,17 @@ import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.StructuredCoder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.GloballyAsSingletonView; +import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.transforms.View.AsSingleton; import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; @@ -75,6 +83,7 @@ import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PCollectionViews; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.WindowingStrategy; @@ -183,13 +192,12 @@ class BatchViewOverrides { } private final DataflowRunner runner; - private final PCollectionView<Map<K, V>> view; - /** Builds an instance of this class from the overridden transform. */ + /** + * Builds an instance of this class from the overridden transform. 
+ */ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() - public BatchViewAsMap( - DataflowRunner runner, CreatePCollectionView<KV<K, V>, Map<K, V>> transform) { + public BatchViewAsMap(DataflowRunner runner, View.AsMap<K, V> transform) { this.runner = runner; - this.view = transform.getView(); } @Override @@ -199,7 +207,12 @@ class BatchViewOverrides { private <W extends BoundedWindow> PCollectionView<Map<K, V>> applyInternal(PCollection<KV<K, V>> input) { + + @SuppressWarnings({"rawtypes", "unchecked"}) + KvCoder<K, V> inputCoder = (KvCoder) input.getCoder(); try { + PCollectionView<Map<K, V>> view = PCollectionViews.mapView( + input, input.getWindowingStrategy(), inputCoder); return BatchViewAsMultimap.applyForMapLike(runner, input, view, true /* unique keys */); } catch (NonDeterministicException e) { runner.recordViewUsesNonDeterministicKeyCoder(this); @@ -236,14 +249,19 @@ class BatchViewOverrides { inputCoder.getKeyCoder(), FullWindowedValueCoder.of(inputCoder.getValueCoder(), windowCoder))); + TransformedMap<K, WindowedValue<V>, V> defaultValue = new TransformedMap<>( + WindowedValueToValue.<V>of(), + ImmutableMap.<K, WindowedValue<V>>of()); + return BatchViewAsSingleton.<KV<K, V>, TransformedMap<K, WindowedValue<V>, V>, Map<K, V>, W> applyForSingleton( runner, input, new ToMapDoFn<K, V, W>(windowCoder), - finalValueCoder, - view); + true, + defaultValue, + finalValueCoder); } } @@ -662,13 +680,12 @@ class BatchViewOverrides { } private final DataflowRunner runner; - private final PCollectionView<Map<K, Iterable<V>>> view; - /** Builds an instance of this class from the overridden transform. */ + /** + * Builds an instance of this class from the overridden transform. + */ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() - public BatchViewAsMultimap( - DataflowRunner runner, CreatePCollectionView<KV<K, V>, Map<K, Iterable<V>>> transform) { + public BatchViewAsMultimap(DataflowRunner runner, View.AsMultimap<K, V> transform) { this.runner = runner; - this.view = transform.getView(); } @Override @@ -678,7 +695,12 @@ class BatchViewOverrides { private <W extends BoundedWindow> PCollectionView<Map<K, Iterable<V>>> applyInternal(PCollection<KV<K, V>> input) { + @SuppressWarnings({"rawtypes", "unchecked"}) + KvCoder<K, V> inputCoder = (KvCoder) input.getCoder(); try { + PCollectionView<Map<K, Iterable<V>>> view = PCollectionViews.multimapView( + input, input.getWindowingStrategy(), inputCoder); + return applyForMapLike(runner, input, view, false /* unique keys not expected */); } catch (NonDeterministicException e) { runner.recordViewUsesNonDeterministicKeyCoder(this); @@ -716,15 +738,16 @@ class BatchViewOverrides { IterableWithWindowedValuesToIterable.<V>of(), ImmutableMap.<K, Iterable<WindowedValue<V>>>of()); - return BatchViewAsSingleton - .<KV<K, V>, TransformedMap<K, Iterable<WindowedValue<V>>, Iterable<V>>, - Map<K, Iterable<V>>, W> - applyForSingleton( - runner, - input, - new ToMultimapDoFn<K, V, W>(windowCoder), - finalValueCoder, - view); + return BatchViewAsSingleton.<KV<K, V>, + TransformedMap<K, Iterable<WindowedValue<V>>, Iterable<V>>, + Map<K, Iterable<V>>, + W> applyForSingleton( + runner, + input, + new ToMultimapDoFn<K, V, W>(windowCoder), + true, + defaultValue, + finalValueCoder); } private static <K, V, W extends BoundedWindow, ViewT> PCollectionView<ViewT> applyForMapLike( @@ -804,9 +827,10 @@ class BatchViewOverrides { PCollectionList.of(ImmutableList.of( perHashWithReifiedWindows, windowMapSizeMetadata, 
windowMapKeysMetadata)); - Pipeline.applyTransform(outputs, Flatten.<IsmRecord<WindowedValue<V>>>pCollections()) - .apply(CreateDataflowView.<IsmRecord<WindowedValue<V>>, ViewT>of(view)); - return view; + return Pipeline.applyTransform(outputs, + Flatten.<IsmRecord<WindowedValue<V>>>pCollections()) + .apply(CreateDataflowView.<IsmRecord<WindowedValue<V>>, + ViewT>of(view)); } @Override @@ -891,12 +915,14 @@ class BatchViewOverrides { } private final DataflowRunner runner; - private final PCollectionView<T> view; - /** Builds an instance of this class from the overridden transform. */ + private final View.AsSingleton<T> transform; + /** + * Builds an instance of this class from the overridden transform. + */ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() - public BatchViewAsSingleton(DataflowRunner runner, CreatePCollectionView<T, T> transform) { + public BatchViewAsSingleton(DataflowRunner runner, View.AsSingleton<T> transform) { this.runner = runner; - this.view = transform.getView(); + this.transform = transform; } @Override @@ -909,8 +935,9 @@ class BatchViewOverrides { runner, input, new IsmRecordForSingularValuePerWindowDoFn<T, BoundedWindow>(windowCoder), - input.getCoder(), - view); + transform.hasDefaultValue(), + transform.defaultValue(), + input.getCoder()); } static <T, FinalT, ViewT, W extends BoundedWindow> PCollectionView<ViewT> @@ -919,13 +946,23 @@ class BatchViewOverrides { PCollection<T> input, DoFn<KV<Integer, Iterable<KV<W, WindowedValue<T>>>>, IsmRecord<WindowedValue<FinalT>>> doFn, - Coder<FinalT> defaultValueCoder, - PCollectionView<ViewT> view) { + boolean hasDefault, + FinalT defaultValue, + Coder<FinalT> defaultValueCoder) { @SuppressWarnings("unchecked") Coder<W> windowCoder = (Coder<W>) input.getWindowingStrategy().getWindowFn().windowCoder(); + @SuppressWarnings({"rawtypes", "unchecked"}) + PCollectionView<ViewT> view = + (PCollectionView<ViewT>) PCollectionViews.<FinalT, W>singletonView( + (PCollection) input, + (WindowingStrategy) input.getWindowingStrategy(), + hasDefault, + defaultValue, + defaultValueCoder); + IsmRecordCoder<WindowedValue<FinalT>> ismCoder = coderForSingleton(windowCoder, defaultValueCoder); @@ -935,9 +972,8 @@ class BatchViewOverrides { reifiedPerWindowAndSorted.setCoder(ismCoder); runner.addPCollectionRequiringIndexedFormat(reifiedPerWindowAndSorted); - reifiedPerWindowAndSorted.apply( + return reifiedPerWindowAndSorted.apply( CreateDataflowView.<IsmRecord<WindowedValue<FinalT>>, ViewT>of(view)); - return view; } @Override @@ -1043,18 +1079,18 @@ class BatchViewOverrides { } private final DataflowRunner runner; - private final PCollectionView<List<T>> view; /** * Builds an instance of this class from the overridden transform. 
*/ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() - public BatchViewAsList(DataflowRunner runner, CreatePCollectionView<T, List<T>> transform) { + public BatchViewAsList(DataflowRunner runner, View.AsList<T> transform) { this.runner = runner; - this.view = transform.getView(); } @Override public PCollectionView<List<T>> expand(PCollection<T> input) { + PCollectionView<List<T>> view = PCollectionViews.listView( + input, input.getWindowingStrategy(), input.getCoder()); return applyForIterableLike(runner, input, view); } @@ -1080,9 +1116,8 @@ class BatchViewOverrides { reifiedPerWindowAndSorted.setCoder(ismCoder); runner.addPCollectionRequiringIndexedFormat(reifiedPerWindowAndSorted); - reifiedPerWindowAndSorted.apply( + return reifiedPerWindowAndSorted.apply( CreateDataflowView.<IsmRecord<WindowedValue<T>>, ViewT>of(view)); - return view; } PCollection<IsmRecord<WindowedValue<T>>> reifiedPerWindowAndSorted = input @@ -1091,9 +1126,8 @@ class BatchViewOverrides { reifiedPerWindowAndSorted.setCoder(ismCoder); runner.addPCollectionRequiringIndexedFormat(reifiedPerWindowAndSorted); - reifiedPerWindowAndSorted.apply( + return reifiedPerWindowAndSorted.apply( CreateDataflowView.<IsmRecord<WindowedValue<T>>, ViewT>of(view)); - return view; } @Override @@ -1130,17 +1164,18 @@ class BatchViewOverrides { extends PTransform<PCollection<T>, PCollectionView<Iterable<T>>> { private final DataflowRunner runner; - private final PCollectionView<Iterable<T>> view; - /** Builds an instance of this class from the overridden transform. */ + /** + * Builds an instance of this class from the overridden transform. + */ @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply() - public BatchViewAsIterable( - DataflowRunner runner, CreatePCollectionView<T, Iterable<T>> transform) { + public BatchViewAsIterable(DataflowRunner runner, View.AsIterable<T> transform) { this.runner = runner; - this.view = transform.getView(); } @Override public PCollectionView<Iterable<T>> expand(PCollection<T> input) { + PCollectionView<Iterable<T>> view = PCollectionViews.iterableView( + input, input.getWindowingStrategy(), input.getCoder()); return BatchViewAsList.applyForIterableLike(runner, input, view); } } @@ -1342,4 +1377,59 @@ class BatchViewOverrides { verifyDeterministic(this, "Expected map coder to be deterministic.", originalMapCoder); } } + + static class BatchCombineGloballyAsSingletonViewFactory<ElemT, ViewT> + extends SingleInputOutputOverrideFactory< + PCollection<ElemT>, PCollectionView<ViewT>, + Combine.GloballyAsSingletonView<ElemT, ViewT>> { + private final DataflowRunner runner; + + BatchCombineGloballyAsSingletonViewFactory(DataflowRunner runner) { + this.runner = runner; + } + + @Override + public PTransformReplacement<PCollection<ElemT>, PCollectionView<ViewT>> + getReplacementTransform( + AppliedPTransform< + PCollection<ElemT>, PCollectionView<ViewT>, + GloballyAsSingletonView<ElemT, ViewT>> + transform) { + GloballyAsSingletonView<ElemT, ViewT> combine = transform.getTransform(); + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + new BatchCombineGloballyAsSingletonView<>( + runner, combine.getCombineFn(), combine.getFanout(), combine.getInsertDefault())); + } + + private static class BatchCombineGloballyAsSingletonView<ElemT, ViewT> + extends PTransform<PCollection<ElemT>, PCollectionView<ViewT>> { + private final DataflowRunner runner; + private final GlobalCombineFn<? 
super ElemT, ?, ViewT> combineFn; + private final int fanout; + private final boolean insertDefault; + + BatchCombineGloballyAsSingletonView( + DataflowRunner runner, + GlobalCombineFn<? super ElemT, ?, ViewT> combineFn, + int fanout, + boolean insertDefault) { + this.runner = runner; + this.combineFn = combineFn; + this.fanout = fanout; + this.insertDefault = insertDefault; + } + + @Override + public PCollectionView<ViewT> expand(PCollection<ElemT> input) { + PCollection<ViewT> combined = + input.apply(Combine.globally(combineFn).withoutDefaults().withFanout(fanout)); + AsSingleton<ViewT> viewAsSingleton = View.asSingleton(); + if (insertDefault) { + viewAsSingleton.withDefaultValue(combineFn.defaultValue()); + } + return combined.apply(new BatchViewAsSingleton<>(runner, viewAsSingleton)); + } + } + } }
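
[Editor's note, not part of the commit] For context on the BatchCombineGloballyAsSingletonViewFactory restored in the BatchViewOverrides.java hunk above: it replaces Combine.GloballyAsSingletonView on the Dataflow batch runner. The following is a minimal sketch, under stated assumptions, of the user-side pipeline code that produces such a transform and would therefore be rewritten by this factory. The class name, the options setup, and running the pipeline locally are illustrative assumptions; only Combine.globally(...).asSingletonView() and Sum.ofIntegers() are taken from the Beam SDK surface used elsewhere in this change.

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.transforms.Combine;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.Sum;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.PCollectionView;

    // Hypothetical example class, not present in the Beam codebase.
    public class CombineAsSingletonViewExample {
      public static void main(String[] args) {
        // Assumes a runner (e.g. DirectRunner or DataflowRunner) is on the classpath.
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

        PCollection<Integer> nums = p.apply(Create.of(1, 2, 3, 4));

        // Combine.globally(...).asSingletonView() yields a
        // Combine.GloballyAsSingletonView transform; on the Dataflow batch
        // runner that transform is the one matched and replaced by
        // BatchCombineGloballyAsSingletonViewFactory in the diff above,
        // which expands it into Combine.globally(...).withoutDefaults()
        // followed by BatchViewAsSingleton.
        PCollectionView<Integer> sumView =
            nums.apply(Combine.globally(Sum.ofIntegers()).asSingletonView());

        p.run();
      }
    }

The factory mirrors the original transform's configuration (combine fn, fanout, insertDefault), so user code like the sketch above keeps its semantics while the runner substitutes the Dataflow-specific singleton-view materialization.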
