Improve sum aggregate functions

Patch by Alex Petrov; reviewed by Branimir Lambov for CASSANDRA-12417


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/04cc3a93
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/04cc3a93
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/04cc3a93

Branch: refs/heads/cassandra-3.X
Commit: 04cc3a9309fdc4a8c9ae33ed00d2b681a6bb117a
Parents: f5f44f6
Author: Alex Petrov <oleksandr.pet...@gmail.com>
Authored: Wed Oct 5 10:09:04 2016 +0200
Committer: Aleksey Yeschenko <alek...@apache.org>
Committed: Mon Oct 17 18:28:18 2016 +0100

----------------------------------------------------------------------
 CHANGES.txt                                     |  1 +
 .../cassandra/cql3/functions/AggregateFcts.java | 99 ++++++++++++--------
 .../validation/operations/AggregationTest.java  |  8 ++
 3 files changed, 67 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/04cc3a93/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 4f5bd57..d230462 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 3.10
+ * Improve sum aggregate functions (CASSANDRA-12417)
  * Make cassandra.yaml docs for batch_size_*_threshold_in_kb reflect changes 
in CASSANDRA-10876 (CASSANDRA-12761)
  * cqlsh fails to format collections when using aliases (CASSANDRA-11534)
  * Check for hash conflicts in prepared statements (CASSANDRA-12733)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/04cc3a93/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java 
b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
index 4e3b977..530b7ba 100644
--- a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
+++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
@@ -480,29 +480,11 @@ public abstract class AggregateFcts
             {
                 public Aggregate newAggregate()
                 {
-                    return new Aggregate()
+                    return new FloatSumAggregate(FloatType.instance)
                     {
-                        private float sum;
-
-                        public void reset()
-                        {
-                            sum = 0;
-                        }
-
-                        public ByteBuffer compute(int protocolVersion)
-                        {
-                            return ((FloatType) returnType()).decompose(sum);
-                        }
-
-                        public void addInput(int protocolVersion, 
List<ByteBuffer> values)
+                        public ByteBuffer compute(int protocolVersion) throws 
InvalidRequestException
                         {
-                            ByteBuffer value = values.get(0);
-
-                            if (value == null)
-                                return;
-
-                            Number number = ((Number) 
argTypes().get(0).compose(value));
-                            sum += number.floatValue();
+                            return FloatType.instance.decompose((float) 
computeInternal());
                         }
                     };
                 }
@@ -534,33 +516,68 @@ public abstract class AggregateFcts
             {
                 public Aggregate newAggregate()
                 {
-                    return new Aggregate()
+                    return new FloatSumAggregate(DoubleType.instance)
                     {
-                        private double sum;
-
-                        public void reset()
+                        public ByteBuffer compute(int protocolVersion) throws 
InvalidRequestException
                         {
-                            sum = 0;
+                            return 
DoubleType.instance.decompose(computeInternal());
                         }
+                    };
+                }
+            };
 
-                        public ByteBuffer compute(int protocolVersion)
-                        {
-                            return ((DoubleType) returnType()).decompose(sum);
-                        }
+    /**
+     * Sum aggregate function for floating point numbers, using double 
arithmetic and
+     * Kahan's algorithm to improve result precision.
+     */
+    private static abstract class FloatSumAggregate implements 
AggregateFunction.Aggregate
+    {
+        private double sum;
+        private double compensation;
+        private double simpleSum;
 
-                        public void addInput(int protocolVersion, 
List<ByteBuffer> values)
-                        {
-                            ByteBuffer value = values.get(0);
+        private final AbstractType numberType;
 
-                            if (value == null)
-                                return;
+        public FloatSumAggregate(AbstractType numberType)
+        {
+            this.numberType = numberType;
+        }
+
+        public void reset()
+        {
+            sum = 0;
+            compensation = 0;
+            simpleSum = 0;
+        }
+
+        public void addInput(int protocolVersion, List<ByteBuffer> values)
+        {
+            ByteBuffer value = values.get(0);
+
+            if (value == null)
+                return;
+
+            double number = ((Number) numberType.compose(value)).doubleValue();
+            simpleSum += number;
+            double tmp = number - compensation;
+            double rounded = sum + tmp;
+            compensation = (rounded - sum) - tmp;
+            sum = rounded;
+        }
+
+        public double computeInternal()
+        {
+            // correctly compute final sum if it's NaN from consecutively
+            // adding same-signed infinite values.
+            double tmp = sum + compensation;
+
+            if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
+                return simpleSum;
+            else
+                return tmp;
+        }
+    }
 
-                            Number number = ((Number) 
argTypes().get(0).compose(value));
-                            sum += number.doubleValue();
-                        }
-                    };
-                }
-            };
     /**
      * Average aggregate for floating point numbers, using double arithmetic 
and Kahan's algorithm
      * to calculate sum by default, switching to BigDecimal on sum overflow. 
Resulting number is

http://git-wip-us.apache.org/repos/asf/cassandra/blob/04cc3a93/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
----------------------------------------------------------------------
diff --git 
a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
 
b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
index b01993c..8f03635 100644
--- 
a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
+++ 
b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
@@ -2043,6 +2043,8 @@ public class AggregationTest extends CQLTester
 
         assertRows(execute("select avg(v1), avg(v2) from %s where bucket in 
(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);"),
                    row(Float.NaN, Double.NaN));
+        assertRows(execute("select sum(v1), sum(v2) from %s where bucket in 
(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);"),
+                   row(Float.NaN, Double.NaN));
     }
 
     @Test
@@ -2062,6 +2064,9 @@ public class AggregationTest extends CQLTester
 
             assertRows(execute("select avg(v1), avg(v2) from %s where bucket 
in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10);"),
                        row(FLOAT_INFINITY, DOUBLE_INFINITY));
+            assertRows(execute("select sum(v1), avg(v2) from %s where bucket 
in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10);"),
+                       row(FLOAT_INFINITY, DOUBLE_INFINITY));
+
             execute("truncate %s");
         }
     }
@@ -2073,5 +2078,8 @@ public class AggregationTest extends CQLTester
 
         for (int i = 1; i <= 17; i++)
             execute("insert into %s (bucket, v1, v2, v3) values (?, ?, ?, ?)", 
i, (float) (i / 10.0), i / 10.0, BigDecimal.valueOf(i / 10.0));
+
+        assertRows(execute("select sum(v1), sum(v2), sum(v3) from %s;"),
+                   row((float) 15.3, 15.3, BigDecimal.valueOf(15.3)));
     }
 }

Reply via email to