This is an automated email from the ASF dual-hosted git repository.

gian pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new db82adcdfde SCALAR_IN_ARRAY: Optimization and behavioral follow-ups. (#16311)
db82adcdfde is described below

commit db82adcdfde8c456cb78daf5394e032e1f8657f5
Author: Gian Merlino <[email protected]>
AuthorDate: Fri Apr 26 16:01:17 2024 -0700

    SCALAR_IN_ARRAY: Optimization and behavioral follow-ups. (#16311)
    
    * Four changes to scalar_in_array as follow-ups to #16306:
    
    1) Align behavior for `null` scalars to the behavior of the native `in` and
       `inType` filters: return `true` if the array itself contains null, else
       return `null`.
    
    2) Rename the class to more closely match the function name.
    
    3) Add a specialization for constant arrays, where we build a `HashSet`.
    
    4) Use `castForEqualityComparison` to properly handle cross-type comparisons.
       Additional tests verify comparisons between LONG and DOUBLE are now
       handled properly.
    
    * Fix spelling.
    
    * Adjustments from review.
---
 docs/querying/math-expr.md                         |  2 +-
 docs/querying/sql-array-functions.md               |  6 +-
 docs/querying/sql-functions.md                     |  9 +-
 .../java/org/apache/druid/math/expr/Function.java  | 98 ++++++++++++++++++++--
 .../org/apache/druid/math/expr/FunctionTest.java   | 12 ++-
 5 files changed, 113 insertions(+), 14 deletions(-)
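
The null-handling rules from items 1 and 3 of the commit message can be illustrated outside Druid with a minimal sketch. It assumes plain Java objects in place of `ExprEval`, with a `null` return standing in for SQL `UNKNOWN`. The class and method names (`ScalarInArraySketch`, `ConstantArrayMatcher`) are hypothetical and not part of this commit.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical illustration only; not Druid code.
final class ScalarInArraySketch
{
  /**
   * Returns TRUE, FALSE, or null (standing in for UNKNOWN).
   */
  static Boolean scalarInArray(Object scalar, Object[] array)
  {
    if (array == null) {
      // A null array always yields UNKNOWN.
      return null;
    }
    if (scalar == null) {
      // A null scalar matches only if the array itself contains null.
      return Arrays.asList(array).contains(null) ? Boolean.TRUE : null;
    }
    return Arrays.asList(array).contains(scalar);
  }

  /**
   * Variant for an array known ahead of time: the set is built once,
   * so each lookup avoids a linear scan.
   */
  static final class ConstantArrayMatcher
  {
    private final Set<Object> matchValues;

    ConstantArrayMatcher(Object[] constantArray)
    {
      // HashSet permits a null element, so contains(null) is safe.
      this.matchValues = new HashSet<>(Arrays.asList(constantArray));
    }

    Boolean matches(Object scalar)
    {
      if (scalar == null) {
        return matchValues.contains(null) ? Boolean.TRUE : null;
      }
      return matchValues.contains(scalar);
    }
  }

  public static void main(String[] args)
  {
    System.out.println(scalarInArray(2L, new Object[]{1L, 2L, 3L}));     // true
    System.out.println(scalarInArray(null, new Object[]{1L, null, 2L})); // true
    System.out.println(scalarInArray(null, new Object[]{1L, 2L}));       // null
    System.out.println(scalarInArray(1L, null));                         // null
  }
}
```

The sketch compares values with `equals` only, so it does not cover the cross-type casting from item 4.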

diff --git a/docs/querying/math-expr.md b/docs/querying/math-expr.md
index d5255544a03..38ced649c06 100644
--- a/docs/querying/math-expr.md
+++ b/docs/querying/math-expr.md
@@ -184,7 +184,7 @@ See javadoc of java.lang.Math for detailed explanation for each function.
 | array_ordinal(arr,long) | returns the array element at the 1 based index supplied, or null for an out of range index |
 | array_contains(arr,expr) | returns 1 if the array contains the element specified by expr, or contains all elements specified by expr if expr is an array, else 0 |
 | array_overlap(arr1,arr2) | returns 1 if arr1 and arr2 have any elements in common, else 0 |
-| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 |
+| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 if the expr is non-null, or null if the expr is null |
 | array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. |
 | array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. |
 | array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array |
diff --git a/docs/querying/sql-array-functions.md b/docs/querying/sql-array-functions.md
index ab84c664dee..7b0f2112b6f 100644
--- a/docs/querying/sql-array-functions.md
+++ b/docs/querying/sql-array-functions.md
@@ -52,9 +52,9 @@ The following table describes array functions. To learn more about array aggrega
 |`ARRAY_LENGTH(arr)`|Returns length of the array expression.|
 |`ARRAY_OFFSET(arr, long)`|Returns the array element at the 0-based index supplied, or null for an out of range index.|
 |`ARRAY_ORDINAL(arr, long)`|Returns the array element at the 1-based index supplied, or null for an out of range index.|
-|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.|
-|`ARRAY_OVERLAP(arr1, arr2)`|Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
-| `SCALAR_IN_ARRAY(expr, arr)`|Returns 1 if the scalar `expr` is present in `arr`. else 0.|
+|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns true if `arr` contains all elements of `expr`. Otherwise returns false.|
+|`ARRAY_OVERLAP(arr1, arr2)`|Returns true if `arr1` and `arr2` have any elements in common, else false.|
+|`SCALAR_IN_ARRAY(expr, arr)`|Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or `UNKNOWN` if the scalar `expr` is `NULL`.|
 |`ARRAY_OFFSET_OF(arr, expr)`|Returns the 0-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).|
 |`ARRAY_ORDINAL_OF(arr, expr)`|Returns the 1-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).|
 |`ARRAY_PREPEND(expr, arr)`|Adds `expr` to the beginning of `arr`, the resulting array type determined by the type of `arr`.|
diff --git a/docs/querying/sql-functions.md b/docs/querying/sql-functions.md
index 093e7ce60fd..883f3b209ac 100644
--- a/docs/querying/sql-functions.md
+++ b/docs/querying/sql-functions.md
@@ -156,7 +156,7 @@ Concatenates array inputs into a single array.
 
 **Function type:** [Array](./sql-array-functions.md)
 
-If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.
+If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns false.
 
 
 ## ARRAY_LENGTH
@@ -204,7 +204,7 @@ Returns the 1-based index of the first occurrence of `expr` in the array. If no
 
 **Function type:** [Array](./sql-array-functions.md)
 
-Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
+Returns true if `arr1` and `arr2` have any elements in common, else false.
 
 ## SCALAR_IN_ARRAY
 
@@ -212,7 +212,10 @@ Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
 
 **Function type:** [Array](./sql-array-functions.md)
 
-Returns 1 if the scalar `expr` is present in `arr`, else 0.|
+Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or
+`UNKNOWN` if the scalar `expr` is `NULL`.
+
+Returns `UNKNOWN` if `arr` is `NULL`.
 
 ## ARRAY_PREPEND
 
diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java
index aa54409e132..48bc0570aaa 100644
--- a/processing/src/main/java/org/apache/druid/math/expr/Function.java
+++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java
@@ -45,6 +45,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -3724,8 +3725,11 @@ public interface Function extends NamedFunction
     }
   }
 
-  class ArrayScalarInFunction extends ArrayScalarFunction
+  class ScalarInArrayFunction extends ArrayScalarFunction
   {
+    private static final int SCALAR_ARG = 0;
+    private static final int ARRAY_ARG = 1;
+
     @Override
     public String name()
     {
@@ -3742,23 +3746,105 @@ public interface Function extends NamedFunction
     @Override
     Expr getScalarArgument(List<Expr> args)
     {
-      return args.get(0);
+      return args.get(SCALAR_ARG);
     }
 
     @Override
     Expr getArrayArgument(List<Expr> args)
     {
-      return args.get(1);
+      return args.get(ARRAY_ARG);
     }
 
     @Override
-    ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
+    ExprEval doApply(ExprEval arrayEval, ExprEval scalarEval)
     {
-      final Object[] array = arrayExpr.castTo(scalarExpr.asArrayType()).asArray();
+      final Object[] array = arrayEval.asArray();
       if (array == null) {
         return ExprEval.ofLong(null);
       }
-      return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarExpr.value()));
+
+      if (scalarEval.value() == null) {
+        return Arrays.asList(array).contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);
+      }
+
+      final ExpressionType matchType = arrayEval.elementType();
+      final ExprEval<?> scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType);
+
+      if (scalarEvalForComparison == null) {
+        return ExprEval.ofLongBoolean(false);
+      } else {
+        return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarEvalForComparison.value()));
+      }
+    }
+
+    @Override
+    public Function asSingleThreaded(List<Expr> args, Expr.InputBindingInspector inspector)
+    {
+      if (args.get(ARRAY_ARG).isLiteral()) {
+        final ExpressionType lhsType = args.get(SCALAR_ARG).getOutputType(inspector);
+        if (lhsType == null) {
+          return this;
+        }
+
+        final ExprEval<?> arrayEval = args.get(ARRAY_ARG).eval(InputBindings.nilBindings());
+        final Object[] arrayValues = arrayEval.asArray();
+
+        if (arrayValues == null) {
+          return WithNullArray.INSTANCE;
+        } else {
+          final Set<Object> matchValues = new HashSet<>(Arrays.asList(arrayValues));
+          final ExpressionType matchType = arrayEval.elementType();
+          return new WithConstantArray(matchValues, matchType);
+        }
+      }
+      return this;
+    }
+
+    /**
+     * Specialization of {@link ScalarInArrayFunction} for null {@link #ARRAY_ARG}.
+     */
+    private static final class WithNullArray extends ScalarInArrayFunction
+    {
+      private static final WithNullArray INSTANCE = new WithNullArray();
+
+      @Override
+      public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
+      {
+        return ExprEval.of(null);
+      }
+    }
+
+    /**
+     * Specialization of {@link ScalarInArrayFunction} for constant, non-null {@link #ARRAY_ARG}.
+     */
+    private static final class WithConstantArray extends ScalarInArrayFunction
+    {
+      private final Set<Object> matchValues;
+      private final ExpressionType matchType;
+
+      public WithConstantArray(Set<Object> matchValues, ExpressionType matchType)
+      {
+        this.matchValues = Preconditions.checkNotNull(matchValues, "matchValues");
+        this.matchType = Preconditions.checkNotNull(matchType, "matchType");
+      }
+
+      @Override
+      public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
+      {
+        final ExprEval scalarEval = args.get(SCALAR_ARG).eval(bindings);
+
+        if (scalarEval.value() == null) {
+          return matchValues.contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);
+        }
+
+        final ExprEval<?> scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType);
+
+        if (scalarEvalForComparison == null) {
+          return ExprEval.ofLongBoolean(false);
+        } else {
+          return ExprEval.ofLongBoolean(matchValues.contains(scalarEvalForComparison.value()));
+        }
+      }
     }
   }
 
diff --git a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java
index da81a556b0b..d6143fd1fa1 100644
--- a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java
+++ b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java
@@ -373,12 +373,15 @@ public class FunctionTest extends InitializedNullHandlingTest
   public void testScalarInArray()
   {
     assertExpr("scalar_in_array(2, [1, 2, 3])", 1L);
+    assertExpr("scalar_in_array(2.1, [1, 2, 3])", 0L);
+    assertExpr("scalar_in_array(2, [1.1, 2.1, 3.1])", 0L);
+    assertExpr("scalar_in_array(2, [1.1, 2.0, 3.1])", 1L);
     assertExpr("scalar_in_array(4, [1, 2, 3])", 0L);
     assertExpr("scalar_in_array(b, [3, 4])", 0L);
     assertExpr("scalar_in_array(1, null)", null);
     assertExpr("scalar_in_array(null, null)", null);
     assertExpr("scalar_in_array(null, [1, null, 2])", 1L);
-    assertExpr("scalar_in_array(null, [1, 2])", 0L);
+    assertExpr("scalar_in_array(null, [1, 2])", null);
   }
 
   @Test
@@ -1290,6 +1293,13 @@ public class FunctionTest extends InitializedNullHandlingTest
     final Expr singleThreaded = Expr.singleThreaded(expr, bindings);
     Assert.assertEquals(singleThreaded.stringify(), expectedResult, singleThreaded.eval(bindings).value());
 
+    final Expr singleThreadedNoFlatten = Expr.singleThreaded(exprNoFlatten, bindings);
+    Assert.assertEquals(
+        singleThreadedNoFlatten.stringify(),
+        expectedResult,
+        singleThreadedNoFlatten.eval(bindings).value()
+    );
+
     Assert.assertEquals(expr.stringify(), roundTrip.stringify());
     Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify());
     Assert.assertArrayEquals(expr.getCacheKey(), roundTrip.getCacheKey());
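
The LONG/DOUBLE cases added to `testScalarInArray` can be reproduced outside Druid with a rough sketch of the cast-then-compare idea. The helper `castForEquality` below is a hypothetical stand-in written for this illustration. It assumes that a scalar which cannot be represented exactly in the array's element type yields `null` and therefore never matches; Druid's actual `ExprEval.castForEqualityComparison` may differ in its details.

```java
import java.util.Arrays;

// Hypothetical illustration only; not Druid code.
final class CrossTypeMatchSketch
{
  enum ElementType { LONG, DOUBLE }

  /**
   * Casts the scalar to the element type, or returns null when the cast
   * would change the value (assumed behavior for this sketch).
   */
  static Object castForEquality(Object scalar, ElementType target)
  {
    if (target == ElementType.LONG) {
      if (scalar instanceof Long) {
        return scalar;
      }
      if (scalar instanceof Double) {
        final double d = (Double) scalar;
        // Only whole numbers inside the long range convert exactly.
        if (d >= -0x1p63 && d < 0x1p63 && (double) (long) d == d) {
          return (long) d;
        }
      }
      return null;
    } else {
      if (scalar instanceof Double) {
        return scalar;
      }
      if (scalar instanceof Long) {
        final long l = (Long) scalar;
        final double d = (double) l;
        // Reject longs that lose precision when widened to double.
        if (d < 0x1p63 && (long) d == l) {
          return d;
        }
      }
      return null;
    }
  }

  /** Returns TRUE, FALSE, or null (standing in for UNKNOWN). */
  static Boolean scalarInArray(Object scalar, Object[] array, ElementType elementType)
  {
    if (array == null) {
      return null;
    }
    if (scalar == null) {
      return Arrays.asList(array).contains(null) ? Boolean.TRUE : null;
    }
    final Object cast = castForEquality(scalar, elementType);
    return cast != null && Arrays.asList(array).contains(cast);
  }

  public static void main(String[] args)
  {
    final Object[] longs = {1L, 2L, 3L};
    final Object[] doubles = {1.1, 2.0, 3.1};
    System.out.println(scalarInArray(2.1, longs, ElementType.LONG));    // false
    System.out.println(scalarInArray(2L, doubles, ElementType.DOUBLE)); // true
  }
}
```

Without the cast, `Long.valueOf(2).equals(Double.valueOf(2.0))` is false in Java, which is why a plain `contains` check is not enough for mixed numeric types.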


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
