package org.apache.calcite;

import com.google.common.collect.ImmutableList;
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.calcite.avatica.util.Quoting;
import org.apache.calcite.linq4j.Enumerable;
import org.apache.calcite.linq4j.QueryProvider;
import org.apache.calcite.linq4j.Queryable;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.plan.ConventionTraitDef;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelTraitDef;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.ModifiableTable;
import org.apache.calcite.schema.ProjectableFilterableTable;
import org.apache.calcite.schema.ScannableTable;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.schema.Statistic;
import org.apache.calcite.schema.Statistics;
import org.apache.calcite.schema.impl.AbstractSchema;
import org.apache.calcite.schema.impl.AbstractTable;
import org.apache.calcite.sql.SqlExplainFormat;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.tools.Planner;
import org.apache.calcite.tools.Programs;
import org.apache.calcite.util.ImmutableBitSet;
import org.junit.Test;

import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

public class TestCalcite {

    private static class TableColumn {
        private final String columnName;
        private final SqlTypeName type;

        public TableColumn(String cname, SqlTypeName type) {
            this.columnName = cname;
            this.type = type;
        }

        public String getColumnName() {
            return this.columnName;
        }

        public SqlTypeName getType() {
            return this.type;
        }
    }

    private static class TableImpl extends AbstractTable
            implements ModifiableTable, ScannableTable, ProjectableFilterableTable {


        private  final List<TableColumn> columns;
        private final String tName;

        public TableImpl(List<TableColumn> cols, String name) {
            this.columns =  cols;
            this.tName = name;
        }

        @Override
        public Collection getModifiableCollection() {
            return null;
        }

        @Override
        public TableModify toModificationRel(RelOptCluster cluster, RelOptTable table, Prepare.CatalogReader catalogReader, RelNode child, TableModify.Operation operation, List<String> updateColumnList, List<RexNode> sourceExpressionList, boolean flattened) {
            return null;
        }

        @Override
        public Enumerable<Object[]> scan(DataContext root, List<RexNode> filters, int[] projects) {
            return null;
        }

        @Override
        public <T> Queryable<T> asQueryable(QueryProvider queryProvider, SchemaPlus schema, String tableName) {
            return null;
        }

        @Override
        public Type getElementType() {
            return null;
        }

        @Override
        public Expression getExpression(SchemaPlus schema, String tableName, Class clazz) {
            return null;
        }

        @Override
        public Enumerable<Object[]> scan(DataContext root) {
            return null;
        }



        @Override
        public RelDataType getRowType(RelDataTypeFactory typeFactory) {
            RelDataTypeFactory.Builder builder = new RelDataTypeFactory.Builder(typeFactory);

            for (TableColumn tc : columns) {
                builder.add(tc.getColumnName(),  typeFactory.createSqlType(tc.getType()));
            }
            return builder.build();
        }

        @Override
        public Statistic getStatistic() {
            return Statistics.of(15D,
                    ImmutableList.of(ImmutableBitSet.of(0)),
                    ImmutableList.of());
        }
    }


    private static  FrameworkConfig config;

    private static final SqlParser.Config SQL_PARSER_CONFIG
            = SqlParser.configBuilder(SqlParser.Config.DEFAULT)
            .setCaseSensitive(false)
            .setConformance(SqlConformanceEnum.MYSQL_5)
            .setQuoting(Quoting.BACK_TICK)
            .build();

    private static final List<RelTraitDef> TRAITS = Collections
            .unmodifiableList(java.util.Arrays.asList(ConventionTraitDef.INSTANCE,
                    RelCollationTraitDef.INSTANCE));

    private static final SchemaPlus _rootSchema = Frameworks.createRootSchema(true);


    private static SchemaPlus getSchema(String tableSpace, TableImpl ... tables) {
        final SchemaPlus result = _rootSchema.add(tableSpace, new AbstractSchema());
        if (result != null) {
            Arrays.stream(tables).forEach(t -> {
                result.add(t.tName, t);
            });
        }
        return result;
    }


