/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.examples.wordcount;

import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

import java.io.Serializable;
import java.nio.file.Paths;
import java.util.Iterator;

/**
 * Stand-alone job that exercises keyed {@link ValueState} holding a mutable POJO under frequent
 * checkpoints and unlimited restarts.
 *
 * <p>A {@link RandomLongSource} emits values in {@code [0, 100)}. The stream is keyed by the value
 * itself, and {@link TestFunction} rewrites the per-key {@link MultiStorePacketState} for every
 * element. Checkpoints go to a filesystem state backend under the JVM temp directory.
 */
public class Test {

    /** Checkpoint interval in milliseconds. */
    private static final long CHECKPOINT_INTERVAL_MS = 50L;

    /** Attempt number above which the source emits a single final element and finishes. */
    private static final int SOURCE_MAX_ATTEMPTS = 10000;

    /** Delay in milliseconds between two records emitted by the source. */
    private static final long SOURCE_DELAY_MS = 1L;

    public static void main(String[] args) throws Exception {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.enableCheckpointing(CHECKPOINT_INTERVAL_MS);
        env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 0L));

        // Build the URI from a Path rather than concatenating "file:///" with the temp dir.
        // Concatenation yields "file:////tmp/..." on Unix and embeds backslashes on Windows.
        final String checkpointUri =
                Paths.get(System.getProperty("java.io.tmpdir"), "flink", "backend")
                        .toUri()
                        .toString();
        env.setStateBackend(new FsStateBackend(checkpointUri, false));

        env.addSource(new RandomLongSource(SOURCE_MAX_ATTEMPTS, SOURCE_DELAY_MS))
                .keyBy(new IdentityKeySelector<>())
                .process(new TestFunction())
                .addSink(new DiscardingSink<>());

        env.execute("Test");
    }

    /**
     * Keyed function that stores one {@link MultiStorePacketState} per key.
     *
     * <p>For every element it overwrites both fields of that state with a name derived from the
     * element, then forwards the element as an {@code int}.
     */
    public static class TestFunction extends KeyedProcessFunction<Long, Long, Integer>
            implements CheckpointedFunction, Serializable {

        private static final long serialVersionUID = 1L;

        private transient ValueState<MultiStorePacketState> state;

        /**
         * Whether this instance has processed at least one element.
         *
         * <p>Keyed state can only be read while the backend has a current key. That key is first
         * set when an element is processed, and it is absent again after a restore (a fresh
         * instance starts with this flag {@code false}).
         */
        private transient boolean keyContextSet;

        @Override
        public void open(Configuration config) {
            // No default value is registered: the constructor taking one is deprecated, and a
            // missing entry is handled explicitly in processElement instead.
            ValueStateDescriptor<MultiStorePacketState> descriptor =
                    new ValueStateDescriptor<>(
                            "MultiStorePacketState", // the state name
                            TypeInformation.of(
                                    new TypeHint<MultiStorePacketState>() {})); // type information
            state = getRuntimeContext().getState(descriptor);
        }

        @Override
        public void processElement(Long value, Context ctx, Collector<Integer> out)
                throws Exception {
            MultiStorePacketState so = state.value();
            if (so == null) {
                // First element seen for this key: start from a fresh instance.
                so = new MultiStorePacketState();
            }
            so.exportedFileName = "aaaa" + value;
            so.fileName = so.exportedFileName;
            state.update(so);
            keyContextSet = true;
            out.collect(value.intValue());
        }

        /**
         * Logs that a checkpoint is running and prints the state of the most recently processed
         * key. Only that single key is inspected, not all keys.
         */
        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            System.out.println("execute checkpoint");
            if (!keyContextSet) {
                // No element processed yet (or just restored): there is no current key, and
                // reading keyed state now would fail the checkpoint and the task.
                return;
            }
            MultiStorePacketState so = state.value();
            System.out.println(so == null ? null : so.exportedFileName);
        }

        /** Nothing to initialize: all state of this function is keyed state obtained in open. */
        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {}
    }

    /** Key selector that uses each element as its own key. */
    private static class IdentityKeySelector<T> implements KeySelector<T, T> {
        private static final long serialVersionUID = 1L;

        @Override
        public T getKey(T value) throws Exception {
            return value;
        }
    }

    /**
     * Mutable per-key state object.
     *
     * <p>Flink's type extraction treats this public class with public fields and an implicit
     * no-arg constructor as a POJO. POJO serialization skips {@code transient} fields, so
     * {@link #fileName} is not preserved when the state is copied, serialized, or restored. Only
     * {@link #exportedFileName} survives.
     */
    public static class MultiStorePacketState implements Serializable {
        private static final long serialVersionUID = 1L;

        public transient String fileName;
        public String exportedFileName;
    }

    /**
     * Source that emits {@code currentKey % 100} and then advances {@code currentKey} by the source
     * parallelism, so each subtask walks its own residue class.
     *
     * <p>The current key is checkpointed in operator list state. Once the task's attempt number
     * exceeds {@code maxAttempts}, the source emits one final element and finishes.
     */
    private static final class RandomLongSource extends RichParallelSourceFunction<Long>
            implements CheckpointedFunction {

        private static final long serialVersionUID = 1L;

        /** Generator delay between two events, in milliseconds (passed to Thread.sleep). */
        final long delay;

        /** Maximum restarts before shutting down this source. */
        final int maxAttempts;

        /** State that holds the current key for recovery. */
        transient ListState<Long> sourceCurrentKeyState;

        /** Generator's current key. */
        long currentKey;

        /** Generator runs while this is true. */
        volatile boolean running;

        RandomLongSource(int maxAttempts, long delay) {
            this.delay = delay;
            this.maxAttempts = maxAttempts;
            this.running = true;
        }

        @Override
        public void run(SourceContext<Long> sourceContext) throws Exception {

            int numberOfParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
            int subtaskIdx = getRuntimeContext().getIndexOfThisSubtask();

            // The source emits one final event and shuts down once we have reached max attempts.
            if (getRuntimeContext().getAttemptNumber() > maxAttempts) {
                synchronized (sourceContext.getCheckpointLock()) {
                    sourceContext.collect((Long.MAX_VALUE - subtaskIdx) % 100);
                }
                return;
            }

            while (running) {
                // Emit and advance atomically with respect to checkpoints, so the snapshotted
                // key always matches what has been emitted.
                synchronized (sourceContext.getCheckpointLock()) {
                    sourceContext.collect(currentKey % 100);
                    currentKey += numberOfParallelSubtasks;
                }

                if (delay > 0) {
                    Thread.sleep(delay);
                }
            }
        }

        @Override
        public void cancel() {
            running = false;
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            sourceCurrentKeyState.clear();
            sourceCurrentKeyState.add(currentKey);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {

            ListStateDescriptor<Long> currentKeyDescriptor =
                    new ListStateDescriptor<>("currentKey", Long.class);
            sourceCurrentKeyState =
                    context.getOperatorStateStore().getListState(currentKeyDescriptor);

            // Fresh start: begin at this subtask's index. Restore: continue from the saved key.
            currentKey = getRuntimeContext().getIndexOfThisSubtask();
            Iterable<Long> iterable = sourceCurrentKeyState.get();
            if (iterable != null) {
                Iterator<Long> iterator = iterable.iterator();
                if (iterator.hasNext()) {
                    currentKey = iterator.next();
                    // Each subtask snapshots exactly one entry; more indicates redistributed state.
                    Preconditions.checkState(!iterator.hasNext());
                }
            }
        }
    }
}
