kgyrtkirk commented on code in PR #16388:
URL: https://github.com/apache/druid/pull/16388#discussion_r1595594980
##########
sql/src/main/java/org/apache/druid/sql/calcite/rule/ReverseLookupRule.java:
##########
@@ -559,8 +572,16 @@ private RexNode makeMatchCondition(
} else {
return SearchOperatorConversion.makeIn(
reverseLookupKey.arg,
- stringsToRexNodes(reversedMatchValues, rexBuilder),
+ reversedMatchValues,
+ rexBuilder.getTypeFactory()
+ .createTypeWithNullability(
+
rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR),
+ true
+ ),
reverseLookupKey.negate,
+
+ // Use regular equals, or SCALAR_IN_ARRAY, depending on
inFunctionThreshold.
+ reversedMatchValues.size() >=
plannerContext.queryContext().getInFunctionThreshold(),
Review Comment:
I wonder if it would be simpler to pass `plannerContext` instead of
`inFunctionThreshold`, and let this logic live inside `makeIn`.
##########
sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java:
##########
@@ -5634,6 +5719,245 @@ public void testInIsNotTrueAndLessThanFilter()
);
}
+ @Test
+ public void testInExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 IN ('abc', 'def', 'ghi'), COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+
"scalar_in_array(\"dim1\",array('abc','def','ghi'))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, 4L},
+ new Object[]{true, 2L}
+ )
+ );
+ }
+
+ @Test
+ public void testInExpressionBelowThreshold()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 IN ('abc', 'def', 'ghi'), COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ QueryContexts.override(QUERY_CONTEXT_DEFAULT,
QueryContexts.IN_FUNCTION_EXPR_THRESHOLD, 100),
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "((\"dim1\" == 'abc') || (\"dim1\" == 'def')
|| (\"dim1\" == 'ghi'))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, 4L},
+ new Object[]{true, 2L}
+ )
+ );
+ }
+
+ @Test
+ public void testInOrIsNullExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 IN ('abc', 'def', 'ghi') OR dim1 IS NULL, COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "(isnull(\"dim1\") ||
scalar_in_array(\"dim1\",array('abc','def','ghi')))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, NullHandling.sqlCompatible() ? 4L : 3L},
+ new Object[]{true, NullHandling.sqlCompatible() ? 2L : 3L}
+ )
+ );
+ }
+
+ @Test
+ public void testNotInOrIsNullExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT NOT (dim1 IN ('abc', 'def', 'ghi') OR dim1 IS NULL),
COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "(notnull(\"dim1\") && (!
scalar_in_array(\"dim1\",array('abc','def','ghi'))))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, NullHandling.sqlCompatible() ? 2L : 3L},
+ new Object[]{true, NullHandling.sqlCompatible() ? 4L : 3L}
+ )
+ );
+ }
+
+ @Test
+ public void testNotInAndIsNotNullExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 NOT IN ('abc', 'def', 'ghi') AND dim1 IS NOT NULL,
COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+ "(notnull(\"dim1\") && (!
scalar_in_array(\"dim1\",array('abc','def','ghi'))))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, NullHandling.sqlCompatible() ? 2L : 3L},
+ new Object[]{true, NullHandling.sqlCompatible() ? 4L : 3L}
+ )
+ );
+ }
+
+ @Test
+ public void testInOrGreaterThanExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 IN ('abc', 'def', 'ghi') OR dim1 > 'zzz', COUNT(*)\n"
+ + "FROM druid.foo\n"
+ + "GROUP BY 1",
+ ImmutableList.of(
+ GroupByQuery.builder()
+ .setDataSource(CalciteTests.DATASOURCE1)
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setVirtualColumns(
+ expressionVirtualColumn(
+ "v0",
+
"(scalar_in_array(\"dim1\",array('abc','def','ghi')) || (\"dim1\" > 'zzz'))",
+ ColumnType.LONG
+ )
+ )
+ .setDimensions(dimensions(new
DefaultDimensionSpec("v0", "d0", ColumnType.LONG)))
+ .setAggregatorSpecs(new CountAggregatorFactory("a0"))
+ .setContext(QUERY_CONTEXT_DEFAULT)
+ .build()
+ ),
+ ImmutableList.of(
+ new Object[]{false, 4L},
+ new Object[]{true, 2L}
+ )
+ );
+ }
+
+ @Test
+ public void testNotInAndLessThanExpression()
+ {
+ // Cannot vectorize scalar_in_array expression.
+ cannotVectorize();
+
+ testQuery(
+ "SELECT dim1 NOT IN ('abc', 'def', 'ghi') AND dim1 < 'zzz', COUNT(*)\n"
Review Comment:
I think it would be more interesting to have these tests apply an inequality
that could filter out some of the `IN` literals.
##########
sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidSqlValidator.java:
##########
@@ -774,6 +776,59 @@ public void validateCall(SqlCall call, SqlValidatorScope
scope)
super.validateCall(call, scope);
}
+ @Override
+ protected SqlNode performUnconditionalRewrites(SqlNode node, final boolean
underFrom)
+ {
+ if (node != null && (node.getKind() == SqlKind.IN || node.getKind() ==
SqlKind.NOT_IN)) {
+ final SqlNode rewritten = rewriteInToScalarInArrayIfNeeded((SqlCall)
node, underFrom);
+ //noinspection ObjectEquality
+ if (rewritten != node) {
+ return rewritten;
+ }
+ }
+
+ return super.performUnconditionalRewrites(node, underFrom);
+ }
+
+ /**
+ * Rewrites "x IN (values)" to "SCALAR_IN_ARRAY(x, values)", if appropriate.
Checks the form of the IN and checks
+ * the value of {@link QueryContext#getInFunctionThreshold()}.
+ *
+ * @param call call to {@link SqlKind#IN} or {@link SqlKind#NOT_IN}
+ * @param underFrom underFrom arg from {@link
#performUnconditionalRewrites(SqlNode, boolean)}, used for
+ * recursive calls
+ *
+ * @return rewritten call, or the original call if no rewrite was appropriate
+ */
+ private SqlNode rewriteInToScalarInArrayIfNeeded(final SqlCall call, final
boolean underFrom)
+ {
+ if (call.getOperandList().size() == 2 && call.getOperandList().get(1)
instanceof SqlNodeList) {
+ // expr IN (values)
+ final SqlNode exprNode = call.getOperandList().get(0);
+ final SqlNodeList valuesNode = (SqlNodeList)
call.getOperandList().get(1);
+
+ // Confirm valuesNode is big enough to convert to SCALAR_IN_ARRAY, and
references only nonnull literals.
+ // (Can't include NULL literals in the conversion, because
SCALAR_IN_ARRAY matches NULLs as if they were regular
+ // values, whereas IN does not.)
+ if (valuesNode.size() >
plannerContext.queryContext().getInFunctionThreshold()
+ && valuesNode.stream().allMatch(node -> node.getKind() ==
SqlKind.LITERAL && !SqlUtil.isNull(node))) {
Review Comment:
Why not handle mixed value lists as well? The literals could be handled by this
rewrite, while the other problematic operands are left outside in an `OR`.
The `NULL` case would also be less problematic, since those values would be left
outside as well...
Or is there something wrong with:
`x IN (1,2,3,y,null) => DRUID_IN(x,[1,2,3]) OR x = y OR x = null`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]