This is an automated email from the ASF dual-hosted git repository.
gian pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new 72432c2e78c Speed up SQL IN using SCALAR_IN_ARRAY. (#16388)
72432c2e78c is described below
commit 72432c2e78c452f66726013bac5c3c9d81d83c5b
Author: Gian Merlino <[email protected]>
AuthorDate: Tue May 14 08:09:27 2024 -0700
Speed up SQL IN using SCALAR_IN_ARRAY. (#16388)
* Speed up SQL IN using SCALAR_IN_ARRAY.
Main changes:
1) DruidSqlValidator now includes a rewrite of IN to SCALAR_IN_ARRAY, when
the size of
the IN is above inFunctionThreshold. The default value of
inFunctionThreshold
is 100. Users can restore the prior behavior by setting it to
Integer.MAX_VALUE.
2) SearchOperatorConversion now generates SCALAR_IN_ARRAY when converting
to a regular
expression, when the size of the SEARCH is above
inFunctionExprThreshold. The default
value of inFunctionExprThreshold is 2. Users can restore the prior
behavior by setting
it to Integer.MAX_VALUE.
3) ReverseLookupRule generates SCALAR_IN_ARRAY if the set of
reverse-looked-up values is
greater than inFunctionThreshold.
* Revert test.
* Additional coverage.
* Update docs/querying/sql-query-context.md
Co-authored-by: Benedict Jin <[email protected]>
* New test.
---------
Co-authored-by: Benedict Jin <[email protected]>
---
.../lookup/SqlReverseLookupBenchmark.java | 23 +
.../druid/benchmark/query/InPlanningBenchmark.java | 67 ++-
docs/querying/sql-query-context.md | 4 +-
.../java/org/apache/druid/query/QueryContext.java | 32 +-
.../java/org/apache/druid/query/QueryContexts.java | 4 +
.../org/apache/druid/query/QueryContextTest.java | 27 +-
.../builtin/SearchOperatorConversion.java | 92 ++--
.../sql/calcite/planner/DruidSqlValidator.java | 55 +++
.../druid/sql/calcite/rule/ReverseLookupRule.java | 36 +-
.../calcite/CalciteLookupFunctionQueryTest.java | 58 +++
.../apache/druid/sql/calcite/CalciteQueryTest.java | 512 ++++++++++++++++++++-
11 files changed, 857 insertions(+), 53 deletions(-)
diff --git
a/benchmarks/src/test/java/org/apache/druid/benchmark/lookup/SqlReverseLookupBenchmark.java
b/benchmarks/src/test/java/org/apache/druid/benchmark/lookup/SqlReverseLookupBenchmark.java
index bd4dcd4d9ad..9db22a2ec59 100644
---
a/benchmarks/src/test/java/org/apache/druid/benchmark/lookup/SqlReverseLookupBenchmark.java
+++
b/benchmarks/src/test/java/org/apache/druid/benchmark/lookup/SqlReverseLookupBenchmark.java
@@ -157,4 +157,27 @@ public class SqlReverseLookupBenchmark
blackhole.consume(plannerResult);
}
}
+
+ @Benchmark
+ @BenchmarkMode(Mode.AverageTime)
+ @OutputTimeUnit(TimeUnit.MILLISECONDS)
+ public void planEqualsInsideAndOutsideCase(Blackhole blackhole)
+ {
+ final String sql = StringUtils.format(
+ "SELECT COUNT(*) FROM foo\n"
+ + "WHERE\n"
+ + " CASE WHEN LOOKUP(dimZipf, 'benchmark-lookup', 'N/A') = '%s'\n"
+ + " THEN NULL\n"
+ + " ELSE LOOKUP(dimZipf, 'benchmark-lookup', 'N/A')\n"
+ + " END IN ('%s', '%s', '%s')",
+ LookupBenchmarkUtil.makeKeyOrValue(0),
+ LookupBenchmarkUtil.makeKeyOrValue(1),
+ LookupBenchmarkUtil.makeKeyOrValue(2),
+ LookupBenchmarkUtil.makeKeyOrValue(3)
+ );
+ try (final DruidPlanner planner =
plannerFactory.createPlannerForTesting(engine, sql, ImmutableMap.of())) {
+ final PlannerResult plannerResult = planner.plan();
+ blackhole.consume(plannerResult);
+ }
+ }
}
diff --git
a/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java
b/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java
index 4322e727388..ce01324116f 100644
---
a/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java
+++
b/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java
@@ -26,6 +26,7 @@ import com.google.common.collect.ImmutableSet;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.impl.DimensionSchema;
import org.apache.druid.data.input.impl.DimensionsSpec;
+import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer;
@@ -36,6 +37,7 @@ import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.segment.AutoTypeColumnSchema;
import org.apache.druid.segment.IndexSpec;
import org.apache.druid.segment.QueryableIndex;
+import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.generator.GeneratorBasicSchemas;
import org.apache.druid.segment.generator.GeneratorSchemaInfo;
import org.apache.druid.segment.generator.SegmentGenerator;
@@ -204,7 +206,7 @@ public class InPlanningBenchmark
);
String prefix = ("explain plan for select long1 from foo where long1 in ");
- final String sql = createQuery(prefix, inClauseLiteralsCount);
+ final String sql = createQuery(prefix, inClauseLiteralsCount,
ValueType.LONG);
final Sequence<Object[]> resultSequence = getPlan(sql, null);
final Object[] planResult = resultSequence.toList().get(0);
@@ -222,12 +224,13 @@ public class InPlanningBenchmark
closer.close();
}
+ @Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public void queryInSql(Blackhole blackhole)
{
String prefix = "explain plan for select long1 from foo where long1 in ";
- final String sql = createQuery(prefix, inClauseLiteralsCount);
+ final String sql = createQuery(prefix, inClauseLiteralsCount,
ValueType.LONG);
getPlan(sql, blackhole);
}
@@ -238,7 +241,7 @@ public class InPlanningBenchmark
{
String prefix =
"explain plan for select long1 from foo where string1 = '7' or long1
in ";
- final String sql = createQuery(prefix, inClauseLiteralsCount);
+ final String sql = createQuery(prefix, inClauseLiteralsCount,
ValueType.LONG);
getPlan(sql, blackhole);
}
@@ -250,28 +253,74 @@ public class InPlanningBenchmark
{
String prefix =
"explain plan for select long1 from foo where string1 = '7' or string1
= '8' or long1 in ";
- final String sql = createQuery(prefix, inClauseLiteralsCount);
+ final String sql = createQuery(prefix, inClauseLiteralsCount,
ValueType.LONG);
getPlan(sql, blackhole);
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
- public void queryJoinEqualOrInSql(Blackhole blackhole)
+ public void queryStringFunctionInSql(Blackhole blackhole)
{
+ String prefix =
+ "explain plan for select count(*) from foo where long1 = 8 or
lower(string1) in ";
+ final String sql = createQuery(prefix, inClauseLiteralsCount,
ValueType.STRING);
+ getPlan(sql, blackhole);
+ }
+ @Benchmark
+ @BenchmarkMode(Mode.AverageTime)
+ @OutputTimeUnit(TimeUnit.MILLISECONDS)
+ public void queryStringFunctionIsNotNullAndNotInSql(Blackhole blackhole)
+ {
+ String prefix =
+ "explain plan for select count(*) from foo where long1 = 8 and
lower(string1) is not null and lower(string1) not in ";
+ final String sql = createQuery(prefix, inClauseLiteralsCount,
ValueType.STRING);
+ getPlan(sql, blackhole);
+ }
+
+ @Benchmark
+ @BenchmarkMode(Mode.AverageTime)
+ @OutputTimeUnit(TimeUnit.MILLISECONDS)
+ public void queryStringFunctionIsNullOrInSql(Blackhole blackhole)
+ {
+ String prefix =
+ "explain plan for select count(*) from foo where long1 = 8 and
(lower(string1) is null or lower(string1) in ";
+ final String sql = createQuery(prefix, inClauseLiteralsCount,
ValueType.STRING) + ')';
+ getPlan(sql, blackhole);
+ }
+
+ @Benchmark
+ @BenchmarkMode(Mode.AverageTime)
+ @OutputTimeUnit(TimeUnit.MILLISECONDS)
+ public void queryJoinEqualOrInSql(Blackhole blackhole)
+ {
String prefix =
"explain plan for select foo.long1, fooright.string1 from foo inner
join foo as fooright on foo.string1 = fooright.string1 where fooright.string1 =
'7' or foo.long1 in ";
- final String sql = createQuery(prefix, inClauseLiteralsCount);
+ final String sql = createQuery(prefix, inClauseLiteralsCount,
ValueType.LONG);
getPlan(sql, blackhole);
}
- private String createQuery(String prefix, int inClauseLiteralsCount)
+ private String createQuery(String prefix, int inClauseLiteralsCount,
ValueType type)
{
StringBuilder sqlBuilder = new StringBuilder();
sqlBuilder.append(prefix).append('(');
- IntStream.range(1, inClauseLiteralsCount - 1).forEach(i ->
sqlBuilder.append(i).append(","));
- sqlBuilder.append(inClauseLiteralsCount).append(")");
+ IntStream.range(1, inClauseLiteralsCount + 1).forEach(
+ i -> {
+ if (i > 1) {
+ sqlBuilder.append(',');
+ }
+
+ if (type == ValueType.LONG) {
+ sqlBuilder.append(i);
+ } else if (type == ValueType.STRING) {
+ sqlBuilder.append("'").append(i).append("'");
+ } else {
+ throw new ISE("Cannot generate IN with type[%s]", type);
+ }
+ }
+ );
+ sqlBuilder.append(")");
return sqlBuilder.toString();
}
diff --git a/docs/querying/sql-query-context.md
b/docs/querying/sql-query-context.md
index f8b1576a913..dc73c9e1ab3 100644
--- a/docs/querying/sql-query-context.md
+++ b/docs/querying/sql-query-context.md
@@ -52,7 +52,9 @@ Configure Druid SQL query planning using the parameters in
the table below.
|`sqlPullUpLookup`|Whether to consider the [pull-up
rewrite](lookups.md#pull-up) of the `LOOKUP` function during SQL planning.|true|
|`enableJoinLeftTableScanDirect`|`false`|This flag applies to queries which
have joins. For joins, where left child is a simple scan with a filter, by
default, druid will run the scan as a query and then join the results to the
right child on broker. Setting this flag to true overrides that behavior and
druid will attempt to push the join to data servers instead. Please note that
the flag could be applicable to queries even if there is no explicit join.
since queries can internally trans [...]
|`maxNumericInFilters`|`-1`|Max limit for the amount of numeric values that
can be compared for a string type dimension when the entire SQL WHERE clause of
a query translates only to an [OR](../querying/filters.md#or) of [Bound
filter](../querying/filters.md#bound-filter). By default, Druid does not
restrict the amount of numeric Bound Filters on String columns, although
this situation may block other queries from running. Set this parameter to a
smaller value to prevent Druid from ru [...]
-|`inSubQueryThreshold`|`2147483647`| Threshold for minimum number of values in
an IN clause to convert the query to a JOIN operation on an inlined table
rather than a predicate. A threshold of 0 forces usage of an inline table in
all cases; a threshold of [Integer.MAX_VALUE] forces usage of OR in all cases. |
+|`inFunctionThreshold`|`100`| At or beyond this threshold number of values,
SQL `IN` is converted to [`SCALAR_IN_ARRAY`](sql-functions.md#scalar_in_array).
A threshold of 0 forces this conversion in all cases. A threshold of
[Integer.MAX_VALUE] disables this conversion. The converted function is
eligible for fewer planning-time optimizations, which speeds up planning but
may prevent certain optimizations from being applied.|
+|`inFunctionExprThreshold`|`2`| At or beyond this threshold number of values,
SQL `IN` is eligible for execution using the native function `scalar_in_array`
rather than an <code>||</code> of `==`, even if the number of values
is below `inFunctionThreshold`. This property only affects translation of SQL
`IN` to a [native expression](math-expr.md). It does not affect translation of
SQL `IN` to a [native filter](filters.md). This property is provided for
backwards compatibility pu [...]
+|`inSubQueryThreshold`|`2147483647`| At or beyond this threshold number of
values, SQL `IN` is converted to `JOIN` on an inline table.
`inFunctionThreshold` takes priority over this setting. A threshold of 0 forces
usage of an inline table in all cases where the size of a SQL `IN` is larger
than `inFunctionThreshold`. A threshold of `2147483647` disables the rewrite of
SQL `IN` to `JOIN`. |
## Setting the query context
The query context parameters can be specified as a "context" object in the
[JSON API](../api-reference/sql-api.md) or as a [JDBC connection properties
object](../api-reference/sql-jdbc.md).
diff --git a/processing/src/main/java/org/apache/druid/query/QueryContext.java
b/processing/src/main/java/org/apache/druid/query/QueryContext.java
index 5c08678b888..80a860bb273 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContext.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContext.java
@@ -24,10 +24,11 @@ import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.QueryContexts.Vectorize;
+import org.apache.druid.query.filter.InDimFilter;
+import org.apache.druid.query.filter.TypedInFilter;
import org.apache.druid.segment.QueryableIndexStorageAdapter;
import javax.annotation.Nullable;
-
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
@@ -575,6 +576,35 @@ public class QueryContext
);
}
+ /**
+ * At or above this threshold number of values, when planning SQL queries,
use the SQL SCALAR_IN_ARRAY operator rather
+ * than a stack of SQL ORs. This speeds up planning for large sets of points
because it is opaque to various
+ * expensive optimizations. But, because this does bypass certain
optimizations, we only do the transformation above
+ * a certain threshold. The SCALAR_IN_ARRAY operator is still able to
convert to {@link InDimFilter} or
+ * {@link TypedInFilter}.
+ */
+ public int getInFunctionThreshold()
+ {
+ return getInt(
+ QueryContexts.IN_FUNCTION_THRESHOLD,
+ QueryContexts.DEFAULT_IN_FUNCTION_THRESHOLD
+ );
+ }
+
+ /**
+ * At or above this threshold, when converting the SEARCH operator to a
native expression, use the "scalar_in_array"
+ * function rather than a sequence of equals (==) separated by or (||). This
is typically a lower threshold
+ * than {@link #getInFunctionThreshold()}, because it does not prevent any
SQL planning optimizations, and it
+ * speeds up query execution.
+ */
+ public int getInFunctionExprThreshold()
+ {
+ return getInt(
+ QueryContexts.IN_FUNCTION_EXPR_THRESHOLD,
+ QueryContexts.DEFAULT_IN_FUNCTION_EXPR_THRESHOLD
+ );
+ }
+
public boolean isTimeBoundaryPlanningEnabled()
{
return getBoolean(
diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java
b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
index 2ea31d33948..3010b4fa923 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
@@ -77,6 +77,8 @@ public class QueryContexts
public static final String BY_SEGMENT_KEY = "bySegment";
public static final String BROKER_SERVICE_NAME = "brokerService";
public static final String IN_SUB_QUERY_THRESHOLD_KEY =
"inSubQueryThreshold";
+ public static final String IN_FUNCTION_THRESHOLD = "inFunctionThreshold";
+ public static final String IN_FUNCTION_EXPR_THRESHOLD =
"inFunctionExprThreshold";
public static final String TIME_BOUNDARY_PLANNING_KEY =
"enableTimeBoundaryPlanning";
public static final String POPULATE_CACHE_KEY = "populateCache";
public static final String POPULATE_RESULT_LEVEL_CACHE_KEY =
"populateResultLevelCache";
@@ -120,6 +122,8 @@ public class QueryContexts
public static final boolean DEFAULT_SECONDARY_PARTITION_PRUNING = true;
public static final boolean DEFAULT_ENABLE_DEBUG = false;
public static final int DEFAULT_IN_SUB_QUERY_THRESHOLD = Integer.MAX_VALUE;
+ public static final int DEFAULT_IN_FUNCTION_THRESHOLD = 100;
+ public static final int DEFAULT_IN_FUNCTION_EXPR_THRESHOLD = 2;
public static final boolean DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING = false;
public static final boolean DEFAULT_WINDOWING_STRICT_VALIDATION = true;
diff --git
a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
index 71b477d16c3..c555c2ed437 100644
--- a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
+++ b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java
@@ -41,7 +41,6 @@ import org.joda.time.Interval;
import org.junit.Test;
import javax.annotation.Nullable;
-
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -337,11 +336,35 @@ public class QueryContextTest
ImmutableMap.of(QueryContexts.MAX_SUBQUERY_BYTES_KEY, "auto")
);
assertEquals("auto", context2.getMaxSubqueryMemoryBytes(null));
-
+
final QueryContext context3 = new QueryContext(ImmutableMap.of());
assertEquals("disabled", context3.getMaxSubqueryMemoryBytes("disabled"));
}
+ @Test
+ public void testGetInFunctionThreshold()
+ {
+ final QueryContext context1 = new QueryContext(
+ ImmutableMap.of(QueryContexts.IN_FUNCTION_THRESHOLD, Integer.MAX_VALUE)
+ );
+ assertEquals(Integer.MAX_VALUE, context1.getInFunctionThreshold());
+
+ final QueryContext context2 = QueryContext.empty();
+ assertEquals(QueryContexts.DEFAULT_IN_FUNCTION_THRESHOLD,
context2.getInFunctionThreshold());
+ }
+
+ @Test
+ public void testGetInFunctionExprThreshold()
+ {
+ final QueryContext context1 = new QueryContext(
+ ImmutableMap.of(QueryContexts.IN_FUNCTION_EXPR_THRESHOLD,
Integer.MAX_VALUE)
+ );
+ assertEquals(Integer.MAX_VALUE, context1.getInFunctionExprThreshold());
+
+ final QueryContext context2 = QueryContext.empty();
+ assertEquals(QueryContexts.DEFAULT_IN_FUNCTION_EXPR_THRESHOLD,
context2.getInFunctionExprThreshold());
+ }
+
@Test
public void testDefaultEnableQueryDebugging()
{
diff --git
a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SearchOperatorConversion.java
b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SearchOperatorConversion.java
index 967978c8760..20202b1a280 100644
---
a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SearchOperatorConversion.java
+++
b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/SearchOperatorConversion.java
@@ -24,6 +24,7 @@ import com.google.common.collect.Iterables;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeSet;
+import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
@@ -48,6 +49,7 @@ import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.Collections;
import java.util.List;
@@ -82,7 +84,7 @@ public class SearchOperatorConversion implements
SqlOperatorConversion
plannerContext,
rowSignature,
virtualColumnRegistry,
- expandSearch((RexCall) rexNode, REX_BUILDER)
+ expandSearch((RexCall) rexNode, REX_BUILDER,
plannerContext.queryContext().getInFunctionThreshold())
);
}
@@ -97,7 +99,7 @@ public class SearchOperatorConversion implements
SqlOperatorConversion
return Expressions.toDruidExpression(
plannerContext,
rowSignature,
- expandSearch((RexCall) rexNode, REX_BUILDER)
+ expandSearch((RexCall) rexNode, REX_BUILDER,
plannerContext.queryContext().getInFunctionExprThreshold())
);
}
@@ -111,7 +113,8 @@ public class SearchOperatorConversion implements
SqlOperatorConversion
*/
public static RexNode expandSearch(
final RexCall call,
- final RexBuilder rexBuilder
+ final RexBuilder rexBuilder,
+ final int scalarInArrayThreshold
)
{
final RexNode arg = call.operands.get(0);
@@ -139,13 +142,10 @@ public class SearchOperatorConversion implements
SqlOperatorConversion
notInPoints = getPoints(complement);
notInRexNode = makeIn(
arg,
- ImmutableList.copyOf(
- Iterables.transform(
- notInPoints,
- point -> rexBuilder.makeLiteral(point, sargRex.getType(),
true, true)
- )
- ),
+ notInPoints,
+ sargRex.getType(),
true,
+ notInPoints.size() >= scalarInArrayThreshold,
rexBuilder
);
}
@@ -155,13 +155,10 @@ public class SearchOperatorConversion implements
SqlOperatorConversion
sarg.pointCount == 0 ? Collections.emptyList() : (List<Comparable>)
getPoints(sarg.rangeSet);
final RexNode inRexNode = makeIn(
arg,
- ImmutableList.copyOf(
- Iterables.transform(
- inPoints,
- point -> rexBuilder.makeLiteral(point, sargRex.getType(),
true, true)
- )
- ),
+ inPoints,
+ sargRex.getType(),
false,
+ inPoints.size() >= scalarInArrayThreshold,
rexBuilder
);
if (inRexNode != null) {
@@ -225,14 +222,36 @@ public class SearchOperatorConversion implements
SqlOperatorConversion
return retVal;
}
+ /**
+ * Make an IN condition for an "arg" matching certain "points", as in "arg
IN (points)".
+ *
+ * @param arg lhs of the IN
+ * @param pointObjects rhs of the IN. Must match the "pointType"
+ * @param pointType type of "pointObjects"
+ * @param negate true for NOT IN, false for IN
+ * @param useScalarInArray if true, use {@link
ScalarInArrayOperatorConversion#SQL_FUNCTION} when there is more
+ * than one point; if false, use a stack of ORs
+ * @param rexBuilder rex builder
+ *
+ * @return SQL rex nodes equivalent to the IN filter, or null if
"pointObjects" is empty
+ */
@Nullable
public static RexNode makeIn(
final RexNode arg,
- final List<RexNode> points,
+ final Collection<? extends Comparable> pointObjects,
+ final RelDataType pointType,
final boolean negate,
+ final boolean useScalarInArray,
final RexBuilder rexBuilder
)
{
+ final List<RexNode> points = ImmutableList.copyOf(
+ Iterables.transform(
+ pointObjects,
+ point -> rexBuilder.makeLiteral(point, pointType, false, false)
+ )
+ );
+
if (points.isEmpty()) {
return null;
} else if (points.size() == 1) {
@@ -244,22 +263,33 @@ public class SearchOperatorConversion implements
SqlOperatorConversion
return rexBuilder.makeCall(negate ? SqlStdOperatorTable.NOT_EQUALS :
SqlStdOperatorTable.EQUALS, arg, point);
}
} else {
- // x = a || x = b || x = c ...
- RexNode retVal = rexBuilder.makeCall(
- SqlStdOperatorTable.OR,
- ImmutableList.copyOf(
- Iterables.transform(
- points,
- point -> {
- if (RexUtil.isNullLiteral(point, true)) {
- return rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL,
arg);
- } else {
- return rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
arg, point);
+ RexNode retVal;
+
+ if (useScalarInArray) {
+ // SCALAR_IN_ARRAY(x, ARRAY[a, b, c])
+ retVal = rexBuilder.makeCall(
+ ScalarInArrayOperatorConversion.SQL_FUNCTION,
+ arg,
+ rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
points)
+ );
+ } else {
+ // x = a || x = b || x = c ...
+ retVal = rexBuilder.makeCall(
+ SqlStdOperatorTable.OR,
+ ImmutableList.copyOf(
+ Iterables.transform(
+ points,
+ point -> {
+ if (RexUtil.isNullLiteral(point, true)) {
+ return
rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, arg);
+ } else {
+ return rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
arg, point);
+ }
}
- }
- )
- )
- );
+ )
+ )
+ );
+ }
if (negate) {
retVal = rexBuilder.makeCall(SqlStdOperatorTable.NOT, retVal);
diff --git
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java
index b4f006ce97e..d1a47520a90 100644
---
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java
+++
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java
@@ -60,10 +60,12 @@ import org.apache.druid.common.utils.IdUtils;
import org.apache.druid.error.InvalidSqlInput;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularity;
+import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.ValueType;
+import
org.apache.druid.sql.calcite.expression.builtin.ScalarInArrayOperatorConversion;
import org.apache.druid.sql.calcite.parser.DruidSqlIngest;
import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;
@@ -774,6 +776,59 @@ public class DruidSqlValidator extends
BaseDruidSqlValidator
super.validateCall(call, scope);
}
+ @Override
+ protected SqlNode performUnconditionalRewrites(SqlNode node, final boolean
underFrom)
+ {
+ if (node != null && (node.getKind() == SqlKind.IN || node.getKind() ==
SqlKind.NOT_IN)) {
+ final SqlNode rewritten = rewriteInToScalarInArrayIfNeeded((SqlCall)
node, underFrom);
+ //noinspection ObjectEquality
+ if (rewritten != node) {
+ return rewritten;
+ }
+ }
+
+ return super.performUnconditionalRewrites(node, underFrom);
+ }
+
+ /**
+ * Rewrites "x IN (values)" to "SCALAR_IN_ARRAY(x, values)", if appropriate.
Checks the form of the IN and checks
+ * the value of {@link QueryContext#getInFunctionThreshold()}.
+ *
+ * @param call call to {@link SqlKind#IN} or {@link SqlKind#NOT_IN}
+ * @param underFrom underFrom arg from {@link
#performUnconditionalRewrites(SqlNode, boolean)}, used for
+ * recursive calls
+ *
+ * @return rewritten call, or the original call if no rewrite was appropriate
+ */
+ private SqlNode rewriteInToScalarInArrayIfNeeded(final SqlCall call, final
boolean underFrom)
+ {
+ if (call.getOperandList().size() == 2 && call.getOperandList().get(1)
instanceof SqlNodeList) {
+ // expr IN (values)
+ final SqlNode exprNode = call.getOperandList().get(0);
+ final SqlNodeList valuesNode = (SqlNodeList)
call.getOperandList().get(1);
+
+ // Confirm valuesNode is big enough to convert to SCALAR_IN_ARRAY, and
references only nonnull literals.
+ // (Can't include NULL literals in the conversion, because
SCALAR_IN_ARRAY matches NULLs as if they were regular
+ // values, whereas IN does not.)
+ if (valuesNode.size() >
plannerContext.queryContext().getInFunctionThreshold()
+ && valuesNode.stream().allMatch(node -> node.getKind() ==
SqlKind.LITERAL && !SqlUtil.isNull(node))) {
+ final SqlCall newCall =
ScalarInArrayOperatorConversion.SQL_FUNCTION.createCall(
+ call.getParserPosition(),
+ performUnconditionalRewrites(exprNode, underFrom),
+ SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR.createCall(valuesNode)
+ );
+
+ if (call.getKind() == SqlKind.NOT_IN) {
+ return SqlStdOperatorTable.NOT.createCall(call.getParserPosition(),
newCall);
+ } else {
+ return newCall;
+ }
+ }
+ }
+
+ return call;
+ }
+
public static CalciteContextException buildCalciteContextException(String
message, SqlNode call)
{
return buildCalciteContextException(new CalciteException(message, null),
message, call);
diff --git
a/sql/src/main/java/org/apache/druid/sql/calcite/rule/ReverseLookupRule.java
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/ReverseLookupRule.java
index 95ad2b11334..30329816f04 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/ReverseLookupRule.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/ReverseLookupRule.java
@@ -46,6 +46,7 @@ import org.apache.druid.query.lookup.LookupExtractionFn;
import org.apache.druid.query.lookup.LookupExtractor;
import
org.apache.druid.sql.calcite.expression.builtin.MultiValueStringOperatorConversions;
import
org.apache.druid.sql.calcite.expression.builtin.QueryLookupOperatorConversion;
+import
org.apache.druid.sql.calcite.expression.builtin.ScalarInArrayOperatorConversion;
import
org.apache.druid.sql.calcite.expression.builtin.SearchOperatorConversion;
import org.apache.druid.sql.calcite.filtration.CollectComparisons;
import org.apache.druid.sql.calcite.planner.Calcites;
@@ -275,12 +276,16 @@ public class ReverseLookupRule extends RelOptRule
implements SubstitutionRule
}
/**
- * When we encounter SEARCH, expand it using {@link
SearchOperatorConversion#expandSearch(RexCall, RexBuilder)}
+ * When we encounter SEARCH, expand it using {@link
SearchOperatorConversion#expandSearch(RexCall, RexBuilder, int)}
* and continue processing what lies beneath.
*/
private RexNode visitSearch(final RexCall call)
{
- final RexNode expanded = SearchOperatorConversion.expandSearch(call,
rexBuilder);
+ final RexNode expanded = SearchOperatorConversion.expandSearch(
+ call,
+ rexBuilder,
+ plannerContext.queryContext().getInFunctionThreshold()
+ );
if (expanded instanceof RexCall) {
final RexNode converted = visitCall((RexCall) expanded);
@@ -300,10 +305,17 @@ public class ReverseLookupRule extends RelOptRule
implements SubstitutionRule
*/
private RexNode visitComparison(final RexCall call)
{
- return CollectionUtils.getOnlyElement(
+ final RexNode retVal = CollectionUtils.getOnlyElement(
new CollectReverseLookups(Collections.singletonList(call),
rexBuilder).collect(),
ret -> new ISE("Expected to collect single node, got[%s]", ret)
);
+
+ //noinspection ObjectEquality
+ if (retVal != call) {
+ return retVal;
+ } else {
+ return super.visitCall(call);
+ }
}
/**
@@ -398,12 +410,13 @@ public class ReverseLookupRule extends RelOptRule
implements SubstitutionRule
return Collections.singleton(null);
} else {
// Compute the set of values that this comparison operator matches.
- // Note that MV_CONTAINS and MV_OVERLAP match nulls, but other
comparison operators do not.
+ // Note that MV_CONTAINS, MV_OVERLAP, and SCALAR_IN_ARRAY match
nulls, but other comparison operators do not.
// See "isBinaryComparison" for the set of operators we might
encounter here.
final RexNode matchLiteral = call.getOperands().get(1);
final boolean matchNulls =
call.getOperator().equals(MultiValueStringOperatorConversions.CONTAINS.calciteOperator())
- ||
call.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.calciteOperator());
+ ||
call.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.calciteOperator())
+ ||
call.getOperator().equals(ScalarInArrayOperatorConversion.SQL_FUNCTION);
return toStringSet(matchLiteral, matchNulls);
}
}
@@ -559,8 +572,16 @@ public class ReverseLookupRule extends RelOptRule
implements SubstitutionRule
} else {
return SearchOperatorConversion.makeIn(
reverseLookupKey.arg,
- stringsToRexNodes(reversedMatchValues, rexBuilder),
+ reversedMatchValues,
+ rexBuilder.getTypeFactory()
+ .createTypeWithNullability(
+
rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR),
+ true
+ ),
reverseLookupKey.negate,
+
+ // Use regular equals, or SCALAR_IN_ARRAY, depending on
inFunctionThreshold.
+ reversedMatchValues.size() >=
plannerContext.queryContext().getInFunctionThreshold(),
rexBuilder
);
}
@@ -598,7 +619,8 @@ public class ReverseLookupRule extends RelOptRule
implements SubstitutionRule
return call.getKind() == SqlKind.EQUALS
|| call.getKind() == SqlKind.NOT_EQUALS
||
call.getOperator().equals(MultiValueStringOperatorConversions.CONTAINS.calciteOperator())
- ||
call.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.calciteOperator());
+ ||
call.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.calciteOperator())
+ ||
call.getOperator().equals(ScalarInArrayOperatorConversion.SQL_FUNCTION);
} else {
return false;
}
diff --git
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java
index 1aa4c89b416..80d9f1bbf17 100644
---
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java
+++
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java
@@ -169,6 +169,30 @@ public class CalciteLookupFunctionQueryTest extends
BaseCalciteQueryTest
);
}
+ @Test
+ public void testFilterScalarInArrayLookupOfConcat()
+ {
+ cannotVectorize();
+
+ testQuery(
+ buildFilterTestSql("SCALAR_IN_ARRAY(LOOKUP(CONCAT(dim1, 'a', dim2),
'lookyloo'), ARRAY['xa', 'xabc'])"),
+ QUERY_CONTEXT,
+ buildFilterTestExpectedQuery(
+ or(
+ and(
+ equality("dim1", "", ColumnType.STRING),
+ equality("dim2", "", ColumnType.STRING)
+ ),
+ and(
+ equality("dim1", "", ColumnType.STRING),
+ equality("dim2", "bc", ColumnType.STRING)
+ )
+ )
+ ),
+ ImmutableList.of()
+ );
+ }
+
@Test
public void testFilterConcatOfLookup()
{
@@ -378,6 +402,40 @@ public class CalciteLookupFunctionQueryTest extends
BaseCalciteQueryTest
);
}
+ @Test
+ public void testFilterScalarInArray()
+ {
+ cannotVectorize();
+
+ testQuery(
+ buildFilterTestSql("SCALAR_IN_ARRAY(LOOKUP(dim1, 'lookyloo'),
ARRAY['xabc', 'x6', 'nonexistent'])"),
+ QUERY_CONTEXT,
+ buildFilterTestExpectedQuery(in("dim1", Arrays.asList("6", "abc"))),
+ ImmutableList.of(new Object[]{"xabc", 1L})
+ );
+ }
+
+ @Test
+ public void testFilterInOverScalarInArrayThreshold()
+ {
+ cannotVectorize();
+
+ // Set inFunctionThreshold = 1 to cause the IN to be converted to
SCALAR_IN_ARRAY.
+ final ImmutableMap<String, Object> queryContext =
+ ImmutableMap.<String, Object>builder()
+ .putAll(QUERY_CONTEXT_DEFAULT)
+ .put(PlannerContext.CTX_SQL_REVERSE_LOOKUP, true)
+ .put(QueryContexts.IN_FUNCTION_THRESHOLD, 1)
+ .build();
+
+ testQuery(
+ buildFilterTestSql("LOOKUP(dim1, 'lookyloo') IN ('xabc', 'x6',
'nonexistent')"),
+ queryContext,
+ buildFilterTestExpectedQuery(in("dim1", Arrays.asList("6", "abc"))),
+ ImmutableList.of(new Object[]{"xabc", 1L})
+ );
+ }
+
@Test
public void testFilterInOverMaxSize()
{
diff --git
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index 9b302534ef6..7a3df49eedf 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -23,6 +23,8 @@ import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
import org.apache.calcite.runtime.CalciteContextException;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.error.DruidException;
@@ -4039,6 +4041,49 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
);
}
+ @Test
+ public void testGroupingWithNullPlusNonNullInFilter()
+ {
+ msqIncompatible();
+ testQuery(
+ "SELECT COUNT(*) FROM foo WHERE dim1 IN (NULL, 'abc')",
+ ImmutableList.of(
+ Druids.newTimeseriesQueryBuilder()
+ .dataSource(CalciteTests.DATASOURCE1)
+ .intervals(querySegmentSpec(Filtration.eternity()))
+ .granularity(Granularities.ALL)
+ .filters(equality("dim1", "abc", ColumnType.STRING))
+ .aggregators(aggregators(new CountAggregatorFactory("a0")))
+ .context(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(new Object[]{1L})
+ );
+ }
+
+ @Test
+ public void testGroupingWithNotNullPlusNonNullInFilter()
+ {
+ msqIncompatible();
+ testQuery(
+ "SELECT COUNT(*) FROM foo WHERE dim1 NOT IN (NULL, 'abc')",
+ ImmutableList.of(
+ newScanQueryBuilder()
+ .dataSource(
+ InlineDataSource.fromIterable(
+ ImmutableList.of(new Object[]{0L}),
+ RowSignature.builder().add("EXPR$0",
ColumnType.LONG).build()
+ )
+ )
+ .intervals(querySegmentSpec(Filtration.eternity()))
+ .columns("EXPR$0")
+
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+ .context(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(new Object[]{0L})
+ );
+ }
@Test
public void testGroupByNothingWithLiterallyFalseFilter()
@@ -5557,6 +5602,46 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
);
}
+ @Test
+ public void testNotInAndIsNotNullFilter()
+ {
+ testQuery(
+ "SELECT dim1, COUNT(*) FROM druid.foo "
+ + "WHERE dim1 NOT IN ('ghi', 'abc', 'def') AND dim1 IS NOT NULL "
+ + "GROUP BY dim1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(dimensions(new
DefaultDimensionSpec("dim1", "d0")))
+ .setDimFilter(and(
+ notNull("dim1"),
+ not(in("dim1", ColumnType.STRING,
ImmutableList.of("abc", "def", "ghi")))
+ ))
+ .setAggregatorSpecs(
+ aggregators(
+ new CountAggregatorFactory("a0")
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ NullHandling.sqlCompatible()
+ ? ImmutableList.of(
+ new Object[]{"", 1L},
+ new Object[]{"1", 1L},
+ new Object[]{"10.1", 1L},
+ new Object[]{"2", 1L}
+ )
+ : ImmutableList.of(
+ new Object[]{"1", 1L},
+ new Object[]{"10.1", 1L},
+ new Object[]{"2", 1L}
+ )
+ );
+ }
+
@Test
public void testNotInAndLessThanFilter()
{
@@ -5631,6 +5716,279 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
);
}
+ @Test
+ public void testInExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 IN ('abc', 'def', 'ghi'), COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+
"scalar_in_array(\"dim1\",array('abc','def','ghi'))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, 4L},
+ new Object[]{true, 2L}
+ )
+ );
+ }
+
+ @Test
+ public void testInExpressionBelowThreshold()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 IN ('abc', 'def', 'ghi'), COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ QueryContexts.override(QUERY_CONTEXT_DEFAULT,
QueryContexts.IN_FUNCTION_EXPR_THRESHOLD, 100),
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "((\"dim1\" == 'abc') || (\"dim1\" == 'def')
|| (\"dim1\" == 'ghi'))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, 4L},
+ new Object[]{true, 2L}
+ )
+ );
+ }
+
+ @Test
+ public void testInOrIsNullExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 IN ('abc', 'def', 'ghi') OR dim1 IS NULL, COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "(isnull(\"dim1\") ||
scalar_in_array(\"dim1\",array('abc','def','ghi')))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, NullHandling.sqlCompatible() ? 4L : 3L},
+ new Object[]{true, NullHandling.sqlCompatible() ? 2L : 3L}
+ )
+ );
+ }
+
+ @Test
+ public void testNotInOrIsNullExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT NOT (dim1 IN ('abc', 'def', 'ghi') OR dim1 IS NULL),
COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "(notnull(\"dim1\") && (!
scalar_in_array(\"dim1\",array('abc','def','ghi'))))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, NullHandling.sqlCompatible() ? 2L : 3L},
+ new Object[]{true, NullHandling.sqlCompatible() ? 4L : 3L}
+ )
+ );
+ }
+
+ @Test
+ public void testNotInAndIsNotNullExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 NOT IN ('abc', 'def', 'ghi') AND dim1 IS NOT NULL,
COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "(notnull(\"dim1\") && (!
scalar_in_array(\"dim1\",array('abc','def','ghi'))))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, NullHandling.sqlCompatible() ? 2L : 3L},
+ new Object[]{true, NullHandling.sqlCompatible() ? 4L : 3L}
+ )
+ );
+ }
+
+ @Test
+ public void testInOrGreaterThanExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 IN ('abc', 'def', 'ghi') OR dim1 > 'zzz', COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+
"(scalar_in_array(\"dim1\",array('abc','def','ghi')) || (\"dim1\" > 'zzz'))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, 4L},
+ new Object[]{true, 2L}
+ )
+ );
+ }
+
+ @Test
+ public void testNotInAndLessThanExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 NOT IN ('abc', 'def', 'ghi') AND dim1 < 'zzz', COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "((\"dim1\" < 'zzz') && (!
scalar_in_array(\"dim1\",array('abc','def','ghi'))))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, 2L},
+ new Object[]{true, 4L}
+ )
+ );
+ }
+
+ @Test
+ public void testNotInOrEqualToOneOfThemExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 NOT IN ('abc', 'def', 'ghi') OR dim1 = 'def', COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "(!
scalar_in_array(\"dim1\",array('abc','ghi')))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, 1L},
+ new Object[]{true, 5L}
+ )
+ );
+ }
+
@Test
public void testSqlIsNullToInFilter()
{
@@ -5685,14 +6043,91 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
final String elementsString = Joiner.on(",").join(elements.stream().map(s
-> "'" + s + "'").iterator());
testQuery(
- "SELECT dim1, COUNT(*) FROM druid.foo WHERE dim1 IN (" +
elementsString + ") GROUP BY dim1",
+ "SELECT dim1, COUNT(*) FROM druid.foo\n"
+ + "WHERE dim1 IN (" + elementsString + ") OR dim1 = 'xyz' OR dim1 IS
NULL\n"
+ + "GROUP BY dim1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(dimensions(new
DefaultDimensionSpec("dim1", "d0")))
+ .setDimFilter(
+ NullHandling.sqlCompatible()
+ ? or(
+ in("dim1",
ImmutableSet.<String>builder().addAll(elements).add("xyz").build()),
+ isNull("dim1")
+ )
+ : in(
+ "dim1",
+ Lists.newArrayList(
+ Iterables.concat(
+ Collections.singleton(null),
+ elements,
+ Collections.singleton("xyz")
+ )
+ )
+ )
+ )
+ .setAggregatorSpecs(
+ aggregators(
+ new CountAggregatorFactory("a0")
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ NullHandling.sqlCompatible()
+ ? ImmutableList.of(
+ new Object[]{"abc", 1L},
+ new Object[]{"def", 1L}
+ )
+ : ImmutableList.of(
+ new Object[]{"", 1L},
+ new Object[]{"abc", 1L},
+ new Object[]{"def", 1L}
+ )
+ );
+ }
+
+ @Test
+ public void testInFilterWith23Elements_overScalarInArrayThreshold()
+ {
+ final List<String> elements = new ArrayList<>();
+ elements.add("abc");
+ elements.add("def");
+ elements.add("ghi");
+ for (int i = 0; i < 20; i++) {
+ elements.add("dummy" + i);
+ }
+
+ final String elementsString = Joiner.on(",").join(elements.stream().map(s
-> "'" + s + "'").iterator());
+
+ testQuery(
+ "SELECT dim1, COUNT(*) FROM druid.foo\n"
+ + "WHERE dim1 IN (" + elementsString + ") OR dim1 = 'xyz' OR dim1 IS
NULL\n"
+ + "GROUP BY dim1",
+ QueryContexts.override(QUERY_CONTEXT_DEFAULT,
QueryContexts.IN_FUNCTION_THRESHOLD, 20),
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.DATASOURCE1)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(new
DefaultDimensionSpec("dim1", "d0")))
- .setDimFilter(in("dim1", elements))
+ .setDimFilter(
+ // [dim1 = xyz] is not combined into the IN
filter, because SCALAR_IN_ARRAY was used,
+ // and it is opaque to most optimizations. (That's
its main purpose.)
+ NullHandling.sqlCompatible()
+ ? or(
+ in("dim1", elements),
+ isNull("dim1"),
+ equality("dim1", "xyz", ColumnType.STRING)
+ )
+ : or(
+ in("dim1", Arrays.asList(null, "xyz")),
+ in("dim1", elements)
+ )
+ )
.setAggregatorSpecs(
aggregators(
new CountAggregatorFactory("a0")
@@ -5701,7 +6136,80 @@ public class CalciteQueryTest extends
BaseCalciteQueryTest
.setContext(QUERY_CONTEXT_DEFAULT)
.build()
),
+ NullHandling.sqlCompatible()
+ ? ImmutableList.of(
+ new Object[]{"abc", 1L},
+ new Object[]{"def", 1L}
+ )
+ : ImmutableList.of(
+ new Object[]{"", 1L},
+ new Object[]{"abc", 1L},
+ new Object[]{"def", 1L}
+ )
+ );
+ }
+
+ @Test
+ public void
testInFilterWith23Elements_overBothScalarInArrayAndInSubQueryThresholds()
+ {
+ // Verify that when an IN filter surpasses both inFunctionThreshold and
inSubQueryThreshold, the
+ // inFunctionThreshold takes priority.
+ final List<String> elements = new ArrayList<>();
+ elements.add("abc");
+ elements.add("def");
+ elements.add("ghi");
+ for (int i = 0; i < 20; i++) {
+ elements.add("dummy" + i);
+ }
+
+ final String elementsString = Joiner.on(",").join(elements.stream().map(s
-> "'" + s + "'").iterator());
+
+ testQuery(
+ "SELECT dim1, COUNT(*) FROM druid.foo\n"
+ + "WHERE dim1 IN (" + elementsString + ") OR dim1 = 'xyz' OR dim1 IS
NULL\n"
+ + "GROUP BY dim1",
+ QueryContexts.override(
+ QUERY_CONTEXT_DEFAULT,
+ ImmutableMap.of(
+ QueryContexts.IN_FUNCTION_THRESHOLD, 20,
+ QueryContexts.IN_SUB_QUERY_THRESHOLD_KEY, 20
+ )
+ ),
ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDimensions(dimensions(new
DefaultDimensionSpec("dim1", "d0")))
+ .setDimFilter(
+ // [dim1 = xyz] is not combined into the IN
filter, because SCALAR_IN_ARRAY was used,
+ // and it is opaque to most optimizations. (That's
its main purpose.)
+ NullHandling.sqlCompatible()
+ ? or(
+ in("dim1", elements),
+ isNull("dim1"),
+ equality("dim1", "xyz", ColumnType.STRING)
+ )
+ : or(
+ in("dim1", Arrays.asList(null, "xyz")),
+ in("dim1", elements)
+ )
+ )
+ .setAggregatorSpecs(
+ aggregators(
+ new CountAggregatorFactory("a0")
+ )
+ )
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ NullHandling.sqlCompatible()
+ ? ImmutableList.of(
+ new Object[]{"abc", 1L},
+ new Object[]{"def", 1L}
+ )
+ : ImmutableList.of(
+ new Object[]{"", 1L},
new Object[]{"abc", 1L},
new Object[]{"def", 1L}
)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]