Improve sum aggregate functions. Patch by Alex Petrov; reviewed by Branimir Lambov for CASSANDRA-12417.
Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/04cc3a93 Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/04cc3a93 Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/04cc3a93 Branch: refs/heads/cassandra-3.X Commit: 04cc3a9309fdc4a8c9ae33ed00d2b681a6bb117a Parents: f5f44f6 Author: Alex Petrov <oleksandr.pet...@gmail.com> Authored: Wed Oct 5 10:09:04 2016 +0200 Committer: Aleksey Yeschenko <alek...@apache.org> Committed: Mon Oct 17 18:28:18 2016 +0100 ---------------------------------------------------------------------- CHANGES.txt | 1 + .../cassandra/cql3/functions/AggregateFcts.java | 99 ++++++++++++-------- .../validation/operations/AggregationTest.java | 8 ++ 3 files changed, 67 insertions(+), 41 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cassandra/blob/04cc3a93/CHANGES.txt ---------------------------------------------------------------------- diff --git a/CHANGES.txt b/CHANGES.txt index 4f5bd57..d230462 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,5 @@ 3.10 + * Improve sum aggregate functions (CASSANDRA-12417) * Make cassandra.yaml docs for batch_size_*_threshold_in_kb reflect changes in CASSANDRA-10876 (CASSANDRA-12761) * cqlsh fails to format collections when using aliases (CASSANDRA-11534) * Check for hash conflicts in prepared statements (CASSANDRA-12733) http://git-wip-us.apache.org/repos/asf/cassandra/blob/04cc3a93/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java index 4e3b977..530b7ba 100644 --- a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java +++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java @@ -480,29 
+480,11 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new Aggregate() + return new FloatSumAggregate(FloatType.instance) { - private float sum; - - public void reset() - { - sum = 0; - } - - public ByteBuffer compute(int protocolVersion) - { - return ((FloatType) returnType()).decompose(sum); - } - - public void addInput(int protocolVersion, List<ByteBuffer> values) + public ByteBuffer compute(int protocolVersion) throws InvalidRequestException { - ByteBuffer value = values.get(0); - - if (value == null) - return; - - Number number = ((Number) argTypes().get(0).compose(value)); - sum += number.floatValue(); + return FloatType.instance.decompose((float) computeInternal()); } }; } @@ -534,33 +516,68 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new Aggregate() + return new FloatSumAggregate(DoubleType.instance) { - private double sum; - - public void reset() + public ByteBuffer compute(int protocolVersion) throws InvalidRequestException { - sum = 0; + return DoubleType.instance.decompose(computeInternal()); } + }; + } + }; - public ByteBuffer compute(int protocolVersion) - { - return ((DoubleType) returnType()).decompose(sum); - } + /** + * Sum aggregate function for floating point numbers, using double arithmetics and + * Kahan's algorithm to improve result precision. 
+ */ + private static abstract class FloatSumAggregate implements AggregateFunction.Aggregate + { + private double sum; + private double compensation; + private double simpleSum; - public void addInput(int protocolVersion, List<ByteBuffer> values) - { - ByteBuffer value = values.get(0); + private final AbstractType numberType; - if (value == null) - return; + public FloatSumAggregate(AbstractType numberType) + { + this.numberType = numberType; + } + + public void reset() + { + sum = 0; + compensation = 0; + simpleSum = 0; + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + double number = ((Number) numberType.compose(value)).doubleValue(); + simpleSum += number; + double tmp = number - compensation; + double rounded = sum + tmp; + compensation = (rounded - sum) - tmp; + sum = rounded; + } + + public double computeInternal() + { + // correctly compute final sum if it's NaN from consequently + // adding same-signed infinite values. + double tmp = sum + compensation; + + if (Double.isNaN(tmp) && Double.isInfinite(simpleSum)) + return simpleSum; + else + return tmp; + } + } - Number number = ((Number) argTypes().get(0).compose(value)); - sum += number.doubleValue(); - } - }; - } - }; /** * Average aggregate for floating point umbers, using double arithmetics and Kahan's algorithm * to calculate sum by default, switching to BigDecimal on sum overflow. 
Resulting number is http://git-wip-us.apache.org/repos/asf/cassandra/blob/04cc3a93/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java index b01993c..8f03635 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java @@ -2043,6 +2043,8 @@ public class AggregationTest extends CQLTester assertRows(execute("select avg(v1), avg(v2) from %s where bucket in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10);"), row(Float.NaN, Double.NaN)); + assertRows(execute("select sum(v1), sum(v2) from %s where bucket in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10);"), + row(Float.NaN, Double.NaN)); } @Test @@ -2062,6 +2064,9 @@ public class AggregationTest extends CQLTester assertRows(execute("select avg(v1), avg(v2) from %s where bucket in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10);"), row(FLOAT_INFINITY, DOUBLE_INFINITY)); + assertRows(execute("select sum(v1), avg(v2) from %s where bucket in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10);"), + row(FLOAT_INFINITY, DOUBLE_INFINITY)); + execute("truncate %s"); } } @@ -2073,5 +2078,8 @@ public class AggregationTest extends CQLTester for (int i = 1; i <= 17; i++) execute("insert into %s (bucket, v1, v2, v3) values (?, ?, ?, ?)", i, (float) (i / 10.0), i / 10.0, BigDecimal.valueOf(i / 10.0)); + + assertRows(execute("select sum(v1), sum(v2), sum(v3) from %s;"), + row((float) 15.3, 15.3, BigDecimal.valueOf(15.3))); } }