/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.flink.translation.wrappers.streaming.io;

import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.flink.api.common.functions.StoppableFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Wrapper for executing {@link BoundedSource BoundedSources} as a Flink Source.
 *
 * <p>The source is split once, in the constructor, so that all parallel instances
 * agree on the split assignment. At runtime each subtask reads the splits whose
 * index is congruent to its subtask index modulo the parallelism.
 */
public class BoundedSourceWrapper<OutputT>
    extends RichParallelSourceFunction<WindowedValue<OutputT>>
    implements StoppableFunction {

  private static final Logger LOG = LoggerFactory.getLogger(BoundedSourceWrapper.class);

  /**
   * Keep the options so that we can initialize the readers.
   */
  private final SerializedPipelineOptions serializedOptions;

  /**
   * The split sources. We split them in the constructor to ensure that all parallel
   * sources are consistent about the split sources.
   */
  private List<? extends BoundedSource<OutputT>> splitSources;

  /**
   * Readers currently being drained by this subtask. Kept as a field so that
   * {@link #close()} can close any reader that is still open. Readers that are
   * exhausted during {@link #run(SourceContext)} are closed eagerly and removed
   * from this list.
   */
  private transient List<BoundedSource.BoundedReader<OutputT>> readers;

  /**
   * Initialize here and not in run() to prevent races where we cancel a job before run() is
   * ever called or run() is called after cancel().
   */
  private volatile boolean isRunning = true;

  /**
   * Creates the wrapper and eagerly splits {@code source}.
   *
   * @param pipelineOptions options used to size and create readers
   * @param source the bounded source to execute
   * @param parallelism the source parallelism; must be positive
   * @throws Exception if size estimation or splitting fails
   */
  @SuppressWarnings("unchecked")
  public BoundedSourceWrapper(
      PipelineOptions pipelineOptions,
      BoundedSource<OutputT> source,
      int parallelism) throws Exception {
    if (parallelism <= 0) {
      throw new IllegalArgumentException("parallelism must be positive, got: " + parallelism);
    }
    this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);

    // Never ask for a zero-byte bundle: if the estimated size is smaller than
    // the parallelism, the integer division would yield 0.
    long desiredBundleSize =
        Math.max(1L, source.getEstimatedSizeBytes(pipelineOptions) / parallelism);

    // get the splits early. we assume that the generated splits are stable,
    // this is necessary so that the mapping of state to source is correct
    // when restoring
    splitSources = source.split(desiredBundleSize, pipelineOptions);
  }

  @Override
  public void run(SourceContext<WindowedValue<OutputT>> ctx) throws Exception {

    // figure out which split sources we're responsible for
    int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
    int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();

    List<BoundedSource<OutputT>> localSources = new ArrayList<>();

    for (int i = 0; i < splitSources.size(); i++) {
      if (i % numSubtasks == subtaskIndex) {
        localSources.add(splitSources.get(i));
      }
    }

    LOG.info("Bounded Flink Source {}/{} is reading from sources: {}",
        subtaskIndex,
        numSubtasks,
        localSources);

    readers = new ArrayList<>();
    // initialize readers from scratch
    for (BoundedSource<OutputT> source : localSources) {
      readers.add(source.createReader(serializedOptions.getPipelineOptions()));
    }

    if (readers.size() == 1) {
      // the easy case, we just read from one reader
      BoundedSource.BoundedReader<OutputT> reader = readers.get(0);

      boolean dataAvailable = reader.start();
      if (dataAvailable) {
        emitElement(ctx, reader);
      }

      while (isRunning) {
        dataAvailable = reader.advance();

        if (dataAvailable) {
          emitElement(ctx, reader);
        } else {
          // bounded source: advance() == false means exhausted, we're done
          break;
        }
      }
    } else {
      // a bit more complicated, we are responsible for several readers
      // loop through them round-robin until all are exhausted

      int currentReader = 0;

      // start each reader and emit data if immediately available
      for (BoundedSource.BoundedReader<OutputT> reader : readers) {
        boolean dataAvailable = reader.start();
        if (dataAvailable) {
          emitElement(ctx, reader);
        }
      }

      // a flag telling us whether any of the readers had data
      // if no reader had data, sleep for bit
      boolean hadData = false;
      while (isRunning && !readers.isEmpty()) {
        BoundedSource.BoundedReader<OutputT> reader = readers.get(currentReader);
        boolean dataAvailable = reader.advance();

        if (dataAvailable) {
          emitElement(ctx, reader);
          hadData = true;
        } else {
          // Reader is exhausted: close it now and drop it from the list.
          // If we only removed it, close() would never see it again and the
          // reader's resources (files, connections, ...) would leak.
          reader.close();
          readers.remove(currentReader);
          currentReader--;
          if (readers.isEmpty()) {
            break;
          }
        }

        currentReader = (currentReader + 1) % readers.size();
        if (currentReader == 0 && !hadData) {
          Thread.sleep(50);
        } else if (currentReader == 0) {
          hadData = false;
        }
      }

    }

    // emit final Long.MAX_VALUE watermark, just to be sure
    ctx.emitWatermark(new Watermark(Long.MAX_VALUE));
  }

  /**
   * Emit the current element from the given Reader. The reader is guaranteed to have data.
   */
  private void emitElement(
      SourceContext<WindowedValue<OutputT>> ctx,
      BoundedSource.BoundedReader<OutputT> reader) {
    // make sure that reader state update and element emission are atomic
    // with respect to snapshots
    synchronized (ctx.getCheckpointLock()) {

      OutputT item = reader.getCurrent();
      Instant timestamp = reader.getCurrentTimestamp();

      WindowedValue<OutputT> windowedValue =
          WindowedValue.of(item, timestamp, GlobalWindow.INSTANCE, PaneInfo.NO_FIRING);
      ctx.collectWithTimestamp(windowedValue, timestamp.getMillis());
    }
  }

  @Override
  public void close() throws Exception {
    super.close();
    // Close whatever readers are still open; exhausted readers were already
    // closed and removed inside run().
    if (readers != null) {
      for (BoundedSource.BoundedReader<OutputT> reader: readers) {
        reader.close();
      }
    }
  }

  @Override
  public void cancel() {
    isRunning = false;
  }

  @Override
  public void stop() {
    this.isRunning = false;
  }

  /**
   * Visible so that we can check this in tests. Must not be used for anything else.
   */
  @VisibleForTesting
  public List<? extends BoundedSource<OutputT>> getSplitSources() {
    return splitSources;
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.flink.translation.wrappers.streaming.io;

import static com.google.common.base.Preconditions.checkArgument;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An example unbounded Beam source that reads input from a socket.
 * This is used mainly for testing and debugging.
 *
 * <p>The source never splits (it always reads from a single socket) and does not
 * support checkpointing; {@link #getCheckpointMarkCoder()} returns {@code null}.
 */
public class UnboundedSocketSource<CheckpointMarkT extends UnboundedSource.CheckpointMark>
    extends UnboundedSource<String, CheckpointMarkT> {

  private static final Coder<String> DEFAULT_SOCKET_CODER = StringUtf8Coder.of();

  private static final long serialVersionUID = 1L;

  /** Default pause between reconnection attempts, in milliseconds. */
  private static final int DEFAULT_CONNECTION_RETRY_SLEEP = 500;

  /** 0 means "no timeout" for {@link Socket#connect(java.net.SocketAddress, int)}. */
  private static final int CONNECTION_TIMEOUT_TIME = 0;

  private final String hostname;
  private final int port;
  private final char delimiter;
  private final long maxNumRetries;
  private final long delayBetweenRetries;

  /**
   * Creates a source with the default delay between connection retries.
   */
  public UnboundedSocketSource(String hostname, int port, char delimiter, long maxNumRetries) {
    this(hostname, port, delimiter, maxNumRetries, DEFAULT_CONNECTION_RETRY_SLEEP);
  }

  /**
   * Creates a source reading from {@code hostname:port}, splitting records on
   * {@code delimiter}.
   *
   * @param maxNumRetries number of reconnection attempts, or -1 for infinite retries
   * @param delayBetweenRetries sleep between reconnection attempts, in milliseconds
   */
  public UnboundedSocketSource(String hostname,
                               int port,
                               char delimiter,
                               long maxNumRetries,
                               long delayBetweenRetries) {
    this.hostname = hostname;
    this.port = port;
    this.delimiter = delimiter;
    this.maxNumRetries = maxNumRetries;
    this.delayBetweenRetries = delayBetweenRetries;
  }

  public String getHostname() {
    return this.hostname;
  }

  public int getPort() {
    return this.port;
  }

  public char getDelimiter() {
    return this.delimiter;
  }

  public long getMaxNumRetries() {
    return this.maxNumRetries;
  }

  public long getDelayBetweenRetries() {
    return this.delayBetweenRetries;
  }

  @Override
  public List<? extends UnboundedSource<String, CheckpointMarkT>> split(
      int desiredNumSplits,
      PipelineOptions options) throws Exception {
    // A socket can only be consumed once, so this source is never split.
    return Collections.<UnboundedSource<String, CheckpointMarkT>>singletonList(this);
  }

  @Override
  public UnboundedReader<String> createReader(PipelineOptions options,
                                              @Nullable CheckpointMarkT checkpointMark) {
    // checkpointMark is ignored: this source does not support checkpointing.
    return new UnboundedSocketReader(this);
  }

  @Nullable
  @Override
  public Coder getCheckpointMarkCoder() {
    // Flink and Dataflow have different checkpointing mechanisms.
    // In our case we do not need a coder.
    return null;
  }

  @Override
  public void validate() {
    checkArgument(port > 0 && port < 65536, "port is out of range");
    checkArgument(maxNumRetries >= -1, "maxNumRetries must be zero or larger (num retries), "
        + "or -1 (infinite retries)");
    checkArgument(delayBetweenRetries >= 0, "delayBetweenRetries must be zero or positive");
  }

  @Override
  public Coder getDefaultOutputCoder() {
    return DEFAULT_SOCKET_CODER;
  }

  /**
   * Unbounded socket reader. Reconnects up to {@code maxNumRetries} times (or
   * forever when -1), then reads delimiter-separated records from the socket.
   */
  public static class UnboundedSocketReader extends UnboundedSource.UnboundedReader<String> {

    private static final Logger LOG = LoggerFactory.getLogger(UnboundedSocketReader.class);

    private final UnboundedSocketSource source;

    private Socket socket;
    private BufferedReader reader;

    private boolean isRunning;

    private String currentRecord;

    public UnboundedSocketReader(UnboundedSocketSource source) {
      this.source = source;
    }

    private void openConnection() throws IOException {
      this.socket = new Socket();
      this.socket.connect(new InetSocketAddress(this.source.getHostname(), this.source.getPort()),
          CONNECTION_TIMEOUT_TIME);
      this.reader = new BufferedReader(new InputStreamReader(this.socket.getInputStream()));
      this.isRunning = true;
    }

    /**
     * Connects (retrying per the source's retry settings) and advances to the
     * first record.
     *
     * @return true if connected and a first record is available, false if the
     *     connection could not be established
     */
    @Override
    public boolean start() throws IOException {
      int attempt = 0;
      while (!isRunning) {
        try {
          openConnection();
          LOG.info("Connected to server socket {}:{}",
              this.source.getHostname(), this.source.getPort());

          return advance();
        } catch (IOException e) {
          if (this.source.getMaxNumRetries() == -1 || attempt++ < this.source.getMaxNumRetries()) {
            // Only announce a retry when we will actually retry.
            LOG.info("Lost connection to server socket {}:{}. Retrying in {} msecs...",
                this.source.getHostname(), this.source.getPort(),
                this.source.getDelayBetweenRetries());
            try {
              Thread.sleep(this.source.getDelayBetweenRetries());
            } catch (InterruptedException e1) {
              // Restore the interrupt flag and stop retrying: an interrupt here
              // typically means the task is being cancelled.
              Thread.currentThread().interrupt();
              this.isRunning = false;
              break;
            }
          } else {
            this.isRunning = false;
            break;
          }
        }
      }
      LOG.error("Unable to connect to host {} : {}",
          this.source.getHostname(), this.source.getPort());
      return false;
    }

    /**
     * Reads characters until the delimiter and stores the completed record.
     * A trailing '\r' before the delimiter is stripped (CRLF handling).
     *
     * @return true if a complete record was read, false on EOF or shutdown
     */
    @Override
    public boolean advance() throws IOException {
      final StringBuilder buffer = new StringBuilder();
      int data;
      while (isRunning && (data = reader.read()) != -1) {
        // check if the string is complete
        if (data != this.source.getDelimiter()) {
          buffer.append((char) data);
        } else {
          if (buffer.length() > 0 && buffer.charAt(buffer.length() - 1) == '\r') {
            buffer.setLength(buffer.length() - 1);
          }
          this.currentRecord = buffer.toString();
          buffer.setLength(0);
          return true;
        }
      }
      // NOTE(review): a partial record buffered at EOF is discarded here —
      // presumably intentional for this debugging source.
      return false;
    }

    @Override
    public byte[] getCurrentRecordId() throws NoSuchElementException {
      // No deduping support; ids are not meaningful for this source.
      return new byte[0];
    }

    @Override
    public String getCurrent() throws NoSuchElementException {
      return this.currentRecord;
    }

    @Override
    public Instant getCurrentTimestamp() throws NoSuchElementException {
      // Processing-time timestamps: elements are stamped on arrival.
      return Instant.now();
    }

    @Override
    public void close() throws IOException {
      // Guard against close() before a successful openConnection() (e.g. when
      // start() never managed to connect): reader/socket may still be null.
      try {
        if (this.reader != null) {
          this.reader.close();
        }
      } finally {
        if (this.socket != null) {
          this.socket.close();
        }
        this.isRunning = false;
      }
      LOG.info("Closed connection to server socket at {}:{}.",
          this.source.getHostname(), this.source.getPort());
    }

    @Override
    public Instant getWatermark() {
      // Processing-time watermark to match the processing-time timestamps.
      return Instant.now();
    }

    @Override
    public CheckpointMark getCheckpointMark() {
      // No checkpointing support for this source.
      return null;
    }

    @Override
    public UnboundedSource<String, ?> getCurrentSource() {
      return this.source;
    }
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.flink.translation.wrappers.streaming.io;

import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.StoppableFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.OperatorStateStore;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.CheckpointListener;
import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.tasks.ProcessingTimeCallback;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Wrapper for executing {@link UnboundedSource UnboundedSources} as a Flink Source.
 *
 * <p>The source is split eagerly in the constructor so that all parallel instances
 * agree on the split assignment. Each subtask reads the splits whose index is
 * congruent to its subtask index modulo the parallelism. Reader state
 * (CheckpointMarks, paired with the source that produced them) is snapshotted
 * via Flink's {@link CheckpointedFunction} mechanism, and finalized on
 * {@link #notifyCheckpointComplete(long)}. Watermarks are emitted on a
 * processing-time timer (see {@link #onProcessingTime(long)}).
 */
public class UnboundedSourceWrapper<
    OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark>
    extends RichParallelSourceFunction<WindowedValue<OutputT>>
    implements ProcessingTimeCallback, StoppableFunction,
    CheckpointListener, CheckpointedFunction {

  private static final Logger LOG = LoggerFactory.getLogger(UnboundedSourceWrapper.class);

  /**
   * Keep the options so that we can initialize the localReaders.
   */
  private final SerializedPipelineOptions serializedOptions;

  /**
   * For snapshot and restore. A KV coder of (source, checkpoint mark) so that
   * restored marks can be matched back to the source that produced them.
   * {@code null} when the source provides no CheckpointMark coder, in which
   * case no snapshots are taken.
   */
  private final KvCoder<
      ? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> checkpointCoder;

  /**
   * The split sources. We split them in the constructor to ensure that all parallel
   * sources are consistent about the split sources.
   */
  private final List<? extends UnboundedSource<OutputT, CheckpointMarkT>> splitSources;

  /**
   * The local split sources. Assigned at runtime when the wrapper is executed in parallel.
   */
  private transient List<UnboundedSource<OutputT, CheckpointMarkT>> localSplitSources;

  /**
   * The local split readers. Assigned at runtime when the wrapper is executed in parallel.
   * Make it a field so that we can access it in {@link #onProcessingTime(long)} for
   * emitting watermarks.
   */
  private transient List<UnboundedSource.UnboundedReader<OutputT>> localReaders;

  /**
   * Flag to indicate whether the source is running.
   * Initialize here and not in run() to prevent races where we cancel a job before run() is
   * ever called or run() is called after cancel().
   */
  private volatile boolean isRunning = true;

  /**
   * Make it a field so that we can access it in {@link #onProcessingTime(long)} for registering new
   * triggers.
   */
  private transient StreamingRuntimeContext runtimeContext;

  /**
   * Make it a field so that we can access it in {@link #onProcessingTime(long)} for emitting
   * watermarks.
   */
  private transient SourceContext<WindowedValue<OutputT>> context;

  /**
   * Pending checkpoints which have not been acknowledged yet.
   * Keyed by checkpoint id; LinkedHashMap so iteration order is oldest-first,
   * which the pruning in snapshotState() and notifyCheckpointComplete() rely on.
   */
  private transient LinkedHashMap<Long, List<CheckpointMarkT>> pendingCheckpoints;
  /**
   * Keep a maximum of 32 checkpoints for {@code CheckpointMark.finalizeCheckpoint()}.
   */
  private static final int MAX_NUMBER_PENDING_CHECKPOINTS = 32;

  // Flink operator state holding one (source, mark) pair per local reader.
  private transient ListState<KV<? extends
      UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>> stateForCheckpoint;

  /**
   * false if checkpointCoder is null or no restore state by starting first.
   */
  private transient boolean isRestored = false;

  /**
   * Creates the wrapper, builds the checkpoint coder (if the source supports
   * checkpointing) and eagerly splits the source.
   *
   * @param pipelineOptions options used to create readers at runtime
   * @param source the unbounded source to execute
   * @param parallelism desired number of splits
   * @throws Exception if splitting the source fails
   */
  @SuppressWarnings("unchecked")
  public UnboundedSourceWrapper(
      PipelineOptions pipelineOptions,
      UnboundedSource<OutputT, CheckpointMarkT> source,
      int parallelism) throws Exception {
    this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);

    if (source.requiresDeduping()) {
      LOG.warn("Source {} requires deduping but Flink runner doesn't support this yet.", source);
    }

    Coder<CheckpointMarkT> checkpointMarkCoder = source.getCheckpointMarkCoder();
    if (checkpointMarkCoder == null) {
      LOG.info("No CheckpointMarkCoder specified for this source. Won't create snapshots.");
      checkpointCoder = null;
    } else {

      // Sources are Serializable, so a SerializableCoder suffices for the key side.
      Coder<? extends UnboundedSource<OutputT, CheckpointMarkT>> sourceCoder =
          (Coder) SerializableCoder.of(new TypeDescriptor<UnboundedSource>() {
          });

      checkpointCoder = KvCoder.of(sourceCoder, checkpointMarkCoder);
    }

    // get the splits early. we assume that the generated splits are stable,
    // this is necessary so that the mapping of state to source is correct
    // when restoring
    splitSources = source.split(parallelism, pipelineOptions);
  }


  /**
   * Initialize and restore state before starting execution of the source.
   * When restoring, the local sources/readers are rebuilt from the checkpointed
   * (source, mark) pairs instead of from the modulo assignment, so that each
   * reader resumes from its own mark.
   */
  @Override
  public void open(Configuration parameters) throws Exception {
    runtimeContext = (StreamingRuntimeContext) getRuntimeContext();

    // figure out which split sources we're responsible for
    int subtaskIndex = runtimeContext.getIndexOfThisSubtask();
    int numSubtasks = runtimeContext.getNumberOfParallelSubtasks();

    localSplitSources = new ArrayList<>();
    localReaders = new ArrayList<>();

    pendingCheckpoints = new LinkedHashMap<>();

    if (isRestored) {
      // restore the splitSources from the checkpoint to ensure consistent ordering
      for (KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> restored:
          stateForCheckpoint.get()) {
        localSplitSources.add(restored.getKey());
        localReaders.add(restored.getKey().createReader(
            serializedOptions.getPipelineOptions(), restored.getValue()));
      }
    } else {
      // initialize localReaders and localSources from scratch
      for (int i = 0; i < splitSources.size(); i++) {
        if (i % numSubtasks == subtaskIndex) {
          UnboundedSource<OutputT, CheckpointMarkT> source =
              splitSources.get(i);
          UnboundedSource.UnboundedReader<OutputT> reader =
              source.createReader(serializedOptions.getPipelineOptions(), null);
          localSplitSources.add(source);
          localReaders.add(reader);
        }
      }
    }

    LOG.info("Unbounded Flink Source {}/{} is reading from sources: {}",
        subtaskIndex,
        numSubtasks,
        localSplitSources);
  }

  /**
   * Main read loop. Three cases: no local readers (idle but stay alive so
   * checkpointing keeps working), exactly one reader (simple loop), or several
   * readers (round-robin with a 50ms back-off when none had data).
   */
  @Override
  public void run(SourceContext<WindowedValue<OutputT>> ctx) throws Exception {

    context = ctx;

    if (localReaders.size() == 0) {
      // do nothing, but still look busy ...
      // also, output a Long.MAX_VALUE watermark since we know that we're not
      // going to emit anything
      // we can't return here since Flink requires that all operators stay up,
      // otherwise checkpointing would not work correctly anymore
      ctx.emitWatermark(new Watermark(Long.MAX_VALUE));

      // wait until this is canceled
      final Object waitLock = new Object();
      while (isRunning) {
        try {
          // Flink will interrupt us at some point
          //noinspection SynchronizationOnLocalVariableOrMethodParameter
          synchronized (waitLock) {
            // don't wait indefinitely, in case something goes horribly wrong
            waitLock.wait(1000);
          }
        } catch (InterruptedException e) {
          if (!isRunning) {
            // restore the interrupted state, and fall through the loop
            Thread.currentThread().interrupt();
          }
        }
      }
    } else if (localReaders.size() == 1) {
      // the easy case, we just read from one reader
      UnboundedSource.UnboundedReader<OutputT> reader = localReaders.get(0);

      boolean dataAvailable = reader.start();
      if (dataAvailable) {
        emitElement(ctx, reader);
      }

      setNextWatermarkTimer(this.runtimeContext);

      while (isRunning) {
        dataAvailable = reader.advance();

        if (dataAvailable) {
          emitElement(ctx, reader);
        } else {
          // unbounded source: no data right now, back off briefly and retry
          Thread.sleep(50);
        }
      }
    } else {
      // a bit more complicated, we are responsible for several localReaders
      // loop through them and sleep if none of them had any data

      int numReaders = localReaders.size();
      int currentReader = 0;

      // start each reader and emit data if immediately available
      for (UnboundedSource.UnboundedReader<OutputT> reader : localReaders) {
        boolean dataAvailable = reader.start();
        if (dataAvailable) {
          emitElement(ctx, reader);
        }
      }

      // a flag telling us whether any of the localReaders had data
      // if no reader had data, sleep for bit
      boolean hadData = false;
      while (isRunning) {
        UnboundedSource.UnboundedReader<OutputT> reader = localReaders.get(currentReader);
        boolean dataAvailable = reader.advance();

        if (dataAvailable) {
          emitElement(ctx, reader);
          hadData = true;
        }

        // only sleep after a full round in which no reader produced anything
        currentReader = (currentReader + 1) % numReaders;
        if (currentReader == 0 && !hadData) {
          Thread.sleep(50);
        } else if (currentReader == 0) {
          hadData = false;
        }
      }

    }
  }

  /**
   * Emit the current element from the given Reader. The reader is guaranteed to have data.
   */
  private void emitElement(
      SourceContext<WindowedValue<OutputT>> ctx,
      UnboundedSource.UnboundedReader<OutputT> reader) {
    // make sure that reader state update and element emission are atomic
    // with respect to snapshots
    synchronized (ctx.getCheckpointLock()) {

      OutputT item = reader.getCurrent();
      Instant timestamp = reader.getCurrentTimestamp();

      WindowedValue<OutputT> windowedValue =
          WindowedValue.of(item, timestamp, GlobalWindow.INSTANCE, PaneInfo.NO_FIRING);
      ctx.collectWithTimestamp(windowedValue, timestamp.getMillis());
    }
  }

  @Override
  public void close() throws Exception {
    super.close();
    // NOTE(review): if one reader.close() throws, the remaining readers are
    // not closed — confirm whether that is acceptable here.
    if (localReaders != null) {
      for (UnboundedSource.UnboundedReader<OutputT> reader: localReaders) {
        reader.close();
      }
    }
  }

  @Override
  public void cancel() {
    isRunning = false;
  }

  @Override
  public void stop() {
    isRunning = false;
  }

  // ------------------------------------------------------------------------
  //  Checkpoint and restore
  // ------------------------------------------------------------------------

  /**
   * Snapshots one (source, CheckpointMark) pair per local reader into operator
   * state, and records the marks under the checkpoint id so they can be
   * finalized in {@link #notifyCheckpointComplete(long)}. Old pending
   * checkpoints beyond {@link #MAX_NUMBER_PENDING_CHECKPOINTS} are dropped
   * (oldest first) without being finalized.
   */
  @Override
  public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
    if (!isRunning) {
      LOG.debug("snapshotState() called on closed source");
    } else {

      if (checkpointCoder == null) {
        // no checkpoint coder available in this source
        return;
      }

      stateForCheckpoint.clear();

      long checkpointId = functionSnapshotContext.getCheckpointId();

      // we checkpoint the sources along with the CheckpointMarkT to ensure
      // than we have a correct mapping of checkpoints to sources when
      // restoring
      List<CheckpointMarkT> checkpointMarks = new ArrayList<>(localSplitSources.size());

      for (int i = 0; i < localSplitSources.size(); i++) {
        UnboundedSource<OutputT, CheckpointMarkT> source = localSplitSources.get(i);
        UnboundedSource.UnboundedReader<OutputT> reader = localReaders.get(i);

        @SuppressWarnings("unchecked")
        CheckpointMarkT mark = (CheckpointMarkT) reader.getCheckpointMark();
        checkpointMarks.add(mark);
        KV<UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> kv =
            KV.of(source, mark);
        stateForCheckpoint.add(kv);
      }

      // cleanup old pending checkpoints and add new checkpoint
      int diff = pendingCheckpoints.size() - MAX_NUMBER_PENDING_CHECKPOINTS;
      if (diff >= 0) {
        // removes diff + 1 oldest entries so that after put() below the map
        // holds at most MAX_NUMBER_PENDING_CHECKPOINTS entries
        for (Iterator<Long> iterator = pendingCheckpoints.keySet().iterator();
             diff >= 0;
             diff--) {
          iterator.next();
          iterator.remove();
        }
      }
      pendingCheckpoints.put(checkpointId, checkpointMarks);

    }
  }

  /**
   * Acquires the operator-state handle for the (source, mark) pairs and records
   * whether we are restoring; the actual restore happens in {@link #open}.
   */
  @Override
  public void initializeState(FunctionInitializationContext context) throws Exception {
    if (checkpointCoder == null) {
      // no checkpoint coder available in this source
      return;
    }

    OperatorStateStore stateStore = context.getOperatorStateStore();
    CoderTypeInformation<
        KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>>
        typeInformation = (CoderTypeInformation) new CoderTypeInformation<>(checkpointCoder);
    stateForCheckpoint = stateStore.getOperatorState(
        new ListStateDescriptor<>(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME,
            typeInformation.createSerializer(new ExecutionConfig())));

    if (context.isRestored()) {
      isRestored = true;
      LOG.info("Having restore state in the UnbounedSourceWrapper.");
    } else {
      LOG.info("No restore state for UnbounedSourceWrapper.");
    }
  }

  /**
   * Periodic watermark emission: emits the minimum watermark over all local
   * readers (under the checkpoint lock), then re-registers itself for the next
   * auto-watermark interval.
   */
  @Override
  public void onProcessingTime(long timestamp) throws Exception {
    if (this.isRunning) {
      synchronized (context.getCheckpointLock()) {
        // find minimum watermark over all localReaders
        long watermarkMillis = Long.MAX_VALUE;
        for (UnboundedSource.UnboundedReader<OutputT> reader: localReaders) {
          Instant watermark = reader.getWatermark();
          if (watermark != null) {
            watermarkMillis = Math.min(watermark.getMillis(), watermarkMillis);
          }
        }
        context.emitWatermark(new Watermark(watermarkMillis));
      }
      setNextWatermarkTimer(this.runtimeContext);
    }
  }

  // Registers this as a one-shot processing-time callback; re-registered from
  // onProcessingTime() to get periodic behavior.
  private void setNextWatermarkTimer(StreamingRuntimeContext runtime) {
    if (this.isRunning) {
      long watermarkInterval =  runtime.getExecutionConfig().getAutoWatermarkInterval();
      long timeToNextWatermark = getTimeToNextWatermark(watermarkInterval);
      runtime.getProcessingTimeService().registerTimer(timeToNextWatermark, this);
    }
  }

  private long getTimeToNextWatermark(long watermarkInterval) {
    return System.currentTimeMillis() + watermarkInterval;
  }

  /**
   * Visible so that we can check this in tests. Must not be used for anything else.
   */
  @VisibleForTesting
  public List<? extends UnboundedSource<OutputT, CheckpointMarkT>> getSplitSources() {
    return splitSources;
  }

  /**
   * Visible so that we can check this in tests. Must not be used for anything else.
   */
  @VisibleForTesting
  public List<? extends UnboundedSource<OutputT, CheckpointMarkT>> getLocalSplitSources() {
    return localSplitSources;
  }

  /**
   * Called by Flink when a checkpoint is acknowledged: finalizes the marks for
   * that checkpoint and discards all older pending checkpoints (their marks are
   * superseded and are not finalized).
   */
  @Override
  public void notifyCheckpointComplete(long checkpointId) throws Exception {

    List<CheckpointMarkT> checkpointMarks = pendingCheckpoints.get(checkpointId);

    if (checkpointMarks != null) {

      // remove old checkpoints including the current one
      Iterator<Long> iterator = pendingCheckpoints.keySet().iterator();
      long currentId;
      do {
        currentId = iterator.next();
        iterator.remove();
      } while (currentId != checkpointId);

      // confirm all marks
      for (CheckpointMarkT mark : checkpointMarks) {
        mark.finalizeCheckpoint();
      }

    }
  }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Internal implementation of the Beam runner for Apache Flink. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.io; http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/package-info.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/package-info.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/package-info.java new file mode 100644 index 0000000..0674871 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Internal implementation of the Beam runner for Apache Flink. + */ +package org.apache.beam.runners.flink.translation.wrappers.streaming; http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java new file mode 100644 index 0000000..3203446 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java @@ -0,0 +1,865 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.MapCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineWithContext; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; +import org.apache.beam.sdk.util.CombineContextFactory; +import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.CombiningState; +import org.apache.beam.sdk.util.state.MapState; +import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.SetState; +import org.apache.beam.sdk.util.state.State; +import org.apache.beam.sdk.util.state.StateContext; +import org.apache.beam.sdk.util.state.StateContexts; +import org.apache.beam.sdk.util.state.ValueState; +import org.apache.beam.sdk.util.state.WatermarkHoldState; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.runtime.state.DefaultOperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateBackend; + +/** + * {@link StateInternals} that uses a Flink {@link DefaultOperatorStateBackend} + * to manage the broadcast state. + * The state is the same on all parallel instances of the operator. 
+ * So we just need store state of operator-0 in OperatorStateBackend. + * + * <p>Note: Ignore index of key. + * Mainly for SideInputs. + */ +public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { + + private int indexInSubtaskGroup; + private final DefaultOperatorStateBackend stateBackend; + // stateName -> <namespace, state> + private Map<String, Map<String, ?>> stateForNonZeroOperator; + + public FlinkBroadcastStateInternals(int indexInSubtaskGroup, OperatorStateBackend stateBackend) { + //TODO flink do not yet expose through public API + this.stateBackend = (DefaultOperatorStateBackend) stateBackend; + this.indexInSubtaskGroup = indexInSubtaskGroup; + if (indexInSubtaskGroup != 0) { + stateForNonZeroOperator = new HashMap<>(); + } + } + + @Override + public K getKey() { + return null; + } + + @Override + public <T extends State> T state( + final StateNamespace namespace, + StateTag<? super K, T> address) { + + return state(namespace, address, StateContexts.nullContext()); + } + + @Override + public <T extends State> T state( + final StateNamespace namespace, + StateTag<? super K, T> address, + final StateContext<?> context) { + + return address.bind(new StateTag.StateBinder<K>() { + + @Override + public <T> ValueState<T> bindValue( + StateTag<? super K, ValueState<T>> address, + Coder<T> coder) { + + return new FlinkBroadcastValueState<>(stateBackend, address, namespace, coder); + } + + @Override + public <T> BagState<T> bindBag( + StateTag<? super K, BagState<T>> address, + Coder<T> elemCoder) { + + return new FlinkBroadcastBagState<>(stateBackend, address, namespace, elemCoder); + } + + @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, + Coder<T> elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); + } + + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? 
super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MapState.class.getSimpleName())); + } + + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> + bindCombiningValue( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + + return new FlinkCombiningState<>( + stateBackend, address, combineFn, namespace, accumCoder); + } + + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { + return new FlinkKeyedCombiningState<>( + stateBackend, + address, + combineFn, + namespace, + accumCoder, + FlinkBroadcastStateInternals.this); + } + + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + CombineWithContext.KeyedCombineFnWithContext< + ? super K, InputT, AccumT, OutputT> combineFn) { + return new FlinkCombiningStateWithContext<>( + stateBackend, + address, + combineFn, + namespace, + accumCoder, + FlinkBroadcastStateInternals.this, + CombineContextFactory.createFromStateContext(context)); + } + + @Override + public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark( + StateTag<? super K, WatermarkHoldState<W>> address, + OutputTimeFn<? super W> outputTimeFn) { + throw new UnsupportedOperationException( + String.format("%s is not supported", WatermarkHoldState.class.getSimpleName())); + } + }); + } + + /** + * 1. 
The way we would use it is to only checkpoint anything from the operator
+   * with subtask index 0 because we assume that the state is the same on all
+   * parallel instances of the operator.
+   *
+   * <p>2. Use map to support namespace.
+   */
+  private abstract class AbstractBroadcastState<T> {
+
+    // Name of the state cell; also the key into stateForNonZeroOperator on
+    // non-zero subtasks.
+    private String name;
+    private final StateNamespace namespace;
+    private final ListStateDescriptor<Map<String, T>> flinkStateDescriptor;
+    private final DefaultOperatorStateBackend flinkStateBackend;
+
+    AbstractBroadcastState(
+        DefaultOperatorStateBackend flinkStateBackend,
+        String name,
+        StateNamespace namespace,
+        Coder<T> coder) {
+      this.name = name;
+
+      this.namespace = namespace;
+      this.flinkStateBackend = flinkStateBackend;
+
+      // The Flink state cell holds one map of namespace-string -> value,
+      // encoded with the user-supplied value coder.
+      CoderTypeInformation<Map<String, T>> typeInfo =
+          new CoderTypeInformation<>(MapCoder.of(StringUtf8Coder.of(), coder));
+
+      flinkStateDescriptor = new ListStateDescriptor<>(name,
+          typeInfo.createSerializer(new ExecutionConfig()));
+    }
+
+    /**
+     * Get map (namespace->T) from index 0.
+     *
+     * <p>Subtask 0 reads the broadcast operator state directly. Other subtasks
+     * serve reads from {@code stateForNonZeroOperator}; on a miss (i.e. after a
+     * restore) they fall back to the broadcast state once, cache the result,
+     * and clear the broadcast copy.
+     */
+    Map<String, T> getMap() throws Exception {
+      if (indexInSubtaskGroup == 0) {
+        return getMapFromBroadcastState();
+      } else {
+        Map<String, T> result = (Map<String, T>) stateForNonZeroOperator.get(name);
+        // maybe restore from BroadcastState of Operator-0
+        if (result == null) {
+          result = getMapFromBroadcastState();
+          if (result != null) {
+            stateForNonZeroOperator.put(name, result);
+            // we don't need it anymore, must clear it.
+            flinkStateBackend.getBroadcastOperatorState(
+                flinkStateDescriptor).clear();
+          }
+        }
+        return result;
+      }
+    }
+
+    /**
+     * Reads the namespace->value map from the broadcast operator state.
+     * Returns {@code null} if the state is empty; otherwise returns the first
+     * (index 0) entry of the list state.
+     */
+    Map<String, T> getMapFromBroadcastState() throws Exception {
+      ListState<Map<String, T>> state = flinkStateBackend.getBroadcastOperatorState(
+          flinkStateDescriptor);
+      Iterable<Map<String, T>> iterable = state.get();
+      Map<String, T> ret = null;
+      if (iterable != null) {
+        // just use index 0
+        Iterator<Map<String, T>> iterator = iterable.iterator();
+        if (iterator.hasNext()) {
+          ret = iterator.next();
+        }
+      }
+      return ret;
+    }
+
+    /**
+     * Update map (namespace->T) from index 0.
+     *
+     * <p>Subtask 0 rewrites the broadcast operator state (clearing the cell
+     * when the map is empty); other subtasks only update their local cache.
+     */
+    void updateMap(Map<String, T> map) throws Exception {
+      if (indexInSubtaskGroup == 0) {
+        ListState<Map<String, T>> state = flinkStateBackend.getBroadcastOperatorState(
+            flinkStateDescriptor);
+        state.clear();
+        if (map.size() > 0) {
+          state.add(map);
+        }
+      } else {
+        if (map.size() == 0) {
+          stateForNonZeroOperator.remove(name);
+          // updateMap is always behind getMap,
+          // getMap will clear map in BroadcastOperatorState,
+          // we don't need clear here.
+ } else { + stateForNonZeroOperator.put(name, map); + } + } + } + + void writeInternal(T input) { + try { + Map<String, T> map = getMap(); + if (map == null) { + map = new HashMap<>(); + } + map.put(namespace.stringKey(), input); + updateMap(map); + } catch (Exception e) { + throw new RuntimeException("Error updating state.", e); + } + } + + T readInternal() { + try { + Map<String, T> map = getMap(); + if (map == null) { + return null; + } else { + return map.get(namespace.stringKey()); + } + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + void clearInternal() { + try { + Map<String, T> map = getMap(); + if (map != null) { + map.remove(namespace.stringKey()); + updateMap(map); + } + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + } + + private class FlinkBroadcastValueState<K, T> + extends AbstractBroadcastState<T> implements ValueState<T> { + + private final StateNamespace namespace; + private final StateTag<? super K, ValueState<T>> address; + + FlinkBroadcastValueState( + DefaultOperatorStateBackend flinkStateBackend, + StateTag<? 
super K, ValueState<T>> address, + StateNamespace namespace, + Coder<T> coder) { + super(flinkStateBackend, address.getId(), namespace, coder); + + this.namespace = namespace; + this.address = address; + + } + + @Override + public void write(T input) { + writeInternal(input); + } + + @Override + public ValueState<T> readLater() { + return this; + } + + @Override + public T read() { + return readInternal(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkBroadcastValueState<?, ?> that = (FlinkBroadcastValueState<?, ?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + + @Override + public void clear() { + clearInternal(); + } + } + + private class FlinkBroadcastBagState<K, T> extends AbstractBroadcastState<List<T>> + implements BagState<T> { + + private final StateNamespace namespace; + private final StateTag<? super K, BagState<T>> address; + + FlinkBroadcastBagState( + DefaultOperatorStateBackend flinkStateBackend, + StateTag<? super K, BagState<T>> address, + StateNamespace namespace, + Coder<T> coder) { + super(flinkStateBackend, address.getId(), namespace, ListCoder.of(coder)); + + this.namespace = namespace; + this.address = address; + } + + @Override + public void add(T input) { + List<T> list = readInternal(); + if (list == null) { + list = new ArrayList<>(); + } + list.add(input); + writeInternal(list); + } + + @Override + public BagState<T> readLater() { + return this; + } + + @Override + public Iterable<T> read() { + List<T> result = readInternal(); + return result != null ? 
result : Collections.<T>emptyList(); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + List<T> result = readInternal(); + return result == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + clearInternal(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkBroadcastBagState<?, ?> that = (FlinkBroadcastBagState<?, ?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + private class FlinkCombiningState<K, InputT, AccumT, OutputT> + extends AbstractBroadcastState<AccumT> + implements CombiningState<InputT, AccumT, OutputT> { + + private final StateNamespace namespace; + private final StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address; + private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn; + + FlinkCombiningState( + DefaultOperatorStateBackend flinkStateBackend, + StateTag<? 
super K, CombiningState<InputT, AccumT, OutputT>> address, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn, + StateNamespace namespace, + Coder<AccumT> accumCoder) { + super(flinkStateBackend, address.getId(), namespace, accumCoder); + + this.namespace = namespace; + this.address = address; + this.combineFn = combineFn; + } + + @Override + public CombiningState<InputT, AccumT, OutputT> readLater() { + return this; + } + + @Override + public void add(InputT value) { + AccumT current = readInternal(); + if (current == null) { + current = combineFn.createAccumulator(); + } + current = combineFn.addInput(current, value); + writeInternal(current); + } + + @Override + public void addAccum(AccumT accum) { + AccumT current = readInternal(); + + if (current == null) { + writeInternal(accum); + } else { + current = combineFn.mergeAccumulators(Arrays.asList(current, accum)); + writeInternal(current); + } + } + + @Override + public AccumT getAccum() { + return readInternal(); + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(accumulators); + } + + @Override + public OutputT read() { + AccumT accum = readInternal(); + if (accum != null) { + return combineFn.extractOutput(accum); + } else { + return combineFn.extractOutput(combineFn.createAccumulator()); + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + return readInternal() == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + clearInternal(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkCombiningState<?, ?, ?, ?> that = + (FlinkCombiningState<?, ?, ?, 
?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + private class FlinkKeyedCombiningState<K, InputT, AccumT, OutputT> + extends AbstractBroadcastState<AccumT> + implements CombiningState<InputT, AccumT, OutputT> { + + private final StateNamespace namespace; + private final StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address; + private final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn; + private final FlinkBroadcastStateInternals<K> flinkStateInternals; + + FlinkKeyedCombiningState( + DefaultOperatorStateBackend flinkStateBackend, + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn, + StateNamespace namespace, + Coder<AccumT> accumCoder, + FlinkBroadcastStateInternals<K> flinkStateInternals) { + super(flinkStateBackend, address.getId(), namespace, accumCoder); + + this.namespace = namespace; + this.address = address; + this.combineFn = combineFn; + this.flinkStateInternals = flinkStateInternals; + + } + + @Override + public CombiningState<InputT, AccumT, OutputT> readLater() { + return this; + } + + @Override + public void add(InputT value) { + try { + AccumT current = readInternal(); + if (current == null) { + current = combineFn.createAccumulator(flinkStateInternals.getKey()); + } + current = combineFn.addInput(flinkStateInternals.getKey(), current, value); + writeInternal(current); + } catch (Exception e) { + throw new RuntimeException("Error adding to state." 
, e); + } + } + + @Override + public void addAccum(AccumT accum) { + try { + AccumT current = readInternal(); + if (current == null) { + writeInternal(accum); + } else { + current = combineFn.mergeAccumulators( + flinkStateInternals.getKey(), + Arrays.asList(current, accum)); + writeInternal(current); + } + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public AccumT getAccum() { + try { + return readInternal(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators); + } + + @Override + public OutputT read() { + try { + AccumT accum = readInternal(); + if (accum != null) { + return combineFn.extractOutput(flinkStateInternals.getKey(), accum); + } else { + return combineFn.extractOutput( + flinkStateInternals.getKey(), + combineFn.createAccumulator(flinkStateInternals.getKey())); + } + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + return readInternal() == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + clearInternal(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkKeyedCombiningState<?, ?, ?, ?> that = + (FlinkKeyedCombiningState<?, ?, ?, ?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + 
address.hashCode(); + return result; + } + } + + private class FlinkCombiningStateWithContext<K, InputT, AccumT, OutputT> + extends AbstractBroadcastState<AccumT> + implements CombiningState<InputT, AccumT, OutputT> { + + private final StateNamespace namespace; + private final StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address; + private final CombineWithContext.KeyedCombineFnWithContext< + ? super K, InputT, AccumT, OutputT> combineFn; + private final FlinkBroadcastStateInternals<K> flinkStateInternals; + private final CombineWithContext.Context context; + + FlinkCombiningStateWithContext( + DefaultOperatorStateBackend flinkStateBackend, + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + CombineWithContext.KeyedCombineFnWithContext< + ? super K, InputT, AccumT, OutputT> combineFn, + StateNamespace namespace, + Coder<AccumT> accumCoder, + FlinkBroadcastStateInternals<K> flinkStateInternals, + CombineWithContext.Context context) { + super(flinkStateBackend, address.getId(), namespace, accumCoder); + + this.namespace = namespace; + this.address = address; + this.combineFn = combineFn; + this.flinkStateInternals = flinkStateInternals; + this.context = context; + + } + + @Override + public CombiningState<InputT, AccumT, OutputT> readLater() { + return this; + } + + @Override + public void add(InputT value) { + try { + AccumT current = readInternal(); + if (current == null) { + current = combineFn.createAccumulator(flinkStateInternals.getKey(), context); + } + current = combineFn.addInput(flinkStateInternals.getKey(), current, value, context); + writeInternal(current); + } catch (Exception e) { + throw new RuntimeException("Error adding to state." 
, e); + } + } + + @Override + public void addAccum(AccumT accum) { + try { + + AccumT current = readInternal(); + if (current == null) { + writeInternal(accum); + } else { + current = combineFn.mergeAccumulators( + flinkStateInternals.getKey(), + Arrays.asList(current, accum), + context); + writeInternal(current); + } + } catch (Exception e) { + throw new RuntimeException("Error adding to state.", e); + } + } + + @Override + public AccumT getAccum() { + try { + return readInternal(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators, context); + } + + @Override + public OutputT read() { + try { + AccumT accum = readInternal(); + return combineFn.extractOutput(flinkStateInternals.getKey(), accum, context); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + return readInternal() == null; + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + clearInternal(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkCombiningStateWithContext<?, ?, ?, ?> that = + (FlinkCombiningStateWithContext<?, ?, ?, ?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + +} 
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java new file mode 100644 index 0000000..24b340e --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers.streaming.state; + +import static org.apache.flink.util.Preconditions.checkArgument; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.Coder.Context; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineWithContext; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.CombiningState; +import org.apache.beam.sdk.util.state.MapState; +import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.SetState; +import org.apache.beam.sdk.util.state.State; +import org.apache.beam.sdk.util.state.StateContext; +import org.apache.beam.sdk.util.state.StateContexts; +import org.apache.beam.sdk.util.state.ValueState; +import org.apache.beam.sdk.util.state.WatermarkHoldState; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.state.KeyGroupsList; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.streaming.api.operators.HeapInternalTimerService; +import org.apache.flink.util.InstantiationUtil; +import org.apache.flink.util.Preconditions; + +/** + * {@link StateInternals} that uses {@link KeyGroupCheckpointedOperator} + * to 
checkpoint state. + * + * <p>Note: + * Ignore index of key. + * Just implement BagState. + * + * <p>Reference from {@link HeapInternalTimerService} to the local key-group range. + */ +public class FlinkKeyGroupStateInternals<K> implements StateInternals<K> { + + private final Coder<K> keyCoder; + private final KeyGroupsList localKeyGroupRange; + private KeyedStateBackend keyedStateBackend; + private final int localKeyGroupRangeStartIdx; + + // stateName -> namespace -> (valueCoder, value) + private final Map<String, Tuple2<Coder<?>, Map<String, ?>>>[] stateTables; + + public FlinkKeyGroupStateInternals( + Coder<K> keyCoder, + KeyedStateBackend keyedStateBackend) { + this.keyCoder = keyCoder; + this.keyedStateBackend = keyedStateBackend; + this.localKeyGroupRange = keyedStateBackend.getKeyGroupRange(); + // find the starting index of the local key-group range + int startIdx = Integer.MAX_VALUE; + for (Integer keyGroupIdx : localKeyGroupRange) { + startIdx = Math.min(keyGroupIdx, startIdx); + } + this.localKeyGroupRangeStartIdx = startIdx; + stateTables = (Map<String, Tuple2<Coder<?>, Map<String, ?>>>[]) + new Map[localKeyGroupRange.getNumberOfKeyGroups()]; + for (int i = 0; i < stateTables.length; i++) { + stateTables[i] = new HashMap<>(); + } + } + + @Override + public K getKey() { + ByteBuffer keyBytes = (ByteBuffer) keyedStateBackend.getCurrentKey(); + try { + return CoderUtils.decodeFromByteArray(keyCoder, keyBytes.array()); + } catch (CoderException e) { + throw new RuntimeException("Error decoding key.", e); + } + } + + @Override + public <T extends State> T state( + final StateNamespace namespace, + StateTag<? super K, T> address) { + + return state(namespace, address, StateContexts.nullContext()); + } + + @Override + public <T extends State> T state( + final StateNamespace namespace, + StateTag<? 
super K, T> address, + final StateContext<?> context) { + + return address.bind(new StateTag.StateBinder<K>() { + + @Override + public <T> ValueState<T> bindValue( + StateTag<? super K, ValueState<T>> address, + Coder<T> coder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", ValueState.class.getSimpleName())); + } + + @Override + public <T> BagState<T> bindBag( + StateTag<? super K, BagState<T>> address, + Coder<T> elemCoder) { + + return new FlinkKeyGroupBagState<>(address, namespace, elemCoder); + } + + @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, + Coder<T> elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); + } + + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MapState.class.getSimpleName())); + } + + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> + bindCombiningValue( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + throw new UnsupportedOperationException("bindCombiningValue is not supported."); + } + + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { + throw new UnsupportedOperationException("bindKeyedCombiningValue is not supported."); + + } + + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext( + StateTag<? 
super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + CombineWithContext.KeyedCombineFnWithContext< + ? super K, InputT, AccumT, OutputT> combineFn) { + throw new UnsupportedOperationException( + "bindKeyedCombiningValueWithContext is not supported."); + } + + @Override + public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark( + StateTag<? super K, WatermarkHoldState<W>> address, + OutputTimeFn<? super W> outputTimeFn) { + throw new UnsupportedOperationException( + String.format("%s is not supported", CombiningState.class.getSimpleName())); + } + }); + } + + /** + * Reference from {@link Combine.CombineFn}. + * + * <p>Accumulators are stored in each KeyGroup, call addInput() when a element comes, + * call extractOutput() to produce the desired value when need to read data. + */ + interface KeyGroupCombiner<InputT, AccumT, OutputT> { + + /** + * Returns a new, mutable accumulator value, representing the accumulation + * of zero input values. + */ + AccumT createAccumulator(); + + /** + * Adds the given input value to the given accumulator, returning the + * new accumulator value. + */ + AccumT addInput(AccumT accumulator, InputT input); + + /** + * Returns the output value that is the result of all accumulators from KeyGroups + * that are assigned to this operator. + */ + OutputT extractOutput(Iterable<AccumT> accumulators); + } + + private abstract class AbstractKeyGroupState<InputT, AccumT, OutputT> { + + private String stateName; + private String namespace; + private Coder<AccumT> coder; + private KeyGroupCombiner<InputT, AccumT, OutputT> keyGroupCombiner; + + AbstractKeyGroupState( + String stateName, + String namespace, + Coder<AccumT> coder, + KeyGroupCombiner<InputT, AccumT, OutputT> keyGroupCombiner) { + this.stateName = stateName; + this.namespace = namespace; + this.coder = coder; + this.keyGroupCombiner = keyGroupCombiner; + } + + /** + * Choose keyGroup of input and addInput to accumulator. 
+ */ + void addInput(InputT input) { + int keyGroupIdx = keyedStateBackend.getCurrentKeyGroupIndex(); + int localIdx = getIndexForKeyGroup(keyGroupIdx); + Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable = stateTables[localIdx]; + Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName); + if (tuple2 == null) { + tuple2 = new Tuple2<>(); + tuple2.f0 = coder; + tuple2.f1 = new HashMap<>(); + stateTable.put(stateName, tuple2); + } + Map<String, AccumT> map = (Map<String, AccumT>) tuple2.f1; + AccumT accumulator = map.get(namespace); + if (accumulator == null) { + accumulator = keyGroupCombiner.createAccumulator(); + } + accumulator = keyGroupCombiner.addInput(accumulator, input); + map.put(namespace, accumulator); + } + + /** + * Get all accumulators and invoke extractOutput(). + */ + OutputT extractOutput() { + List<AccumT> accumulators = new ArrayList<>(stateTables.length); + for (Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable : stateTables) { + Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName); + if (tuple2 != null) { + AccumT accumulator = (AccumT) tuple2.f1.get(namespace); + if (accumulator != null) { + accumulators.add(accumulator); + } + } + } + return keyGroupCombiner.extractOutput(accumulators); + } + + /** + * Find the first accumulator and return immediately. + */ + boolean isEmptyInternal() { + for (Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable : stateTables) { + Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName); + if (tuple2 != null) { + AccumT accumulator = (AccumT) tuple2.f1.get(namespace); + if (accumulator != null) { + return false; + } + } + } + return true; + } + + /** + * Clear accumulators and clean empty map. 
+ */ + void clearInternal() { + for (Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable : stateTables) { + Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName); + if (tuple2 != null) { + tuple2.f1.remove(namespace); + if (tuple2.f1.size() == 0) { + stateTable.remove(stateName); + } + } + } + } + + } + + private int getIndexForKeyGroup(int keyGroupIdx) { + checkArgument(localKeyGroupRange.contains(keyGroupIdx), + "Key Group " + keyGroupIdx + " does not belong to the local range."); + return keyGroupIdx - this.localKeyGroupRangeStartIdx; + } + + private class KeyGroupBagCombiner<T> implements KeyGroupCombiner<T, List<T>, Iterable<T>> { + + @Override + public List<T> createAccumulator() { + return new ArrayList<>(); + } + + @Override + public List<T> addInput(List<T> accumulator, T input) { + accumulator.add(input); + return accumulator; + } + + @Override + public Iterable<T> extractOutput(Iterable<List<T>> accumulators) { + List<T> result = new ArrayList<>(); + // maybe can return an unmodifiable view. + for (List<T> list : accumulators) { + result.addAll(list); + } + return result; + } + } + + private class FlinkKeyGroupBagState<T> extends AbstractKeyGroupState<T, List<T>, Iterable<T>> + implements BagState<T> { + + private final StateNamespace namespace; + private final StateTag<? super K, BagState<T>> address; + + FlinkKeyGroupBagState( + StateTag<? super K, BagState<T>> address, + StateNamespace namespace, + Coder<T> coder) { + super(address.getId(), namespace.stringKey(), ListCoder.of(coder), + new KeyGroupBagCombiner<T>()); + this.namespace = namespace; + this.address = address; + } + + @Override + public void add(T input) { + addInput(input); + } + + @Override + public BagState<T> readLater() { + return this; + } + + @Override + public Iterable<T> read() { + Iterable<T> result = extractOutput(); + return result != null ? 
result : Collections.<T>emptyList(); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + return isEmptyInternal(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public void clear() { + clearInternal(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkKeyGroupBagState<?> that = (FlinkKeyGroupBagState<?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + + /** + * Snapshots the state {@code (stateName -> (valueCoder && (namespace -> value)))} for a given + * {@code keyGroupIdx}. + * + * @param keyGroupIdx the id of the key-group to be put in the snapshot. + * @param out the stream to write to. + */ + public void snapshotKeyGroupState(int keyGroupIdx, DataOutputStream out) throws Exception { + int localIdx = getIndexForKeyGroup(keyGroupIdx); + Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable = stateTables[localIdx]; + Preconditions.checkState(stateTable.size() <= Short.MAX_VALUE, + "Too many States: " + stateTable.size() + ". 
Currently at most " + + Short.MAX_VALUE + " states are supported"); + out.writeShort(stateTable.size()); + for (Map.Entry<String, Tuple2<Coder<?>, Map<String, ?>>> entry : stateTable.entrySet()) { + out.writeUTF(entry.getKey()); + Coder coder = entry.getValue().f0; + InstantiationUtil.serializeObject(out, coder); + Map<String, ?> map = entry.getValue().f1; + out.writeInt(map.size()); + for (Map.Entry<String, ?> entry1 : map.entrySet()) { + StringUtf8Coder.of().encode(entry1.getKey(), out, Context.NESTED); + coder.encode(entry1.getValue(), out, Context.NESTED); + } + } + } + + /** + * Restore the state {@code (stateName -> (valueCoder && (namespace -> value)))} + * for a given {@code keyGroupIdx}. + * + * @param keyGroupIdx the id of the key-group to be put in the snapshot. + * @param in the stream to read from. + * @param userCodeClassLoader the class loader that will be used to deserialize + * the valueCoder. + */ + public void restoreKeyGroupState(int keyGroupIdx, DataInputStream in, + ClassLoader userCodeClassLoader) throws Exception { + int localIdx = getIndexForKeyGroup(keyGroupIdx); + Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable = stateTables[localIdx]; + int numStates = in.readShort(); + for (int i = 0; i < numStates; ++i) { + String stateName = in.readUTF(); + Coder coder = InstantiationUtil.deserializeObject(in, userCodeClassLoader); + Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName); + if (tuple2 == null) { + tuple2 = new Tuple2<>(); + tuple2.f0 = coder; + tuple2.f1 = new HashMap<>(); + stateTable.put(stateName, tuple2); + } + Map<String, Object> map = (Map<String, Object>) tuple2.f1; + int mapSize = in.readInt(); + for (int j = 0; j < mapSize; j++) { + String namespace = StringUtf8Coder.of().decode(in, Context.NESTED); + Object value = coder.decode(in, Context.NESTED); + map.put(namespace, value); + } + } + } + +}