package streamtest;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.apache.avro.generic.GenericRecord;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend;
import org.apache.flink.runtime.state.storage.FileSystemCheckpointStorage;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.CheckpointConfig.ExternalizedCheckpointCleanup;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.util.Collector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

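/**
 * Demo pipeline: a simulated Kafka byte[] source feeds Avro
 * single-object-encoded messages through deserialization to GenericRecord,
 * a keyed processing-time delay, a field transformation, and
 * re-serialization back to byte[], with record sizes printed as a stand-in
 * for a Kafka sink. A custom TypeInformation (DynamicGenericRecordTypeInfo)
 * keeps Flink from falling back to Kryo for GenericRecord.
 */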
public class FlinkGenericRecordTypeSafeJob {

	private static final Logger LOG = LoggerFactory.getLogger(FlinkGenericRecordTypeSafeJob.class);

	private static final Random random = new Random();

	public static void main(String[] args) throws Exception {
		LOG.debug("Entering FlinkGenericRecordTypeSafeJob...");

		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
		// Disable chaining so each operator shows up individually in the web UI.
		env.disableOperatorChaining();
		// Optional strictness: fail fast whenever Flink would fall back to
		// generic-type/Kryo serialization.
		// env.getConfig().disableGenericTypes();
		// env.getConfig().disableForceKryo();

		// RocksDB state backend; let RocksDB allocate its own memory rather than
		// drawing from Flink's managed memory budget.
		EmbeddedRocksDBStateBackend embedded = new EmbeddedRocksDBStateBackend();
		embedded.getMemoryConfiguration().setUseManagedMemory(false);
		// embedded.getMemoryConfiguration().setWriteBufferRatio(.8);
		env.setStateBackend(embedded);

		env.getCheckpointConfig().setCheckpointInterval(5000);
		env.getCheckpointConfig()
				.setCheckpointStorage(new FileSystemCheckpointStorage("file:///mnt/c/Users/ezmitka/checkpointDir"));
		// Keep externalized checkpoints when the job is cancelled so the job can
		// be restored from them later.
		env.getCheckpointConfig()
				.setExternalizedCheckpointCleanup(ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
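		// A retained checkpoint can be used to restore the job later, e.g.:
		//   flink run -s file:///mnt/c/Users/ezmitka/checkpointDir/<job-id>/chk-<n> ...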

		// Simulated Kafka byte[] input: a small pool of Avro single-object-encoded
		// messages across two schemas (user and product) to draw from.
		List<byte[]> kafkaBytes = new ArrayList<>();
		for (int i = 0; i < 10; i++) {
			byte[] msg = KafkaAvroMessageGenerator.createAvroMessageSingleObjectEncoding(1, 1001, "Alice");
			System.out.println("Input record size: " + msg.length + " bytes");
			kafkaBytes.add(msg);
			kafkaBytes.add(KafkaAvroMessageGenerator.createAvroMessageSingleObjectEncoding(1, 1002, "Bob"));
			kafkaBytes.add(KafkaAvroMessageGenerator.createProductAvroMessageSingleObjectEncoding(2, 501, "Gadget"));
			kafkaBytes.add(KafkaAvroMessageGenerator.createProductAvroMessageSingleObjectEncoding(2, 502, "Widget"));
		}

		// Bounded alternative for one-shot testing:
		// DataStream<byte[]> byteStream = env.fromCollection(kafkaBytes);

		// Unbounded source: emit one random message from the pool every second.
		DataStream<byte[]> byteStream = env.addSource(new SourceFunction<byte[]>() {
			private volatile boolean running = true;

			@Override
			public void run(SourceContext<byte[]> ctx) throws Exception {
				while (running) {
					int index = random.nextInt(kafkaBytes.size());
					byte[] msg = kafkaBytes.get(index);
					System.out.println("Input record size: " + msg.length + " bytes");
					ctx.collect(msg);
					Thread.sleep(1000);
				}
			}

			@Override
			public void cancel() {
				running = false;
			}
		});
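		// Note: SourceFunction is deprecated in recent Flink releases; a unified
		// Source (e.g. DataGeneratorSource from flink-connector-datagen) is the
		// long-term replacement for this kind of test source.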

		// Deserialize byte[] -> GenericRecord
		DataStream<GenericRecord> avroStream = byteStream.map(new MapFunction<byte[], GenericRecord>() {
			private final CustomAvroDeserializationSchema deserializer = new CustomAvroDeserializationSchema();

			@Override
			public GenericRecord map(byte[] value) throws Exception {
				return deserializer.deserialize(value);
			}
		}).setParallelism(2)
				// Custom TypeInformation keeps Flink from falling back to Kryo for
				// GenericRecord.
				.returns(new DynamicGenericRecordTypeInfo());
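		// CustomAvroDeserializationSchema is assumed to decode Avro single-object
		// encoding: a 2-byte marker (0xC3 0x01), an 8-byte CRC-64-AVRO schema
		// fingerprint, then the binary-encoded record. The fingerprint lets it
		// resolve the correct writer schema (user vs. product) per message.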

		// === Delay each record by 5 seconds using a processing-time timer ===
		// Key by whichever id field the schema carries (user "id" or product
		// "productId") so delayed records are buffered per key.
		DataStream<GenericRecord> delayedStream = avroStream.keyBy(record -> {
			if (record.getSchema().getField("id") != null) {
				return record.get("id").toString();
			} else if (record.getSchema().getField("productId") != null) {
				return record.get("productId").toString();
			} else {
				return "unknown";
			}
		}).process(new DelayFunction(5000)).returns(new DynamicGenericRecordTypeInfo());

		// === Intermediate Operator 1: Add a suffix to 'name' field ===
		DataStream<GenericRecord> modifiedStream = delayedStream.map(new MapFunction<GenericRecord, GenericRecord>() {
			@Override
			public GenericRecord map(GenericRecord record) {
				// Guard: not every schema necessarily carries a "name" field, and
				// GenericRecord.put would throw for a missing field.
				Object currentName = record.get("name");
				if (currentName != null) {
					record.put("name", currentName + "_processed");
				}
				return record;
			}
		}).returns(new DynamicGenericRecordTypeInfo());

		// === Intermediate Operator 2: Log contents ===
		modifiedStream.map(new MapFunction<GenericRecord, String>() {
			@Override
			public String map(GenericRecord record) {
				String log = "Modified Record: " + record.toString();
				System.out.println(log);
				return log;
			}
		}).returns(TypeInformation.of(String.class)); // no Kryo needed for strings

		// === Serialize GenericRecord -> byte[] ===
		DataStream<byte[]> serializedStream = modifiedStream.map(new MapFunction<GenericRecord, byte[]>() {
			private final CustomAvroSerializationSchema serializer = new CustomAvroSerializationSchema();

			@Override
			public byte[] map(GenericRecord record) {
				return serializer.serialize(record);
			}
		});

		// === Simulated Kafka Sink (print bytes length) ===
		serializedStream.map(new MapFunction<byte[], String>() {
			@Override
			public String map(byte[] bytes) {
				return "Serialized record size: " + bytes.length + " bytes";
			}
		}).print();

		env.execute("Flink Avro Serialization Pipeline with Custom TypeInfo");
	}

	// === Non-blocking delay: buffer serialized records in keyed state and emit
	// them when a processing-time timer fires. Because each new element pushes
	// the key's timer back, a busy key can hold records longer than delayMs. ===

	public static class DelayFunction extends KeyedProcessFunction<String, GenericRecord, GenericRecord> {

		private final long delayMs;

		private transient ListState<byte[]> recordState;
		private transient ValueState<Long> timerState;
		private transient CustomAvroDeserializationSchema deserializer;

		public DelayFunction(long delayMs) {
			this.delayMs = delayMs;
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			ListStateDescriptor<byte[]> recordStateDescriptor = new ListStateDescriptor<>("recordState", byte[].class);
			recordState = getRuntimeContext().getListState(recordStateDescriptor);

			ValueStateDescriptor<Long> timerStateDescriptor = new ValueStateDescriptor<>("timerState", Long.class);
			timerState = getRuntimeContext().getState(timerStateDescriptor);

			deserializer = new CustomAvroDeserializationSchema();
		}

		@Override
		public void processElement(GenericRecord value, Context ctx, Collector<GenericRecord> out) throws Exception {
			byte[] serialized = new CustomAvroSerializationSchema().serialize(value);
			recordState.add(serialized);

			long currentTime = ctx.timerService().currentProcessingTime();
			long triggerTime = currentTime + delayMs;

			Long registeredTimer = timerState.value();

			// Register a timer only if none exists or the new trigger time is
			// later; this coalesces to one live timer per key, at the cost of
			// extending the delay of records already buffered for that key.
			if (registeredTimer == null || triggerTime > registeredTimer) {
				if (registeredTimer != null) {
					ctx.timerService().deleteProcessingTimeTimer(registeredTimer);
				}
				ctx.timerService().registerProcessingTimeTimer(triggerTime);
				timerState.update(triggerTime);
			}

			// Optional debug log
			System.out.printf("processElement: key=%s, currentTime=%d, triggerTime=%d, registeredTimer=%s%n",
					ctx.getCurrentKey(), currentTime, triggerTime, timerState.value());
		}

		@Override
		public void onTimer(long timestamp, OnTimerContext ctx, Collector<GenericRecord> out) throws Exception {
			System.out.printf("onTimer triggered at timestamp: %d for key: %s%n", timestamp, ctx.getCurrentKey());

			for (byte[] bytes : recordState.get()) {
				GenericRecord record = deserializer.deserialize(bytes);
				System.out.println("Emitting delayed record: " + record);
				out.collect(record);
			}

			recordState.clear();
			timerState.clear();

			System.out.println("Cleared state after emitting records for key: " + ctx.getCurrentKey());
		}
	}
}
