import org.apache.arrow.gandiva.evaluator.Filter;
import org.apache.arrow.gandiva.evaluator.Projector;
import org.apache.arrow.gandiva.evaluator.SelectionVectorInt32;
import org.apache.arrow.gandiva.exceptions.GandivaException;
import org.apache.arrow.gandiva.expression.Condition;
import org.apache.arrow.gandiva.expression.ExpressionTree;
import org.apache.arrow.gandiva.expression.TreeBuilder;
import org.apache.arrow.gandiva.expression.TreeNode;
import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.TimeStampMicroTZVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class GandivaFilterTest {

    private static final Field TIME_FIELD = new Field("time", FieldType.notNullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")), null);
    private static final Field ID_FIELD = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null);
    private static final Field VALUE_FIELD = new Field("value", FieldType.nullable(Types.MinorType.FLOAT8.getType()), null);
    private static final Field TYPE_FIELD = new Field("type", FieldType.nullable(new ArrowType.Utf8()), null);
    private static final Schema SCHEMA = new Schema(asList(TIME_FIELD, ID_FIELD, VALUE_FIELD, TYPE_FIELD));

    // Batch layout. The 1M-row benchmark configuration uses 30 full batches of 32 * 1024 = 32768 rows
    // plus a last batch of 16960 rows (30 * 32768 + 16960 = 1,000,000); small values keep the test fast.
    /** Rows in each regular batch (32768 in the benchmark configuration). */
    private static final int BATCH_SIZE = 10;
    /** Number of regular-sized batches; one extra batch of LAST_BATCH_SIZE rows follows them. */
    private static final int FULL_BATCH_COUNT = 30;
    /** Rows in the final batch (16960 in the benchmark configuration). */
    private static final int LAST_BATCH_SIZE = 10;

    /**
     * First timestamp written to the "time" column: 2023-01-04T00:00:00Z in MICROseconds, because the
     * column is declared with TimeUnit.MICROSECOND (the millisecond literal is scaled by 1000).
     */
    private static final long START_TIMESTAMP = 1672790400000L * 1000L;
    private static final String[] TYPES = new String[]{"A", "B", "C"};
    private static final Random RANDOM = new Random();

    /**
     * Filters randomly generated batches on {@code id IN (4, 5)} and verifies that exactly the
     * matching rows survive. Also prints the elapsed filtering time in milliseconds.
     */
    @Test
    public void testFilterOnIntColumn() throws GandivaException {
        // Compose the filter condition once (it does not depend on the batch): (id == 4) OR (id == 5).
        // See https://github.com/apache/arrow/blob/master/cpp/src/gandiva/function_registry_arithmetic.cc
        TreeNode idField = TreeBuilder.makeField(SCHEMA.findField("id"));
        TreeNode idEquals4 = TreeBuilder.makeFunction(
                "equal", ImmutableList.of(idField, TreeBuilder.makeLiteral(4)), BIT.getType());
        TreeNode idEquals5 = TreeBuilder.makeFunction(
                "equal", ImmutableList.of(idField, TreeBuilder.makeLiteral(5)), BIT.getType());
        Condition condition = TreeBuilder.makeCondition(TreeBuilder.makeOr(ImmutableList.of(idEquals4, idEquals5)));

        try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
            long startNanos = 0;
            long endNanos = 0;

            for (int i = 0; i <= FULL_BATCH_COUNT; i++) {
                int batchSize = (i == FULL_BATCH_COUNT) ? LAST_BATCH_SIZE : BATCH_SIZE;
                VectorSchemaRoot root = createBatch(allocator, batchSize);
                // Count the expected matches before take(), which consumes root.
                int expectedMatches = countMatchingIds(root, 4, 5);

                if (i == 0) {
                    startNanos = System.nanoTime();
                }
                try (VectorSchemaRoot filteredRoot = take(allocator, root, condition)) {
                    System.err.println(filteredRoot.getRowCount());
                    Assertions.assertEquals(expectedMatches, filteredRoot.getRowCount());
                    IntVector filteredIds = (IntVector) filteredRoot.getVector("id");
                    for (int row = 0; row < filteredRoot.getRowCount(); row++) {
                        int id = filteredIds.get(row);
                        Assertions.assertTrue(id == 4 || id == 5, "unexpected id " + id + " at row " + row);
                    }
                }
                if (i == FULL_BATCH_COUNT) {
                    endNanos = System.nanoTime();
                }
            }
            // nanoTime is monotonic; convert the elapsed nanoseconds to milliseconds.
            System.err.println((endNanos - startNanos) / 1_000_000);
        }
    }

    /**
     * Creates a root over SCHEMA, fills it with {@code batchSize} random rows and prints it.
     * The caller owns the returned root.
     */
    private static VectorSchemaRoot createBatch(BufferAllocator allocator, int batchSize) {
        VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA, allocator);
        generateData(root, batchSize);
        root.setRowCount(batchSize);
        System.err.println("full table");
        System.err.println(root.contentToTSVString());
        return root;
    }

    /** Counts the non-null rows of the "id" column whose value equals any of {@code wanted}. */
    private static int countMatchingIds(VectorSchemaRoot root, int... wanted) {
        IntVector ids = (IntVector) root.getVector("id");
        int count = 0;
        for (int row = 0; row < root.getRowCount(); row++) {
            if (ids.isNull(row)) {
                continue;
            }
            int id = ids.get(row);
            for (int candidate : wanted) {
                if (id == candidate) {
                    count++;
                    break;
                }
            }
        }
        return count;
    }

    /**
     * Allocates every column of {@code root} for {@code batchSize} rows and fills them:
     * increasing timestamps from START_TIMESTAMP, random ids in [0, 10),
     * values in {1.0, 1.1, ..., 1.9}, and a random entry of TYPES.
     */
    private static void generateData(VectorSchemaRoot root, int batchSize) {
        TimeStampMicroTZVector timeVector = (TimeStampMicroTZVector) root.getVector("time");
        timeVector.allocateNew(batchSize);
        IntVector idVector = (IntVector) root.getVector("id");
        idVector.allocateNew(batchSize);
        Float8Vector valueVector = (Float8Vector) root.getVector("value");
        valueVector.allocateNew(batchSize);
        VarCharVector typeVector = (VarCharVector) root.getVector("type");
        typeVector.allocateNew(batchSize);

        for (int row = 0; row < batchSize; row++) {
            timeVector.set(row, START_TIMESTAMP + row);
            idVector.set(row, RANDOM.nextInt(10));
            // Divide in double precision: float division would store e.g. 1.2999999523 instead of 1.3.
            valueVector.set(row, 1 + RANDOM.nextInt(10) / 10.0);
            String typeName = TYPES[RANDOM.nextInt(TYPES.length)];
            typeVector.set(row, typeName.getBytes(StandardCharsets.UTF_8));
        }

        timeVector.setValueCount(batchSize);
        idVector.setValueCount(batchSize);
        valueVector.setValueCount(batchSize);
        typeVector.setValueCount(batchSize);
    }

    /**
     * Filters randomly generated batches on {@code type IN ("A", "B", "C")}. Because generateData
     * only emits those values, every row must survive the filter.
     */
    @Test
    public void testFilterOnStringColumn() throws GandivaException {
        // Compose the filter condition once: (type == "A") OR (type == "B") OR (type == "C").
        // See https://github.com/apache/arrow/blob/master/cpp/src/gandiva/function_registry_arithmetic.cc
        TreeNode typeField = TreeBuilder.makeField(SCHEMA.findField("type"));
        List<TreeNode> equalities = new ArrayList<>();
        for (String typeName : TYPES) {
            equalities.add(TreeBuilder.makeFunction(
                    "equal", ImmutableList.of(typeField, TreeBuilder.makeStringLiteral(typeName)), BIT.getType()));
        }
        Condition condition = TreeBuilder.makeCondition(TreeBuilder.makeOr(equalities));

        try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
            long startNanos = 0;
            long endNanos = 0;

            for (int i = 0; i <= FULL_BATCH_COUNT; i++) {
                int batchSize = (i == FULL_BATCH_COUNT) ? LAST_BATCH_SIZE : BATCH_SIZE;
                VectorSchemaRoot root = createBatch(allocator, batchSize);

                if (i == 0) {
                    startNanos = System.nanoTime();
                }
                // take() consumes root (returns it or closes it), so root must not be closed again here.
                try (VectorSchemaRoot filteredRoot = take(allocator, root, condition)) {
                    System.err.println(filteredRoot.contentToTSVString());
                    Assertions.assertEquals(batchSize, filteredRoot.getRowCount());
                    VarCharVector filteredTypes = (VarCharVector) filteredRoot.getVector("type");
                    for (int row = 0; row < filteredRoot.getRowCount(); row++) {
                        String typeName = new String(filteredTypes.get(row), StandardCharsets.UTF_8);
                        Assertions.assertTrue(asList(TYPES).contains(typeName),
                                "unexpected type " + typeName + " at row " + row);
                    }
                }

                if (i == FULL_BATCH_COUNT) {
                    endNanos = System.nanoTime();
                }
            }
            // nanoTime is monotonic; convert the elapsed nanoseconds to milliseconds.
            System.err.println((endNanos - startNanos) / 1_000_000);
        }
    }

    /**
     * Returns the rows of {@code root} that satisfy {@code condition}.
     *
     * <p>Takes ownership of {@code root}: the returned root is either {@code root} itself (all rows
     * matched) or a new root, in which case {@code root} has been closed. If an exception is thrown,
     * {@code root} and any partially built result are closed. The caller must close the returned root.
     */
    private static VectorSchemaRoot take(
            BufferAllocator allocator, VectorSchemaRoot root, Condition condition)
            throws GandivaException {
        VectorSchemaRoot result = null;
        boolean succeeded = false;
        try {
            Schema schema = root.getSchema();
            Filter filter = Filter.make(schema, condition);
            try {
                final int rowCount = root.getRowCount();
                // Convert VSR to ArrowRecordBatch so that we can apply Filter and Projector.
                try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch();
                        // buffer size = row_count * 4 bytes per entry
                        ArrowBuf selectionBuffer = allocator.buffer(rowCount * 4L)) {
                    result = take(
                            allocator, root, schema, recordBatch, rowCount, filter, selectionBuffer);
                }
            } finally {
                filter.close();
            }
            succeeded = true;
            return result;
        } finally {
            boolean returningInputRoot = succeeded && result == root;
            if (!returningInputRoot) {
                // root is not handed back to the caller, so release it here.
                root.close();
            }
            if (!succeeded && result != null && result != root) {
                // A new result was built but cannot be returned; avoid leaking it.
                result.close();
            }
        }
    }

    /**
     * Evaluates {@code filter} over {@code recordBatch} and materializes the selected rows.
     *
     * <p>Does not close {@code root}; the caller owns it. Returns {@code root} itself when every row
     * matches, an empty new root when none match, and otherwise a new root holding the selected rows
     * produced by an identity projection over the selection vector.
     */
    private static VectorSchemaRoot take(
            BufferAllocator allocator,
            VectorSchemaRoot root,
            Schema schema,
            ArrowRecordBatch recordBatch,
            int rowCount,
            Filter filter,
            ArrowBuf selectionBuffer)
            throws GandivaException {
        // After filter evaluation, the selection buffer contains
        // the indexes of rows for which the expression is true.
        SelectionVectorInt32 selectionVector = new SelectionVectorInt32(selectionBuffer);
        filter.evaluate(recordBatch, selectionVector);

        final int matchedRowCount = selectionVector.getRecordCount();
        if (matchedRowCount == rowCount) {
            System.err.println("All rows match");
            // No need to project
            return root;
        }
        if (matchedRowCount == 0) {
            System.err.println("No rows match");
            VectorSchemaRoot emptyRoot = VectorSchemaRoot.create(schema, allocator);
            emptyRoot.setRowCount(0);
            return emptyRoot;
        }
        System.err.println(matchedRowCount);
        System.err.println("A proper subset of rows match");

        // Identity expression per column: projecting through the selection vector copies only the selected rows.
        ImmutableList.Builder<ExpressionTree> fieldNodes = ImmutableList.builder();
        for (Field field : schema.getFields()) {
            fieldNodes.add(TreeBuilder.makeExpression(TreeBuilder.makeField(field), field));
        }

        VectorSchemaRoot newRoot = VectorSchemaRoot.create(schema, allocator);
        boolean succeeded = false;
        try {
            // Output vectors, filled by the projector; sized for the selected rows, whatever the schema.
            List<ValueVector> output = new ArrayList<>(newRoot.getFieldVectors());
            for (ValueVector vector : output) {
                vector.setInitialCapacity(matchedRowCount);
                vector.allocateNew();
            }

            Projector projector =
                    Projector.make(schema, fieldNodes.build(), SelectionVectorType.SV_INT32);
            try {
                projector.evaluate(recordBatch, selectionVector, output);
            } finally {
                projector.close();
            }

            newRoot.setRowCount(matchedRowCount);
            succeeded = true;
            return newRoot;
        } finally {
            if (!succeeded) {
                newRoot.close();
            }
        }
    }
}