    public static void init(String defaultTableSpace, TableImpl ... tables) {
        config = Frameworks.newConfigBuilder()
                .parserConfig(SQL_PARSER_CONFIG)
                .defaultSchema(getSchema(defaultTableSpace, tables))
                .traitDefs(TRAITS)
                // define the rules you want to apply

                .programs(Programs.ofRules(Programs.RULE_SET))
                .build();

    }

    private static Planner getPlanner() {
        Planner planner = Frameworks.getPlanner(config);
        return planner;
    }

    public static void testQuery(String query) {
        try {
            Planner planner = getPlanner();
            SqlNode n = planner.parse(query);
            n = planner.validate(n);
            RelNode root = planner.rel(n).project();
            RelOptCluster cluster = root.getCluster();
            final RelOptPlanner optPlanner = cluster.getPlanner();

            RelTraitSet desiredTraits  =
                    cluster.traitSet().replace(EnumerableConvention.INSTANCE);
            final RelNode newRoot = optPlanner.changeTraits(root, desiredTraits);
                System.out.println(
                        RelOptUtil.dumpPlan("-- Mid Plan", newRoot, SqlExplainFormat.TEXT,
                                SqlExplainLevel.DIGEST_ATTRIBUTES));
            optPlanner.setRoot(newRoot);
            RelNode bestExp = optPlanner.findBestExp();
                System.out.println(
                        RelOptUtil.dumpPlan("-- Best Plan", bestExp, SqlExplainFormat.TEXT,
                                SqlExplainLevel.DIGEST_ATTRIBUTES));
        } catch(Exception ex) {
            ex.printStackTrace();
        }
    }

//    @Test
//    public void updateQueryFailure() {
//        TableColumn tck1 = new TableColumn("k1", SqlTypeName.VARCHAR);
//        TableColumn tck2 = new TableColumn("k2", SqlTypeName.VARCHAR);
//        TableColumn tcfk = new TableColumn("fk", SqlTypeName.VARCHAR);
//
//        TableColumn tcn1 = new TableColumn("n1", SqlTypeName.INTEGER);
//
//
//
//        TableImpl table1 = new TableImpl(Arrays.asList(tck1, tcn1), "table1");
//
//        TableImpl table2 = new TableImpl(Arrays.asList(tck2,tcfk), "table2");
//
//        init("tblspace1",table1, table2);
//        testQuery("UPDATE tblspace1.table1 set n1=1000"
//                + "WHERE k1 in (SELECT fk FROM tblspace1.table2 WHERE k2=10)");
//    }
//
//    @Test
//    public void minMaxQueryFailure() {
//        TableColumn tck1 =  new TableColumn("k1", SqlTypeName.VARCHAR);
//        TableColumn tcn1 =  new TableColumn("n1", SqlTypeName.INTEGER);
//        TableColumn tcs1 =  new TableColumn("s1", SqlTypeName.VARCHAR);
//
//        TableImpl table1 = new TableImpl(Arrays.asList(tck1, tcn1,tcs1), "tsql");
//
//        init("tblspace1", table1);
//
//        //CREATE TABLE tblspace1.tsql (k1 string primary key,n1 int,s1 string)
//        String query = "SELECT MIN(n1) as mi, MAX(n1) as ma FROM tblspace1.tsql WHERE k1='no_results' GROUP BY k1";
//        testQuery(query);
//
//    }

    @Test
    public void problem_with_1_21() {
        //            execute(manager, "CREATE TABLE tblspace1.tsql (k1 string primary key,n1 int,s1 string)", Collections.emptyList());

        TableColumn tck1 =  new TableColumn("k1", SqlTypeName.VARCHAR);
        TableColumn tcn1 =  new TableColumn("n1", SqlTypeName.INTEGER);
        TableColumn tcs1 =  new TableColumn("s1", SqlTypeName.VARCHAR);

        TableImpl table1 = new TableImpl(Arrays.asList(tck1, tcn1,tcs1), "tsql");

        init("tblspace1", table1);

        String query = "SELECT * FROM tblspace1.tsql where n1=? and k1 in (SELECT k1 FROM tblspace1.tsql where n1=?)";
        testQuery(query);

    }
}
