This is an automated email from the ASF dual-hosted git repository.
gian pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new db82adcdfde SCALAR_IN_ARRAY: Optimization and behavioral follow-ups. (#16311)
db82adcdfde is described below
commit db82adcdfde8c456cb78daf5394e032e1f8657f5
Author: Gian Merlino <[email protected]>
AuthorDate: Fri Apr 26 16:01:17 2024 -0700
SCALAR_IN_ARRAY: Optimization and behavioral follow-ups. (#16311)
* Four changes to scalar_in_array as follow-ups to #16306:
1) Align behavior for `null` scalars to the behavior of the native `in` and `inType` filters: return `true` if the array itself contains null, else return `null`.
2) Rename the class to more closely match the function name.
3) Add a specialization for constant arrays, where we build a `HashSet`.
4) Use `castForEqualityComparison` to properly handle cross-type comparisons.
Additional tests verify comparisons between LONG and DOUBLE are now handled properly.
* Fix spelling.
* Adjustments from review.
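Item 1 above changes what the function returns when the scalar is null. Below is a minimal standalone sketch of the resulting semantics. It is not Druid's code. `ScalarInArraySemantics` and `scalarInArray` are hypothetical names, and Java `null` stands in for the native `null` / SQL `UNKNOWN` result. Both sides are assumed to share one element type, so the cross-type casting from item 4 is left out.

```java
import java.util.Arrays;
import java.util.Objects;

final class ScalarInArraySemantics
{
  // Hypothetical sketch: TRUE / FALSE, or null for an unknown result.
  // Assumes scalar and elements already share a type (no cross-type casting).
  static Boolean scalarInArray(Object scalar, Object[] array)
  {
    if (array == null) {
      // A null array yields null, whatever the scalar is.
      return null;
    }
    if (scalar == null) {
      // Like the native `in` filter: true if the array holds a null, else null.
      return Arrays.asList(array).contains(null) ? Boolean.TRUE : null;
    }
    for (Object element : array) {
      if (Objects.equals(scalar, element)) {
        return Boolean.TRUE;
      }
    }
    return Boolean.FALSE;
  }

  public static void main(String[] args)
  {
    System.out.println(scalarInArray(2L, new Object[]{1L, 2L, 3L}));     // true
    System.out.println(scalarInArray(4L, new Object[]{1L, 2L, 3L}));     // false
    System.out.println(scalarInArray(null, new Object[]{1L, null, 2L})); // true
    System.out.println(scalarInArray(null, new Object[]{1L, 2L}));       // null
    System.out.println(scalarInArray(1L, null));                         // null
  }
}
```

The expected values mirror the same-typed cases in `testScalarInArray` from this commit.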
---
docs/querying/math-expr.md | 2 +-
docs/querying/sql-array-functions.md | 6 +-
docs/querying/sql-functions.md | 9 +-
.../java/org/apache/druid/math/expr/Function.java | 98 ++++++++++++++++++++--
.../org/apache/druid/math/expr/FunctionTest.java | 12 ++-
5 files changed, 113 insertions(+), 14 deletions(-)
diff --git a/docs/querying/math-expr.md b/docs/querying/math-expr.md
index d5255544a03..38ced649c06 100644
--- a/docs/querying/math-expr.md
+++ b/docs/querying/math-expr.md
@@ -184,7 +184,7 @@ See javadoc of java.lang.Math for detailed explanation for each function.
 | array_ordinal(arr,long) | returns the array element at the 1 based index supplied, or null for an out of range index |
 | array_contains(arr,expr) | returns 1 if the array contains the element specified by expr, or contains all elements specified by expr if expr is an array, else 0 |
 | array_overlap(arr1,arr2) | returns 1 if arr1 and arr2 have any elements in common, else 0 |
-| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 |
+| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 if the expr is non-null, or null if the expr is null |
 | array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. |
 | array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. |
 | array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array |
diff --git a/docs/querying/sql-array-functions.md b/docs/querying/sql-array-functions.md
index ab84c664dee..7b0f2112b6f 100644
--- a/docs/querying/sql-array-functions.md
+++ b/docs/querying/sql-array-functions.md
@@ -52,9 +52,9 @@ The following table describes array functions. To learn more about array aggrega
 |`ARRAY_LENGTH(arr)`|Returns length of the array expression.|
 |`ARRAY_OFFSET(arr, long)`|Returns the array element at the 0-based index supplied, or null for an out of range index.|
 |`ARRAY_ORDINAL(arr, long)`|Returns the array element at the 1-based index supplied, or null for an out of range index.|
-|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.|
-|`ARRAY_OVERLAP(arr1, arr2)`|Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
-| `SCALAR_IN_ARRAY(expr, arr)`|Returns 1 if the scalar `expr` is present in `arr`. else 0.|
+|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns true if `arr` contains all elements of `expr`. Otherwise returns false.|
+|`ARRAY_OVERLAP(arr1, arr2)`|Returns true if `arr1` and `arr2` have any elements in common, else false.|
+|`SCALAR_IN_ARRAY(expr, arr)`|Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or `UNKNOWN` if the scalar `expr` is `NULL`.|
 |`ARRAY_OFFSET_OF(arr, expr)`|Returns the 0-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).|
 |`ARRAY_ORDINAL_OF(arr, expr)`|Returns the 1-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).|
 |`ARRAY_PREPEND(expr, arr)`|Adds `expr` to the beginning of `arr`, the resulting array type determined by the type of `arr`.|
diff --git a/docs/querying/sql-functions.md b/docs/querying/sql-functions.md
index 093e7ce60fd..883f3b209ac 100644
--- a/docs/querying/sql-functions.md
+++ b/docs/querying/sql-functions.md
@@ -156,7 +156,7 @@ Concatenates array inputs into a single array.
 
 **Function type:** [Array](./sql-array-functions.md)
 
-If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.
+If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns false.
 
 ## ARRAY_LENGTH
 
@@ -204,7 +204,7 @@ Returns the 1-based index of the first occurrence of `expr` in the array. If no
 
 **Function type:** [Array](./sql-array-functions.md)
 
-Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
+Returns true if `arr1` and `arr2` have any elements in common, else false.
 
 ## SCALAR_IN_ARRAY
 
@@ -212,7 +212,10 @@ Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
 
 **Function type:** [Array](./sql-array-functions.md)
 
-Returns 1 if the scalar `expr` is present in `arr`, else 0.|
+Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or
+`UNKNOWN` if the scalar `expr` is `NULL`.
+
+Returns `UNKNOWN` if `arr` is `NULL`.
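The SQL documentation changes above distinguish false from `UNKNOWN`. Under standard SQL three-valued logic, `NOT UNKNOWN` is still `UNKNOWN`, and a `WHERE` clause keeps a row only when its condition is true. A row whose `expr` is `NULL` (with no `NULL` in `arr`) is therefore dropped by both `WHERE SCALAR_IN_ARRAY(expr, arr)` and `WHERE NOT SCALAR_IN_ARRAY(expr, arr)`. The plain Java sketch below models that logic, assuming SQL-compatible null handling. It does not show Druid's planner or filter code, and `ThreeValuedFilter`, `not`, and `keepsRow` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

final class ThreeValuedFilter
{
  // SQL NOT under three-valued logic; null models UNKNOWN, and NOT UNKNOWN stays UNKNOWN.
  static Boolean not(Boolean b)
  {
    return b == null ? null : !b;
  }

  // A WHERE clause keeps a row only when the condition is TRUE; FALSE and UNKNOWN both drop it.
  static boolean keepsRow(Boolean condition)
  {
    return Boolean.TRUE.equals(condition);
  }

  public static void main(String[] args)
  {
    // Hypothetical results of SCALAR_IN_ARRAY(x, ARRAY['a', 'b']) for three rows:
    // x = 'a' -> TRUE, x = 'c' -> FALSE, x = NULL -> UNKNOWN.
    Boolean[] results = {Boolean.TRUE, Boolean.FALSE, null};
    List<Integer> kept = new ArrayList<>();
    List<Integer> keptByNegation = new ArrayList<>();
    for (int row = 0; row < results.length; row++) {
      if (keepsRow(results[row])) {
        kept.add(row);
      }
      if (keepsRow(not(results[row]))) {
        keptByNegation.add(row);
      }
    }
    System.out.println(kept);           // [0]
    System.out.println(keptByNegation); // [1]
  }
}
```

Row 2 (the `NULL` scalar) appears in neither list. Returning false instead of `UNKNOWN` would have made it reappear under the negated filter.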
## ARRAY_PREPEND
diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java
index aa54409e132..48bc0570aaa 100644
--- a/processing/src/main/java/org/apache/druid/math/expr/Function.java
+++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java
@@ -45,6 +45,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -3724,8 +3725,11 @@ public interface Function extends NamedFunction
     }
   }
 
-  class ArrayScalarInFunction extends ArrayScalarFunction
+  class ScalarInArrayFunction extends ArrayScalarFunction
   {
+    private static final int SCALAR_ARG = 0;
+    private static final int ARRAY_ARG = 1;
+
     @Override
     public String name()
     {
@@ -3742,23 +3746,105 @@ public interface Function extends NamedFunction
     @Override
     Expr getScalarArgument(List<Expr> args)
     {
-      return args.get(0);
+      return args.get(SCALAR_ARG);
     }
 
     @Override
     Expr getArrayArgument(List<Expr> args)
     {
-      return args.get(1);
+      return args.get(ARRAY_ARG);
     }
 
     @Override
-    ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
+    ExprEval doApply(ExprEval arrayEval, ExprEval scalarEval)
     {
-      final Object[] array = arrayExpr.castTo(scalarExpr.asArrayType()).asArray();
+      final Object[] array = arrayEval.asArray();
       if (array == null) {
         return ExprEval.ofLong(null);
       }
-      return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarExpr.value()));
+
+      if (scalarEval.value() == null) {
+        return Arrays.asList(array).contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);
+      }
+
+      final ExpressionType matchType = arrayEval.elementType();
+      final ExprEval<?> scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType);
+
+      if (scalarEvalForComparison == null) {
+        return ExprEval.ofLongBoolean(false);
+      } else {
+        return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarEvalForComparison.value()));
+      }
+    }
+
+    @Override
+    public Function asSingleThreaded(List<Expr> args, Expr.InputBindingInspector inspector)
+    {
+      if (args.get(ARRAY_ARG).isLiteral()) {
+        final ExpressionType lhsType = args.get(SCALAR_ARG).getOutputType(inspector);
+        if (lhsType == null) {
+          return this;
+        }
+
+        final ExprEval<?> arrayEval = args.get(ARRAY_ARG).eval(InputBindings.nilBindings());
+        final Object[] arrayValues = arrayEval.asArray();
+
+        if (arrayValues == null) {
+          return WithNullArray.INSTANCE;
+        } else {
+          final Set<Object> matchValues = new HashSet<>(Arrays.asList(arrayValues));
+          final ExpressionType matchType = arrayEval.elementType();
+          return new WithConstantArray(matchValues, matchType);
+        }
+      }
+      return this;
+    }
+
+    /**
+     * Specialization of {@link ScalarInArrayFunction} for null {@link #ARRAY_ARG}.
+     */
+    private static final class WithNullArray extends ScalarInArrayFunction
+    {
+      private static final WithNullArray INSTANCE = new WithNullArray();
+
+      @Override
+      public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
+      {
+        return ExprEval.of(null);
+      }
+    }
+
+    /**
+     * Specialization of {@link ScalarInArrayFunction} for constant, non-null {@link #ARRAY_ARG}.
+     */
+    private static final class WithConstantArray extends ScalarInArrayFunction
+    {
+      private final Set<Object> matchValues;
+      private final ExpressionType matchType;
+
+      public WithConstantArray(Set<Object> matchValues, ExpressionType matchType)
+      {
+        this.matchValues = Preconditions.checkNotNull(matchValues, "matchValues");
+        this.matchType = Preconditions.checkNotNull(matchType, "matchType");
+      }
+
+      @Override
+      public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
+      {
+        final ExprEval scalarEval = args.get(SCALAR_ARG).eval(bindings);
+
+        if (scalarEval.value() == null) {
+          return matchValues.contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);
+        }
+
+        final ExprEval<?> scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType);
+
+        if (scalarEvalForComparison == null) {
+          return ExprEval.ofLongBoolean(false);
+        } else {
+          return ExprEval.ofLongBoolean(matchValues.contains(scalarEvalForComparison.value()));
+        }
+      }
     }
   }
 
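The `asSingleThreaded` override in this diff moves work out of the per-row path. When the array argument is a literal, the `HashSet` is built once, and each row then costs one hash lookup instead of a linear scan. Below is a simplified standalone Java sketch of that pattern, not Druid code. `ConstantArrayMatcher` and `specialize` are hypothetical names, and `java.util.function.Function` stands in for Druid's own `Function` interface. A Java `null` models a null result. The `castForEqualityComparison` step is omitted, so both sides are assumed to share one element type.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Function;

final class ConstantArrayMatcher
{
  // Hypothetical: inspect the constant array once and return a per-row matcher.
  // The matcher yields TRUE / FALSE, or null for an unknown result.
  static Function<Object, Boolean> specialize(Object[] constantArray)
  {
    if (constantArray == null) {
      // Analogue of WithNullArray: every row evaluates to null.
      return scalar -> null;
    }
    // Built once. HashSet permits a null element, so "array contains null" is a single lookup.
    final Set<Object> matchValues = new HashSet<>(Arrays.asList(constantArray));
    return scalar -> {
      if (scalar == null) {
        return matchValues.contains(null) ? Boolean.TRUE : null;
      }
      return matchValues.contains(scalar);
    };
  }

  public static void main(String[] args)
  {
    Function<Object, Boolean> inSmallPrimes = specialize(new Object[]{2L, 3L, 5L, 7L});
    System.out.println(inSmallPrimes.apply(5L));    // true
    System.out.println(inSmallPrimes.apply(4L));    // false
    System.out.println(inSmallPrimes.apply(null));  // null
    System.out.println(specialize(null).apply(1L)); // null
  }
}
```

Hash lookups rely on `equals` and `hashCode`, so `Long.valueOf(2)` and `Double.valueOf(2.0)` are different keys. This is one reason to cast the scalar to the array's element type before the lookup.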
diff --git a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java
index da81a556b0b..d6143fd1fa1 100644
--- a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java
+++ b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java
@@ -373,12 +373,15 @@ public class FunctionTest extends InitializedNullHandlingTest
   public void testScalarInArray()
   {
     assertExpr("scalar_in_array(2, [1, 2, 3])", 1L);
+    assertExpr("scalar_in_array(2.1, [1, 2, 3])", 0L);
+    assertExpr("scalar_in_array(2, [1.1, 2.1, 3.1])", 0L);
+    assertExpr("scalar_in_array(2, [1.1, 2.0, 3.1])", 1L);
     assertExpr("scalar_in_array(4, [1, 2, 3])", 0L);
     assertExpr("scalar_in_array(b, [3, 4])", 0L);
     assertExpr("scalar_in_array(1, null)", null);
     assertExpr("scalar_in_array(null, null)", null);
     assertExpr("scalar_in_array(null, [1, null, 2])", 1L);
-    assertExpr("scalar_in_array(null, [1, 2])", 0L);
+    assertExpr("scalar_in_array(null, [1, 2])", null);
   }
 
   @Test
@@ -1290,6 +1293,13 @@ public class FunctionTest extends InitializedNullHandlingTest
     final Expr singleThreaded = Expr.singleThreaded(expr, bindings);
     Assert.assertEquals(singleThreaded.stringify(), expectedResult, singleThreaded.eval(bindings).value());
+    final Expr singleThreadedNoFlatten = Expr.singleThreaded(exprNoFlatten, bindings);
+    Assert.assertEquals(
+        singleThreadedNoFlatten.stringify(),
+        expectedResult,
+        singleThreadedNoFlatten.eval(bindings).value()
+    );
+
     Assert.assertEquals(expr.stringify(), roundTrip.stringify());
     Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify());
     Assert.assertArrayEquals(expr.getCacheKey(), roundTrip.getCacheKey());
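The three new `testScalarInArray` cases exercise item 4 of the commit message by comparing LONG and DOUBLE values. The sketch below is a hypothetical stand-in for the idea behind `ExprEval.castForEqualityComparison`, not its actual implementation. It converts the scalar to the array's element type and reports "no possible match" (modeled as `null`) when the conversion would be lossy, as with 2.1 against LONG elements. `EqualityCast`, `castForEquality`, and `matches` are illustrative names, and only `Long` and `Double` element types are handled.

```java
import java.util.Arrays;

final class EqualityCast
{
  // Hypothetical: convert a numeric scalar to the element type (Long or Double only),
  // or return null when no value of that type could equal the scalar.
  static Number castForEquality(Number scalar, Class<? extends Number> elementType)
  {
    if (elementType == Double.class) {
      // LONG -> DOUBLE: 2 becomes 2.0.
      return scalar.doubleValue();
    }
    // Target element type is LONG.
    if (scalar instanceof Long) {
      return scalar;
    }
    final double d = scalar.doubleValue();
    final long asLong = (long) d;
    // Only integral doubles such as 2.0 survive; 2.1 (or NaN) cannot equal any LONG.
    return asLong == d ? Long.valueOf(asLong) : null;
  }

  static boolean matches(Number scalar, Number[] array, Class<? extends Number> elementType)
  {
    final Number cast = castForEquality(scalar, elementType);
    // A null cast means no element of this type can be equal, so report no match.
    return cast != null && Arrays.asList(array).contains(cast);
  }

  public static void main(String[] args)
  {
    // Same inputs as the new test cases in testScalarInArray.
    System.out.println(matches(2.1, new Number[]{1L, 2L, 3L}, Long.class));     // false
    System.out.println(matches(2L, new Number[]{1.1, 2.1, 3.1}, Double.class)); // false
    System.out.println(matches(2L, new Number[]{1.1, 2.0, 3.1}, Double.class)); // true
    System.out.println(matches(2L, new Number[]{1L, 2L, 3L}, Long.class));      // true
  }
}
```

Without the cast, the third case would fail: `Long.valueOf(2)` is not `equals` to `Double.valueOf(2.0)`, so a plain `contains` check would miss the match.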