This is an automated email from the ASF dual-hosted git repository.

adarshsanjeev pushed a commit to branch 30.0.0
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/30.0.0 by this push:
     new dace56596e1 Use typecasting comparator for numeric "any" aggregations. (#16494) (#16499)
dace56596e1 is described below

commit dace56596e18c0cb1156f1bdaa33861604dc1079
Author: Adarsh Sanjeev <[email protected]>
AuthorDate: Mon May 27 14:01:21 2024 +0530

    Use typecasting comparator for numeric "any" aggregations. (#16494) (#16499)
    
    This brings them in line with the behavior of other numeric aggregations.
    It is important because otherwise ClassCastExceptions can be thrown when
    comparing values of different numeric types, which may result from deserialization.
    
    Co-authored-by: Gian Merlino <[email protected]>
---
 .../apache/druid/query/aggregation/FloatSumAggregator.java    |  2 +-
 .../org/apache/druid/query/aggregation/LongSumAggregator.java |  2 +-
 .../query/aggregation/any/DoubleAnyAggregatorFactory.java     |  5 ++---
 .../query/aggregation/any/FloatAnyAggregatorFactory.java      |  5 ++---
 .../druid/query/aggregation/any/LongAnyAggregatorFactory.java |  5 ++---
 .../druid/query/aggregation/any/DoubleAnyAggregationTest.java | 11 +++++++++++
 .../druid/query/aggregation/any/FloatAnyAggregationTest.java  | 11 +++++++++++
 .../druid/query/aggregation/any/LongAnyAggregationTest.java   | 11 +++++++++++
 8 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/FloatSumAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/FloatSumAggregator.java
index 90a2fd4fa76..caa5a26c46e 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/FloatSumAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/FloatSumAggregator.java
@@ -28,7 +28,7 @@ import java.util.Comparator;
  */
 public class FloatSumAggregator implements Aggregator
 {
-  static final Comparator COMPARATOR = new Ordering()
+  public static final Comparator COMPARATOR = new Ordering()
   {
     @Override
     public int compare(Object o, Object o1)
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/LongSumAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/LongSumAggregator.java
index 30b339337d1..f9ae93c9d6a 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/LongSumAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/LongSumAggregator.java
@@ -29,7 +29,7 @@ import java.util.Comparator;
  */
 public class LongSumAggregator implements Aggregator
 {
-  static final Comparator COMPARATOR = new Ordering()
+  public static final Comparator COMPARATOR = new Ordering()
   {
     @Override
     public int compare(Object o, Object o1)
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java
index 0a51e563394..eaebec9da49 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregatorFactory.java
@@ -28,6 +28,7 @@ import org.apache.druid.query.aggregation.Aggregator;
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.AggregatorUtil;
 import org.apache.druid.query.aggregation.BufferAggregator;
+import org.apache.druid.query.aggregation.DoubleSumAggregator;
 import org.apache.druid.query.aggregation.VectorAggregator;
 import org.apache.druid.query.cache.CacheKeyBuilder;
 import org.apache.druid.segment.BaseDoubleColumnValueSelector;
@@ -48,8 +49,6 @@ import java.util.Objects;
 
 public class DoubleAnyAggregatorFactory extends AggregatorFactory
 {
-  private static final Comparator<Double> VALUE_COMPARATOR = Comparator.nullsFirst(Double::compare);
-
   private static final Aggregator NIL_AGGREGATOR = new DoubleAnyAggregator(
       NilColumnValueSelector.instance()
   )
@@ -136,7 +135,7 @@ public class DoubleAnyAggregatorFactory extends AggregatorFactory
   @Override
   public Comparator getComparator()
   {
-    return VALUE_COMPARATOR;
+    return DoubleSumAggregator.COMPARATOR;
   }
 
   @Override
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java
index a9ee3519b9e..9015a6eda25 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/FloatAnyAggregatorFactory.java
@@ -28,6 +28,7 @@ import org.apache.druid.query.aggregation.Aggregator;
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.AggregatorUtil;
 import org.apache.druid.query.aggregation.BufferAggregator;
+import org.apache.druid.query.aggregation.FloatSumAggregator;
 import org.apache.druid.query.aggregation.VectorAggregator;
 import org.apache.druid.query.cache.CacheKeyBuilder;
 import org.apache.druid.segment.BaseFloatColumnValueSelector;
@@ -47,8 +48,6 @@ import java.util.Objects;
 
 public class FloatAnyAggregatorFactory extends AggregatorFactory
 {
-  private static final Comparator<Float> VALUE_COMPARATOR = Comparator.nullsFirst(Float::compare);
-
   private static final Aggregator NIL_AGGREGATOR = new FloatAnyAggregator(
       NilColumnValueSelector.instance()
   )
@@ -133,7 +132,7 @@ public class FloatAnyAggregatorFactory extends AggregatorFactory
   @Override
   public Comparator getComparator()
   {
-    return VALUE_COMPARATOR;
+    return FloatSumAggregator.COMPARATOR;
   }
 
   @Override
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java
index 9b220337e35..86c2964bf99 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/any/LongAnyAggregatorFactory.java
@@ -28,6 +28,7 @@ import org.apache.druid.query.aggregation.Aggregator;
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.AggregatorUtil;
 import org.apache.druid.query.aggregation.BufferAggregator;
+import org.apache.druid.query.aggregation.LongSumAggregator;
 import org.apache.druid.query.aggregation.VectorAggregator;
 import org.apache.druid.query.cache.CacheKeyBuilder;
 import org.apache.druid.segment.BaseLongColumnValueSelector;
@@ -46,8 +47,6 @@ import java.util.List;
 
 public class LongAnyAggregatorFactory extends AggregatorFactory
 {
-  private static final Comparator<Long> VALUE_COMPARATOR = Comparator.nullsFirst(Long::compare);
-
   private static final Aggregator NIL_AGGREGATOR = new LongAnyAggregator(
       NilColumnValueSelector.instance()
   )
@@ -132,7 +131,7 @@ public class LongAnyAggregatorFactory extends AggregatorFactory
   @Override
   public Comparator getComparator()
   {
-    return VALUE_COMPARATOR;
+    return LongSumAggregator.COMPARATOR;
   }
 
   @Override
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregationTest.java
index 8866cebaa47..103d5c1d4b9 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregationTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/DoubleAnyAggregationTest.java
@@ -117,6 +117,17 @@ public class DoubleAnyAggregationTest extends InitializedNullHandlingTest
     Assert.assertEquals(-1, comparator.compare(d2, d1));
   }
 
+  @Test
+  public void testComparatorWithTypeMismatch()
+  {
+    Long n1 = 3L;
+    Double n2 = 4.0;
+    Comparator comparator = doubleAnyAggFactory.getComparator();
+    Assert.assertEquals(0, comparator.compare(n1, n1));
+    Assert.assertEquals(-1, comparator.compare(n1, n2));
+    Assert.assertEquals(1, comparator.compare(n2, n1));
+  }
+
   @Test
   public void testDoubleAnyCombiningAggregator()
   {
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/FloatAnyAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/FloatAnyAggregationTest.java
index d04713d2633..f31e8bbec56 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/FloatAnyAggregationTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/FloatAnyAggregationTest.java
@@ -117,6 +117,17 @@ public class FloatAnyAggregationTest extends InitializedNullHandlingTest
     Assert.assertEquals(-1, comparator.compare(f2, f1));
   }
 
+  @Test
+  public void testComparatorWithTypeMismatch()
+  {
+    Long n1 = 3L;
+    Float n2 = 4.0f;
+    Comparator comparator = floatAnyAggFactory.getComparator();
+    Assert.assertEquals(0, comparator.compare(n1, n1));
+    Assert.assertEquals(-1, comparator.compare(n1, n2));
+    Assert.assertEquals(1, comparator.compare(n2, n1));
+  }
+
   @Test
   public void testFloatAnyCombiningAggregator()
   {
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/any/LongAnyAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/any/LongAnyAggregationTest.java
index 89cb8ec8cfa..9525cdfb3f3 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/any/LongAnyAggregationTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/any/LongAnyAggregationTest.java
@@ -118,6 +118,17 @@ public class LongAnyAggregationTest extends InitializedNullHandlingTest
     Assert.assertEquals(-1, comparator.compare(l2, l1));
   }
 
+  @Test
+  public void testComparatorWithTypeMismatch()
+  {
+    Integer n1 = 3;
+    Long n2 = 4L;
+    Comparator comparator = longAnyAggFactory.getComparator();
+    Assert.assertEquals(0, comparator.compare(n1, n1));
+    Assert.assertEquals(-1, comparator.compare(n1, n2));
+    Assert.assertEquals(1, comparator.compare(n2, n1));
+  }
+
   @Test
   public void testLongAnyCombiningAggregator()
   {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to