This is an automated email from the ASF dual-hosted git repository.

rohangarg pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ae6cc6e60 Fix string first/last aggregator comparator (#12773)
7ae6cc6e60 is described below

commit 7ae6cc6e60b6bb7fd83c6a649861a5ddd982ea72
Author: Rohan Garg <[email protected]>
AuthorDate: Mon Aug 1 20:54:15 2022 +0530

    Fix string first/last aggregator comparator (#12773)
---
 .../SerializablePairLongStringSerde.java           | 11 ++++--
 .../first/StringFirstAggregatorFactory.java        | 42 ++++------------------
 .../first/StringFirstAggregationTest.java          | 27 ++++++++++++++
 .../last/StringLastAggregationTest.java            | 27 ++++++++++++++
 4 files changed, 69 insertions(+), 38 deletions(-)

diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongStringSerde.java b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongStringSerde.java
index ab40ec7a18..49300ff531 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongStringSerde.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/SerializablePairLongStringSerde.java
@@ -19,9 +19,9 @@
 
 package org.apache.druid.query.aggregation;
 
+import org.apache.druid.collections.SerializablePair;
 import org.apache.druid.data.input.InputRow;
 import org.apache.druid.java.util.common.StringUtils;
-import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory;
 import org.apache.druid.segment.GenericColumnSerializer;
 import org.apache.druid.segment.column.ColumnBuilder;
 import org.apache.druid.segment.data.GenericIndexed;
@@ -34,6 +34,7 @@ import org.apache.druid.segment.writeout.SegmentWriteOutMedium;
 
 import javax.annotation.Nullable;
 import java.nio.ByteBuffer;
+import java.util.Comparator;
 
 /**
  * The SerializablePairLongStringSerde serializes a Long-String pair (SerializablePairLongString).
@@ -46,6 +47,12 @@ public class SerializablePairLongStringSerde extends ComplexMetricSerde
 {
 
   private static final String TYPE_NAME = "serializablePairLongString";
+  // Null SerializablePairLongString values are put first
+  private static final Comparator<SerializablePairLongString> COMPARATOR = Comparator.nullsFirst(
+      // assumes that the LHS of the pair will never be null
+      Comparator.<SerializablePairLongString>comparingLong(SerializablePair::getLhs)
+                .thenComparing(SerializablePair::getRhs, Comparator.nullsFirst(Comparator.naturalOrder()))
+  );
 
   @Override
   public String getTypeName()
@@ -87,7 +94,7 @@ public class SerializablePairLongStringSerde extends ComplexMetricSerde
       @Override
       public int compare(@Nullable SerializablePairLongString o1, @Nullable SerializablePairLongString o2)
       {
-        return StringFirstAggregatorFactory.VALUE_COMPARATOR.compare(o1, o2);
+        return COMPARATOR.compare(o1, o2);
       }
 
       @Override
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java
index f808d1debf..1e0c9bedba 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/StringFirstAggregatorFactory.java
@@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonTypeName;
 import com.google.common.base.Preconditions;
 import com.google.common.primitives.Longs;
+import org.apache.druid.collections.SerializablePair;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.query.aggregation.AggregateCombiner;
 import org.apache.druid.query.aggregation.Aggregator;
@@ -87,42 +88,11 @@ public class StringFirstAggregatorFactory extends AggregatorFactory
       ((SerializablePairLongString) o2).lhs
   );
 
-  public static final Comparator<SerializablePairLongString> VALUE_COMPARATOR = (o1, o2) -> {
-    int comparation;
-
-    // First we check if the objects are null
-    if (o1 == null && o2 == null) {
-      comparation = 0;
-    } else if (o1 == null) {
-      comparation = -1;
-    } else if (o2 == null) {
-      comparation = 1;
-    } else {
-
-      // If the objects are not null, we will try to compare using timestamp
-      comparation = o1.lhs.compareTo(o2.lhs);
-
-      // If both timestamp are the same, we try to compare the Strings
-      if (comparation == 0) {
-
-        // First we check if the strings are null
-        if (o1.rhs == null && o2.rhs == null) {
-          comparation = 0;
-        } else if (o1.rhs == null) {
-          comparation = -1;
-        } else if (o2.rhs == null) {
-          comparation = 1;
-        } else {
-
-          // If the strings are not null, we will compare them
-          // Note: This comparation maybe doesn't make sense to first/last aggregators
-          comparation = o1.rhs.compareTo(o2.rhs);
-        }
-      }
-    }
-
-    return comparation;
-  };
+  // used in comparing aggregation results amongst distinct groups. hence the comparison is done on the finalized
+  // result which is string/value part of the result pair. Null SerializablePairLongString values are put first.
+  public static final Comparator<SerializablePairLongString> VALUE_COMPARATOR = Comparator.nullsFirst(
+      Comparator.comparing(SerializablePair::getRhs, Comparator.nullsFirst(Comparator.naturalOrder()))
+  );
 
   private final String fieldName;
   private final String name;
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstAggregationTest.java
index 2c8cfb84bd..d2f332f0bb 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstAggregationTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/first/StringFirstAggregationTest.java
@@ -39,6 +39,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
+import java.util.Comparator;
 
 public class StringFirstAggregationTest extends InitializedNullHandlingTest
 {
@@ -215,6 +216,32 @@ public class StringFirstAggregationTest extends InitializedNullHandlingTest
     Assert.assertEquals(pairs[1], stringFirstAggregateCombiner.getObject());
   }
 
+  @Test
+  @SuppressWarnings("EqualsWithItself")
+  public void testStringLastAggregatorComparator()
+  {
+    Comparator<SerializablePairLongString> comparator =
+        (Comparator<SerializablePairLongString>) stringFirstAggFactory.getComparator();
+    SerializablePairLongString pair1 = new SerializablePairLongString(1L, "Z");
+    SerializablePairLongString pair2 = new SerializablePairLongString(2L, "A");
+    SerializablePairLongString pair3 = new SerializablePairLongString(3L, null);
+
+    // check non null values
+    Assert.assertEquals(0, comparator.compare(pair1, pair1));
+    Assert.assertTrue(comparator.compare(pair1, pair2) > 0);
+    Assert.assertTrue(comparator.compare(pair2, pair1) < 0);
+
+    // check non null value with null value (null values first comparator)
+    Assert.assertEquals(0, comparator.compare(pair3, pair3));
+    Assert.assertTrue(comparator.compare(pair1, pair3) > 0);
+    Assert.assertTrue(comparator.compare(pair3, pair1) < 0);
+
+    // check non null pair with null pair (null pairs first comparator)
+    Assert.assertEquals(0, comparator.compare(null, null));
+    Assert.assertTrue(comparator.compare(pair1, null) > 0);
+    Assert.assertTrue(comparator.compare(null, pair1) < 0);
+  }
+
   private void aggregate(
       Aggregator agg
   )
diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastAggregationTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastAggregationTest.java
index da402b5535..6a6aa1e650 100644
--- a/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastAggregationTest.java
+++ b/processing/src/test/java/org/apache/druid/query/aggregation/last/StringLastAggregationTest.java
@@ -38,6 +38,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
+import java.util.Comparator;
 
 public class StringLastAggregationTest
 {
@@ -217,6 +218,32 @@ public class StringLastAggregationTest
     Assert.assertEquals(pairs[1], stringFirstAggregateCombiner.getObject());
   }
 
+  @Test
+  @SuppressWarnings("EqualsWithItself")
+  public void testStringLastAggregatorComparator()
+  {
+    Comparator<SerializablePairLongString> comparator =
+        (Comparator<SerializablePairLongString>) stringLastAggFactory.getComparator();
+    SerializablePairLongString pair1 = new SerializablePairLongString(1L, "Z");
+    SerializablePairLongString pair2 = new SerializablePairLongString(2L, "A");
+    SerializablePairLongString pair3 = new SerializablePairLongString(3L, null);
+
+    // check non null values
+    Assert.assertEquals(0, comparator.compare(pair1, pair1));
+    Assert.assertTrue(comparator.compare(pair1, pair2) > 0);
+    Assert.assertTrue(comparator.compare(pair2, pair1) < 0);
+
+    // check non null value with null value (null values first comparator)
+    Assert.assertEquals(0, comparator.compare(pair3, pair3));
+    Assert.assertTrue(comparator.compare(pair1, pair3) > 0);
+    Assert.assertTrue(comparator.compare(pair3, pair1) < 0);
+
+    // check non null pair with null pair (null pairs first comparator)
+    Assert.assertEquals(0, comparator.compare(null, null));
+    Assert.assertTrue(comparator.compare(pair1, null) > 0);
+    Assert.assertTrue(comparator.compare(null, pair1) < 0);
+  }
+
   private void aggregate(
       Aggregator agg
   )


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to