http://git-wip-us.apache.org/repos/asf/calcite/blob/5acf84f3/druid/src/main/java/org/apache/calcite/adapter/druid/DruidQuery.java ---------------------------------------------------------------------- diff --git a/druid/src/main/java/org/apache/calcite/adapter/druid/DruidQuery.java b/druid/src/main/java/org/apache/calcite/adapter/druid/DruidQuery.java index 4056412..aa9a02a 100644 --- a/druid/src/main/java/org/apache/calcite/adapter/druid/DruidQuery.java +++ b/druid/src/main/java/org/apache/calcite/adapter/druid/DruidQuery.java @@ -19,7 +19,6 @@ package org.apache.calcite.adapter.druid; import org.apache.calcite.DataContext; import org.apache.calcite.avatica.ColumnMetaData; import org.apache.calcite.config.CalciteConnectionConfig; -import org.apache.calcite.config.CalciteConnectionProperty; import org.apache.calcite.interpreter.BindableRel; import org.apache.calcite.interpreter.Bindables; import org.apache.calcite.interpreter.Compiler; @@ -32,7 +31,6 @@ import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptPlanner; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptTable; -import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.AbstractRelNode; import org.apache.calcite.rel.RelFieldCollation; @@ -53,59 +51,118 @@ import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.rex.RexUtil; import org.apache.calcite.runtime.Hook; import org.apache.calcite.schema.ScannableTable; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlValidatorUtil; import org.apache.calcite.util.ImmutableBitSet; import 
org.apache.calcite.util.Litmus; import org.apache.calcite.util.Pair; -import org.apache.calcite.util.TimestampString; import org.apache.calcite.util.Util; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonGenerator; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Function; import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.base.Strings; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; -import com.google.common.collect.Sets; +import com.google.common.collect.Maps; import org.joda.time.Interval; import java.io.IOException; import java.io.StringWriter; -import java.math.BigDecimal; -import java.text.SimpleDateFormat; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; -import java.util.Properties; import java.util.Set; import java.util.TimeZone; import java.util.regex.Pattern; -import static org.apache.calcite.sql.SqlKind.INPUT_REF; +import javax.annotation.Nullable; /** * Relational expression representing a scan of a Druid data set. */ public class DruidQuery extends AbstractRelNode implements BindableRel { + /** + * Provides a standard list of supported Calcite operators that can be converted to + * Druid Expressions. This can be used as is or re-adapted based on underlying + * engine operator syntax. 
+ */ + public static final List<DruidSqlOperatorConverter> DEFAULT_OPERATORS_LIST = + ImmutableList.<DruidSqlOperatorConverter>builder() + .add(new DirectOperatorConversion(SqlStdOperatorTable.EXP, "exp")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.CONCAT, "concat")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.DIVIDE_INTEGER, "div")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.LIKE, "like")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.LN, "log")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.SQRT, "sqrt")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.LOWER, "lower")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.LOG10, "log10")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.REPLACE, "replace")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.UPPER, "upper")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.POWER, "pow")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.ABS, "abs")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.SIN, "sin")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.COS, "cos")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.TAN, "tan")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.CASE, "case_searched")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.CHAR_LENGTH, "strlen")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.CHARACTER_LENGTH, "strlen")) + .add(new BinaryOperatorConversion(SqlStdOperatorTable.EQUALS, "==")) + .add(new BinaryOperatorConversion(SqlStdOperatorTable.NOT_EQUALS, "!=")) + .add(new NaryOperatorConverter(SqlStdOperatorTable.OR, "||")) + .add(new NaryOperatorConverter(SqlStdOperatorTable.AND, "&&")) + .add(new BinaryOperatorConversion(SqlStdOperatorTable.LESS_THAN, "<")) + .add(new BinaryOperatorConversion(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, "<=")) + .add(new BinaryOperatorConversion(SqlStdOperatorTable.GREATER_THAN, ">")) + .add(new 
BinaryOperatorConversion(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, ">=")) + .add(new BinaryOperatorConversion(SqlStdOperatorTable.PLUS, "+")) + .add(new BinaryOperatorConversion(SqlStdOperatorTable.MINUS, "-")) + .add(new BinaryOperatorConversion(SqlStdOperatorTable.MULTIPLY, "*")) + .add(new BinaryOperatorConversion(SqlStdOperatorTable.DIVIDE, "/")) + .add(new BinaryOperatorConversion(SqlStdOperatorTable.MOD, "%")) + .add(new DruidSqlCastConverter()) + .add(new ExtractOperatorConversion()) + .add(new UnaryPrefixOperatorConversion(SqlStdOperatorTable.NOT, "!")) + .add(new UnaryPrefixOperatorConversion(SqlStdOperatorTable.UNARY_MINUS, "-")) + .add(new UnarySuffixOperatorConversion(SqlStdOperatorTable.IS_FALSE, "<= 0")) + .add(new UnarySuffixOperatorConversion(SqlStdOperatorTable.IS_NOT_TRUE, "<= 0")) + .add(new UnarySuffixOperatorConversion(SqlStdOperatorTable.IS_TRUE, "> 0")) + .add(new UnarySuffixOperatorConversion(SqlStdOperatorTable.IS_NOT_FALSE, "> 0")) + .add(new UnarySuffixOperatorConversion(SqlStdOperatorTable.IS_NULL, "== null")) + .add(new UnarySuffixOperatorConversion(SqlStdOperatorTable.IS_NOT_NULL, "!= null")) + .add(new FloorOperatorConversion()) + .add(new CeilOperatorConversion()) + .add(new SubstringOperatorConversion()) + .build(); protected QuerySpec querySpec; final RelOptTable table; final DruidTable druidTable; final ImmutableList<Interval> intervals; final ImmutableList<RelNode> rels; + /** + * This operator map provides DruidSqlOperatorConverter instance to convert a Calcite RexNode to + * Druid Expression when possible. 
+ */ + final Map<SqlOperator, DruidSqlOperatorConverter> converterOperatorMap; - private static final Pattern VALID_SIG = Pattern.compile("sf?p?(a?|ao)l?"); + private static final Pattern VALID_SIG = Pattern.compile("sf?p?(a?|ah|ah?o)l?"); private static final String EXTRACT_COLUMN_NAME_PREFIX = "extract"; private static final String FLOOR_COLUMN_NAME_PREFIX = "floor"; protected static final String DRUID_QUERY_FETCH = "druid.query.fetch"; @@ -120,25 +177,236 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { * @param druidTable Druid table * @param intervals Intervals for the query * @param rels Internal relational expressions + * @param converterOperatorMap mapping of Calcite Sql Operator to Druid Expression API. */ protected DruidQuery(RelOptCluster cluster, RelTraitSet traitSet, RelOptTable table, DruidTable druidTable, - List<Interval> intervals, List<RelNode> rels) { + List<Interval> intervals, List<RelNode> rels, + Map<SqlOperator, DruidSqlOperatorConverter> converterOperatorMap) { super(cluster, traitSet); this.table = table; this.druidTable = druidTable; this.intervals = ImmutableList.copyOf(intervals); this.rels = ImmutableList.copyOf(rels); - + this.converterOperatorMap = Preconditions.checkNotNull(converterOperatorMap, "Operator map " + + "can not be null"); assert isValid(Litmus.THROW, null); } + /** Returns whether a signature represents a sequence of relational operators + * that can be translated into a valid Druid query. */ + static boolean isValidSignature(String signature) { + return VALID_SIG.matcher(signature).matches(); + } + + /** Creates a DruidQuery. 
*/ + public static DruidQuery create(RelOptCluster cluster, RelTraitSet traitSet, + RelOptTable table, DruidTable druidTable, List<RelNode> rels) { + final ImmutableMap converterOperatorMap = ImmutableMap.<SqlOperator, + DruidSqlOperatorConverter>builder().putAll( + Lists.transform(DEFAULT_OPERATORS_LIST, new Function<DruidSqlOperatorConverter, + Map.Entry<SqlOperator, DruidSqlOperatorConverter>>() { + @Nullable @Override public Map.Entry<SqlOperator, DruidSqlOperatorConverter> apply( + final DruidSqlOperatorConverter input) { + return Maps.immutableEntry(input.calciteOperator(), input); + } + })).build(); + return create(cluster, traitSet, table, druidTable, druidTable.intervals, rels, + converterOperatorMap); + } + + /** Creates a DruidQuery. */ + public static DruidQuery create(RelOptCluster cluster, RelTraitSet traitSet, + RelOptTable table, DruidTable druidTable, List<RelNode> rels, + Map<SqlOperator, DruidSqlOperatorConverter> converterOperatorMap) { + return create(cluster, traitSet, table, druidTable, druidTable.intervals, rels, + converterOperatorMap); + } + + /** + * Creates a DruidQuery. + */ + private static DruidQuery create(RelOptCluster cluster, RelTraitSet traitSet, + RelOptTable table, DruidTable druidTable, List<Interval> intervals, + List<RelNode> rels, Map<SqlOperator, DruidSqlOperatorConverter> converterOperatorMap) { + return new DruidQuery(cluster, traitSet, table, druidTable, intervals, rels, + converterOperatorMap); + } + + /** Extends a DruidQuery. */ + public static DruidQuery extendQuery(DruidQuery query, RelNode r) { + final ImmutableList.Builder<RelNode> builder = ImmutableList.builder(); + return DruidQuery.create(query.getCluster(), r.getTraitSet().replace(query.getConvention()), + query.getTable(), query.druidTable, query.intervals, + builder.addAll(query.rels).add(r).build(), query.getOperatorConversionMap()); + } + + /** Extends a DruidQuery. 
*/ + public static DruidQuery extendQuery(DruidQuery query, + List<Interval> intervals) { + return DruidQuery.create(query.getCluster(), query.getTraitSet(), query.getTable(), + query.druidTable, intervals, query.rels, query.getOperatorConversionMap()); + } + + /** + * @param rexNode leaf Input Ref to Druid Column + * @param rowType row type + * @param druidQuery druid query + * + * @return {@link Pair} of Column name and Extraction Function on the top of the input ref or + * {@link Pair of(null, null)} when can not translate to valid Druid column + */ + protected static Pair<String, ExtractionFunction> toDruidColumn(RexNode rexNode, + RelDataType rowType, DruidQuery druidQuery) { + final String columnName; + final ExtractionFunction extractionFunction; + final Granularity granularity; + switch (rexNode.getKind()) { + case INPUT_REF: + columnName = extractColumnName(rexNode, rowType, druidQuery); + //@TODO we can remove this ugly check by treating druid time columns as LONG + if (rexNode.getType().getFamily() == SqlTypeFamily.DATE + || rexNode.getType().getFamily() == SqlTypeFamily.TIMESTAMP) { + extractionFunction = TimeExtractionFunction + .createDefault(druidQuery.getConnectionConfig().timeZone()); + } else { + extractionFunction = null; + } + break; + case EXTRACT: + granularity = DruidDateTimeUtils + .extractGranularity(rexNode, druidQuery.getConnectionConfig().timeZone()); + if (granularity == null) { + // unknown Granularity + return Pair.of(null, null); + } + if (!TimeExtractionFunction.isValidTimeExtract((RexCall) rexNode)) { + return Pair.of(null, null); + } + extractionFunction = + TimeExtractionFunction.createExtractFromGranularity(granularity, + druidQuery.getConnectionConfig().timeZone()); + columnName = + extractColumnName(((RexCall) rexNode).getOperands().get(1), rowType, druidQuery); + + break; + case FLOOR: + granularity = DruidDateTimeUtils + .extractGranularity(rexNode, druidQuery.getConnectionConfig().timeZone()); + if (granularity == null) { + 
// unknown Granularity + return Pair.of(null, null); + } + if (!TimeExtractionFunction.isValidTimeFloor((RexCall) rexNode)) { + return Pair.of(null, null); + } + extractionFunction = + TimeExtractionFunction + .createFloorFromGranularity(granularity, druidQuery.getConnectionConfig().timeZone()); + columnName = + extractColumnName(((RexCall) rexNode).getOperands().get(0), rowType, druidQuery); + break; + case CAST: + // CASE we have a cast over InputRef. Check that cast is valid + if (!isValidLeafCast(rexNode)) { + return Pair.of(null, null); + } + columnName = + extractColumnName(((RexCall) rexNode).getOperands().get(0), rowType, druidQuery); + // CASE CAST to TIME/DATE need to make sure that we have valid extraction fn + final SqlTypeName toTypeName = rexNode.getType().getSqlTypeName(); + if (toTypeName.getFamily() == SqlTypeFamily.TIMESTAMP + || toTypeName.getFamily() == SqlTypeFamily.DATETIME) { + extractionFunction = TimeExtractionFunction.translateCastToTimeExtract(rexNode, + TimeZone.getTimeZone(druidQuery.getConnectionConfig().timeZone())); + if (extractionFunction == null) { + // no extraction Function means cast is not valid thus bail out + return Pair.of(null, null); + } + } else { + extractionFunction = null; + } + break; + default: + return Pair.of(null, null); + } + return Pair.of(columnName, extractionFunction); + } + + /** + * @param rexNode rexNode + * + * @return true if the operand is an inputRef and it is a valid Druid Cast operation + */ + private static boolean isValidLeafCast(RexNode rexNode) { + assert rexNode.isA(SqlKind.CAST); + final RexNode input = ((RexCall) rexNode).getOperands().get(0); + if (!input.isA(SqlKind.INPUT_REF)) { + // it is not a leaf cast don't bother going further. 
+ return false; + } + final SqlTypeName toTypeName = rexNode.getType().getSqlTypeName(); + if (toTypeName.getFamily() == SqlTypeFamily.CHARACTER) { + // CAST of input to character type + return true; + } + if (toTypeName.getFamily() == SqlTypeFamily.NUMERIC) { + // CAST of input to numeric type, it is part of a bounded comparison + return true; + } + if (toTypeName.getFamily() == SqlTypeFamily.TIMESTAMP + || toTypeName.getFamily() == SqlTypeFamily.DATETIME) { + // CAST of input to timestamp type + return true; + } + if (toTypeName.getFamily().contains(input.getType())) { + //same type it is okay to push it + return true; + } + // Currently other CAST operations cannot be pushed to Druid + return false; + + } + + /** + * @param rexNode Druid input ref node + * @param rowType rowType + * @param query Druid Query + * + * @return Druid column name or null when not possible to translate. + */ + @Nullable + protected static String extractColumnName(RexNode rexNode, RelDataType rowType, + DruidQuery query) { + if (rexNode.getKind() == SqlKind.INPUT_REF) { + final RexInputRef ref = (RexInputRef) rexNode; + final String columnName = rowType.getFieldNames().get(ref.getIndex()); + if (columnName == null) { + return null; + } + //calcite has this indirect renaming of timestampFieldName to native druid `__time` + if (query.getDruidTable().timestampFieldName.equals(columnName)) { + return DruidTable.DEFAULT_TIMESTAMP_COLUMN; + } + return columnName; + } + return null; + } + + /** + * Equivalent of String.format(Locale.ENGLISH, message, formatArgs). + */ + public static String format(String message, Object... formatArgs) { + return String.format(Locale.ENGLISH, message, formatArgs); + } + /** Returns a string describing the operations inside this query. 
* - * <p>For example, "sfpaol" means {@link TableScan} (s) + * <p>For example, "sfpahol" means {@link TableScan} (s) * followed by {@link Filter} (f) * followed by {@link Project} (p) * followed by {@link Aggregate} (a) + * followed by {@link Filter} (h) * followed by {@link Project} (o) * followed by {@link Sort} (l). * @@ -150,11 +418,12 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { for (RelNode rel : rels) { b.append(rel instanceof TableScan ? 's' : (rel instanceof Project && flag) ? 'o' - : rel instanceof Filter ? 'f' - : rel instanceof Aggregate ? 'a' - : rel instanceof Sort ? 'l' - : rel instanceof Project ? 'p' - : '!'); + : (rel instanceof Filter && flag) ? 'h' + : rel instanceof Aggregate ? 'a' + : rel instanceof Filter ? 'f' + : rel instanceof Sort ? 'l' + : rel instanceof Project ? 'p' + : '!'); flag = flag || rel instanceof Aggregate; } return b.toString(); @@ -194,7 +463,9 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { } if (r instanceof Filter) { final Filter filter = (Filter) r; - if (!isValidFilter(filter.getCondition())) { + final DruidJsonFilter druidJsonFilter = DruidJsonFilter + .toDruidFilters(filter.getCondition(), filter.getInput().getRowType(), this); + if (druidJsonFilter == null) { return litmus.fail("invalid filter [{}]", filter.getCondition()); } } @@ -209,109 +480,8 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { return true; } - public boolean isValidFilter(RexNode e) { - return isValidFilter(e, false); - } - - public boolean isValidFilter(RexNode e, boolean boundedComparator) { - switch (e.getKind()) { - case INPUT_REF: - return true; - case LITERAL: - return ((RexLiteral) e).getValue() != null; - case AND: - case OR: - case NOT: - case IN: - case IS_NULL: - case IS_NOT_NULL: - return areValidFilters(((RexCall) e).getOperands(), false); - case EQUALS: - case NOT_EQUALS: - case LESS_THAN: - case LESS_THAN_OR_EQUAL: - case GREATER_THAN: - case 
GREATER_THAN_OR_EQUAL: - case BETWEEN: - return areValidFilters(((RexCall) e).getOperands(), true); - case CAST: - return isValidCast((RexCall) e, boundedComparator); - case EXTRACT: - return TimeExtractionFunction.isValidTimeExtract((RexCall) e); - case FLOOR: - return TimeExtractionFunction.isValidTimeFloor((RexCall) e); - case IS_TRUE: - return isValidFilter(((RexCall) e).getOperands().get(0), boundedComparator); - default: - return false; - } - } - - private boolean areValidFilters(List<RexNode> es, boolean boundedComparator) { - for (RexNode e : es) { - if (!isValidFilter(e, boundedComparator)) { - return false; - } - } - return true; - } - - private static boolean isValidCast(RexCall e, boolean boundedComparator) { - assert e.isA(SqlKind.CAST); - if (e.getOperands().get(0).isA(INPUT_REF) - && e.getType().getFamily() == SqlTypeFamily.CHARACTER) { - // CAST of input to character type - return true; - } - if (e.getOperands().get(0).isA(INPUT_REF) - && e.getType().getFamily() == SqlTypeFamily.NUMERIC - && boundedComparator) { - // CAST of input to numeric type, it is part of a bounded comparison - return true; - } - if (e.getOperands().get(0).isA(SqlKind.LITERAL) - && (e.getType().getSqlTypeName() == SqlTypeName.DATE - || e.getType().getSqlTypeName() == SqlTypeName.TIMESTAMP - || e.getType().getSqlTypeName() == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE)) { - // CAST of literal to timestamp type - return true; - } - // Currently other CAST operations cannot be pushed to Druid - return false; - } - - /** Returns whether a signature represents an sequence of relational operators - * that can be translated into a valid Druid query. */ - static boolean isValidSignature(String signature) { - return VALID_SIG.matcher(signature).matches(); - } - - /** Creates a DruidQuery. 
*/ - public static DruidQuery create(RelOptCluster cluster, RelTraitSet traitSet, - RelOptTable table, DruidTable druidTable, List<RelNode> rels) { - return new DruidQuery(cluster, traitSet, table, druidTable, druidTable.intervals, rels); - } - - /** Creates a DruidQuery. */ - private static DruidQuery create(RelOptCluster cluster, RelTraitSet traitSet, - RelOptTable table, DruidTable druidTable, List<Interval> intervals, - List<RelNode> rels) { - return new DruidQuery(cluster, traitSet, table, druidTable, intervals, rels); - } - - /** Extends a DruidQuery. */ - public static DruidQuery extendQuery(DruidQuery query, RelNode r) { - final ImmutableList.Builder<RelNode> builder = ImmutableList.builder(); - return DruidQuery.create(query.getCluster(), r.getTraitSet().replace(query.getConvention()), - query.getTable(), query.druidTable, query.intervals, - builder.addAll(query.rels).add(r).build()); - } - - /** Extends a DruidQuery. */ - public static DruidQuery extendQuery(DruidQuery query, - List<Interval> intervals) { - return DruidQuery.create(query.getCluster(), query.getTraitSet(), query.getTable(), - query.druidTable, intervals, query.rels); + protected Map<SqlOperator, DruidSqlOperatorConverter> getOperatorConversionMap() { + return converterOperatorMap; } @Override public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) { @@ -389,6 +559,8 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { .multiplyBy( RelMdUtil.linear(querySpec.fieldNames.size(), 2, 100, 1d, 2d)) .multiplyBy(getQueryTypeCostMultiplier()) + //A Scan leaf filter is better than having filter spec if possible. + .multiplyBy(rels.size() > 1 && rels.get(1) instanceof Filter ? 0.5 : 1.0) // a plan with sort pushed to druid is better than doing sort outside of druid .multiplyBy(Util.last(rels) instanceof Sort ? 
0.1 : 1.0) .multiplyBy(getIntervalCostMultiplier()); @@ -455,16 +627,14 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { final RelDataType rowType = table.getRowType(); int i = 1; - RexNode filter = null; + Filter filterRel = null; if (i < rels.size() && rels.get(i) instanceof Filter) { - final Filter filterRel = (Filter) rels.get(i++); - filter = filterRel.getCondition(); + filterRel = (Filter) rels.get(i++); } - List<RexNode> projects = null; + Project project = null; if (i < rels.size() && rels.get(i) instanceof Project) { - final Project project = (Project) rels.get(i++); - projects = project.getProjects(); + project = (Project) rels.get(i++); } ImmutableBitSet groupSet = null; @@ -478,6 +648,11 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { groupSet.cardinality()); } + Filter havingFilter = null; + if (i < rels.size() && rels.get(i) instanceof Filter) { + havingFilter = (Filter) rels.get(i++); + } + Project postProject = null; if (i < rels.size() && rels.get(i) instanceof Project) { postProject = (Project) rels.get(i++); @@ -506,9 +681,9 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { throw new AssertionError("could not implement all rels"); } - return getQuery(rowType, filter, projects, groupSet, aggCalls, aggNames, + return getQuery(rowType, filterRel, project, groupSet, aggCalls, aggNames, collationIndexes, collationDirections, numericCollationBitSetBuilder.build(), fetch, - postProject); + postProject, havingFilter); } public QueryType getQueryType() { @@ -523,355 +698,668 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { return getCluster().getPlanner().getContext().unwrap(CalciteConnectionConfig.class); } - protected QuerySpec getQuery(RelDataType rowType, RexNode filter, List<RexNode> projects, - ImmutableBitSet groupSet, List<AggregateCall> aggCalls, List<String> aggNames, - List<Integer> collationIndexes, List<Direction> collationDirections, - 
ImmutableBitSet numericCollationIndexes, Integer fetch, Project postProject) { - final CalciteConnectionConfig config = getConnectionConfig(); - QueryType queryType = QueryType.SCAN; - final Translator translator = new Translator(druidTable, rowType, config.timeZone()); - List<String> fieldNames = rowType.getFieldNames(); - Set<String> usedFieldNames = Sets.newHashSet(fieldNames); - - // Handle filter - Json jsonFilter = null; + /** + * Translates Filter rel to Druid Filter Json object if possible. + * Currently Filter rel input has to be Druid Table scan + * + * @param filterRel input filter rel + * @param druidQuery Druid query + * + * @return DruidJson Filter or null if can not translate one of filters + */ + @Nullable + private static DruidJsonFilter computeFilter(@Nullable Filter filterRel, + DruidQuery druidQuery) { + if (filterRel == null) { + return null; + } + final RexNode filter = filterRel.getCondition(); + final RelDataType inputRowType = filterRel.getInput().getRowType(); if (filter != null) { - jsonFilter = translator.translateFilter(filter); + return DruidJsonFilter.toDruidFilters(filter, inputRowType, druidQuery); } + return null; + } - // Then we handle project - if (projects != null) { - translator.clearFieldNameLists(); - final ImmutableList.Builder<String> builder = ImmutableList.builder(); - for (RexNode project : projects) { - builder.add(translator.translate(project, true, false)); + /** + * Translates list of projects to Druid Column names and Virtual Columns if any + * We can not use {@link Pair#zip(Object[], Object[])}, since size can be different + * + * @param projectRel Project Rel + * + * @param druidQuery Druid query + * + * @return Pair of list of Druid Columns and Expression Virtual Columns or null when can not + * translate one of the projects. 
+ */ + @Nullable + protected static Pair<List<String>, List<VirtualColumn>> computeProjectAsScan( + @Nullable Project projectRel, RelDataType inputRowType, DruidQuery druidQuery) { + if (projectRel == null) { + return null; + } + final Set<String> usedFieldNames = new HashSet<>(); + final ImmutableList.Builder<VirtualColumn> virtualColumnsBuilder = ImmutableList.builder(); + final ImmutableList.Builder<String> projectedColumnsBuilder = ImmutableList.builder(); + final List<RexNode> projects = projectRel.getProjects(); + for (RexNode project : projects) { + Pair<String, ExtractionFunction> druidColumn = + toDruidColumn(project, inputRowType, druidQuery); + if (druidColumn.left == null || druidColumn.right != null) { + // It is a complex project pushed as expression + final String expression = DruidExpressions + .toDruidExpression(project, inputRowType, druidQuery); + if (expression == null) { + return null; + } + final String virColName = SqlValidatorUtil.uniquify("vc", + usedFieldNames, SqlValidatorUtil.EXPR_SUGGESTER); + virtualColumnsBuilder.add(VirtualColumn.builder() + .withName(virColName) + .withExpression(expression).withType( + DruidExpressions.EXPRESSION_TYPES.get(project.getType().getSqlTypeName())) + .build()); + usedFieldNames.add(virColName); + projectedColumnsBuilder.add(virColName); + } else { + // simple inputRef or extractable function + if (usedFieldNames.contains(druidColumn.left)) { + final String virColName = SqlValidatorUtil.uniquify("vc", + usedFieldNames, SqlValidatorUtil.EXPR_SUGGESTER); + virtualColumnsBuilder.add(VirtualColumn.builder() + .withName(virColName) + .withExpression(DruidExpressions.fromColumn(druidColumn.left)).withType( + DruidExpressions.EXPRESSION_TYPES.get(project.getType().getSqlTypeName())) + .build()); + usedFieldNames.add(virColName); + projectedColumnsBuilder.add(virColName); + } else { + projectedColumnsBuilder.add(druidColumn.left); + usedFieldNames.add(druidColumn.left); + } } - fieldNames = builder.build(); } + 
return Pair.<List<String>, List<VirtualColumn>>of(projectedColumnsBuilder.build(), + virtualColumnsBuilder.build()); + } - // Finally we handle aggregate and sort. Handling of these - // operators is more complex, since we need to extract - // the conditions to know whether the query will be - // executed as a Timeseries, TopN, or GroupBy in Druid - final List<DimensionSpec> dimensions = new ArrayList<>(); - final List<JsonAggregation> aggregations = new ArrayList<>(); - final List<JsonPostAggregation> postAggs = new ArrayList<>(); - Granularity finalGranularity = Granularities.all(); - Direction timeSeriesDirection = null; - JsonLimit limit = null; - TimeExtractionDimensionSpec timeExtractionDimensionSpec = null; - if (groupSet != null) { - assert aggCalls != null; - assert aggNames != null; - assert aggCalls.size() == aggNames.size(); - - int timePositionIdx = -1; - ImmutableList.Builder<String> builder = ImmutableList.builder(); - if (projects != null) { - for (int groupKey : groupSet) { - final String fieldName = fieldNames.get(groupKey); - final RexNode project = projects.get(groupKey); - if (project instanceof RexInputRef) { - // Reference could be to the timestamp or druid dimension but no druid metric - final RexInputRef ref = (RexInputRef) project; - final String originalFieldName = druidTable.getRowType(getCluster().getTypeFactory()) - .getFieldList().get(ref.getIndex()).getName(); - if (originalFieldName.equals(druidTable.timestampFieldName)) { - finalGranularity = Granularities.all(); - String extractColumnName = SqlValidatorUtil.uniquify(EXTRACT_COLUMN_NAME_PREFIX, - usedFieldNames, SqlValidatorUtil.EXPR_SUGGESTER); - timeExtractionDimensionSpec = TimeExtractionDimensionSpec.makeFullTimeExtract( - extractColumnName, config.timeZone()); - dimensions.add(timeExtractionDimensionSpec); - builder.add(extractColumnName); - assert timePositionIdx == -1; - timePositionIdx = groupKey; - } else { - dimensions.add(new DefaultDimensionSpec(fieldName)); - 
builder.add(fieldName); - } - } else if (project instanceof RexCall) { - // Call, check if we should infer granularity - final RexCall call = (RexCall) project; - final Granularity funcGranularity = - DruidDateTimeUtils.extractGranularity(call, config.timeZone()); - if (funcGranularity != null) { - final String extractColumnName; - switch (call.getKind()) { - case EXTRACT: - // case extract field from time column - finalGranularity = Granularities.all(); - extractColumnName = - SqlValidatorUtil.uniquify(EXTRACT_COLUMN_NAME_PREFIX + "_" - + funcGranularity.getType().lowerName, usedFieldNames, - SqlValidatorUtil.EXPR_SUGGESTER); - timeExtractionDimensionSpec = TimeExtractionDimensionSpec.makeTimeExtract( - funcGranularity, extractColumnName, config.timeZone()); - dimensions.add(timeExtractionDimensionSpec); - builder.add(extractColumnName); - break; - case FLOOR: - // case floor time column - if (groupSet.cardinality() > 1) { - // case we have more than 1 group by key -> then will have druid group by - extractColumnName = - SqlValidatorUtil.uniquify(FLOOR_COLUMN_NAME_PREFIX - + "_" + funcGranularity.getType().lowerName, - usedFieldNames, SqlValidatorUtil.EXPR_SUGGESTER); - dimensions.add( - TimeExtractionDimensionSpec.makeTimeFloor(funcGranularity, - extractColumnName, config.timeZone())); - finalGranularity = Granularities.all(); - builder.add(extractColumnName); - } else { - // case timeseries we can not use extraction function - finalGranularity = funcGranularity; - builder.add(fieldName); - } - assert timePositionIdx == -1; - timePositionIdx = groupKey; - break; - default: - throw new AssertionError(); - } + /** + * @param projectNode Project under the Aggregates if any + * @param groupSet ids of grouping keys as they are listed in {@code projects} list + * @param inputRowType Input row type under the project + * @param druidQuery Druid Query + * + * @return Pair of: Ordered {@link List<DimensionSpec>} containing the group by dimensions + * and {@link 
List<VirtualColumn>} containing Druid virtual column projections or Null, + * if translation is not possible. Note that the size of lists can be different. + */ + @Nullable + protected static Pair<List<DimensionSpec>, List<VirtualColumn>> computeProjectGroupSet( + @Nullable Project projectNode, ImmutableBitSet groupSet, + RelDataType inputRowType, + DruidQuery druidQuery) { + final List<DimensionSpec> dimensionSpecList = new ArrayList<>(); + final List<VirtualColumn> virtualColumnList = new ArrayList<>(); + final Set<String> usedFieldNames = new HashSet<>(); + for (int groupKey : groupSet) { + final DimensionSpec dimensionSpec; + final RexNode project; + if (projectNode == null) { + project = RexInputRef.of(groupKey, inputRowType); + } else { + project = projectNode.getProjects().get(groupKey); + } - } else { - dimensions.add(new DefaultDimensionSpec(fieldName)); - builder.add(fieldName); - } - } else { - throw new AssertionError("incompatible project expression: " + project); - } + Pair<String, ExtractionFunction> druidColumn = + toDruidColumn(project, inputRowType, druidQuery); + if (druidColumn.left != null && druidColumn.right == null) { + //SIMPLE INPUT REF + dimensionSpec = new DefaultDimensionSpec(druidColumn.left, druidColumn.left, + DruidExpressions.EXPRESSION_TYPES.get(project.getType().getSqlTypeName())); + usedFieldNames.add(druidColumn.left); + } else if (druidColumn.left != null && druidColumn.right != null) { + // CASE it is an extraction Dimension + final String columnPrefix; + //@TODO Remove it! if else statement is not really needed it is here to make tests pass. 
+ if (project.getKind() == SqlKind.EXTRACT) { + columnPrefix = + EXTRACT_COLUMN_NAME_PREFIX + "_" + Objects + .requireNonNull(DruidDateTimeUtils + .extractGranularity(project, druidQuery.getConnectionConfig().timeZone()) + .getType().lowerName); + } else if (project.getKind() == SqlKind.FLOOR) { + columnPrefix = + FLOOR_COLUMN_NAME_PREFIX + "_" + Objects + .requireNonNull(DruidDateTimeUtils + .extractGranularity(project, druidQuery.getConnectionConfig().timeZone()) + .getType().lowerName); + } else { + columnPrefix = "extract"; + } + final String uniqueExtractColumnName = SqlValidatorUtil + .uniquify(columnPrefix, usedFieldNames, + SqlValidatorUtil.EXPR_SUGGESTER); + dimensionSpec = new ExtractionDimensionSpec(druidColumn.left, + druidColumn.right, uniqueExtractColumnName); + usedFieldNames.add(uniqueExtractColumnName); + } else { + // CASE it is Expression + final String expression = DruidExpressions + .toDruidExpression(project, inputRowType, druidQuery); + if (Strings.isNullOrEmpty(expression)) { + return null; } + final String name = SqlValidatorUtil + .uniquify("vc", usedFieldNames, + SqlValidatorUtil.EXPR_SUGGESTER); + VirtualColumn vc = new VirtualColumn(name, expression, + DruidExpressions.EXPRESSION_TYPES.get(project.getType().getSqlTypeName())); + virtualColumnList.add(vc); + dimensionSpec = new DefaultDimensionSpec(name, name, + DruidExpressions.EXPRESSION_TYPES.get(project.getType().getSqlTypeName())); + usedFieldNames.add(name); + + } + + dimensionSpecList.add(dimensionSpec); + } + return Pair.of(dimensionSpecList, virtualColumnList); + } + + /** + * Translates Aggregators Calls to Druid Json Aggregators when possible. 
+ * + * @param aggCalls List of Agg Calls to translate + * @param aggNames Lit of Agg names + * @param project Input project under the Agg Calls, if null means we have TableScan->Agg + * @param druidQuery Druid Query Rel + * + * @return List of Valid Druid Json Aggregate or null if any of the aggregates is not supported + */ + @Nullable + protected static List<JsonAggregation> computeDruidJsonAgg(List<AggregateCall> aggCalls, + List<String> aggNames, @Nullable Project project, + DruidQuery druidQuery) { + final List<JsonAggregation> aggregations = new ArrayList<>(); + for (Pair<AggregateCall, String> agg : Pair.zip(aggCalls, aggNames)) { + final String fieldName; + final String expression; + final AggregateCall aggCall = agg.left; + final RexNode filterNode; + // Type check First + final RelDataType type = aggCall.getType(); + final SqlTypeName sqlTypeName = type.getSqlTypeName(); + final boolean isNotAcceptedType; + if (SqlTypeFamily.APPROXIMATE_NUMERIC.getTypeNames().contains(sqlTypeName) + || SqlTypeFamily.INTEGER.getTypeNames().contains(sqlTypeName)) { + isNotAcceptedType = false; + } else if (SqlTypeFamily.EXACT_NUMERIC.getTypeNames().contains(sqlTypeName) && ( + type.getScale() == 0 || druidQuery.getConnectionConfig().approximateDecimal())) { + // Decimal, If scale is zero or we allow approximating decimal, we can proceed + isNotAcceptedType = false; } else { - for (int groupKey : groupSet) { - final String s = fieldNames.get(groupKey); - if (s.equals(druidTable.timestampFieldName)) { - finalGranularity = Granularities.all(); - // Generate unique name as timestampFieldName is taken - String extractColumnName = SqlValidatorUtil.uniquify(EXTRACT_COLUMN_NAME_PREFIX, - usedFieldNames, SqlValidatorUtil.EXPR_SUGGESTER); - timeExtractionDimensionSpec = TimeExtractionDimensionSpec.makeFullTimeExtract( - extractColumnName, config.timeZone()); - dimensions.add(timeExtractionDimensionSpec); - builder.add(extractColumnName); - assert timePositionIdx == -1; - 
timePositionIdx = groupKey; + isNotAcceptedType = true; + } + if (isNotAcceptedType) { + return null; + } + + // Extract filters + if (project != null && aggCall.hasFilter()) { + filterNode = project.getProjects().get(aggCall.filterArg); + } else { + filterNode = null; + } + if (aggCall.getArgList().size() == 0) { + fieldName = null; + expression = null; + } else { + int index = Iterables.getOnlyElement(aggCall.getArgList()); + if (project == null) { + fieldName = druidQuery.table.getRowType().getFieldNames().get(index); + expression = null; + } else { + final RexNode rexNode = project.getProjects().get(index); + final RelDataType inputRowType = project.getInput().getRowType(); + if (rexNode.isA(SqlKind.INPUT_REF)) { + expression = null; + fieldName = + extractColumnName(rexNode, inputRowType, druidQuery); } else { - dimensions.add(new DefaultDimensionSpec(s)); - builder.add(s); + expression = DruidExpressions + .toDruidExpression(rexNode, inputRowType, druidQuery); + if (Strings.isNullOrEmpty(expression)) { + return null; + } + fieldName = null; } } + //One should be not null and the other should be null. 
+ assert expression == null ^ fieldName == null; } + final JsonAggregation jsonAggregation = getJsonAggregation(agg.right, agg.left, filterNode, + fieldName, expression, + druidQuery); + if (jsonAggregation == null) { + return null; + } + aggregations.add(jsonAggregation); + } + return aggregations; + } + + protected QuerySpec getQuery(RelDataType rowType, Filter filter, Project project, + ImmutableBitSet groupSet, List<AggregateCall> aggCalls, List<String> aggNames, + List<Integer> collationIndexes, List<Direction> collationDirections, + ImmutableBitSet numericCollationIndexes, Integer fetch, Project postProject, + Filter havingFilter) { + // Handle filter + final DruidJsonFilter jsonFilter = computeFilter(filter, this); - for (Pair<AggregateCall, String> agg : Pair.zip(aggCalls, aggNames)) { - final JsonAggregation jsonAggregation = - getJsonAggregation(fieldNames, agg.right, agg.left, projects, translator); - aggregations.add(jsonAggregation); - builder.add(jsonAggregation.name); + if (groupSet == null) { + //It is Scan Query since no Grouping + assert aggCalls == null; + assert aggNames == null; + assert collationIndexes == null || collationIndexes.isEmpty(); + assert collationDirections == null || collationDirections.isEmpty(); + final List<String> scanColumnNames; + final List<VirtualColumn> virtualColumnList = new ArrayList<>(); + if (project != null) { + //project some fields only + Pair<List<String>, List<VirtualColumn>> projectResult = computeProjectAsScan( + project, project.getInput().getRowType(), this); + scanColumnNames = projectResult.left; + virtualColumnList.addAll(projectResult.right); + } else { + //Scan all the fields + scanColumnNames = rowType.getFieldNames(); } + final ScanQuery scanQuery = new ScanQuery(druidTable.dataSource, intervals, jsonFilter, + virtualColumnList, scanColumnNames, fetch); + return new QuerySpec(QueryType.SCAN, + Preconditions.checkNotNull(scanQuery.toQuery(), "Can not plan Scan Druid Query"), + scanColumnNames); + } - 
fieldNames = builder.build(); - - if (postProject != null) { - builder = ImmutableList.builder(); - for (Pair<RexNode, String> pair : postProject.getNamedProjects()) { - String fieldName = pair.right; - RexNode rex = pair.left; - builder.add(fieldName); - // Render Post JSON object when PostProject exists. In DruidPostAggregationProjectRule - // all check has been done to ensure all RexCall rexNode can be pushed in. - if (rex instanceof RexCall) { - DruidQuery.JsonPostAggregation jsonPost = getJsonPostAggregation(fieldName, rex, - postProject.getInput()); - postAggs.add(jsonPost); - } + // At this Stage we have a valid Aggregate thus Query is one of Timeseries, TopN, or GroupBy + // Handling aggregate and sort is more complex, since + // we need to extract the conditions to know whether the query will be executed as a + // Timeseries, TopN, or GroupBy in Druid + assert aggCalls != null; + assert aggNames != null; + assert aggCalls.size() == aggNames.size(); + + final List<JsonExpressionPostAgg> postAggs = new ArrayList<>(); + final JsonLimit limit; + final RelDataType aggInputRowType = table.getRowType(); + final List<String> aggregateStageFieldNames = new ArrayList<>(); + + Pair<List<DimensionSpec>, List<VirtualColumn>> projectGroupSet = computeProjectGroupSet( + project, groupSet, aggInputRowType, this); + + final List<DimensionSpec> groupByKeyDims = projectGroupSet.left; + final List<VirtualColumn> virtualColumnList = projectGroupSet.right; + for (DimensionSpec dim : groupByKeyDims) { + aggregateStageFieldNames.add(dim.getOutputName()); + } + final List<JsonAggregation> aggregations = computeDruidJsonAgg(aggCalls, aggNames, project, + this); + for (JsonAggregation jsonAgg : aggregations) { + aggregateStageFieldNames.add(jsonAgg.name); + } + + + final DruidJsonFilter havingJsonFilter; + if (havingFilter != null) { + havingJsonFilter = DruidJsonFilter + .toDruidFilters(havingFilter.getCondition(), havingFilter.getInput().getRowType(), this); + } else { + 
havingJsonFilter = null; + } + + //Then we handle projects after aggregates as Druid Post Aggregates + final List<String> postAggregateStageFieldNames; + if (postProject != null) { + final List<String> postProjectDimListBuilder = new ArrayList<>(); + final RelDataType postAggInputRowType = getCluster().getTypeFactory() + .createStructType(Pair.right(postProject.getInput().getRowType().getFieldList()), + aggregateStageFieldNames); + // this is an index of existing columns coming out aggregate layer. Will use this index to: + // filter out any project down the road that doesn't change values e.g inputRef/identity cast + Map<String, String> existingProjects = Maps + .uniqueIndex(aggregateStageFieldNames, new Function<String, String>() { + @Override public String apply(@Nullable String input) { + return DruidExpressions.fromColumn(input); + } + }); + for (Pair<RexNode, String> pair : postProject.getNamedProjects()) { + final RexNode postProjectRexNode = pair.left; + final String postProjectFieldName = pair.right; + String expression = DruidExpressions + .toDruidExpression(postProjectRexNode, postAggInputRowType, this); + final String existingFieldName = existingProjects.get(expression); + if (existingFieldName != null) { + //simple input ref or Druid runtime identity cast will skip it, since it is here already + postProjectDimListBuilder.add(existingFieldName); + } else { + postAggs.add(new JsonExpressionPostAgg(postProjectFieldName, expression, null)); + postProjectDimListBuilder.add(postProjectFieldName); } - fieldNames = builder.build(); } + postAggregateStageFieldNames = postProjectDimListBuilder; + } else { + postAggregateStageFieldNames = null; + } + + // final Query output row field names. + final List<String> queryOutputFieldNames = postAggregateStageFieldNames == null + ? 
aggregateStageFieldNames + : postAggregateStageFieldNames; + + //handle sort all together + limit = computeSort(fetch, collationIndexes, collationDirections, numericCollationIndexes, + queryOutputFieldNames); + + final String timeSeriesQueryString = planAsTimeSeries(groupByKeyDims, jsonFilter, + virtualColumnList, aggregations, postAggs, limit, havingJsonFilter); + if (timeSeriesQueryString != null) { + final String timeExtractColumn = groupByKeyDims.isEmpty() + ? null + : groupByKeyDims.get(0).getOutputName(); + if (timeExtractColumn != null) { + //Case we have transformed the group by time to druid timeseries with Granularity + //Need to replace the name of the column with druid timestamp field name + final List<String> timeseriesFieldNames = Lists + .transform(queryOutputFieldNames, new Function<String, String>() { + @Override public String apply(@Nullable String input) { + if (timeExtractColumn.equals(input)) { + return "timestamp"; + } + return input; + } + }); + return new QuerySpec(QueryType.TIMESERIES, timeSeriesQueryString, timeseriesFieldNames); + } + return new QuerySpec(QueryType.TIMESERIES, timeSeriesQueryString, queryOutputFieldNames); + } + final String topNQuery = planAsTopN(groupByKeyDims, jsonFilter, + virtualColumnList, aggregations, postAggs, limit, havingJsonFilter); + if (topNQuery != null) { + return new QuerySpec(QueryType.TOP_N, topNQuery, queryOutputFieldNames); + } + + final String groupByQuery = planAsGroupBy(groupByKeyDims, jsonFilter, + virtualColumnList, aggregations, postAggs, limit, havingJsonFilter); + + if (groupByQuery == null) { + throw new IllegalStateException("Can not plan Druid Query"); + } + return new QuerySpec(QueryType.GROUP_BY, groupByQuery, queryOutputFieldNames); + } - ImmutableList<JsonCollation> collations = null; - boolean sortsMetric = false; - if (collationIndexes != null) { - assert collationDirections != null; - ImmutableList.Builder<JsonCollation> colBuilder = - ImmutableList.builder(); - for (Pair<Integer, 
Direction> p : Pair.zip(collationIndexes, collationDirections)) { - final String dimensionOrder = numericCollationIndexes.get(p.left) ? "numeric" - : "alphanumeric"; - colBuilder.add( - new JsonCollation(fieldNames.get(p.left), - p.right == Direction.DESCENDING ? "descending" : "ascending", dimensionOrder)); - if (p.left >= groupSet.cardinality() && p.right == Direction.DESCENDING) { - // Currently only support for DESC in TopN - sortsMetric = true; - } else if (p.left == timePositionIdx) { - assert timeSeriesDirection == null; - timeSeriesDirection = p.right; + /** + * @param fetch limit to fetch + * @param collationIndexes index of fields as listed in query row output + * @param collationDirections direction of sort + * @param numericCollationIndexes flag of to determine sort comparator + * @param queryOutputFieldNames query output fields + * + * @return always an non null Json Limit object + */ + private JsonLimit computeSort(@Nullable Integer fetch, List<Integer> collationIndexes, + List<Direction> collationDirections, ImmutableBitSet numericCollationIndexes, + List<String> queryOutputFieldNames) { + final List<JsonCollation> collations; + if (collationIndexes != null) { + assert collationDirections != null; + ImmutableList.Builder<JsonCollation> colBuilder = ImmutableList.builder(); + for (Pair<Integer, Direction> p : Pair.zip(collationIndexes, collationDirections)) { + final String dimensionOrder = numericCollationIndexes.get(p.left) + ? "numeric" + : "lexicographic"; + colBuilder.add( + new JsonCollation(queryOutputFieldNames.get(p.left), + p.right == Direction.DESCENDING ? 
"descending" : "ascending", dimensionOrder)); + } + collations = colBuilder.build(); + } else { + collations = null; + } + return new JsonLimit("default", fetch, collations); + } + + @Nullable + private String planAsTimeSeries(List<DimensionSpec> groupByKeyDims, DruidJsonFilter jsonFilter, + List<VirtualColumn> virtualColumnList, List<JsonAggregation> aggregations, + List<JsonExpressionPostAgg> postAggregations, JsonLimit limit, DruidJsonFilter havingFilter) { + if (havingFilter != null) { + return null; + } + if (groupByKeyDims.size() > 1) { + return null; + } + if (limit.limit != null) { + // it has a limit not supported by time series + return null; + } + if (limit.collations != null && limit.collations.size() > 1) { + //it has multiple sort columns + return null; + } + final String sortDirection; + if (limit.collations != null && limit.collations.size() == 1) { + if (groupByKeyDims.isEmpty() + || !(limit.collations.get(0).dimension.equals(groupByKeyDims.get(0).getOutputName()))) { + //sort column is not time column + return null; + } + sortDirection = limit.collations.get(0).direction; + } else { + sortDirection = null; + } + + final Granularity timeseriesGranularity; + if (groupByKeyDims.size() == 1) { + DimensionSpec dimensionSpec = Iterables.getOnlyElement(groupByKeyDims); + Granularity granularity = ExtractionDimensionSpec.toQueryGranularity(dimensionSpec); + //case we have project expression on the top of the time extract then can not use timeseries + boolean hasExpressionOnTopOfTimeExtract = false; + for (JsonExpressionPostAgg postAgg : postAggregations) { + if (postAgg instanceof JsonExpressionPostAgg) { + if (postAgg.expression.contains(groupByKeyDims.get(0).getOutputName())) { + hasExpressionOnTopOfTimeExtract = true; } } - collations = colBuilder.build(); } - - limit = new JsonLimit("default", fetch, collations); - - if (dimensions.isEmpty() && (collations == null || timeSeriesDirection != null)) { - queryType = QueryType.TIMESERIES; - assert fetch == 
null; - } else if (dimensions.size() == 1 - && finalGranularity.equals(Granularities.all()) - && sortsMetric - && collations.size() == 1 - && fetch != null - && config.approximateTopN()) { - queryType = QueryType.TOP_N; - } else { - queryType = QueryType.GROUP_BY; + timeseriesGranularity = hasExpressionOnTopOfTimeExtract ? null : granularity; + if (timeseriesGranularity == null) { + // can not extract granularity bailout + return null; } } else { - assert aggCalls == null; - assert aggNames == null; - assert collationIndexes == null || collationIndexes.isEmpty(); - assert collationDirections == null || collationDirections.isEmpty(); + timeseriesGranularity = Granularities.all(); } + final boolean isCountStar = Granularities.all() == timeseriesGranularity + && aggregations.size() == 1 + && aggregations.get(0).type.equals("count"); + final StringWriter sw = new StringWriter(); final JsonFactory factory = new JsonFactory(); try { final JsonGenerator generator = factory.createGenerator(sw); + generator.writeStartObject(); + generator.writeStringField("queryType", "timeseries"); + generator.writeStringField("dataSource", druidTable.dataSource); + generator.writeBooleanField("descending", sortDirection != null + && sortDirection.equals("descending")); + writeField(generator, "granularity", timeseriesGranularity); + writeFieldIf(generator, "filter", jsonFilter); + writeField(generator, "aggregations", aggregations); + writeFieldIf(generator, "virtualColumns", + virtualColumnList.size() > 0 ? virtualColumnList : null); + writeFieldIf(generator, "postAggregations", + postAggregations.size() > 0 ? 
postAggregations : null); + writeField(generator, "intervals", intervals); + generator.writeFieldName("context"); + // The following field is necessary to conform with SQL semantics (CALCITE-1589) + generator.writeStartObject(); + //Count(*) returns 0 if result set is empty thus need to set skipEmptyBuckets to false + generator.writeBooleanField("skipEmptyBuckets", !isCountStar); + generator.writeEndObject(); + generator.close(); + } catch (IOException e) { + Throwables.propagate(e); + } + return sw.toString(); + } - switch (queryType) { - case TIMESERIES: - generator.writeStartObject(); - - generator.writeStringField("queryType", "timeseries"); - generator.writeStringField("dataSource", druidTable.dataSource); - generator.writeBooleanField("descending", timeSeriesDirection != null - && timeSeriesDirection == Direction.DESCENDING); - writeField(generator, "granularity", finalGranularity); - writeFieldIf(generator, "filter", jsonFilter); - writeField(generator, "aggregations", aggregations); - writeFieldIf(generator, "postAggregations", postAggs.size() > 0 ? 
postAggs : null); - writeField(generator, "intervals", intervals); + @Nullable + private String planAsTopN(List<DimensionSpec> groupByKeyDims, DruidJsonFilter jsonFilter, + List<VirtualColumn> virtualColumnList, List<JsonAggregation> aggregations, + List<JsonExpressionPostAgg> postAggregations, JsonLimit limit, DruidJsonFilter havingFilter) { + if (havingFilter != null) { + return null; + } + if (!getConnectionConfig().approximateTopN() || groupByKeyDims.size() != 1 + || limit.limit == null || limit.collations == null || limit.collations.size() != 1) { + return null; + } + if (limit.collations.get(0).dimension.equals(groupByKeyDims.get(0).getOutputName())) { + return null; + } + if (limit.collations.get(0).direction.equals("ascending")) { + //Only DESC is allowed + return null; + } - generator.writeFieldName("context"); - // The following field is necessary to conform with SQL semantics (CALCITE-1589) - generator.writeStartObject(); - final boolean isCountStar = finalGranularity.equals(Granularities.all()) - && aggregations.size() == 1 - && aggregations.get(0).type.equals("count"); - //Count(*) returns 0 if result set is empty thus need to set skipEmptyBuckets to false - generator.writeBooleanField("skipEmptyBuckets", !isCountStar); - generator.writeEndObject(); + final String topNMetricColumnName = limit.collations.get(0).dimension; + final StringWriter sw = new StringWriter(); + final JsonFactory factory = new JsonFactory(); + try { + final JsonGenerator generator = factory.createGenerator(sw); + generator.writeStartObject(); - generator.writeEndObject(); - break; + generator.writeStringField("queryType", "topN"); + generator.writeStringField("dataSource", druidTable.dataSource); + writeField(generator, "granularity", Granularities.all()); + writeField(generator, "dimension", groupByKeyDims.get(0)); + writeFieldIf(generator, "virtualColumns", + virtualColumnList.size() > 0 ? 
virtualColumnList : null); + generator.writeStringField("metric", topNMetricColumnName); + writeFieldIf(generator, "filter", jsonFilter); + writeField(generator, "aggregations", aggregations); + writeFieldIf(generator, "postAggregations", + postAggregations.size() > 0 ? postAggregations : null); + writeField(generator, "intervals", intervals); + generator.writeNumberField("threshold", limit.limit); + generator.writeEndObject(); + generator.close(); + } catch (IOException e) { + Throwables.propagate(e); + } + return sw.toString(); + } - case TOP_N: - generator.writeStartObject(); + @Nullable + private String planAsGroupBy(List<DimensionSpec> groupByKeyDims, DruidJsonFilter jsonFilter, + List<VirtualColumn> virtualColumnList, List<JsonAggregation> aggregations, + List<JsonExpressionPostAgg> postAggregations, JsonLimit limit, DruidJsonFilter havingFilter) { + final StringWriter sw = new StringWriter(); + final JsonFactory factory = new JsonFactory(); + try { + final JsonGenerator generator = factory.createGenerator(sw); - generator.writeStringField("queryType", "topN"); - generator.writeStringField("dataSource", druidTable.dataSource); - writeField(generator, "granularity", finalGranularity); - writeField(generator, "dimension", dimensions.get(0)); - generator.writeStringField("metric", fieldNames.get(collationIndexes.get(0))); - writeFieldIf(generator, "filter", jsonFilter); - writeField(generator, "aggregations", aggregations); - writeFieldIf(generator, "postAggregations", postAggs.size() > 0 ? postAggs : null); - writeField(generator, "intervals", intervals); - generator.writeNumberField("threshold", fetch); + generator.writeStartObject(); + generator.writeStringField("queryType", "groupBy"); + generator.writeStringField("dataSource", druidTable.dataSource); + writeField(generator, "granularity", Granularities.all()); + writeField(generator, "dimensions", groupByKeyDims); + writeFieldIf(generator, "virtualColumns", + virtualColumnList.size() > 0 ? 
virtualColumnList : null); + writeFieldIf(generator, "limitSpec", limit); + writeFieldIf(generator, "filter", jsonFilter); + writeField(generator, "aggregations", aggregations); + writeFieldIf(generator, "postAggregations", + postAggregations.size() > 0 ? postAggregations : null); + writeField(generator, "intervals", intervals); + writeFieldIf(generator, "having", + havingFilter == null ? null : new DruidJsonFilter.JsonDimHavingFilter(havingFilter)); + generator.writeEndObject(); + generator.close(); + } catch (IOException e) { + Throwables.propagate(e); + } + return sw.toString(); + } - generator.writeEndObject(); - break; + /** + * Druid Scan Query Body + */ + private static class ScanQuery { - case GROUP_BY: - generator.writeStartObject(); - generator.writeStringField("queryType", "groupBy"); - generator.writeStringField("dataSource", druidTable.dataSource); - writeField(generator, "granularity", finalGranularity); - writeField(generator, "dimensions", dimensions); - writeFieldIf(generator, "limitSpec", limit); - writeFieldIf(generator, "filter", jsonFilter); - writeField(generator, "aggregations", aggregations); - writeFieldIf(generator, "postAggregations", postAggs.size() > 0 ? 
postAggs : null); - writeField(generator, "intervals", intervals); - writeFieldIf(generator, "having", null); + private String dataSource; - generator.writeEndObject(); - break; + private List<Interval> intervals; - case SELECT: - generator.writeStartObject(); + private DruidJsonFilter jsonFilter; - generator.writeStringField("queryType", "select"); - generator.writeStringField("dataSource", druidTable.dataSource); - generator.writeBooleanField("descending", false); - writeField(generator, "intervals", intervals); - writeFieldIf(generator, "filter", jsonFilter); - writeField(generator, "dimensions", translator.dimensions); - writeField(generator, "metrics", translator.metrics); - writeField(generator, "granularity", finalGranularity); + private List<VirtualColumn> virtualColumnList; - generator.writeFieldName("pagingSpec"); - generator.writeStartObject(); - generator.writeNumberField("threshold", fetch != null ? fetch - : CalciteConnectionProperty.DRUID_FETCH.wrap(new Properties()).getInt()); - generator.writeBooleanField("fromNext", true); - generator.writeEndObject(); + private List<String> columns; - generator.writeFieldName("context"); - generator.writeStartObject(); - generator.writeBooleanField(DRUID_QUERY_FETCH, fetch != null); - generator.writeEndObject(); + private Integer fetchLimit; - generator.writeEndObject(); - break; + ScanQuery(String dataSource, List<Interval> intervals, + DruidJsonFilter jsonFilter, + List<VirtualColumn> virtualColumnList, + List<String> columns, + Integer fetchLimit) { + this.dataSource = dataSource; + this.intervals = intervals; + this.jsonFilter = jsonFilter; + this.virtualColumnList = virtualColumnList; + this.columns = columns; + this.fetchLimit = fetchLimit; + } - case SCAN: + public String toQuery() { + final StringWriter sw = new StringWriter(); + try { + final JsonFactory factory = new JsonFactory(); + final JsonGenerator generator = factory.createGenerator(sw); generator.writeStartObject(); - 
generator.writeStringField("queryType", "scan"); - generator.writeStringField("dataSource", druidTable.dataSource); + generator.writeStringField("dataSource", dataSource); writeField(generator, "intervals", intervals); writeFieldIf(generator, "filter", jsonFilter); - writeField(generator, "columns", - Lists.transform(fieldNames, new Function<String, String>() { - @Override public String apply(String s) { - return s.equals(druidTable.timestampFieldName) - ? DruidTable.DEFAULT_TIMESTAMP_COLUMN : s; - } - })); - writeField(generator, "granularity", finalGranularity); + writeFieldIf(generator, "virtualColumns", + virtualColumnList.size() > 0 ? virtualColumnList : null); + writeField(generator, "columns", columns); generator.writeStringField("resultFormat", "compactedList"); - if (fetch != null) { - generator.writeNumberField("limit", fetch); + if (fetchLimit != null) { + generator.writeNumberField("limit", fetchLimit); } - generator.writeEndObject(); - break; - - default: - throw new AssertionError("unknown query type " + queryType); + generator.close(); + } catch (IOException e) { + Throwables.propagate(e); } - - generator.close(); - } catch (IOException e) { - e.printStackTrace(); + return sw.toString(); } - - return new QuerySpec(queryType, sw.toString(), fieldNames); } - protected JsonAggregation getJsonAggregation(List<String> fieldNames, - String name, AggregateCall aggCall, List<RexNode> projects, Translator translator) { - final List<String> list = new ArrayList<>(); - for (Integer arg : aggCall.getArgList()) { - list.add(fieldNames.get(arg)); - } - final String only = Iterables.getFirst(list, null); + @Nullable + private static JsonAggregation getJsonAggregation( + String name, AggregateCall aggCall, RexNode filterNode, String fieldName, + String aggExpression, + DruidQuery druidQuery) { final boolean fractional; final RelDataType type = aggCall.getType(); final SqlTypeName sqlTypeName = type.getSqlTypeName(); + final JsonAggregation aggregation; + final 
CalciteConnectionConfig config = druidQuery.getConnectionConfig(); + if (SqlTypeFamily.APPROXIMATE_NUMERIC.getTypeNames().contains(sqlTypeName)) { fractional = true; } else if (SqlTypeFamily.INTEGER.getTypeNames().contains(sqlTypeName)) { @@ -886,138 +1374,78 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { } } else { // Cannot handle this aggregate function type - throw new AssertionError("unknown aggregate type " + type); + return null; } - JsonAggregation aggregation; - - CalciteConnectionConfig config = getConnectionConfig(); - // Convert from a complex metric - ComplexMetric complexMetric = druidTable.resolveComplexMetric(only, aggCall); + ComplexMetric complexMetric = druidQuery.druidTable.resolveComplexMetric(fieldName, aggCall); switch (aggCall.getAggregation().getKind()) { case COUNT: if (aggCall.isDistinct()) { if (aggCall.isApproximate() || config.approximateDistinctCount()) { if (complexMetric == null) { - aggregation = new JsonCardinalityAggregation("cardinality", name, list); + aggregation = new JsonCardinalityAggregation("cardinality", name, + ImmutableList.of(fieldName)); } else { aggregation = new JsonAggregation(complexMetric.getMetricType(), name, - complexMetric.getMetricName()); + complexMetric.getMetricName(), null); } break; } else { - // Gets thrown if one of the rules allows a count(distinct ...) through // when approximate results were not told be acceptable. 
- throw new UnsupportedOperationException("Cannot push " + aggCall - + " because an approximate count distinct is not acceptable."); + return null; } } - if (aggCall.getArgList().size() == 1) { + if (aggCall.getArgList().size() == 1 && !aggCall.isDistinct()) { // case we have count(column) push it as count(*) where column is not null - final JsonFilter matchNulls = new JsonSelector(only, null, null); - final JsonFilter filterOutNulls = new JsonCompositeFilter(JsonFilter.Type.NOT, matchNulls); - aggregation = new JsonFilteredAggregation(filterOutNulls, - new JsonAggregation("count", name, only)); + final DruidJsonFilter matchNulls; + if (fieldName == null) { + matchNulls = new DruidJsonFilter.JsonExpressionFilter(aggExpression + " == null"); + } else { + matchNulls = DruidJsonFilter.getSelectorFilter(fieldName, null, null); + } + aggregation = new JsonFilteredAggregation(DruidJsonFilter.toNotDruidFilter(matchNulls), + new JsonAggregation("count", name, fieldName, aggExpression)); + } else if (!aggCall.isDistinct()) { + aggregation = new JsonAggregation("count", name, fieldName, aggExpression); } else { - aggregation = new JsonAggregation("count", name, only); + aggregation = null; } break; case SUM: case SUM0: - aggregation = new JsonAggregation(fractional ? "doubleSum" : "longSum", name, only); + aggregation = new JsonAggregation(fractional ? "doubleSum" : "longSum", name, fieldName, + aggExpression); break; case MIN: - aggregation = new JsonAggregation(fractional ? "doubleMin" : "longMin", name, only); + aggregation = new JsonAggregation(fractional ? "doubleMin" : "longMin", name, fieldName, + aggExpression); break; case MAX: - aggregation = new JsonAggregation(fractional ? "doubleMax" : "longMax", name, only); + aggregation = new JsonAggregation(fractional ? 
"doubleMax" : "longMax", name, fieldName, + aggExpression); break; default: - throw new AssertionError("unknown aggregate " + aggCall); + return null; } - // Check for filters - if (aggCall.hasFilter()) { - RexCall filterNode = (RexCall) projects.get(aggCall.filterArg); - JsonFilter filter = translator.translateFilter(filterNode.getOperands().get(0)); - aggregation = new JsonFilteredAggregation(filter, aggregation); + if (aggregation == null) { + return null; } - - return aggregation; - } - - public JsonPostAggregation getJsonPostAggregation(String name, RexNode rexNode, RelNode rel) { - if (rexNode instanceof RexCall) { - List<JsonPostAggregation> fields = new ArrayList<>(); - for (RexNode ele : ((RexCall) rexNode).getOperands()) { - JsonPostAggregation field = getJsonPostAggregation("", ele, rel); - if (field == null) { - throw new RuntimeException("Unchecked types that cannot be parsed as Post Aggregator"); - } - fields.add(field); - } - switch (rexNode.getKind()) { - case PLUS: - return new JsonArithmetic(name, "+", fields, null); - case MINUS: - return new JsonArithmetic(name, "-", fields, null); - case DIVIDE: - return new JsonArithmetic(name, "quotient", fields, null); - case TIMES: - return new JsonArithmetic(name, "*", fields, null); - case CAST: - return getJsonPostAggregation(name, ((RexCall) rexNode).getOperands().get(0), - rel); - default: - } - } else if (rexNode instanceof RexInputRef) { - // Subtract only number of grouping columns as offset because for now only Aggregates - // without grouping sets (i.e. indicator columns size is zero) are allowed to pushed - // in Druid Query. - Integer indexSkipGroup = ((RexInputRef) rexNode).getIndex() - - ((Aggregate) rel).getGroupCount(); - AggregateCall aggCall = ((Aggregate) rel).getAggCallList().get(indexSkipGroup); - // Use either the hyper unique estimator, or the theta sketch one. - // Hyper unique is used by default. 
- if (aggCall.isDistinct() - && aggCall.getAggregation().getKind() == SqlKind.COUNT) { - final String fieldName = rel.getRowType().getFieldNames() - .get(((RexInputRef) rexNode).getIndex()); - - List<String> fieldNames = ((Aggregate) rel).getInput().getRowType().getFieldNames(); - String complexName = fieldNames.get(aggCall.getArgList().get(0)); - ComplexMetric metric = druidTable.resolveComplexMetric(complexName, aggCall); - - if (metric != null) { - switch (metric.getDruidType()) { - case THETA_SKETCH: - return new JsonThetaSketchEstimate("", fieldName); - case HYPER_UNIQUE: - return new JsonHyperUniqueCardinality("", fieldName); - default: - throw new AssertionError("Can not translate complex metric type: " - + metric.getDruidType()); - } - } - // Count distinct on a non-complex column. - return new JsonHyperUniqueCardinality("", fieldName); - } - return new JsonFieldAccessor("", - rel.getRowType().getFieldNames().get(((RexInputRef) rexNode).getIndex())); - } else if (rexNode instanceof RexLiteral) { - // Druid constant post aggregator only supports numeric value for now. - // (http://druid.io/docs/0.10.0/querying/post-aggregations.html) Accordingly, all - // numeric type of RexLiteral can only have BigDecimal value, so filter out unsupported - // constant by checking the type of RexLiteral value. 
- if (((RexLiteral) rexNode).getValue3() instanceof BigDecimal) { - return new JsonConstant("", - ((BigDecimal) ((RexLiteral) rexNode).getValue3()).doubleValue()); + // translate filters + if (filterNode != null) { + DruidJsonFilter druidFilter = DruidJsonFilter + .toDruidFilters(filterNode, druidQuery.table.getRowType(), druidQuery); + if (druidFilter == null) { + //can not translate filter + return null; } + return new JsonFilteredAggregation(druidFilter, aggregation); } - throw new RuntimeException("Unchecked types that cannot be parsed as Post Aggregator"); + + return aggregation; } protected static void writeField(JsonGenerator generator, String fieldName, @@ -1054,8 +1482,8 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { generator.writeNumber(i); } else if (o instanceof List) { writeArray(generator, (List<?>) o); - } else if (o instanceof Json) { - ((Json) o).write(generator); + } else if (o instanceof DruidJson) { + ((DruidJson) o).write(generator); } else { throw new AssertionError("not a json object: " + o); } @@ -1126,249 +1554,6 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { } } - /** Translates scalar expressions to Druid field references. 
*/ - @VisibleForTesting - protected static class Translator { - final List<String> dimensions = new ArrayList<>(); - final List<String> metrics = new ArrayList<>(); - final DruidTable druidTable; - final RelDataType rowType; - final String timeZone; - final SimpleDateFormat dateFormatter; - - Translator(DruidTable druidTable, RelDataType rowType, String timeZone) { - this.druidTable = druidTable; - this.rowType = rowType; - for (RelDataTypeField f : rowType.getFieldList()) { - final String fieldName = f.getName(); - if (druidTable.isMetric(fieldName)) { - metrics.add(fieldName); - } else if (!druidTable.timestampFieldName.equals(fieldName) - && !DruidTable.DEFAULT_TIMESTAMP_COLUMN.equals(fieldName)) { - dimensions.add(fieldName); - } - } - this.timeZone = timeZone; - this.dateFormatter = new SimpleDateFormat(TimeExtractionFunction.ISO_TIME_FORMAT, - Locale.ROOT); - if (timeZone != null) { - this.dateFormatter.setTimeZone(TimeZone.getTimeZone(timeZone)); - } - } - - protected void clearFieldNameLists() { - dimensions.clear(); - metrics.clear(); - } - - /** Formats timestamp values to druid format using - * {@link DruidQuery.Translator#dateFormatter}. This is needed when pushing - * timestamp comparisons to druid using a TimeFormatExtractionFunction that - * returns a string value. */ - @SuppressWarnings("incomplete-switch") - String translate(RexNode e, boolean set, boolean formatDateString) { - int index = -1; - switch (e.getKind()) { - case INPUT_REF: - final RexInputRef ref = (RexInputRef) e; - index = ref.getIndex(); - break; - case CAST: - return tr(e, 0, set, formatDateString); - case LITERAL: - final RexLiteral rexLiteral = (RexLiteral) e; - if (!formatDateString) { - return Objects.toString(rexLiteral.getValue3()); - } else { - // Case when we are passing to druid as an extractionFunction - // Need to format the timestamp String in druid format. 
- TimestampString timestampString = DruidDateTimeUtils - .literalValue(e, TimeZone.getTimeZone(timeZone)); - if (timestampString == null) { - throw new AssertionError( - "Cannot translate Literal" + e + " of type " - + rexLiteral.getTypeName() + " to TimestampString"); - } - return dateFormatter.format(timestampString.getMillisSinceEpoch()); - } - case FLOOR: - case EXTRACT: - final RexCall call = (RexCall) e; - assert DruidDateTimeUtils.extractGranularity(call, timeZone) != null; - index = RelOptUtil.InputFinder.bits(e).asList().get(0); - break; - case IS_TRUE: - return ""; // the fieldName for which this is the filter will be added separately - } - if (index == -1) { - throw new AssertionError("invalid expression " + e); - } - final String fieldName = rowType.getFieldList().get(index).getName(); - if (set) { - if (druidTable.metricFieldNames.contains(fieldName)) { - metrics.add(fieldName); - } else if (!druidTable.timestampFieldName.equals(fieldName) - && !DruidTable.DEFAULT_TIMESTAMP_COLUMN.equals(fieldName)) { - dimensions.add(fieldName); - } - } - return fieldName; - } - - private JsonFilter translateFilter(RexNode e) { - final RexCall call; - if (e.isAlwaysTrue()) { - return JsonExpressionFilter.alwaysTrue(); - } - if (e.isAlwaysFalse()) { - return JsonExpressionFilter.alwaysFalse(); - } - switch (e.getKind()) { - case EQUALS: - case NOT_EQUALS: - case GREATER_THAN: - case GREATER_THAN_OR_EQUAL: - case LESS_THAN: - case LESS_THAN_OR_EQUAL: - case IN: - case BETWEEN: - case IS_NULL: - case IS_NOT_NULL: - call = (RexCall) e; - int posRef; - int posConstant; - if (call.getOperands().size() == 1) { // IS NULL and IS NOT NULL - posRef = 0; - posConstant = -1; - } else if (RexUtil.isConstant(call.getOperands().get(1))) { - posRef = 0; - posConstant = 1; - } else if (RexUtil.isConstant(call.getOperands().get(0))) { - posRef = 1; - posConstant = 0; - } else { - throw new AssertionError("it is not a valid comparison: " + e); - } - RexNode posRefNode = 
call.getOperands().get(posRef); - final boolean numeric = - call.getOperands().get(posRef).getType().getFamily() - == SqlTypeFamily.NUMERIC; - boolean formatDateString = false; - final Granularity granularity = - DruidDateTimeUtils.extractGranularity(posRefNode, timeZone); - // in case no extraction the field will be omitted from the serialization - final ExtractionFunction extractionFunction; - if (granularity != null) { - switch (posRefNode.getKind()) { - case EXTRACT: - extractionFunction = - TimeExtractionFunction.createExtractFromGranularity(granularity, - timeZone); - break; - case FLOOR: - extractionFunction = - TimeExtractionFunction.createFloorFromGranularity(granularity, - timeZone); - formatDateString = true; - break; - default: - extractionFunction = null; - } - } else { - extractionFunction = null; - } - String dimName = tr(e, posRef, formatDateString); - if (dimName.equals(DruidConnectionImpl.DEFAULT_RESPONSE_TIMESTAMP_COLUMN)) { - // We need to use Druid default column name to refer to the time dimension in a filter - dimName = DruidTable.DEFAULT_TIMESTAMP_COLUMN; - } - - switch (e.getKind()) { - case EQUALS: - // extractionFunction should be null because if we are using an extraction function - // we have guarantees about the format of the output and thus we can apply the - // normal selector - if (numeric && extractionFunction == null) { - String constantValue = tr(e, posConstant, formatDateString); - return new JsonBound(dimName, constantValue, false, constantValue, false, - numeric, extractionFunction); - } - return new JsonSelector(dimName, tr(e, posConstant, formatDateString), - extractionFunction); - case NOT_EQUALS: - // extractionFunction should be null because if we are using an extraction function - // we have guarantees about the format of the output and thus we can apply the - // normal selector - if (numeric && extractionFunction == null) { - String constantValue = tr(e, posConstant, formatDateString); - return new 
JsonCompositeFilter(JsonFilter.Type.OR, - new JsonBound(dimName, constantValue, true, null, false, - numeric, extractionFunction), - new JsonBound(dimName, null, false, constantValue, true, - numeric, extractionFunction)); - } - return new JsonCompositeFilter(JsonFilter.Type.NOT, - new JsonSelector(dimName, tr(e, posConstant, formatDateString), extractionFunction)); - case GREATER_THAN: - return new JsonBound(dimName, tr(e, posConstant, formatDateString), - true, null, false, numeric, extractionFunction); - case GREATER_THAN_OR_EQUAL: - return new JsonBound(dimName, tr(e, posConstant, formatDateString), - false, null, false, numeric, extractionFunction); - case LESS_THAN: - return new JsonBound(dimName, null, false, - tr(e, posConstant, formatDateString), true, numeric, extractionFunction); - case LESS_THAN_OR_EQUAL: - return new JsonBound(dimName, null, false, - tr(e, posConstant, formatDateString), false, numeric, extractionFunction); - case IN: - ImmutableList.Builder<String> listBuilder = ImmutableList.builder(); - for (RexNode rexNode: call.getOperands()) { - if (rexNode.getKind() == SqlKind.LITERAL) { - listBuilder.add(Objects.toString(((RexLiteral) rexNode).getValue3())); - } - } - return new JsonInFilter(dimName, listBuilder.build(), extractionFunction); - case BETWEEN: - return new JsonBound(dimName, tr(e, 2, formatDateString), false, - tr(e, 3, formatDateString), false, numeric, extractionFunction); - case IS_NULL: - return new JsonSelector(dimName, null, extractionFunction); - case IS_NOT_NULL: - return new JsonCompositeFilter(JsonFilter.Type.NOT, - new JsonSelector(dimName, null, extractionFunction)); - default: - throw new AssertionError(); - } - case AND: - case OR: - case NOT: - call = (RexCall) e; - return new JsonCompositeFilter(JsonFilter.Type.valueOf(e.getKind().name()), - translateFilters(call.getOperands())); - default: - throw new AssertionError("cannot translate filter: " + e); - } - } - - private String tr(RexNode call, int index, boolean 
formatDateString) { - return tr(call, index, false, formatDateString); - } - - private String tr(RexNode call, int index, boolean set, boolean formatDateString) { - return translate(((RexCall) call).getOperands().get(index), set, formatDateString); - } - - private List<JsonFilter> translateFilters(List<RexNode> operands) { - final ImmutableList.Builder<JsonFilter> builder = - ImmutableList.builder(); - for (RexNode operand : operands) { - builder.add(translateFilter(operand)); - } - return builder.build(); - } - } - /** Interpreter node that executes a Druid query and sends the results to a * {@link Sink}. */ private static class DruidQueryNode implements Node { @@ -1411,6 +1596,7 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { private ColumnMetaData.Rep getPrimitive(RelDataTypeField field) { switch (field.getType().getSqlTypeName()) { case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + case TIMESTAMP: return ColumnMetaData.Rep.JAVA_SQL_TIMESTAMP; case BIGINT: return ColumnMetaData.Rep.LONG; @@ -1431,22 +1617,18 @@ public class DruidQuery extends AbstractRelNode implements BindableRel { } } - /** Object that knows how to write itself to a - * {@link com.fasterxml.jackson.core.JsonGenerator}. */ - public in
<TRUNCATED>
