Repository: beam Updated Branches: refs/heads/master aebd3a4c5 -> a05455088
Flink runner: support MapState in FlinkStateInternals. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/dbab052c Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/dbab052c Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/dbab052c Branch: refs/heads/master Commit: dbab052c4456ff51dd4ce44979c77a508acc17e9 Parents: aebd3a4 Author: 波繁 <[email protected]> Authored: Thu May 18 12:23:20 2017 +0800 Committer: Pei He <[email protected]> Committed: Tue Jun 6 23:18:33 2017 +0800 ---------------------------------------------------------------------- runners/flink/pom.xml | 1 - .../streaming/state/FlinkStateInternals.java | 205 ++++++++++++++++++- 2 files changed, 202 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/dbab052c/runners/flink/pom.xml ---------------------------------------------------------------------- diff --git a/runners/flink/pom.xml b/runners/flink/pom.xml index 92f95a0..c4c6b55 100644 --- a/runners/flink/pom.xml +++ b/runners/flink/pom.xml @@ -92,7 +92,6 @@ org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders, org.apache.beam.sdk.testing.LargeKeys$Above100MB, org.apache.beam.sdk.testing.UsesSetState, - org.apache.beam.sdk.testing.UsesMapState, org.apache.beam.sdk.testing.UsesCommittedMetrics, org.apache.beam.sdk.testing.UsesTestStream, org.apache.beam.sdk.testing.UsesSplittableParDo http://git-wip-us.apache.org/repos/asf/beam/blob/dbab052c/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java 
b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index b73abe9..f0d3278 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -26,6 +26,7 @@ import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.InstantCoder; @@ -46,6 +47,7 @@ import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.CombineContextFactory; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.runtime.state.KeyedStateBackend; @@ -132,11 +134,11 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( - StateTag<MapState<KeyT, ValueT>> spec, + StateTag<MapState<KeyT, ValueT>> address, Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", MapState.class.getSimpleName())); + return new FlinkMapState<>( + flinkStateBackend, address, namespace, mapKeyCoder, mapValueCoder); } @Override @@ -1029,4 +1031,201 @@ public class FlinkStateInternals<K> implements StateInternals { return result; } } + + private 
static class FlinkMapState<KeyT, ValueT> implements MapState<KeyT, ValueT> { + + private final StateNamespace namespace; + private final StateTag<MapState<KeyT, ValueT>> address; + private final MapStateDescriptor<KeyT, ValueT> flinkStateDescriptor; + private final KeyedStateBackend<ByteBuffer> flinkStateBackend; + + FlinkMapState( + KeyedStateBackend<ByteBuffer> flinkStateBackend, + StateTag<MapState<KeyT, ValueT>> address, + StateNamespace namespace, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { + this.namespace = namespace; + this.address = address; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateDescriptor = new MapStateDescriptor<>(address.getId(), + new CoderTypeSerializer<>(mapKeyCoder), new CoderTypeSerializer<>(mapValueCoder)); + } + + @Override + public ReadableState<ValueT> get(final KeyT input) { + return new ReadableState<ValueT>() { + @Override + public ValueT read() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(input); + } catch (Exception e) { + throw new RuntimeException("Error get from state.", e); + } + } + + @Override + public ReadableState<ValueT> readLater() { + return this; + } + }; + } + + @Override + public void put(KeyT key, ValueT value) { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).put(key, value); + } catch (Exception e) { + throw new RuntimeException("Error put kv to state.", e); + } + } + + @Override + public ReadableState<ValueT> putIfAbsent(final KeyT key, final ValueT value) { + return new ReadableState<ValueT>() { + @Override + public ValueT read() { + try { + ValueT current = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(key); + + if (current == null) { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + 
flinkStateDescriptor).put(key, value); + } + return current; + } catch (Exception e) { + throw new RuntimeException("Error put kv to state.", e); + } + } + + @Override + public ReadableState<ValueT> readLater() { + return this; + } + }; + } + + @Override + public void remove(KeyT key) { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).remove(key); + } catch (Exception e) { + throw new RuntimeException("Error remove map state key.", e); + } + } + + @Override + public ReadableState<Iterable<KeyT>> keys() { + return new ReadableState<Iterable<KeyT>>() { + @Override + public Iterable<KeyT> read() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).keys(); + } catch (Exception e) { + throw new RuntimeException("Error get map state keys.", e); + } + } + + @Override + public ReadableState<Iterable<KeyT>> readLater() { + return this; + } + }; + } + + @Override + public ReadableState<Iterable<ValueT>> values() { + return new ReadableState<Iterable<ValueT>>() { + @Override + public Iterable<ValueT> read() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).values(); + } catch (Exception e) { + throw new RuntimeException("Error get map state values.", e); + } + } + + @Override + public ReadableState<Iterable<ValueT>> readLater() { + return this; + } + }; + } + + @Override + public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> entries() { + return new ReadableState<Iterable<Map.Entry<KeyT, ValueT>>>() { + @Override + public Iterable<Map.Entry<KeyT, ValueT>> read() { + try { + return flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).entries(); + } catch (Exception e) { + throw new RuntimeException("Error get map state entries.", e); + } + } + + @Override + public 
ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkMapState<?, ?> that = (FlinkMapState<?, ?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + }
