http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8852eb15/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/io/UnboundedSocketSource.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/io/UnboundedSocketSource.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/io/UnboundedSocketSource.java index dd14f68..0b1a5da 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/io/UnboundedSocketSource.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/io/UnboundedSocketSource.java @@ -41,191 +41,191 @@ import static com.google.common.base.Preconditions.checkArgument; * */ public class UnboundedSocketSource<C extends UnboundedSource.CheckpointMark> extends UnboundedSource<String, C> { - private static final Coder<String> DEFAULT_SOCKET_CODER = StringUtf8Coder.of(); + private static final Coder<String> DEFAULT_SOCKET_CODER = StringUtf8Coder.of(); - private static final long serialVersionUID = 1L; - - private static final int DEFAULT_CONNECTION_RETRY_SLEEP = 500; - - private static final int CONNECTION_TIMEOUT_TIME = 0; - - private final String hostname; - private final int port; - private final char delimiter; - private final long maxNumRetries; - private final long delayBetweenRetries; - - public UnboundedSocketSource(String hostname, int port, char delimiter, long maxNumRetries) { - this(hostname, port, delimiter, maxNumRetries, DEFAULT_CONNECTION_RETRY_SLEEP); - } - - public UnboundedSocketSource(String hostname, int port, char delimiter, long maxNumRetries, long delayBetweenRetries) { - this.hostname = hostname; - this.port = port; - this.delimiter = delimiter; - this.maxNumRetries = maxNumRetries; - this.delayBetweenRetries = delayBetweenRetries; - } - - public String getHostname() { - return this.hostname; - } - - public int getPort() { - return this.port; - } - - public char getDelimiter() { - return this.delimiter; - } - - public long getMaxNumRetries() { - return this.maxNumRetries; - } - - public long getDelayBetweenRetries() { - return this.delayBetweenRetries; - } - - @Override - public List<? extends UnboundedSource<String, C>> generateInitialSplits(int desiredNumSplits, PipelineOptions options) throws Exception { - return Collections.<UnboundedSource<String, C>>singletonList(this); - } - - @Override - public UnboundedReader<String> createReader(PipelineOptions options, @Nullable C checkpointMark) { - return new UnboundedSocketReader(this); - } - - @Nullable - @Override - public Coder getCheckpointMarkCoder() { - // Flink and Dataflow have different checkpointing mechanisms. - // In our case we do not need a coder. - return null; - } - - @Override - public void validate() { - checkArgument(port > 0 && port < 65536, "port is out of range"); - checkArgument(maxNumRetries >= -1, "maxNumRetries must be zero or larger (num retries), or -1 (infinite retries)"); - checkArgument(delayBetweenRetries >= 0, "delayBetweenRetries must be zero or positive"); - } - - @Override - public Coder getDefaultOutputCoder() { - return DEFAULT_SOCKET_CODER; - } - - public static class UnboundedSocketReader extends UnboundedSource.UnboundedReader<String> implements Serializable { - - private static final long serialVersionUID = 7526472295622776147L; - private static final Logger LOG = LoggerFactory.getLogger(UnboundedSocketReader.class); - - private final UnboundedSocketSource source; - - private Socket socket; - private BufferedReader reader; - - private boolean isRunning; - - private String currentRecord; - - public UnboundedSocketReader(UnboundedSocketSource source) { - this.source = source; - } - - private void openConnection() throws IOException { - this.socket = new Socket(); - this.socket.connect(new InetSocketAddress(this.source.getHostname(), this.source.getPort()), CONNECTION_TIMEOUT_TIME); - this.reader = new BufferedReader(new InputStreamReader(this.socket.getInputStream())); - this.isRunning = true; - } - - @Override - public boolean start() throws IOException { - int attempt = 0; - while (!isRunning) { - try { - openConnection(); - LOG.info("Connected to server socket " + this.source.getHostname() + ':' + this.source.getPort()); - - return advance(); - } catch (IOException e) { - LOG.info("Lost connection to server socket " + this.source.getHostname() + ':' + this.source.getPort() + ". Retrying in " + this.source.getDelayBetweenRetries() + " msecs..."); - - if (this.source.getMaxNumRetries() == -1 || attempt++ < this.source.getMaxNumRetries()) { - try { - Thread.sleep(this.source.getDelayBetweenRetries()); - } catch (InterruptedException e1) { - e1.printStackTrace(); - } - } else { - this.isRunning = false; - break; - } - } - } - LOG.error("Unable to connect to host " + this.source.getHostname() + " : " + this.source.getPort()); - return false; - } - - @Override - public boolean advance() throws IOException { - final StringBuilder buffer = new StringBuilder(); - int data; - while (isRunning && (data = reader.read()) != -1) { - // check if the string is complete - if (data != this.source.getDelimiter()) { - buffer.append((char) data); - } else { - if (buffer.length() > 0 && buffer.charAt(buffer.length() - 1) == '\r') { - buffer.setLength(buffer.length() - 1); - } - this.currentRecord = buffer.toString(); - buffer.setLength(0); - return true; - } - } - return false; - } - - @Override - public byte[] getCurrentRecordId() throws NoSuchElementException { - return new byte[0]; - } - - @Override - public String getCurrent() throws NoSuchElementException { - return this.currentRecord; - } - - @Override - public Instant getCurrentTimestamp() throws NoSuchElementException { - return Instant.now(); - } - - @Override - public void close() throws IOException { - this.reader.close(); - this.socket.close(); - this.isRunning = false; - LOG.info("Closed connection to server socket at " + this.source.getHostname() + ":" + this.source.getPort() + "."); - } - - @Override - public Instant getWatermark() { - return Instant.now(); - } - - @Override - public CheckpointMark getCheckpointMark() { - return null; - } - - @Override - public UnboundedSource<String, ?> getCurrentSource() { - return this.source; - } - } + private static final long serialVersionUID = 1L; + + private static final int DEFAULT_CONNECTION_RETRY_SLEEP = 500; + + private static final int CONNECTION_TIMEOUT_TIME = 0; + + private final String hostname; + private final int port; + private final char delimiter; + private final long maxNumRetries; + private final long delayBetweenRetries; + + public UnboundedSocketSource(String hostname, int port, char delimiter, long maxNumRetries) { + this(hostname, port, delimiter, maxNumRetries, DEFAULT_CONNECTION_RETRY_SLEEP); + } + + public UnboundedSocketSource(String hostname, int port, char delimiter, long maxNumRetries, long delayBetweenRetries) { + this.hostname = hostname; + this.port = port; + this.delimiter = delimiter; + this.maxNumRetries = maxNumRetries; + this.delayBetweenRetries = delayBetweenRetries; + } + + public String getHostname() { + return this.hostname; + } + + public int getPort() { + return this.port; + } + + public char getDelimiter() { + return this.delimiter; + } + + public long getMaxNumRetries() { + return this.maxNumRetries; + } + + public long getDelayBetweenRetries() { + return this.delayBetweenRetries; + } + + @Override + public List<? extends UnboundedSource<String, C>> generateInitialSplits(int desiredNumSplits, PipelineOptions options) throws Exception { + return Collections.<UnboundedSource<String, C>>singletonList(this); + } + + @Override + public UnboundedReader<String> createReader(PipelineOptions options, @Nullable C checkpointMark) { + return new UnboundedSocketReader(this); + } + + @Nullable + @Override + public Coder getCheckpointMarkCoder() { + // Flink and Dataflow have different checkpointing mechanisms. + // In our case we do not need a coder. + return null; + } + + @Override + public void validate() { + checkArgument(port > 0 && port < 65536, "port is out of range"); + checkArgument(maxNumRetries >= -1, "maxNumRetries must be zero or larger (num retries), or -1 (infinite retries)"); + checkArgument(delayBetweenRetries >= 0, "delayBetweenRetries must be zero or positive"); + } + + @Override + public Coder getDefaultOutputCoder() { + return DEFAULT_SOCKET_CODER; + } + + public static class UnboundedSocketReader extends UnboundedSource.UnboundedReader<String> implements Serializable { + + private static final long serialVersionUID = 7526472295622776147L; + private static final Logger LOG = LoggerFactory.getLogger(UnboundedSocketReader.class); + + private final UnboundedSocketSource source; + + private Socket socket; + private BufferedReader reader; + + private boolean isRunning; + + private String currentRecord; + + public UnboundedSocketReader(UnboundedSocketSource source) { + this.source = source; + } + + private void openConnection() throws IOException { + this.socket = new Socket(); + this.socket.connect(new InetSocketAddress(this.source.getHostname(), this.source.getPort()), CONNECTION_TIMEOUT_TIME); + this.reader = new BufferedReader(new InputStreamReader(this.socket.getInputStream())); + this.isRunning = true; + } + + @Override + public boolean start() throws IOException { + int attempt = 0; + while (!isRunning) { + try { + openConnection(); + LOG.info("Connected to server socket " + this.source.getHostname() + ':' + this.source.getPort()); + + return advance(); + } catch (IOException e) { + LOG.info("Lost connection to server socket " + this.source.getHostname() + ':' + this.source.getPort() + ". Retrying in " + this.source.getDelayBetweenRetries() + " msecs..."); + + if (this.source.getMaxNumRetries() == -1 || attempt++ < this.source.getMaxNumRetries()) { + try { + Thread.sleep(this.source.getDelayBetweenRetries()); + } catch (InterruptedException e1) { + e1.printStackTrace(); + } + } else { + this.isRunning = false; + break; + } + } + } + LOG.error("Unable to connect to host " + this.source.getHostname() + " : " + this.source.getPort()); + return false; + } + + @Override + public boolean advance() throws IOException { + final StringBuilder buffer = new StringBuilder(); + int data; + while (isRunning && (data = reader.read()) != -1) { + // check if the string is complete + if (data != this.source.getDelimiter()) { + buffer.append((char) data); + } else { + if (buffer.length() > 0 && buffer.charAt(buffer.length() - 1) == '\r') { + buffer.setLength(buffer.length() - 1); + } + this.currentRecord = buffer.toString(); + buffer.setLength(0); + return true; + } + } + return false; + } + + @Override + public byte[] getCurrentRecordId() throws NoSuchElementException { + return new byte[0]; + } + + @Override + public String getCurrent() throws NoSuchElementException { + return this.currentRecord; + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + return Instant.now(); + } + + @Override + public void close() throws IOException { + this.reader.close(); + this.socket.close(); + this.isRunning = false; + LOG.info("Closed connection to server socket at " + this.source.getHostname() + ":" + this.source.getPort() + "."); + } + + @Override + public Instant getWatermark() { + return Instant.now(); + } + + @Override + public CheckpointMark getCheckpointMark() { + return null; + } + + @Override + public UnboundedSource<String, ?> getCurrentSource() { + return this.source; + } + } }
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8852eb15/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/io/UnboundedSourceWrapper.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/io/UnboundedSourceWrapper.java index e065f87..5a89894 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/io/UnboundedSourceWrapper.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/io/UnboundedSourceWrapper.java @@ -38,95 +38,95 @@ import org.joda.time.Instant; * */ public class UnboundedSourceWrapper<T> extends RichSourceFunction<WindowedValue<T>> implements Triggerable { - private final String name; - private final UnboundedSource.UnboundedReader<T> reader; - - private StreamingRuntimeContext runtime = null; - private StreamSource.ManualWatermarkContext<WindowedValue<T>> context = null; - - private volatile boolean isRunning = false; - - public UnboundedSourceWrapper(PipelineOptions options, Read.Unbounded<T> transform) { - this.name = transform.getName(); - this.reader = transform.getSource().createReader(options, null); - } - - public String getName() { - return this.name; - } - - WindowedValue<T> makeWindowedValue(T output, Instant timestamp) { - if (timestamp == null) { - timestamp = BoundedWindow.TIMESTAMP_MIN_VALUE; - } - return WindowedValue.of(output, timestamp, GlobalWindow.INSTANCE, PaneInfo.NO_FIRING); - } - - @Override - public void run(SourceContext<WindowedValue<T>> ctx) throws Exception { - if (!(ctx instanceof StreamSource.ManualWatermarkContext)) { - throw new RuntimeException("We assume that all sources in Dataflow are EventTimeSourceFunction. " + - "Apparently " + this.name + " is not. Probably you should consider writing your own Wrapper for this source."); - } - - context = (StreamSource.ManualWatermarkContext<WindowedValue<T>>) ctx; - runtime = (StreamingRuntimeContext) getRuntimeContext(); - - this.isRunning = true; - boolean inputAvailable = reader.start(); - - setNextWatermarkTimer(this.runtime); - - while (isRunning) { - - while (!inputAvailable && isRunning) { - // wait a bit until we retry to pull more records - Thread.sleep(50); - inputAvailable = reader.advance(); - } - - if (inputAvailable) { - - // get it and its timestamp from the source - T item = reader.getCurrent(); - Instant timestamp = reader.getCurrentTimestamp(); - - // write it to the output collector - synchronized (ctx.getCheckpointLock()) { - context.collectWithTimestamp(makeWindowedValue(item, timestamp), timestamp.getMillis()); - } - - inputAvailable = reader.advance(); - } - - } - } - - @Override - public void cancel() { - isRunning = false; - } - - @Override - public void trigger(long timestamp) throws Exception { - if (this.isRunning) { - synchronized (context.getCheckpointLock()) { - long watermarkMillis = this.reader.getWatermark().getMillis(); - context.emitWatermark(new Watermark(watermarkMillis)); - } - setNextWatermarkTimer(this.runtime); - } - } - - private void setNextWatermarkTimer(StreamingRuntimeContext runtime) { - if (this.isRunning) { - long watermarkInterval = runtime.getExecutionConfig().getAutoWatermarkInterval(); - long timeToNextWatermark = getTimeToNextWaternark(watermarkInterval); - runtime.registerTimer(timeToNextWatermark, this); - } - } - - private long getTimeToNextWaternark(long watermarkInterval) { - return System.currentTimeMillis() + watermarkInterval; - } + private final String name; + private final UnboundedSource.UnboundedReader<T> reader; + + private StreamingRuntimeContext runtime = null; + private StreamSource.ManualWatermarkContext<WindowedValue<T>> context = null; + + private volatile boolean isRunning = false; + + public UnboundedSourceWrapper(PipelineOptions options, Read.Unbounded<T> transform) { + this.name = transform.getName(); + this.reader = transform.getSource().createReader(options, null); + } + + public String getName() { + return this.name; + } + + WindowedValue<T> makeWindowedValue(T output, Instant timestamp) { + if (timestamp == null) { + timestamp = BoundedWindow.TIMESTAMP_MIN_VALUE; + } + return WindowedValue.of(output, timestamp, GlobalWindow.INSTANCE, PaneInfo.NO_FIRING); + } + + @Override + public void run(SourceContext<WindowedValue<T>> ctx) throws Exception { + if (!(ctx instanceof StreamSource.ManualWatermarkContext)) { + throw new RuntimeException("We assume that all sources in Dataflow are EventTimeSourceFunction. " + + "Apparently " + this.name + " is not. Probably you should consider writing your own Wrapper for this source."); + } + + context = (StreamSource.ManualWatermarkContext<WindowedValue<T>>) ctx; + runtime = (StreamingRuntimeContext) getRuntimeContext(); + + this.isRunning = true; + boolean inputAvailable = reader.start(); + + setNextWatermarkTimer(this.runtime); + + while (isRunning) { + + while (!inputAvailable && isRunning) { + // wait a bit until we retry to pull more records + Thread.sleep(50); + inputAvailable = reader.advance(); + } + + if (inputAvailable) { + + // get it and its timestamp from the source + T item = reader.getCurrent(); + Instant timestamp = reader.getCurrentTimestamp(); + + // write it to the output collector + synchronized (ctx.getCheckpointLock()) { + context.collectWithTimestamp(makeWindowedValue(item, timestamp), timestamp.getMillis()); + } + + inputAvailable = reader.advance(); + } + + } + } + + @Override + public void cancel() { + isRunning = false; + } + + @Override + public void trigger(long timestamp) throws Exception { + if (this.isRunning) { + synchronized (context.getCheckpointLock()) { + long watermarkMillis = this.reader.getWatermark().getMillis(); + context.emitWatermark(new Watermark(watermarkMillis)); + } + setNextWatermarkTimer(this.runtime); + } + } + + private void setNextWatermarkTimer(StreamingRuntimeContext runtime) { + if (this.isRunning) { + long watermarkInterval = runtime.getExecutionConfig().getAutoWatermarkInterval(); + long timeToNextWatermark = getTimeToNextWaternark(watermarkInterval); + runtime.registerTimer(timeToNextWatermark, this); + } + } + + private long getTimeToNextWaternark(long watermarkInterval) { + return System.currentTimeMillis() + watermarkInterval; + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8852eb15/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/AbstractFlinkTimerInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/AbstractFlinkTimerInternals.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/AbstractFlinkTimerInternals.java index 84a322f..75c8ac6 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/AbstractFlinkTimerInternals.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/AbstractFlinkTimerInternals.java @@ -34,93 +34,93 @@ import java.io.Serializable; * The latter is used when snapshots of the current state are taken, for fault-tolerance. * */ public abstract class AbstractFlinkTimerInternals<K, VIN> implements TimerInternals, Serializable { - private Instant currentInputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; - private Instant currentOutputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; - - public void setCurrentInputWatermark(Instant watermark) { - checkIfValidInputWatermark(watermark); - this.currentInputWatermark = watermark; - } - - public void setCurrentOutputWatermark(Instant watermark) { - checkIfValidOutputWatermark(watermark); - this.currentOutputWatermark = watermark; - } - - private void setCurrentInputWatermarkAfterRecovery(Instant watermark) { - if (!currentInputWatermark.isEqual(BoundedWindow.TIMESTAMP_MIN_VALUE)) { - throw new RuntimeException("Explicitly setting the input watermark is only allowed on " + - "initialization after recovery from a node failure. Apparently this is not " + - "the case here as the watermark is already set."); - } - this.currentInputWatermark = watermark; - } - - private void setCurrentOutputWatermarkAfterRecovery(Instant watermark) { - if (!currentOutputWatermark.isEqual(BoundedWindow.TIMESTAMP_MIN_VALUE)) { - throw new RuntimeException("Explicitly setting the output watermark is only allowed on " + - "initialization after recovery from a node failure. Apparently this is not " + - "the case here as the watermark is already set."); - } - this.currentOutputWatermark = watermark; - } - - @Override - public Instant currentProcessingTime() { - return Instant.now(); - } - - @Override - public Instant currentInputWatermarkTime() { - return currentInputWatermark; - } - - @Nullable - @Override - public Instant currentSynchronizedProcessingTime() { - // TODO - return null; - } - - @Override - public Instant currentOutputWatermarkTime() { - return currentOutputWatermark; - } - - private void checkIfValidInputWatermark(Instant newWatermark) { - if (currentInputWatermark.isAfter(newWatermark)) { - throw new IllegalArgumentException(String.format( - "Cannot set current input watermark to %s. Newer watermarks " + - "must be no earlier than the current one (%s).", - newWatermark, currentInputWatermark)); - } - } - - private void checkIfValidOutputWatermark(Instant newWatermark) { - if (currentOutputWatermark.isAfter(newWatermark)) { - throw new IllegalArgumentException(String.format( - "Cannot set current output watermark to %s. Newer watermarks " + - "must be no earlier than the current one (%s).", - newWatermark, currentOutputWatermark)); - } - } - - public void encodeTimerInternals(DoFn.ProcessContext context, - StateCheckpointWriter writer, - KvCoder<K, VIN> kvCoder, - Coder<? extends BoundedWindow> windowCoder) throws IOException { - if (context == null) { - throw new RuntimeException("The Context has not been initialized."); - } - - writer.setTimestamp(currentInputWatermark); - writer.setTimestamp(currentOutputWatermark); - } - - public void restoreTimerInternals(StateCheckpointReader reader, - KvCoder<K, VIN> kvCoder, - Coder<? extends BoundedWindow> windowCoder) throws IOException { - setCurrentInputWatermarkAfterRecovery(reader.getTimestamp()); - setCurrentOutputWatermarkAfterRecovery(reader.getTimestamp()); - } + private Instant currentInputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; + private Instant currentOutputWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; + + public void setCurrentInputWatermark(Instant watermark) { + checkIfValidInputWatermark(watermark); + this.currentInputWatermark = watermark; + } + + public void setCurrentOutputWatermark(Instant watermark) { + checkIfValidOutputWatermark(watermark); + this.currentOutputWatermark = watermark; + } + + private void setCurrentInputWatermarkAfterRecovery(Instant watermark) { + if (!currentInputWatermark.isEqual(BoundedWindow.TIMESTAMP_MIN_VALUE)) { + throw new RuntimeException("Explicitly setting the input watermark is only allowed on " + + "initialization after recovery from a node failure. Apparently this is not " + + "the case here as the watermark is already set."); + } + this.currentInputWatermark = watermark; + } + + private void setCurrentOutputWatermarkAfterRecovery(Instant watermark) { + if (!currentOutputWatermark.isEqual(BoundedWindow.TIMESTAMP_MIN_VALUE)) { + throw new RuntimeException("Explicitly setting the output watermark is only allowed on " + + "initialization after recovery from a node failure. Apparently this is not " + + "the case here as the watermark is already set."); + } + this.currentOutputWatermark = watermark; + } + + @Override + public Instant currentProcessingTime() { + return Instant.now(); + } + + @Override + public Instant currentInputWatermarkTime() { + return currentInputWatermark; + } + + @Nullable + @Override + public Instant currentSynchronizedProcessingTime() { + // TODO + return null; + } + + @Override + public Instant currentOutputWatermarkTime() { + return currentOutputWatermark; + } + + private void checkIfValidInputWatermark(Instant newWatermark) { + if (currentInputWatermark.isAfter(newWatermark)) { + throw new IllegalArgumentException(String.format( + "Cannot set current input watermark to %s. Newer watermarks " + + "must be no earlier than the current one (%s).", + newWatermark, currentInputWatermark)); + } + } + + private void checkIfValidOutputWatermark(Instant newWatermark) { + if (currentOutputWatermark.isAfter(newWatermark)) { + throw new IllegalArgumentException(String.format( + "Cannot set current output watermark to %s. Newer watermarks " + + "must be no earlier than the current one (%s).", + newWatermark, currentOutputWatermark)); + } + } + + public void encodeTimerInternals(DoFn.ProcessContext context, + StateCheckpointWriter writer, + KvCoder<K, VIN> kvCoder, + Coder<? extends BoundedWindow> windowCoder) throws IOException { + if (context == null) { + throw new RuntimeException("The Context has not been initialized."); + } + + writer.setTimestamp(currentInputWatermark); + writer.setTimestamp(currentOutputWatermark); + } + + public void restoreTimerInternals(StateCheckpointReader reader, + KvCoder<K, VIN> kvCoder, + Coder<? extends BoundedWindow> windowCoder) throws IOException { + setCurrentInputWatermarkAfterRecovery(reader.getTimestamp()); + setCurrentOutputWatermarkAfterRecovery(reader.getTimestamp()); + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8852eb15/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/FlinkStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/FlinkStateInternals.java index 41ab5f0..39fec14 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -41,673 +41,673 @@ import java.util.*; */ public class FlinkStateInternals<K> implements StateInternals<K> { - private final K key; - - private final Coder<K> keyCoder; - - private final Coder<? extends BoundedWindow> windowCoder; - - private final OutputTimeFn<? super BoundedWindow> outputTimeFn; - - private Instant watermarkHoldAccessor; - - public FlinkStateInternals(K key, - Coder<K> keyCoder, - Coder<? extends BoundedWindow> windowCoder, - OutputTimeFn<? super BoundedWindow> outputTimeFn) { - this.key = key; - this.keyCoder = keyCoder; - this.windowCoder = windowCoder; - this.outputTimeFn = outputTimeFn; - } - - public Instant getWatermarkHold() { - return watermarkHoldAccessor; - } - - /** - * This is the interface state has to implement in order for it to be fault tolerant when - * executed by the FlinkPipelineRunner. - */ - private interface CheckpointableIF { - - boolean shouldPersist(); - - void persistState(StateCheckpointWriter checkpointBuilder) throws IOException; - } - - protected final StateTable<K> inMemoryState = new StateTable<K>() { - @Override - protected StateTag.StateBinder binderForNamespace(final StateNamespace namespace, final StateContext<?> c) { - return new StateTag.StateBinder<K>() { - - @Override - public <T> ValueState<T> bindValue(StateTag<? super K, ValueState<T>> address, Coder<T> coder) { - return new FlinkInMemoryValue<>(encodeKey(namespace, address), coder); - } - - @Override - public <T> BagState<T> bindBag(StateTag<? super K, BagState<T>> address, Coder<T> elemCoder) { - return new FlinkInMemoryBag<>(encodeKey(namespace, address), elemCoder); - } - - @Override - public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindCombiningValue( - StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { - return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); - } - - @Override - public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( - StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { - return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); - } - - @Override - public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext( - StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - CombineWithContext.KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn) { - return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); - } - - @Override - public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark(StateTag<? super K, WatermarkHoldState<W>> address, OutputTimeFn<? super W> outputTimeFn) { - return new FlinkWatermarkHoldStateImpl<>(encodeKey(namespace, address), outputTimeFn); - } - }; - } - }; - - @Override - public K getKey() { - return key; - } - - @Override - public <StateT extends State> StateT state(StateNamespace namespace, StateTag<? super K, StateT> address) { - return inMemoryState.get(namespace, address, null); - } - - @Override - public <T extends State> T state(StateNamespace namespace, StateTag<? super K, T> address, StateContext<?> c) { - return inMemoryState.get(namespace, address, c); - } - - public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { - checkpointBuilder.writeInt(getNoOfElements()); - - for (State location : inMemoryState.values()) { - if (!(location instanceof CheckpointableIF)) { - throw new IllegalStateException(String.format( - "%s wasn't created by %s -- unable to persist it", - location.getClass().getSimpleName(), - getClass().getSimpleName())); - } - ((CheckpointableIF) location).persistState(checkpointBuilder); - } - } - - public void restoreState(StateCheckpointReader checkpointReader, ClassLoader loader) - throws IOException, ClassNotFoundException { - - // the number of elements to read. - int noOfElements = checkpointReader.getInt(); - for (int i = 0; i < noOfElements; i++) { - decodeState(checkpointReader, loader); - } - } - - /** - * We remove the first character which encodes the type of the stateTag ('s' for system - * and 'u' for user). For more details check out the source of - * {@link StateTags.StateTagBase#getId()}. - */ - private void decodeState(StateCheckpointReader reader, ClassLoader loader) - throws IOException, ClassNotFoundException { - - StateType stateItemType = StateType.deserialize(reader); - ByteString stateKey = reader.getTag(); - - // first decode the namespace and the tagId... - String[] namespaceAndTag = stateKey.toStringUtf8().split("\\+"); - if (namespaceAndTag.length != 2) { - throw new IllegalArgumentException("Invalid stateKey " + stateKey.toString() + "."); - } - StateNamespace namespace = StateNamespaces.fromString(namespaceAndTag[0], windowCoder); - - // ... decide if it is a system or user stateTag... - char ownerTag = namespaceAndTag[1].charAt(0); - if (ownerTag != 's' && ownerTag != 'u') { - throw new RuntimeException("Invalid StateTag name."); - } - boolean isSystemTag = ownerTag == 's'; - String tagId = namespaceAndTag[1].substring(1); - - // ...then decode the coder (if there is one)... - Coder<?> coder = null; - switch (stateItemType) { - case VALUE: - case LIST: - case ACCUMULATOR: - ByteString coderBytes = reader.getData(); - coder = InstantiationUtil.deserializeObject(coderBytes.toByteArray(), loader); - break; - case WATERMARK: - break; - } - - // ...then decode the combiner function (if there is one)... - CombineWithContext.KeyedCombineFnWithContext<? super K, ?, ?, ?> combineFn = null; - switch (stateItemType) { - case ACCUMULATOR: - ByteString combinerBytes = reader.getData(); - combineFn = InstantiationUtil.deserializeObject(combinerBytes.toByteArray(), loader); - break; - case VALUE: - case LIST: - case WATERMARK: - break; - } - - //... and finally, depending on the type of the state being decoded, - // 1) create the adequate stateTag, - // 2) create the state container, - // 3) restore the actual content. - switch (stateItemType) { - case VALUE: { - StateTag stateTag = StateTags.value(tagId, coder); - stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; - @SuppressWarnings("unchecked") - FlinkInMemoryValue<?> value = (FlinkInMemoryValue<?>) inMemoryState.get(namespace, stateTag, null); - value.restoreState(reader); - break; - } - case WATERMARK: { - @SuppressWarnings("unchecked") - StateTag<Object, WatermarkHoldState<BoundedWindow>> stateTag = StateTags.watermarkStateInternal(tagId, outputTimeFn); - stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; - @SuppressWarnings("unchecked") - FlinkWatermarkHoldStateImpl<?> watermark = (FlinkWatermarkHoldStateImpl<?>) inMemoryState.get(namespace, stateTag, null); - watermark.restoreState(reader); - break; - } - case LIST: { - StateTag stateTag = StateTags.bag(tagId, coder); - stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; - FlinkInMemoryBag<?> bag = (FlinkInMemoryBag<?>) inMemoryState.get(namespace, stateTag, null); - bag.restoreState(reader); - break; - } - case ACCUMULATOR: { - @SuppressWarnings("unchecked") - StateTag<K, AccumulatorCombiningState<?, ?, ?>> stateTag = StateTags.keyedCombiningValueWithContext(tagId, (Coder) coder, combineFn); - stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; - @SuppressWarnings("unchecked") - FlinkInMemoryKeyedCombiningValue<?, ?, ?> combiningValue = - (FlinkInMemoryKeyedCombiningValue<?, ?, ?>) inMemoryState.get(namespace, stateTag, null); - combiningValue.restoreState(reader); - break; - } - default: - throw new RuntimeException("Unknown State Type " + stateItemType + "."); - } - } - - private ByteString encodeKey(StateNamespace namespace, StateTag<? super K, ?> address) { - StringBuilder sb = new StringBuilder(); - try { - namespace.appendTo(sb); - sb.append('+'); - address.appendTo(sb); - } catch (IOException e) { - throw new RuntimeException(e); - } - return ByteString.copyFromUtf8(sb.toString()); - } - - private int getNoOfElements() { - int noOfElements = 0; - for (State state : inMemoryState.values()) { - if (!(state instanceof CheckpointableIF)) { - throw new RuntimeException("State Implementations used by the " + - "Flink Dataflow Runner should implement the CheckpointableIF interface."); - } - - if (((CheckpointableIF) state).shouldPersist()) { - noOfElements++; - } - } - return noOfElements; - } - - private final class FlinkInMemoryValue<T> implements ValueState<T>, CheckpointableIF { - - private final ByteString stateKey; - private final Coder<T> elemCoder; - - private T value = null; - - public FlinkInMemoryValue(ByteString stateKey, Coder<T> elemCoder) { - this.stateKey = stateKey; - this.elemCoder = elemCoder; - } - - @Override - public void clear() { - value = null; - } - - @Override - public void write(T input) { - this.value = input; - } - - @Override - public T read() { - return value; - } - - @Override - public ValueState<T> readLater() { - // Ignore - return this; - } - - @Override - public boolean shouldPersist() { - return value != null; - } - - @Override - public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { - if (value != null) { - // serialize the coder. - byte[] coder = InstantiationUtil.serializeObject(elemCoder); - - // encode the value into a ByteString - ByteString.Output stream = ByteString.newOutput(); - elemCoder.encode(value, stream, Coder.Context.OUTER); - ByteString data = stream.toByteString(); - - checkpointBuilder.addValueBuilder() - .setTag(stateKey) - .setData(coder) - .setData(data); - } - } - - public void restoreState(StateCheckpointReader checkpointReader) throws IOException { - ByteString valueContent = checkpointReader.getData(); - T outValue = elemCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); - write(outValue); - } - } - - private final class FlinkWatermarkHoldStateImpl<W extends BoundedWindow> - implements WatermarkHoldState<W>, CheckpointableIF { - - private final ByteString stateKey; - - private Instant minimumHold = null; - - private OutputTimeFn<? super W> outputTimeFn; - - public FlinkWatermarkHoldStateImpl(ByteString stateKey, OutputTimeFn<? super W> outputTimeFn) { - this.stateKey = stateKey; - this.outputTimeFn = outputTimeFn; - } - - @Override - public void clear() { - // Even though we're clearing we can't remove this from the in-memory state map, since - // other users may already have a handle on this WatermarkBagInternal. - minimumHold = null; - watermarkHoldAccessor = null; - } - - @Override - public void add(Instant watermarkHold) { - if (minimumHold == null || minimumHold.isAfter(watermarkHold)) { - watermarkHoldAccessor = watermarkHold; - minimumHold = watermarkHold; - } - } - - @Override - public ReadableState<Boolean> isEmpty() { - return new ReadableState<Boolean>() { - @Override - public Boolean read() { - return minimumHold == null; - } - - @Override - public ReadableState<Boolean> readLater() { - // Ignore - return this; - } - }; - } - - @Override - public OutputTimeFn<? super W> getOutputTimeFn() { - return outputTimeFn; - } - - @Override - public Instant read() { - return minimumHold; - } - - @Override - public WatermarkHoldState<W> readLater() { - // Ignore - return this; - } - - @Override - public String toString() { - return Objects.toString(minimumHold); - } - - @Override - public boolean shouldPersist() { - return minimumHold != null; - } - - @Override - public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { - if (minimumHold != null) { - checkpointBuilder.addWatermarkHoldsBuilder() - .setTag(stateKey) - .setTimestamp(minimumHold); - } - } - - public void restoreState(StateCheckpointReader checkpointReader) throws IOException { - Instant watermark = checkpointReader.getTimestamp(); - add(watermark); - } - } - - - private static <K, InputT, AccumT, OutputT> CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> withContext( - final Combine.KeyedCombineFn<K, InputT, AccumT, OutputT> combineFn) { - return new CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() { - @Override - public AccumT createAccumulator(K key, CombineWithContext.Context c) { - return combineFn.createAccumulator(key); - } - - @Override - public AccumT addInput(K key, AccumT accumulator, InputT value, CombineWithContext.Context c) { - return combineFn.addInput(key, accumulator, value); - } - - @Override - public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, CombineWithContext.Context c) { - return combineFn.mergeAccumulators(key, accumulators); - } - - @Override - public OutputT extractOutput(K key, AccumT accumulator, CombineWithContext.Context c) { - return combineFn.extractOutput(key, accumulator); - } - }; - } - - private static <K, InputT, AccumT, OutputT> CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> withKeyAndContext( - final Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { - return new CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() { - @Override - public AccumT createAccumulator(K key, CombineWithContext.Context c) { - return combineFn.createAccumulator(); - } - - @Override - public AccumT addInput(K key, AccumT accumulator, InputT value, CombineWithContext.Context c) { - return combineFn.addInput(accumulator, value); - } - - @Override - public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, CombineWithContext.Context c) { - return combineFn.mergeAccumulators(accumulators); - } - - @Override - public OutputT extractOutput(K key, AccumT accumulator, CombineWithContext.Context c) { - return combineFn.extractOutput(accumulator); - } - }; - } - - private final class FlinkInMemoryKeyedCombiningValue<InputT, AccumT, OutputT> - implements AccumulatorCombiningState<InputT, AccumT, OutputT>, CheckpointableIF { - - private final ByteString stateKey; - private final CombineWithContext.KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn; - private final Coder<AccumT> accumCoder; - private final CombineWithContext.Context context; - - private AccumT accum = null; - private boolean isClear = true; - - private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, - Combine.CombineFn<InputT, AccumT, OutputT> combineFn, - Coder<AccumT> accumCoder, - final StateContext<?> stateContext) { - this(stateKey, withKeyAndContext(combineFn), accumCoder, stateContext); - } - - - private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, - Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn, - Coder<AccumT> accumCoder, - final StateContext<?> stateContext) { - this(stateKey, withContext(combineFn), accumCoder, stateContext); - } - - private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, - CombineWithContext.KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn, - Coder<AccumT> accumCoder, - final StateContext<?> stateContext) { - Preconditions.checkNotNull(combineFn); - Preconditions.checkNotNull(accumCoder); - - this.stateKey = stateKey; - this.combineFn = combineFn; - this.accumCoder = accumCoder; - this.context = new CombineWithContext.Context() { - @Override - public PipelineOptions getPipelineOptions() { - return stateContext.getPipelineOptions(); - } - - @Override - public <T> T sideInput(PCollectionView<T> view) { - return stateContext.sideInput(view); - } - }; - accum = combineFn.createAccumulator(key, context); - } - - @Override - public void clear() { - accum = combineFn.createAccumulator(key, context); - isClear = true; - } - - @Override - public void add(InputT input) { - isClear = false; - accum = combineFn.addInput(key, accum, input, context); - } - - @Override - public AccumT getAccum() { - return accum; - } - - @Override - public ReadableState<Boolean> isEmpty() { - return new ReadableState<Boolean>() { - @Override - public ReadableState<Boolean> readLater() { - // Ignore - return this; - } - - @Override - public Boolean read() { - return isClear; - } - }; - } - - @Override - public void addAccum(AccumT accum) { - isClear = false; - this.accum = combineFn.mergeAccumulators(key, Arrays.asList(this.accum, accum), context); - } - - @Override - public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { - return combineFn.mergeAccumulators(key, accumulators, context); - } - - @Override - public OutputT read() { - return combineFn.extractOutput(key, accum, context); - } - - @Override - public AccumulatorCombiningState<InputT, AccumT, OutputT> readLater() { - // Ignore - return this; - } - - @Override - public boolean shouldPersist() { - return !isClear; - } - - @Override - public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { - if (!isClear) { - // serialize the coder. - byte[] coder = InstantiationUtil.serializeObject(accumCoder); - - // serialize the combiner. - byte[] combiner = InstantiationUtil.serializeObject(combineFn); - - // encode the accumulator into a ByteString - ByteString.Output stream = ByteString.newOutput(); - accumCoder.encode(accum, stream, Coder.Context.OUTER); - ByteString data = stream.toByteString(); - - // put the flag that the next serialized element is an accumulator - checkpointBuilder.addAccumulatorBuilder() - .setTag(stateKey) - .setData(coder) - .setData(combiner) - .setData(data); - } - } - - public void restoreState(StateCheckpointReader checkpointReader) throws IOException { - ByteString valueContent = checkpointReader.getData(); - AccumT accum = this.accumCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); - addAccum(accum); - } - } - - private static final class FlinkInMemoryBag<T> implements BagState<T>, CheckpointableIF { - private final List<T> contents = new ArrayList<>(); - - private final ByteString stateKey; - private final Coder<T> elemCoder; - - public FlinkInMemoryBag(ByteString stateKey, Coder<T> elemCoder) { - this.stateKey = stateKey; - this.elemCoder = elemCoder; - } - - @Override - public void clear() { - contents.clear(); - } - - @Override - public Iterable<T> read() { - return contents; - } - - @Override - public BagState<T> readLater() { - // Ignore - return this; - } - - @Override - public void add(T input) { - contents.add(input); - } - - @Override - public ReadableState<Boolean> isEmpty() { - return new ReadableState<Boolean>() { - @Override - public ReadableState<Boolean> readLater() { - // Ignore - return this; - } - - @Override - public Boolean read() { - return contents.isEmpty(); - } - }; - } - - @Override - public boolean shouldPersist() { - return !contents.isEmpty(); - } - - @Override - public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { - if (!contents.isEmpty()) { - // serialize the coder. - byte[] coder = InstantiationUtil.serializeObject(elemCoder); - - checkpointBuilder.addListUpdatesBuilder() - .setTag(stateKey) - .setData(coder) - .writeInt(contents.size()); - - for (T item : contents) { - // encode the element - ByteString.Output stream = ByteString.newOutput(); - elemCoder.encode(item, stream, Coder.Context.OUTER); - ByteString data = stream.toByteString(); - - // add the data to the checkpoint. - checkpointBuilder.setData(data); - } - } - } - - public void restoreState(StateCheckpointReader checkpointReader) throws IOException { - int noOfValues = checkpointReader.getInt(); - for (int j = 0; j < noOfValues; j++) { - ByteString valueContent = checkpointReader.getData(); - T outValue = elemCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); - add(outValue); - } - } - } + private final K key; + + private final Coder<K> keyCoder; + + private final Coder<? extends BoundedWindow> windowCoder; + + private final OutputTimeFn<? super BoundedWindow> outputTimeFn; + + private Instant watermarkHoldAccessor; + + public FlinkStateInternals(K key, + Coder<K> keyCoder, + Coder<? extends BoundedWindow> windowCoder, + OutputTimeFn<? super BoundedWindow> outputTimeFn) { + this.key = key; + this.keyCoder = keyCoder; + this.windowCoder = windowCoder; + this.outputTimeFn = outputTimeFn; + } + + public Instant getWatermarkHold() { + return watermarkHoldAccessor; + } + + /** + * This is the interface state has to implement in order for it to be fault tolerant when + * executed by the FlinkPipelineRunner. + */ + private interface CheckpointableIF { + + boolean shouldPersist(); + + void persistState(StateCheckpointWriter checkpointBuilder) throws IOException; + } + + protected final StateTable<K> inMemoryState = new StateTable<K>() { + @Override + protected StateTag.StateBinder binderForNamespace(final StateNamespace namespace, final StateContext<?> c) { + return new StateTag.StateBinder<K>() { + + @Override + public <T> ValueState<T> bindValue(StateTag<? super K, ValueState<T>> address, Coder<T> coder) { + return new FlinkInMemoryValue<>(encodeKey(namespace, address), coder); + } + + @Override + public <T> BagState<T> bindBag(StateTag<? super K, BagState<T>> address, Coder<T> elemCoder) { + return new FlinkInMemoryBag<>(encodeKey(namespace, address), elemCoder); + } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindCombiningValue( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); + } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { + return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); + } + + @Override + public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext( + StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + CombineWithContext.KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn) { + return new FlinkInMemoryKeyedCombiningValue<>(encodeKey(namespace, address), combineFn, accumCoder, c); + } + + @Override + public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark(StateTag<? super K, WatermarkHoldState<W>> address, OutputTimeFn<? super W> outputTimeFn) { + return new FlinkWatermarkHoldStateImpl<>(encodeKey(namespace, address), outputTimeFn); + } + }; + } + }; + + @Override + public K getKey() { + return key; + } + + @Override + public <StateT extends State> StateT state(StateNamespace namespace, StateTag<? super K, StateT> address) { + return inMemoryState.get(namespace, address, null); + } + + @Override + public <T extends State> T state(StateNamespace namespace, StateTag<? super K, T> address, StateContext<?> c) { + return inMemoryState.get(namespace, address, c); + } + + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + checkpointBuilder.writeInt(getNoOfElements()); + + for (State location : inMemoryState.values()) { + if (!(location instanceof CheckpointableIF)) { + throw new IllegalStateException(String.format( + "%s wasn't created by %s -- unable to persist it", + location.getClass().getSimpleName(), + getClass().getSimpleName())); + } + ((CheckpointableIF) location).persistState(checkpointBuilder); + } + } + + public void restoreState(StateCheckpointReader checkpointReader, ClassLoader loader) + throws IOException, ClassNotFoundException { + + // the number of elements to read. + int noOfElements = checkpointReader.getInt(); + for (int i = 0; i < noOfElements; i++) { + decodeState(checkpointReader, loader); + } + } + + /** + * We remove the first character which encodes the type of the stateTag ('s' for system + * and 'u' for user). For more details check out the source of + * {@link StateTags.StateTagBase#getId()}. + */ + private void decodeState(StateCheckpointReader reader, ClassLoader loader) + throws IOException, ClassNotFoundException { + + StateType stateItemType = StateType.deserialize(reader); + ByteString stateKey = reader.getTag(); + + // first decode the namespace and the tagId... + String[] namespaceAndTag = stateKey.toStringUtf8().split("\\+"); + if (namespaceAndTag.length != 2) { + throw new IllegalArgumentException("Invalid stateKey " + stateKey.toString() + "."); + } + StateNamespace namespace = StateNamespaces.fromString(namespaceAndTag[0], windowCoder); + + // ... decide if it is a system or user stateTag... + char ownerTag = namespaceAndTag[1].charAt(0); + if (ownerTag != 's' && ownerTag != 'u') { + throw new RuntimeException("Invalid StateTag name."); + } + boolean isSystemTag = ownerTag == 's'; + String tagId = namespaceAndTag[1].substring(1); + + // ...then decode the coder (if there is one)... + Coder<?> coder = null; + switch (stateItemType) { + case VALUE: + case LIST: + case ACCUMULATOR: + ByteString coderBytes = reader.getData(); + coder = InstantiationUtil.deserializeObject(coderBytes.toByteArray(), loader); + break; + case WATERMARK: + break; + } + + // ...then decode the combiner function (if there is one)... + CombineWithContext.KeyedCombineFnWithContext<? super K, ?, ?, ?> combineFn = null; + switch (stateItemType) { + case ACCUMULATOR: + ByteString combinerBytes = reader.getData(); + combineFn = InstantiationUtil.deserializeObject(combinerBytes.toByteArray(), loader); + break; + case VALUE: + case LIST: + case WATERMARK: + break; + } + + //... and finally, depending on the type of the state being decoded, + // 1) create the adequate stateTag, + // 2) create the state container, + // 3) restore the actual content. + switch (stateItemType) { + case VALUE: { + StateTag stateTag = StateTags.value(tagId, coder); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + @SuppressWarnings("unchecked") + FlinkInMemoryValue<?> value = (FlinkInMemoryValue<?>) inMemoryState.get(namespace, stateTag, null); + value.restoreState(reader); + break; + } + case WATERMARK: { + @SuppressWarnings("unchecked") + StateTag<Object, WatermarkHoldState<BoundedWindow>> stateTag = StateTags.watermarkStateInternal(tagId, outputTimeFn); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + @SuppressWarnings("unchecked") + FlinkWatermarkHoldStateImpl<?> watermark = (FlinkWatermarkHoldStateImpl<?>) inMemoryState.get(namespace, stateTag, null); + watermark.restoreState(reader); + break; + } + case LIST: { + StateTag stateTag = StateTags.bag(tagId, coder); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + FlinkInMemoryBag<?> bag = (FlinkInMemoryBag<?>) inMemoryState.get(namespace, stateTag, null); + bag.restoreState(reader); + break; + } + case ACCUMULATOR: { + @SuppressWarnings("unchecked") + StateTag<K, AccumulatorCombiningState<?, ?, ?>> stateTag = StateTags.keyedCombiningValueWithContext(tagId, (Coder) coder, combineFn); + stateTag = isSystemTag ? StateTags.makeSystemTagInternal(stateTag) : stateTag; + @SuppressWarnings("unchecked") + FlinkInMemoryKeyedCombiningValue<?, ?, ?> combiningValue = + (FlinkInMemoryKeyedCombiningValue<?, ?, ?>) inMemoryState.get(namespace, stateTag, null); + combiningValue.restoreState(reader); + break; + } + default: + throw new RuntimeException("Unknown State Type " + stateItemType + "."); + } + } + + private ByteString encodeKey(StateNamespace namespace, StateTag<? super K, ?> address) { + StringBuilder sb = new StringBuilder(); + try { + namespace.appendTo(sb); + sb.append('+'); + address.appendTo(sb); + } catch (IOException e) { + throw new RuntimeException(e); + } + return ByteString.copyFromUtf8(sb.toString()); + } + + private int getNoOfElements() { + int noOfElements = 0; + for (State state : inMemoryState.values()) { + if (!(state instanceof CheckpointableIF)) { + throw new RuntimeException("State Implementations used by the " + + "Flink Dataflow Runner should implement the CheckpointableIF interface."); + } + + if (((CheckpointableIF) state).shouldPersist()) { + noOfElements++; + } + } + return noOfElements; + } + + private final class FlinkInMemoryValue<T> implements ValueState<T>, CheckpointableIF { + + private final ByteString stateKey; + private final Coder<T> elemCoder; + + private T value = null; + + public FlinkInMemoryValue(ByteString stateKey, Coder<T> elemCoder) { + this.stateKey = stateKey; + this.elemCoder = elemCoder; + } + + @Override + public void clear() { + value = null; + } + + @Override + public void write(T input) { + this.value = input; + } + + @Override + public T read() { + return value; + } + + @Override + public ValueState<T> readLater() { + // Ignore + return this; + } + + @Override + public boolean shouldPersist() { + return value != null; + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (value != null) { + // serialize the coder. + byte[] coder = InstantiationUtil.serializeObject(elemCoder); + + // encode the value into a ByteString + ByteString.Output stream = ByteString.newOutput(); + elemCoder.encode(value, stream, Coder.Context.OUTER); + ByteString data = stream.toByteString(); + + checkpointBuilder.addValueBuilder() + .setTag(stateKey) + .setData(coder) + .setData(data); + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + ByteString valueContent = checkpointReader.getData(); + T outValue = elemCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); + write(outValue); + } + } + + private final class FlinkWatermarkHoldStateImpl<W extends BoundedWindow> + implements WatermarkHoldState<W>, CheckpointableIF { + + private final ByteString stateKey; + + private Instant minimumHold = null; + + private OutputTimeFn<? super W> outputTimeFn; + + public FlinkWatermarkHoldStateImpl(ByteString stateKey, OutputTimeFn<? super W> outputTimeFn) { + this.stateKey = stateKey; + this.outputTimeFn = outputTimeFn; + } + + @Override + public void clear() { + // Even though we're clearing we can't remove this from the in-memory state map, since + // other users may already have a handle on this WatermarkBagInternal. + minimumHold = null; + watermarkHoldAccessor = null; + } + + @Override + public void add(Instant watermarkHold) { + if (minimumHold == null || minimumHold.isAfter(watermarkHold)) { + watermarkHoldAccessor = watermarkHold; + minimumHold = watermarkHold; + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + return minimumHold == null; + } + + @Override + public ReadableState<Boolean> readLater() { + // Ignore + return this; + } + }; + } + + @Override + public OutputTimeFn<? super W> getOutputTimeFn() { + return outputTimeFn; + } + + @Override + public Instant read() { + return minimumHold; + } + + @Override + public WatermarkHoldState<W> readLater() { + // Ignore + return this; + } + + @Override + public String toString() { + return Objects.toString(minimumHold); + } + + @Override + public boolean shouldPersist() { + return minimumHold != null; + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (minimumHold != null) { + checkpointBuilder.addWatermarkHoldsBuilder() + .setTag(stateKey) + .setTimestamp(minimumHold); + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + Instant watermark = checkpointReader.getTimestamp(); + add(watermark); + } + } + + + private static <K, InputT, AccumT, OutputT> CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> withContext( + final Combine.KeyedCombineFn<K, InputT, AccumT, OutputT> combineFn) { + return new CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() { + @Override + public AccumT createAccumulator(K key, CombineWithContext.Context c) { + return combineFn.createAccumulator(key); + } + + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value, CombineWithContext.Context c) { + return combineFn.addInput(key, accumulator, value); + } + + @Override + public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, CombineWithContext.Context c) { + return combineFn.mergeAccumulators(key, accumulators); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, CombineWithContext.Context c) { + return combineFn.extractOutput(key, accumulator); + } + }; + } + + private static <K, InputT, AccumT, OutputT> CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> withKeyAndContext( + final Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + return new CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() { + @Override + public AccumT createAccumulator(K key, CombineWithContext.Context c) { + return combineFn.createAccumulator(); + } + + @Override + public AccumT addInput(K key, AccumT accumulator, InputT value, CombineWithContext.Context c) { + return combineFn.addInput(accumulator, value); + } + + @Override + public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, CombineWithContext.Context c) { + return combineFn.mergeAccumulators(accumulators); + } + + @Override + public OutputT extractOutput(K key, AccumT accumulator, CombineWithContext.Context c) { + return combineFn.extractOutput(accumulator); + } + }; + } + + private final class FlinkInMemoryKeyedCombiningValue<InputT, AccumT, OutputT> + implements AccumulatorCombiningState<InputT, AccumT, OutputT>, CheckpointableIF { + + private final ByteString stateKey; + private final CombineWithContext.KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn; + private final Coder<AccumT> accumCoder; + private final CombineWithContext.Context context; + + private AccumT accum = null; + private boolean isClear = true; + + private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn, + Coder<AccumT> accumCoder, + final StateContext<?> stateContext) { + this(stateKey, withKeyAndContext(combineFn), accumCoder, stateContext); + } + + + private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, + Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn, + Coder<AccumT> accumCoder, + final StateContext<?> stateContext) { + this(stateKey, withContext(combineFn), accumCoder, stateContext); + } + + private FlinkInMemoryKeyedCombiningValue(ByteString stateKey, + CombineWithContext.KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn, + Coder<AccumT> accumCoder, + final StateContext<?> stateContext) { + Preconditions.checkNotNull(combineFn); + Preconditions.checkNotNull(accumCoder); + + this.stateKey = stateKey; + this.combineFn = combineFn; + this.accumCoder = accumCoder; + this.context = new CombineWithContext.Context() { + @Override + public PipelineOptions getPipelineOptions() { + return stateContext.getPipelineOptions(); + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + return stateContext.sideInput(view); + } + }; + accum = combineFn.createAccumulator(key, context); + } + + @Override + public void clear() { + accum = combineFn.createAccumulator(key, context); + isClear = true; + } + + @Override + public void add(InputT input) { + isClear = false; + accum = combineFn.addInput(key, accum, input, context); + } + + @Override + public AccumT getAccum() { + return accum; + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public ReadableState<Boolean> readLater() { + // Ignore + return this; + } + + @Override + public Boolean read() { + return isClear; + } + }; + } + + @Override + public void addAccum(AccumT accum) { + isClear = false; + this.accum = combineFn.mergeAccumulators(key, Arrays.asList(this.accum, accum), context); + } + + @Override + public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { + return combineFn.mergeAccumulators(key, accumulators, context); + } + + @Override + public OutputT read() { + return combineFn.extractOutput(key, accum, context); + } + + @Override + public AccumulatorCombiningState<InputT, AccumT, OutputT> readLater() { + // Ignore + return this; + } + + @Override + public boolean shouldPersist() { + return !isClear; + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (!isClear) { + // serialize the coder. + byte[] coder = InstantiationUtil.serializeObject(accumCoder); + + // serialize the combiner. + byte[] combiner = InstantiationUtil.serializeObject(combineFn); + + // encode the accumulator into a ByteString + ByteString.Output stream = ByteString.newOutput(); + accumCoder.encode(accum, stream, Coder.Context.OUTER); + ByteString data = stream.toByteString(); + + // put the flag that the next serialized element is an accumulator + checkpointBuilder.addAccumulatorBuilder() + .setTag(stateKey) + .setData(coder) + .setData(combiner) + .setData(data); + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + ByteString valueContent = checkpointReader.getData(); + AccumT accum = this.accumCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); + addAccum(accum); + } + } + + private static final class FlinkInMemoryBag<T> implements BagState<T>, CheckpointableIF { + private final List<T> contents = new ArrayList<>(); + + private final ByteString stateKey; + private final Coder<T> elemCoder; + + public FlinkInMemoryBag(ByteString stateKey, Coder<T> elemCoder) { + this.stateKey = stateKey; + this.elemCoder = elemCoder; + } + + @Override + public void clear() { + contents.clear(); + } + + @Override + public Iterable<T> read() { + return contents; + } + + @Override + public BagState<T> readLater() { + // Ignore + return this; + } + + @Override + public void add(T input) { + contents.add(input); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public ReadableState<Boolean> readLater() { + // Ignore + return this; + } + + @Override + public Boolean read() { + return contents.isEmpty(); + } + }; + } + + @Override + public boolean shouldPersist() { + return !contents.isEmpty(); + } + + @Override + public void persistState(StateCheckpointWriter checkpointBuilder) throws IOException { + if (!contents.isEmpty()) { + // serialize the coder. + byte[] coder = InstantiationUtil.serializeObject(elemCoder); + + checkpointBuilder.addListUpdatesBuilder() + .setTag(stateKey) + .setData(coder) + .writeInt(contents.size()); + + for (T item : contents) { + // encode the element + ByteString.Output stream = ByteString.newOutput(); + elemCoder.encode(item, stream, Coder.Context.OUTER); + ByteString data = stream.toByteString(); + + // add the data to the checkpoint. + checkpointBuilder.setData(data); + } + } + } + + public void restoreState(StateCheckpointReader checkpointReader) throws IOException { + int noOfValues = checkpointReader.getInt(); + for (int j = 0; j < noOfValues; j++) { + ByteString valueContent = checkpointReader.getData(); + T outValue = elemCoder.decode(new ByteArrayInputStream(valueContent.toByteArray()), Coder.Context.OUTER); + add(outValue); + } + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8852eb15/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/StateCheckpointReader.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/StateCheckpointReader.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/StateCheckpointReader.java index ba8ef89..753309e 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/StateCheckpointReader.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/StateCheckpointReader.java @@ -25,65 +25,65 @@ import java.util.concurrent.TimeUnit; public class StateCheckpointReader { - private final DataInputView input; - - public StateCheckpointReader(DataInputView in) { - this.input = in; - } - - public ByteString getTag() throws IOException { - return ByteString.copyFrom(readRawData()); - } - - public String getTagToString() throws IOException { - return input.readUTF(); - } - - public ByteString getData() throws IOException { - return ByteString.copyFrom(readRawData()); - } - - public int getInt() throws IOException { - validate(); - return input.readInt(); - } - - public byte getByte() throws IOException { - validate(); - return input.readByte(); - } - - public Instant getTimestamp() throws IOException { - validate(); - Long watermarkMillis = input.readLong(); - return new Instant(TimeUnit.MICROSECONDS.toMillis(watermarkMillis)); - } - - public <K> K deserializeKey(CoderTypeSerializer<K> keySerializer) throws IOException { - return deserializeObject(keySerializer); - } - - public <T> T deserializeObject(CoderTypeSerializer<T> objectSerializer) throws IOException { - return objectSerializer.deserialize(input); - } - - ///////// Helper Methods /////// - - private byte[] readRawData() throws IOException { - validate(); - int size = input.readInt(); - - byte[] serData = new byte[size]; - int bytesRead = input.read(serData); - if (bytesRead != size) { - throw new RuntimeException("Error while deserializing checkpoint. Not enough bytes in the input stream."); - } - return serData; - } - - private void validate() { - if (this.input == null) { - throw new RuntimeException("StateBackend not initialized yet."); - } - } + private final DataInputView input; + + public StateCheckpointReader(DataInputView in) { + this.input = in; + } + + public ByteString getTag() throws IOException { + return ByteString.copyFrom(readRawData()); + } + + public String getTagToString() throws IOException { + return input.readUTF(); + } + + public ByteString getData() throws IOException { + return ByteString.copyFrom(readRawData()); + } + + public int getInt() throws IOException { + validate(); + return input.readInt(); + } + + public byte getByte() throws IOException { + validate(); + return input.readByte(); + } + + public Instant getTimestamp() throws IOException { + validate(); + Long watermarkMillis = input.readLong(); + return new Instant(TimeUnit.MICROSECONDS.toMillis(watermarkMillis)); + } + + public <K> K deserializeKey(CoderTypeSerializer<K> keySerializer) throws IOException { + return deserializeObject(keySerializer); + } + + public <T> T deserializeObject(CoderTypeSerializer<T> objectSerializer) throws IOException { + return objectSerializer.deserialize(input); + } + + ///////// Helper Methods /////// + + private byte[] readRawData() throws IOException { + validate(); + int size = input.readInt(); + + byte[] serData = new byte[size]; + int bytesRead = input.read(serData); + if (bytesRead != size) { + throw new RuntimeException("Error while deserializing checkpoint. Not enough bytes in the input stream."); + } + return serData; + } + + private void validate() { + if (this.input == null) { + throw new RuntimeException("StateBackend not initialized yet."); + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8852eb15/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/StateCheckpointUtils.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/StateCheckpointUtils.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/StateCheckpointUtils.java index cd85163..1741829 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/StateCheckpointUtils.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/state/StateCheckpointUtils.java @@ -34,120 +34,120 @@ import java.util.Set; public class StateCheckpointUtils { - public static <K> void encodeState(Map<K, FlinkStateInternals<K>> perKeyStateInternals, - StateCheckpointWriter writer, Coder<K> keyCoder) throws IOException { - CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); - - int noOfKeys = perKeyStateInternals.size(); - writer.writeInt(noOfKeys); - for (Map.Entry<K, FlinkStateInternals<K>> keyStatePair : perKeyStateInternals.entrySet()) { - K key = keyStatePair.getKey(); - FlinkStateInternals<K> state = keyStatePair.getValue(); - - // encode the key - writer.serializeKey(key, keySerializer); - - // write the associated state - state.persistState(writer); - } - } - - public static <K> Map<K, FlinkStateInternals<K>> decodeState( - StateCheckpointReader reader, - OutputTimeFn<? super BoundedWindow> outputTimeFn, - Coder<K> keyCoder, - Coder<? extends BoundedWindow> windowCoder, - ClassLoader classLoader) throws IOException, ClassNotFoundException { - - int noOfKeys = reader.getInt(); - Map<K, FlinkStateInternals<K>> perKeyStateInternals = new HashMap<>(noOfKeys); - perKeyStateInternals.clear(); - - CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); - for (int i = 0; i < noOfKeys; i++) { - - // decode the key. - K key = reader.deserializeKey(keySerializer); - - //decode the state associated to the key. - FlinkStateInternals<K> stateForKey = - new FlinkStateInternals<>(key, keyCoder, windowCoder, outputTimeFn); - stateForKey.restoreState(reader, classLoader); - perKeyStateInternals.put(key, stateForKey); - } - return perKeyStateInternals; - } - - ////////////// Encoding/Decoding the Timers //////////////// - - - public static <K> void encodeTimers(Map<K, Set<TimerInternals.TimerData>> allTimers, - StateCheckpointWriter writer, - Coder<K> keyCoder) throws IOException { - CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); - - int noOfKeys = allTimers.size(); - writer.writeInt(noOfKeys); - for (Map.Entry<K, Set<TimerInternals.TimerData>> timersPerKey : allTimers.entrySet()) { - K key = timersPerKey.getKey(); - - // encode the key - writer.serializeKey(key, keySerializer); - - // write the associated timers - Set<TimerInternals.TimerData> timers = timersPerKey.getValue(); - encodeTimerDataForKey(writer, timers); - } - } - - public static <K> Map<K, Set<TimerInternals.TimerData>> decodeTimers( - StateCheckpointReader reader, - Coder<? extends BoundedWindow> windowCoder, - Coder<K> keyCoder) throws IOException { - - int noOfKeys = reader.getInt(); - Map<K, Set<TimerInternals.TimerData>> activeTimers = new HashMap<>(noOfKeys); - activeTimers.clear(); - - CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); - for (int i = 0; i < noOfKeys; i++) { - - // decode the key. - K key = reader.deserializeKey(keySerializer); - - // decode the associated timers. - Set<TimerInternals.TimerData> timers = decodeTimerDataForKey(reader, windowCoder); - activeTimers.put(key, timers); - } - return activeTimers; - } - - private static void encodeTimerDataForKey(StateCheckpointWriter writer, Set<TimerInternals.TimerData> timers) throws IOException { - // encode timers - writer.writeInt(timers.size()); - for (TimerInternals.TimerData timer : timers) { - String stringKey = timer.getNamespace().stringKey(); - - writer.setTag(stringKey); - writer.setTimestamp(timer.getTimestamp()); - writer.writeInt(timer.getDomain().ordinal()); - } - } - - private static Set<TimerInternals.TimerData> decodeTimerDataForKey( - StateCheckpointReader reader, Coder<? extends BoundedWindow> windowCoder) throws IOException { - - // decode the timers: first their number and then the content itself. - int noOfTimers = reader.getInt(); - Set<TimerInternals.TimerData> timers = new HashSet<>(noOfTimers); - for (int i = 0; i < noOfTimers; i++) { - String stringKey = reader.getTagToString(); - Instant instant = reader.getTimestamp(); - TimeDomain domain = TimeDomain.values()[reader.getInt()]; - - StateNamespace namespace = StateNamespaces.fromString(stringKey, windowCoder); - timers.add(TimerInternals.TimerData.of(namespace, instant, domain)); - } - return timers; - } + public static <K> void encodeState(Map<K, FlinkStateInternals<K>> perKeyStateInternals, + StateCheckpointWriter writer, Coder<K> keyCoder) throws IOException { + CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); + + int noOfKeys = perKeyStateInternals.size(); + writer.writeInt(noOfKeys); + for (Map.Entry<K, FlinkStateInternals<K>> keyStatePair : perKeyStateInternals.entrySet()) { + K key = keyStatePair.getKey(); + FlinkStateInternals<K> state = keyStatePair.getValue(); + + // encode the key + writer.serializeKey(key, keySerializer); + + // write the associated state + state.persistState(writer); + } + } + + public static <K> Map<K, FlinkStateInternals<K>> decodeState( + StateCheckpointReader reader, + OutputTimeFn<? super BoundedWindow> outputTimeFn, + Coder<K> keyCoder, + Coder<? extends BoundedWindow> windowCoder, + ClassLoader classLoader) throws IOException, ClassNotFoundException { + + int noOfKeys = reader.getInt(); + Map<K, FlinkStateInternals<K>> perKeyStateInternals = new HashMap<>(noOfKeys); + perKeyStateInternals.clear(); + + CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); + for (int i = 0; i < noOfKeys; i++) { + + // decode the key. + K key = reader.deserializeKey(keySerializer); + + //decode the state associated to the key. + FlinkStateInternals<K> stateForKey = + new FlinkStateInternals<>(key, keyCoder, windowCoder, outputTimeFn); + stateForKey.restoreState(reader, classLoader); + perKeyStateInternals.put(key, stateForKey); + } + return perKeyStateInternals; + } + + ////////////// Encoding/Decoding the Timers //////////////// + + + public static <K> void encodeTimers(Map<K, Set<TimerInternals.TimerData>> allTimers, + StateCheckpointWriter writer, + Coder<K> keyCoder) throws IOException { + CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); + + int noOfKeys = allTimers.size(); + writer.writeInt(noOfKeys); + for (Map.Entry<K, Set<TimerInternals.TimerData>> timersPerKey : allTimers.entrySet()) { + K key = timersPerKey.getKey(); + + // encode the key + writer.serializeKey(key, keySerializer); + + // write the associated timers + Set<TimerInternals.TimerData> timers = timersPerKey.getValue(); + encodeTimerDataForKey(writer, timers); + } + } + + public static <K> Map<K, Set<TimerInternals.TimerData>> decodeTimers( + StateCheckpointReader reader, + Coder<? extends BoundedWindow> windowCoder, + Coder<K> keyCoder) throws IOException { + + int noOfKeys = reader.getInt(); + Map<K, Set<TimerInternals.TimerData>> activeTimers = new HashMap<>(noOfKeys); + activeTimers.clear(); + + CoderTypeSerializer<K> keySerializer = new CoderTypeSerializer<>(keyCoder); + for (int i = 0; i < noOfKeys; i++) { + + // decode the key. + K key = reader.deserializeKey(keySerializer); + + // decode the associated timers. + Set<TimerInternals.TimerData> timers = decodeTimerDataForKey(reader, windowCoder); + activeTimers.put(key, timers); + } + return activeTimers; + } + + private static void encodeTimerDataForKey(StateCheckpointWriter writer, Set<TimerInternals.TimerData> timers) throws IOException { + // encode timers + writer.writeInt(timers.size()); + for (TimerInternals.TimerData timer : timers) { + String stringKey = timer.getNamespace().stringKey(); + + writer.setTag(stringKey); + writer.setTimestamp(timer.getTimestamp()); + writer.writeInt(timer.getDomain().ordinal()); + } + } + + private static Set<TimerInternals.TimerData> decodeTimerDataForKey( + StateCheckpointReader reader, Coder<? extends BoundedWindow> windowCoder) throws IOException { + + // decode the timers: first their number and then the content itself. + int noOfTimers = reader.getInt(); + Set<TimerInternals.TimerData> timers = new HashSet<>(noOfTimers); + for (int i = 0; i < noOfTimers; i++) { + String stringKey = reader.getTagToString(); + Instant instant = reader.getTimestamp(); + TimeDomain domain = TimeDomain.values()[reader.getInt()]; + + StateNamespace namespace = StateNamespaces.fromString(stringKey, windowCoder); + timers.add(TimerInternals.TimerData.of(namespace, instant, domain)); + } + return timers; + } }
