package test;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

import java.io.Serializable;
import java.util.*;

import static org.apache.spark.sql.functions.*;

/**
 * Groups log rows by {@code id} and aggregates their {@code ts} column, once with the built-in
 * {@code max} aggregate and once with the custom {@link FrequencyAggregator} UDAF.
 *
 * <p>Created by asrivast on 1/25/17.
 */
public class EdgeAggregator implements Serializable {

    /**
     * Registers the {@code freqAgg} UDAF and prints two per-{@code id} aggregations of the logs.
     *
     * @param sparkSession  session in which {@code freqAgg} is registered
     * @param adadvisorLogs rows containing an {@code id} column and a string {@code ts} column
     */
    public void execute(SparkSession sparkSession, Dataset<Row> adadvisorLogs) {
        sparkSession.udf().register("freqAgg", new FrequencyAggregator());

        // Largest ts per id.
        adadvisorLogs.groupBy("id")
                .agg(max("ts")).show();

        // Per-id histogram of ts values bucketed by their leading characters.
        adadvisorLogs.groupBy("id")
                .agg(callUDF("freqAgg", col("ts"))).show();
    }

    /**
     * Builds a small in-memory dataset and runs {@link #execute} on a local Spark session.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        List<Row> inputData = new ArrayList<>();
        inputData.add(RowFactory.create("34109042", "20160108114934"));
        inputData.add(RowFactory.create("97805672", "20151217113409"));
        inputData.add(RowFactory.create("97805672", "20151222154655"));
        inputData.add(RowFactory.create("28616951", "20160106062918"));
        inputData.add(RowFactory.create("148578626", "20160108180003"));
        inputData.add(RowFactory.create("1024666", "20151230000725"));
        inputData.add(RowFactory.create("1024666", "20151230000725"));
        inputData.add(RowFactory.create("1024666", "20151230102025"));
        StructType schema = new StructType().add("id", DataTypes.StringType)
                .add("ts", DataTypes.StringType);

        SparkSession sparkSession = SparkSession
                .builder()
                .appName("IPCountFilterTest")
                .master("local")
                .getOrCreate();

        try {
            Dataset<Row> rows = sparkSession.createDataFrame(inputData, schema);
            new EdgeAggregator().execute(sparkSession, rows);
        } finally {
            // Release the local Spark context even if the job fails.
            sparkSession.stop();
        }
    }

    /**
     * UDAF that counts input {@code ts} strings per bucket. The bucket key is the first
     * {@value #BUCKET_PREFIX_LENGTH} characters of the string. For the {@code yyyyMMddHHmmss}
     * values in {@link #main}, {@code "20151230000725"} maps to {@code "2015123000"}.
     * Null inputs are ignored.
     *
     * <p>The result is a {@code map<string,int>} from bucket key to count.
     *
     * <p>Declared {@code static} so that serializing the UDAF does not also serialize an
     * enclosing {@code EdgeAggregator} instance.
     */
    static class FrequencyAggregator extends UserDefinedAggregateFunction {

        /** Number of leading characters of {@code ts} that form the bucket key. */
        private static final int BUCKET_PREFIX_LENGTH = 10;

        /** Type of the buffer field and of the final result: bucket key -> count. */
        private static final DataType FREQ_MAP_TYPE =
                DataTypes.createMapType(DataTypes.StringType, DataTypes.IntegerType);

        @Override
        public StructType inputSchema() {
            return new StructType().add("ts", DataTypes.StringType);
        }

        @Override
        public StructType bufferSchema() {
            return new StructType().add("freq", FREQ_MAP_TYPE);
        }

        /**
         * The result is the bare frequency map. The previous declaration (a struct wrapping the
         * map) did not match {@link #evaluate}, which returns a {@code Map} rather than a
         * {@code Row}, so Spark could not convert the result.
         */
        @Override
        public DataType dataType() {
            return FREQ_MAP_TYPE;
        }

        @Override
        public boolean deterministic() {
            return true;
        }

        /** Starts every group with an empty frequency map. */
        @Override
        public void initialize(MutableAggregationBuffer buffer) {
            buffer.update(0, Collections.<String, Integer>emptyMap());
        }

        /** Increments the count for the bucket of one input {@code ts} value. */
        @Override
        public void update(MutableAggregationBuffer buffer, Row input) {
            // Skip nulls, as Spark's built-in aggregates do, instead of failing with an NPE.
            if (input.isNullAt(0)) {
                return;
            }
            // NOTE(review): assumes ts has at least BUCKET_PREFIX_LENGTH characters; shorter
            // values throw StringIndexOutOfBoundsException.
            String bucket = input.getString(0).substring(0, BUCKET_PREFIX_LENGTH);
            // Work on a mutable copy rather than the map handed back by the buffer.
            Map<String, Integer> freqMap = new HashMap<>(buffer.getJavaMap(0));
            freqMap.merge(bucket, 1, Integer::sum);
            buffer.update(0, freqMap);
        }

        /** Adds the counts of {@code buffer2} into {@code buffer1}, bucket by bucket. */
        @Override
        public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
            Map<String, Integer> freqMap1 = new HashMap<>(buffer1.getJavaMap(0));
            Map<String, Integer> freqMap2 = buffer2.getJavaMap(0);
            freqMap2.forEach((bucket, count) -> freqMap1.merge(bucket, count, Integer::sum));
            buffer1.update(0, freqMap1);
        }

        /** Returns the accumulated bucket -> count map, matching {@link #dataType()}. */
        @Override
        public Object evaluate(Row buffer) {
            return buffer.getJavaMap(0);
        }
    }
}
