package test;

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.IngestionTimeExtractor;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.java.StreamTableEnvironment;
import org.apache.flink.table.functions.TemporalTableFunction;
import org.apache.flink.types.Row;

/**
 * Small self-contained Flink job that joins an aggregated trade table against a
 * temporal table function built from instrument data and prints the result as a
 * retract stream.
 *
 * <p>According to the note left next to the {@code toRetractStream} call, that
 * conversion fails with {@code java.lang.AssertionError: mismatched type $5 TIMESTAMP(3)};
 * the program therefore serves as a reproduction of that failure.
 */
public class Test {
  /**
   * Builds the instrument and trade pipelines, joins them and executes the job.
   *
   * @param args command line arguments (unused)
   * @throws Exception if building or executing the job fails
   */
  public static void main(String[] args) throws Exception {
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
    StreamTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env);

    // Financial instruments data - (InstrumentId, Name, ClosePrice)
    List<Tuple3<Integer, String, Double>> instrumentData = Arrays.asList(
        new Tuple3<>(1, "Apple Inc", 223.58),
        new Tuple3<>(2, "Microsoft Corp", 111.72),
        new Tuple3<>(3, "GlaxoSmithKline", 23.29));

    // Create a temporal table function for the above instrument data
    DataStreamSource<Tuple3<Integer, String, Double>> instrumentStream = env.addSource(new DelayedSource<>(instrumentData, 1L));
    // DelayedSource is generic, so the element type is supplied explicitly via a type hint.
    // The returned operator is intentionally not re-assigned; the hint is applied to instrumentStream itself.
    instrumentStream.returns(new TypeHint<Tuple3<Integer, String, Double>>() {});
    DataStream<Tuple3<Integer, String, Double>> instrumentStreamWithTime = instrumentStream.assignTimestampsAndWatermarks(new IngestionTimeExtractor<>());
    Table instruments = tableEnv.fromDataStream(instrumentStreamWithTime, "InstrumentId, Name, ClosePrice, Instrument_EventTime.rowtime");
    TemporalTableFunction instrumentFunction = instruments.createTemporalTableFunction("Instrument_EventTime", "InstrumentId");
    tableEnv.registerFunction("Instrument", instrumentFunction);

    // Trade data - (TradeId, InstrumentId, CounterpartyId, Quantity, UnitPrice)
    List<Tuple5<Integer, Integer, Integer, Double, Double>> tradeData = Arrays.asList(
        new Tuple5<>(1, 1, 2, 4.0, 220.0),
        new Tuple5<>(2, 2, 2, 3.0, 111.3),
        new Tuple5<>(3, 1, 1, 10.0, 222.34),
        new Tuple5<>(4, 3, 1, 6.0, 22.18));

    // Create a trade table based on the above trade data
    DataStreamSource<Tuple5<Integer, Integer, Integer, Double, Double>> tradeStream = env.addSource(new DelayedSource<>(tradeData, 100L));
    // Same explicit type hint as for the instrument stream (generic source).
    tradeStream.returns(new TypeHint<Tuple5<Integer, Integer, Integer, Double, Double>>() {});
    SingleOutputStreamOperator<Tuple5<Integer, Integer, Integer, Double, Double>> tradeStreamWithTime = tradeStream.assignTimestampsAndWatermarks(new IngestionTimeExtractor<>());
    Table trades = tableEnv.fromDataStream(tradeStreamWithTime, "TradeId, t_InstrumentId, t_CounterpartyId, Quantity, UnitPrice, Trade_EventTime.rowtime");

    // Find the number of trades, the total quantity and the total cost for each
    // instrument and counterparty combination
    Table groupedTrades = trades.groupBy("t_InstrumentId, t_CounterpartyId")
                                .select("t_InstrumentId, t_CounterpartyId, count(TradeId) as TradeCount, sum(Quantity) as Quantity," +
                                        " sum(Quantity * UnitPrice) as Cost, Max(Trade_EventTime) as LastTrade_EventTime");

    System.out.print("groupedTrades: ");
    groupedTrades.printSchema();

    // Enrich the groupedTrades table with instrument details by joining against the
    // registered "Instrument" temporal table function, evaluated at LastTrade_EventTime
    Table instrumentTable = new Table(tableEnv, "Instrument(LastTrade_EventTime)");
    Table tradesByInstr = groupedTrades.join(instrumentTable, "t_InstrumentId = InstrumentId")
                                       .select("InstrumentId, Name, ClosePrice, TradeCount, Quantity, Cost");
    System.out.print("tradesByInstr: ");
    tradesByInstr.printSchema();

    TypeInformation<Row> typeInfo = tradesByInstr.getSchema().toRowType();
    // The following line triggers an exception: java.lang.AssertionError: mismatched type $5 TIMESTAMP(3)
    DataStream<Tuple2<Boolean, Row>> tradesByInstrStream = tableEnv.toRetractStream(tradesByInstr, typeInfo);
    tradesByInstrStream.addSink(new PrintSink<>(typeInfo.toString()));

    env.execute();
    System.out.println("Test completed at " + time());
  }

  /**
   * Returns the current local wall-clock time formatted with
   * {@link DateTimeFormatter#ISO_LOCAL_TIME}.
   *
   * @return the formatted current local time
   */
  public static String time() {
    return LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_TIME);
  }

  /**
   * A source that starts producing data after an initial waiting period.
   *
   * <p>After sleeping for the configured delay it emits every element of the given
   * list once, in list order, logging each element to standard output, and then
   * returns from {@link #run}.
   *
   * @param <T> type of the emitted elements
   */
  private static class DelayedSource<T> extends RichSourceFunction<T> {
    private static final long serialVersionUID = 1L;

    /** Elements to emit, in iteration order. */
    private final List<T> data;
    /** Time in milliseconds to sleep before the first element is emitted. */
    private final long initialDelay;
    /** Set by {@link #cancel()}; volatile because it is read from the loop in {@link #run}. */
    private volatile boolean shutdown;

    /**
     * @param data         elements to emit
     * @param initialDelay delay in milliseconds before emission starts
     */
    private DelayedSource(List<T> data, long initialDelay) {
      this.data = data;
      this.initialDelay = initialDelay;
    }

    /**
     * Sleeps for the initial delay, then emits the elements until the list is
     * exhausted or the source has been cancelled.
     *
     * @param ctx context used to emit the elements
     * @throws Exception if the initial sleep is interrupted or emission fails
     */
    @Override
    public void run(SourceContext<T> ctx) throws Exception {
      Iterator<T> iterator = data.iterator();
      Thread.sleep(initialDelay);
      while (!shutdown && iterator.hasNext()) {
        T next = iterator.next();
        System.out.println(time() + " - producing " + next);
        ctx.collect(next);
      }
    }

    /** Requests that the emission loop in {@link #run} stops before the next element. */
    @Override
    public void cancel() {
      shutdown = true;
    }
  }

  /**
   * A simple sink that just prints out any data it receives, each record preceded
   * by a fixed prefix.
   *
   * @param <T> type of the received records
   */
  private static class PrintSink<T> extends RichSinkFunction<T> {
    private static final long serialVersionUID = 1L;

    /** Text printed in front of every record. */
    private final String prefix;

    /**
     * @param prefix text printed in front of every record
     */
    public PrintSink(String prefix) {
      this.prefix = prefix;
    }

    /**
     * Prints {@code prefix + " = " + value} to standard output.
     *
     * @param value   the record to print
     * @param context sink context (unused)
     */
    @Override
    public void invoke(T value, Context context) {
      System.out.println(prefix + " = " + value);
    }
  }
}
