package org.example;


import com.google.common.base.Preconditions;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
import org.apache.beam.sdk.io.range.OffsetRange;
import org.apache.beam.sdk.options.*;
import org.apache.beam.sdk.transforms.*;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.values.*;
import org.apache.commons.lang3.RandomStringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.time.Duration;
import java.util.List;

/**
 * Example of a streaming job that fails on Dataflow. Probable cause: an upstream DoFn with a
 * side input feeding a splittable (unbounded-per-element) DoFn.
 */
public class TsPointsStreamingFail {
    // The log to output status messages to.
    private static final Logger LOG = LoggerFactory.getLogger(TsPointsStreamingFail.class);

    /** The main DTO carried between the pipeline's transforms. */
    @DefaultCoder(AvroCoder.class)
    public static class RequestParameters {
        public long start = -1;
        public long end = -1;
        public String payload = "";
        public boolean processedByUpstream = false;

        public RequestParameters copy() {
            RequestParameters output = new RequestParameters();
            output.start = this.start;
            output.end = this.end;
            output.payload = this.payload;
            output.processedByUpstream = this.processedByUpstream;
            return output;
        }

        @Override
        public String toString() {
            return "start: " + start + " "
                    + "end: " + end + " "
                    + "payload: " + payload + " "
                    + "processedByUpstream: " + processedByUpstream;
        }
    }

    /**
     * An upstream DoFn that reads a side input, copies its first element into the request
     * payload and marks the request as processed by the upstream step.
     */
    public static class UpstreamDoFn extends DoFn<RequestParameters, RequestParameters> {
        private final PCollectionView<List<String>> sideInput;

        public UpstreamDoFn(PCollectionView<List<String>> sideInput) {
            this.sideInput = sideInput;
        }

        @ProcessElement
        public void processElement(@Element RequestParameters input,
                                   OutputReceiver<RequestParameters> out,
                                   ProcessContext context) {
            List<String> inputList = context.sideInput(sideInput);
            String theInput = inputList.get(0);

            RequestParameters output = input.copy();
            output.payload = theInput;
            output.processedByUpstream = true;

            out.output(output);
        }
    }

    /**
     * A splittable DoFn that generates an unbounded stream of TS datapoint requests,
     * emitting one time interval per poll.
     */
    @DoFn.UnboundedPerElement
    public static class GenerateRequestsUnboundFn extends DoFn<RequestParameters, RequestParameters> {
        private final String randomIdString = RandomStringUtils.randomAlphanumeric(5);
        private final String loggingPrefix = "Generate TS request unbound [" + randomIdString + "] -";

        private static final Duration pollOffset = Duration.ofSeconds(30);
        private static final Duration pollInterval = Duration.ofSeconds(10);

        @Setup
        public void setup() {
            LOG.info("Setting up TS point unbound request generator.");
            // do a bit of setup work
        }

        @ProcessElement
        public ProcessContinuation processElement(@Element RequestParameters query,
                                                  RestrictionTracker<OffsetRange, Long> tracker,
                                                  OutputReceiver<RequestParameters> out,
                                                  ProcessContext context) throws Exception {
            final String batchIdentifierPrefix = "Request batch: " + RandomStringUtils.randomAlphanumeric(6) + " - ";
            final String localLoggingPrefix = loggingPrefix + batchIdentifierPrefix;
            LOG.info(localLoggingPrefix + "Input query: {}", query.toString());
            LOG.info(localLoggingPrefix + "Input restriction {}", tracker.currentRestriction());

            long startRange = tracker.currentRestriction().getFrom();
            long endRange = tracker.currentRestriction().getTo();

            while (startRange < (System.currentTimeMillis() - pollOffset.toMillis())) {
                // Set the query's max end to current time - offset.
                if (endRange > (System.currentTimeMillis() - pollOffset.toMillis())) {
                    endRange = (System.currentTimeMillis() - pollOffset.toMillis());
                }

                if (tracker.tryClaim(endRange - 1)) {
                    LOG.info(localLoggingPrefix + "Building RequestParameters with start = {} and end = {}", startRange, endRange);
                    context.updateWatermark(org.joda.time.Instant.ofEpochMilli(startRange));
                    out.outputWithTimestamp(buildRequestParameters(query, startRange, endRange, localLoggingPrefix),
                            org.joda.time.Instant.ofEpochMilli(startRange));
                    // Update the start and end range for the next iteration
                    startRange = endRange;
                    endRange = tracker.currentRestriction().getTo();
                } else {
                    LOG.info(localLoggingPrefix + "Stopping work due to checkpointing or splitting.");
                    return ProcessContinuation.stop();
                }

                if (startRange >= tracker.currentRestriction().getTo()) {
                    LOG.info(localLoggingPrefix + "Completed the request time range. Will stop watching for new datapoints.");
                    return ProcessContinuation.stop();
                }

                // The loop body runs at most once per invocation: emit one interval, then pause
                // and let the runner resume this restriction after the poll interval.
                LOG.info(localLoggingPrefix + "Pausing for {}", pollInterval.toString());
                return ProcessContinuation.resume().withResumeDelay(org.joda.time.Duration.millis(
                        pollInterval.toMillis()));
            }

            // Nothing to claim yet: the start of the range is still within the poll offset
            // window, so check again after the poll interval.
            LOG.info(localLoggingPrefix + "Pausing for {}", pollInterval.toString());
            return ProcessContinuation.resume().withResumeDelay(org.joda.time.Duration.millis(
                    pollInterval.toMillis()));
        }

        private RequestParameters buildRequestParameters(RequestParameters requestParameters,
                                                         long start,
                                                         long end,
                                                         String loggingPrefix) {
            Preconditions.checkArgument(start < end, "Trying to build request with start >= end.");
            LOG.debug(loggingPrefix + "Building RequestParameters with start = {} and end = {}", start, end);
            // Carry over the payload and flags from the input request; only set the time range.
            RequestParameters output = requestParameters.copy();
            output.start = start;
            output.end = end;
            return output;
        }

        @GetInitialRestriction
        public OffsetRange getInitialRestriction(RequestParameters requestParameters) throws Exception {
            // The restriction is a range of epoch millisecond timestamps. Keep it open-ended
            // so the DoFn keeps polling for new datapoints.
            long startTimestamp = 0L;
            long endTimestamp = Long.MAX_VALUE;

            return new OffsetRange(startTimestamp, endTimestamp);
        }


    }

    /**
     * Sets up the main pipeline structure and runs it.
     *
     * @param options the pipeline options to run with
     * @return the result of the pipeline run
     */
    private static PipelineResult runCdfTsPointsBQ(PipelineOptions options) throws IOException {
        Pipeline p = Pipeline.create(options);

        // In-memory side input (not wired into the pipeline below; kept for comparison).
        PCollectionView<List<String>> sideView = p
                .apply("Some data", Create.of("Nice string"))
                .apply("to view", View.asList());

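        // Note: the filepattern below is a placeholder; point it at a readable text file.
        // With EmptyMatchTreatment.ALLOW the pipeline still constructs when no file matches,
        // but the side input list is then empty and UpstreamDoFn's inputList.get(0) will fail.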
        PCollectionView<List<String>> sideFromFile = p
                .apply("Find file", FileIO.match()
                        .filepattern("gs://add-some-text-file/from-gs.txt")
                        .withEmptyMatchTreatment(EmptyMatchTreatment.ALLOW))
                .apply("Read file metadata", FileIO.readMatches()
                        .withDirectoryTreatment(FileIO.ReadMatches.DirectoryTreatment.SKIP))
                .apply("Read file", ParDo.of(new DoFn<FileIO.ReadableFile, KV<String, String>>() {
                    @ProcessElement
                    public void processElement(@Element FileIO.ReadableFile file,
                                               OutputReceiver<KV<String, String>> out) throws Exception {
                        out.output(KV.of("key", file.readFullyAsUTF8String()));
                    }
                }))
                .apply("remove key", Values.create())
                .apply("view", View.asList());

        // main input
        PCollection<RequestParameters> requestParametersPCollection;

        // streaming mode
        LOG.info("Setting up streaming mode");
        requestParametersPCollection = p
                .apply(Create.of(new RequestParameters()))
                .apply("Troublemaker", ParDo.of(new UpstreamDoFn(sideFromFile))
                        .withSideInputs(sideFromFile))
                .apply("Watch for new items", ParDo.of(new GenerateRequestsUnboundFn()));
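
        /*
         * Control case (a sketch, not verified here): if the probable cause named in the class
         * Javadoc is right, feeding the splittable DoFn directly, without the upstream
         * side-input DoFn, should not trigger the failure:
         *
         *   requestParametersPCollection = p
         *           .apply(Create.of(new RequestParameters()))
         *           .apply("Watch for new items", ParDo.of(new GenerateRequestsUnboundFn()));
         */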

        requestParametersPCollection
                .apply("Log data", MapElements.into(TypeDescriptor.of(RequestParameters.class))
                        .via(request -> {
                            LOG.info("Request parameters: {}", request.toString());
                            return request;
                        }));

        return p.run();
    }

    /**
     * Reads the pipeline options from args and runs the pipeline.
     *
     * @param args command-line arguments holding the pipeline options
     */
    public static void main(String[] args) throws IOException {
        PipelineOptions options = PipelineOptionsFactory.fromArgs(args).withValidation().as(PipelineOptions.class);
        runCdfTsPointsBQ(options);
    }
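
    /*
     * Example invocation (a sketch; assumes a standard Maven setup with the Dataflow runner
     * dependency on the classpath; project, region and bucket values are placeholders):
     *
     *   mvn compile exec:java \
     *     -Dexec.mainClass=org.example.TsPointsStreamingFail \
     *     -Dexec.args="--runner=DataflowRunner --project=<gcp-project> --region=<region> \
     *       --tempLocation=gs://<bucket>/temp --streaming=true"
     */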
}
