package main.java.chapter2;

import static java.util.Arrays.asList;

import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.ActionType;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.PutResult;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;

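/**
 * Arrow Flight producer that loads the M&M CSV dataset with Spark and streams it to
 * clients as Arrow record batches.
 *
 * <p>A minimal sketch of serving this producer (the port and constructor arguments
 * are illustrative, not fixed by this class):
 *
 * <pre>{@code
 * try (BufferAllocator allocator = new RootAllocator();
 *      FlightServer server = FlightServer.builder(
 *          allocator,
 *          Location.forGrpcInsecure("localhost", 33333),
 *          new OrdersFlightProducerSparkGeneral(0, "data1")).build()) {
 *   server.start();
 *   server.awaitTermination();
 * }
 * }</pre>
 */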
public class OrdersFlightProducerSparkGeneral implements FlightProducer, Serializable {
  private int count;
  private String inpath;
  private SparkSession spark;
  private String mnmfile = "data1/mnm_dataset_small.csv";
  private ObjectMapper mapper = new ObjectMapper();
  private static final Logger LOGGER = Logger.getLogger(OrdersFlightProducerSparkGeneral.class);
  private Dataset<Row> mnmDF = null;

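  // Builds (or reuses) a Spark session and eagerly loads the CSV into a DataFrame so
  // that getStream() can serve it later.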
  public OrdersFlightProducerSparkGeneral(int count, String inpath) {
    this.count = count;
    this.inpath = inpath;
    spark = SparkSession
        .builder()
        .appName("MnMCount")
        .getOrCreate();
    mnmDF = spark.read().format("csv")
        .option("header", "true")
        .option("inferSchema", "true")
        .load(mnmfile);
    mnmDF.show();
  }

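  // Spark UDF that serializes a string value to JSON bytes with Jackson.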
  private UDF1<String, byte[]> bytesUdf() {
    return s1 -> mapper.writeValueAsBytes(s1);
  }

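  // Bidirectional exchange endpoint: drains the client's batches and reports the row
  // count of each; nothing is written back on the writer side.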
  @Override
  public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
    System.out.println("started data exchanging");

    while (reader.next()) {
      System.out.println("recieved " + reader.getRoot().getRowCount() + " rows");
    }
    System.out.println("finished");
  }

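  // getStream materializes the whole M&M DataFrame into one Arrow record batch and
  // streams it to the caller (the ticket contents are ignored). A minimal client-side
  // sketch, with an illustrative location and ticket:
  //
  //   try (BufferAllocator alloc = new RootAllocator();
  //        FlightClient client = FlightClient.builder(
  //            alloc, Location.forGrpcInsecure("localhost", 33333)).build();
  //        FlightStream stream = client.getStream(
  //            new Ticket("mnm".getBytes(StandardCharsets.UTF_8)))) {
  //     while (stream.next()) {
  //       System.out.println(stream.getRoot().contentToTSVString());
  //     }
  //   }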
  @Override
  public void getStream(CallContext context, Ticket ticket, ServerStreamListener clientStreamListener) {
    System.out.println("Called getStream");
    // Row-extraction helpers (not used by the streaming path below).
    Function<Row, String> func = new Function<Row, String>() {
      public String call(Row s) {
        return s.getString(0);
      }
    };
    Function<Row, Integer> funcInt = new Function<Row, Integer>() {
      public Integer call(Row s) {
        return s.getInt(0); // autoboxed; new Integer(...) is deprecated
      }
    };
    List<List<String>> statesGroups = new ArrayList<>();
    List<List<String>> colorsGroups = new ArrayList<>();
    List<List<Integer>> countsGroups = new ArrayList<>();
    try (BufferAllocator allocator = new RootAllocator()) {
      Field state = new Field("state", FieldType.nullable(new ArrowType.Utf8()), null);
      Field color = new Field("color", FieldType.nullable(new ArrowType.Utf8()), null);
      Field count = new Field("count", FieldType.nullable(new ArrowType.Int(32, true)), null);
      Schema schema = new Schema(asList(state, color, count));
      try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, allocator)) {
        VarCharVector stateVector = (VarCharVector) vectorSchemaRoot.getVector("state");
        VarCharVector colorVector = (VarCharVector) vectorSchemaRoot.getVector("color");
        IntVector countVector = (IntVector) vectorSchemaRoot.getVector("count");
        vectorSchemaRoot.allocateNew();

        // Iterate on the driver: FlightProducer listeners and Arrow vectors are not
        // serializable, so they cannot be captured by a foreachPartition closure.
        List<String> listValuesState = new ArrayList<>();
        List<String> listValuesColor = new ArrayList<>();
        List<Integer> listValuesCount = new ArrayList<>();
        int i = 0;
        for (Iterator<Row> it = mnmDF.toLocalIterator(); it.hasNext(); ) {
          Row row = it.next();
          String stateStr = row.getAs("State");
          String colorStr = row.getAs("Color");
          int countInt = row.getAs("Count");
          listValuesState.add(stateStr);
          listValuesColor.add(colorStr);
          listValuesCount.add(countInt);
          // setSafe grows the buffers as needed instead of re-allocating every row.
          stateVector.setSafe(i, stateStr.getBytes(StandardCharsets.UTF_8));
          colorVector.setSafe(i, colorStr.getBytes(StandardCharsets.UTF_8));
          countVector.setSafe(i, countInt);
          i++;
        }
        statesGroups.add(listValuesState);
        colorsGroups.add(listValuesColor);
        countsGroups.add(listValuesCount);
        vectorSchemaRoot.setRowCount(i);

        // Stream the single batch: start once, send once, complete once.
        clientStreamListener.start(vectorSchemaRoot);
        clientStreamListener.putNext();
        clientStreamListener.completed();
        System.out.println(vectorSchemaRoot.getRowCount());
      }
    } catch (Exception ex) {
      LOGGER.error("getStream failed", ex);
      clientStreamListener.error(ex);
    }

    System.out.println("sending done");
  }

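  // The remaining FlightProducer endpoints are minimal: listFlights only logs and
  // completes, acceptPut drains uploads, and the rest are left unimplemented.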
  @Override
  public void listFlights(CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
    System.out.println("listflights");
  }

  @Override
  public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
    throw new UnsupportedOperationException();
  }

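  // acceptPut must return a non-null Runnable; the Flight server runs it to consume
  // the client's upload.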
  @Override
  public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
    return () -> {
      // Drain the upload and acknowledge completion; this producer does not store puts.
      while (flightStream.next()) {
      }
      ackStream.onCompleted();
    };
  }

  @Override
  public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
    throw new UnsupportedOperationException();
  }

  @Override
  public void listActions(CallContext context, StreamListener<ActionType> listener) {
    throw new UnsupportedOperationException();
  }
}
