package com.company;

import com.google.common.base.MoreObjects;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.extensions.joinlibrary.Join;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.*;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

import java.io.Serializable;
import java.util.Objects;
import java.util.UUID;
import java.util.stream.StreamSupport;

/**
 * Batch Beam pipeline that keeps traders whose summed transaction amount exceeds {@link #THRESHOLD}.
 *
 * <p>Steps: parse hard-coded {@code "traderId,amount"} lines into {@link Trade}s keyed by trader id,
 * group them per trader, compute each trader's maximum single transaction, inner-join that maximum
 * back onto the grouped trades, keep only traders whose total amount is above {@link #THRESHOLD},
 * and write the {@code toString()} of each surviving record to a text output.
 */
public class MainWithSplit implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Traders whose summed transaction amount is strictly greater than this value are kept. */
    public static final double THRESHOLD = 10.0;

    public static void main(String[] args) {

        new MainWithSplit().runPipeline(args);
    }

    /**
     * Builds the pipeline described on the class and runs it.
     *
     * @param args command-line arguments, parsed and validated into {@link PipelineOptions}
     */
    private void runPipeline(String[] args) {

        PipelineOptions options =
                PipelineOptionsFactory.fromArgs(args).withValidation().create();
        Pipeline pipeline = Pipeline.create(options);

        PCollection<KV<String, Trade>> elements = pipeline
                .apply(Create.of(
                        "trader1,1.0",
                        "trader1,1.0",
                        "trader1,2.0",
                        "trader2,10.0",
                        "trader2,20.0",
                        "trader2,10.0",
                        "trader3,5.0")).apply(ParDo.of(extractKey()));

        PCollection<KV<String, Iterable<Trade>>> grouped = elements.apply(GroupByKey.<String, Trade>create());
        PCollection<KV<String, Double>> calculation = grouped.apply(ParDo.of(calculateMax()));
        // Per trader: (max single amount, all of that trader's trades).
        PCollection<KV<String, KV<Double, Iterable<Trade>>>> joined = Join.innerJoin(calculation, grouped);
        PCollection<KV<String, KV<Double, Iterable<Trade>>>> filtered = joined.apply("filtering", ParDo.of(filterJoined()));

        // TextIO treats this string as an output filename prefix; the written shards carry a shard suffix.
        filtered
                .apply(ToString.elements())
                .apply(TextIO.write().to("largeTrades" + UUID.randomUUID().toString() + ".txt"));

        pipeline.run();
    }

    /**
     * Returns a DoFn that passes through joined records only for traders whose summed transaction
     * amount exceeds {@link #THRESHOLD}; the record is emitted unchanged.
     *
     * <p>Static so the anonymous DoFn does not capture (and serialize) an enclosing
     * {@code MainWithSplit} instance.
     */
    private static DoFn<KV<String, KV<Double, Iterable<Trade>>>, KV<String, KV<Double, Iterable<Trade>>>> filterJoined() {
        return new DoFn<KV<String, KV<Double, Iterable<Trade>>>, KV<String, KV<Double, Iterable<Trade>>>>() {
            @ProcessElement
            public void processElement(ProcessContext context) {
                KV<String, KV<Double, Iterable<Trade>>> element = context.element();
                Iterable<Trade> trades = element.getValue().getValue();

                // NOTE(review): this predicate used to be evaluated twice ("x && x"); confirm no second
                // criterion (e.g. on the joined per-trader max) was intended.
                if (filterSmallTraders(trades)) {
                    context.output(element);
                }
            }
        };
    }

    /**
     * Returns {@code true} when the summed transaction amount of {@code trades} is strictly greater
     * than {@link #THRESHOLD}, i.e. the trader is NOT small and should be kept.
     */
    private static boolean filterSmallTraders(Iterable<Trade> trades) {
        return totalAmount(trades) > THRESHOLD;
    }

    /** Sums the transaction amounts of {@code trades}; 0.0 for an empty iterable. */
    private static double totalAmount(Iterable<Trade> trades) {
        return StreamSupport.stream(trades.spliterator(), false)
                .mapToDouble(Trade::getTransactionAmount)
                .sum();
    }

    /** Returns a predicate that accepts every grouped element (currently not used by the pipeline). */
    private static SerializableFunction<KV<String, Iterable<Trade>>, Boolean> filterAlwaysTrue() {
        return new SerializableFunction<KV<String, Iterable<Trade>>, Boolean>() {
            @Override
            public Boolean apply(KV<String, Iterable<Trade>> trades) {
                return true;
            }
        };
    }

    /**
     * Returns a DoFn that keeps grouped traders whose summed amount exceeds {@link #THRESHOLD}
     * (currently not used by the pipeline; {@link #filterJoined()} applies the same rule after the join).
     */
    private static DoFn<KV<String, Iterable<Trade>>, KV<String, Iterable<Trade>>> filterSmallTradersWithDoFn() {
        return new DoFn<KV<String, Iterable<Trade>>, KV<String, Iterable<Trade>>>() {
            @ProcessElement
            public void processElement(@Element KV<String, Iterable<Trade>> element, ProcessContext context) {
                if (filterSmallTraders(element.getValue())) {
                    context.output(KV.of(element.getKey(), element.getValue()));
                }
            }

        };
    }

    /** Returns a DoFn mapping each grouped trader to its largest single transaction amount. */
    private static DoFn<KV<String, Iterable<Trade>>, KV<String, Double>> calculateMax() {
        return new DoFn<KV<String, Iterable<Trade>>, KV<String, Double>>() {
            @ProcessElement
            public void processElement(@Element KV<String, Iterable<Trade>> element, ProcessContext context) {
                String key = element.getKey();
                Iterable<Trade> iterable = element.getValue();
                // GroupByKey only emits keys that had at least one value, so max() is always present.
                double max = StreamSupport.stream(iterable.spliterator(), false)
                        .mapToDouble(Trade::getTransactionAmount)
                        .max()
                        .getAsDouble();

                context.output(KV.of(key, max));
            }

        };
    }

    /**
     * Returns a DoFn that parses a {@code "traderId,amount"} line into a {@link Trade} keyed by its
     * trader id. Fields beyond the second are ignored.
     *
     * @throws IllegalArgumentException (at processing time) if the line has fewer than two fields
     * @throws NumberFormatException (at processing time) if the amount is not a valid double
     */
    public static DoFn<String, KV<String, Trade>> extractKey() {
        return new DoFn<String, KV<String, Trade>>() {
            @ProcessElement
            public void processElement(@Element String element, ProcessContext context) {
                String[] row = element.split(",");
                if (row.length < 2) {
                    throw new IllegalArgumentException(
                            "Expected a 'traderId,amount' line but got: " + element);
                }
                Trade trade = new Trade(row[0], Double.parseDouble(row[1]));
                context.output(KV.of(trade.getTraderId(), trade));
            }
        };
    }

    /**
     * Immutable single transaction of one trader. Serializable so Beam can encode it; equals/hashCode
     * are defined so encoded values have well-defined equality.
     */
    private static final class Trade implements Serializable {
        private static final long serialVersionUID = 1L;

        private final String traderId;
        private final double transactionAmount;

        /**
         * @param traderId non-null trader identifier
         * @param transactionAmount amount of this transaction
         * @throws NullPointerException if {@code traderId} is null
         */
        public Trade(String traderId, double transactionAmount) {
            this.traderId = Objects.requireNonNull(traderId, "traderId");
            this.transactionAmount = transactionAmount;
        }

        public String getTraderId() {
            return traderId;
        }

        public double getTransactionAmount() {
            return transactionAmount;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Trade)) {
                return false;
            }
            Trade other = (Trade) o;
            // Double.compare keeps equals consistent with the boxed-Double hash used below.
            return Double.compare(transactionAmount, other.transactionAmount) == 0
                    && traderId.equals(other.traderId);
        }

        @Override
        public int hashCode() {
            return Objects.hash(traderId, transactionAmount);
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this)
                    .add("traderId", traderId)
                    .add("transactionAmount", transactionAmount)
                    .toString();
        }
    }
}