Improve avg aggregate functions. Patch by Alex Petrov; reviewed by Branimir Lambov for CASSANDRA-12417
Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/7872318d Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/7872318d Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/7872318d Branch: refs/heads/trunk Commit: 7872318d63009193415ba1365beedc2303a92386 Parents: 153583b Author: Alex Petrov <oleksandr.pet...@gmail.com> Authored: Wed Oct 5 09:46:05 2016 +0200 Committer: Aleksey Yeschenko <alek...@apache.org> Committed: Mon Oct 17 18:16:36 2016 +0100 ---------------------------------------------------------------------- CHANGES.txt | 1 + .../cassandra/cql3/functions/AggregateFcts.java | 327 ++++++++++--------- .../validation/operations/AggregationTest.java | 110 +++++++ 3 files changed, 282 insertions(+), 156 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cassandra/blob/7872318d/CHANGES.txt ---------------------------------------------------------------------- diff --git a/CHANGES.txt b/CHANGES.txt index e82eedd..d5419c6 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,5 @@ 3.0.10 + * Improve avg aggregate functions (CASSANDRA-12417) * Preserve quoted reserved keyword column names in MV creation (CASSANDRA-11803) * nodetool stopdaemon errors out (CASSANDRA-12646) * Split materialized view mutations on build to prevent OOM (CASSANDRA-12268) http://git-wip-us.apache.org/repos/asf/cassandra/blob/7872318d/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java index b2cae50..441fa58 100644 --- a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java +++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java @@ -27,6 +27,7 @@ import java.util.List; import 
org.apache.cassandra.cql3.CQL3Type; import org.apache.cassandra.db.marshal.*; +import org.apache.cassandra.exceptions.InvalidRequestException; /** * Factory methods for aggregate functions. @@ -114,7 +115,7 @@ public abstract class AggregateFcts public ByteBuffer compute(int protocolVersion) { - return ((LongType) returnType()).decompose(count); + return LongType.instance.decompose(count); } public void addInput(int protocolVersion, List<ByteBuffer> values) @@ -155,7 +156,7 @@ public abstract class AggregateFcts if (value == null) return; - BigDecimal number = ((BigDecimal) argTypes().get(0).compose(value)); + BigDecimal number = DecimalType.instance.compose(value); sum = sum.add(number); } }; @@ -172,22 +173,19 @@ public abstract class AggregateFcts { return new Aggregate() { - private BigDecimal sum = BigDecimal.ZERO; + private BigDecimal avg = BigDecimal.ZERO; private int count; public void reset() { count = 0; - sum = BigDecimal.ZERO; + avg = BigDecimal.ZERO; } public ByteBuffer compute(int protocolVersion) { - if (count == 0) - return DecimalType.instance.decompose(BigDecimal.ZERO); - - return DecimalType.instance.decompose(sum.divide(BigDecimal.valueOf(count), BigDecimal.ROUND_HALF_EVEN)); + return DecimalType.instance.decompose(avg); } public void addInput(int protocolVersion, List<ByteBuffer> values) @@ -199,7 +197,9 @@ public abstract class AggregateFcts count++; BigDecimal number = DecimalType.instance.compose(value); - sum = sum.add(number); + + // avg = avg + (value - avg) / count.
+ avg = avg.add(number.subtract(avg).divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN)); } }; } @@ -235,7 +235,7 @@ public abstract class AggregateFcts if (value == null) return; - BigInteger number = ((BigInteger) argTypes().get(0).compose(value)); + BigInteger number = IntegerType.instance.compose(value); sum = sum.add(number); } }; @@ -265,9 +265,9 @@ public abstract class AggregateFcts public ByteBuffer compute(int protocolVersion) { if (count == 0) - return ((IntegerType) returnType()).decompose(BigInteger.ZERO); + return IntegerType.instance.decompose(BigInteger.ZERO); - return ((IntegerType) returnType()).decompose(sum.divide(BigInteger.valueOf(count))); + return IntegerType.instance.decompose(sum.divide(BigInteger.valueOf(count))); } public void addInput(int protocolVersion, List<ByteBuffer> values) @@ -329,35 +329,11 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new Aggregate() + return new AvgAggregate(ByteType.instance) { - private byte sum; - - private int count; - - public void reset() + public ByteBuffer compute(int protocolVersion) throws InvalidRequestException { - count = 0; - sum = 0; - } - - public ByteBuffer compute(int protocolVersion) - { - int avg = count == 0 ? 0 : sum / count; - - return ((ByteType) returnType()).decompose((byte) avg); - } - - public void addInput(int protocolVersion, List<ByteBuffer> values) - { - ByteBuffer value = values.get(0); - - if (value == null) - return; - - count++; - Number number = ((Number) argTypes().get(0).compose(value)); - sum += number.byteValue(); + return ByteType.instance.decompose((byte) computeInternal()); } }; } @@ -407,35 +383,11 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new Aggregate() + return new AvgAggregate(ShortType.instance) { - private short sum; - - private int count; - - public void reset() - { - count = 0; - sum = 0; - } - public ByteBuffer compute(int protocolVersion) { - int avg = count == 0 ? 
0 : sum / count; - - return ((ShortType) returnType()).decompose((short) avg); - } - - public void addInput(int protocolVersion, List<ByteBuffer> values) - { - ByteBuffer value = values.get(0); - - if (value == null) - return; - - count++; - Number number = ((Number) argTypes().get(0).compose(value)); - sum += number.shortValue(); + return ShortType.instance.decompose((short) computeInternal()); } }; } @@ -485,35 +437,11 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new Aggregate() + return new AvgAggregate(Int32Type.instance) { - private int sum; - - private int count; - - public void reset() - { - count = 0; - sum = 0; - } - public ByteBuffer compute(int protocolVersion) { - int avg = count == 0 ? 0 : sum / count; - - return ((Int32Type) returnType()).decompose(avg); - } - - public void addInput(int protocolVersion, List<ByteBuffer> values) - { - ByteBuffer value = values.get(0); - - if (value == null) - return; - - count++; - Number number = ((Number) argTypes().get(0).compose(value)); - sum += number.intValue(); + return Int32Type.instance.decompose((int) computeInternal()); } }; } @@ -539,7 +467,13 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new LongAvgAggregate(); + return new AvgAggregate(LongType.instance) + { + public ByteBuffer compute(int protocolVersion) + { + return LongType.instance.decompose(computeInternal()); + } + }; } }; @@ -587,35 +521,11 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new Aggregate() + return new FloatAvgAggregate(FloatType.instance) { - private float sum; - - private int count; - - public void reset() + public ByteBuffer compute(int protocolVersion) throws InvalidRequestException { - count = 0; - sum = 0; - } - - public ByteBuffer compute(int protocolVersion) - { - float avg = count == 0 ? 
0 : sum / count; - - return ((FloatType) returnType()).decompose(avg); - } - - public void addInput(int protocolVersion, List<ByteBuffer> values) - { - ByteBuffer value = values.get(0); - - if (value == null) - return; - - count++; - Number number = ((Number) argTypes().get(0).compose(value)); - sum += number.floatValue(); + return FloatType.instance.decompose((float) computeInternal()); } }; } @@ -656,6 +566,95 @@ public abstract class AggregateFcts }; } }; + /** + * Average aggregate for floating point numbers, using double arithmetic and Kahan's algorithm + * to calculate the sum by default, switching to BigDecimal on sum overflow. Resulting number is + * converted to corresponding representation by concrete implementations. + */ + private static abstract class FloatAvgAggregate implements AggregateFunction.Aggregate + { + private double sum; + private double compensation; + private double simpleSum; + + private int count; + + private BigDecimal bigSum = null; + private boolean overflow = false; + + private final AbstractType numberType; + + public FloatAvgAggregate(AbstractType numberType) + { + this.numberType = numberType; + } + + public void reset() + { + sum = 0; + compensation = 0; + simpleSum = 0; + + count = 0; + bigSum = null; + overflow = false; + } + + public double computeInternal() + { + if (count == 0) + return 0d; + + if (overflow) + { + return bigSum.divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN).doubleValue(); + } + else + { + // correctly compute final sum if it's NaN from consecutively + // adding same-signed infinite values.
+ double tmp = sum + compensation; + if (Double.isNaN(tmp) && Double.isInfinite(simpleSum)) + sum = simpleSum; + else + sum = tmp; + + return sum / count; + } + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + + double number = ((Number) numberType.compose(value)).doubleValue(); + + if (overflow) + { + bigSum = bigSum.add(BigDecimal.valueOf(number)); + } + else + { + simpleSum += number; + double prev = sum; + double tmp = number - compensation; + double rounded = sum + tmp; + compensation = (rounded - sum) - tmp; + sum = rounded; + + if (Double.isInfinite(sum) && !Double.isInfinite(number)) + { + overflow = true; + bigSum = BigDecimal.valueOf(prev).add(BigDecimal.valueOf(number)); + } + } + } + } /** * AVG function for double values. @@ -665,35 +664,11 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new Aggregate() + return new FloatAvgAggregate(DoubleType.instance) { - private double sum; - - private int count; - - public void reset() - { - count = 0; - sum = 0; - } - - public ByteBuffer compute(int protocolVersion) - { - double avg = count == 0 ? 
0 : sum / count; - - return ((DoubleType) returnType()).decompose(avg); - } - - public void addInput(int protocolVersion, List<ByteBuffer> values) + public ByteBuffer compute(int protocolVersion) throws InvalidRequestException { - ByteBuffer value = values.get(0); - - if (value == null) - return; - - count++; - Number number = ((Number) argTypes().get(0).compose(value)); - sum += number.doubleValue(); + return DoubleType.instance.decompose(computeInternal()); } }; } @@ -719,7 +694,13 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new LongAvgAggregate(); + return new AvgAggregate(LongType.instance) + { + public ByteBuffer compute(int protocolVersion) throws InvalidRequestException + { + return CounterColumnType.instance.decompose(computeInternal()); + } + }; } }; @@ -947,23 +928,43 @@ public abstract class AggregateFcts } } - private static class LongAvgAggregate implements AggregateFunction.Aggregate + /** + * Average aggregate class, collecting the sum using long arithmetic, falling back + * to BigInteger on long overflow. Resulting number is converted to corresponding + * representation by concrete implementations. + */ + private static abstract class AvgAggregate implements AggregateFunction.Aggregate + { private long sum; - private int count; + private BigInteger bigSum = null; + private boolean overflow = false; + + private final AbstractType numberType; + + public AvgAggregate(AbstractType type) + { + this.numberType = type; + } public void reset() { count = 0; - sum = 0; + sum = 0L; + overflow = false; + bigSum = null; } - public ByteBuffer compute(int protocolVersion) + long computeInternal() { - long avg = count == 0 ? 0 : sum / count; - - return LongType.instance.decompose(avg); + if (overflow) + { + return bigSum.divide(BigInteger.valueOf(count)).longValue(); + } + else + { + return count == 0 ?
0 : (sum / count); + } } public void addInput(int protocolVersion, List<ByteBuffer> values) @@ -974,8 +975,22 @@ public abstract class AggregateFcts return; count++; - Number number = LongType.instance.compose(value); - sum += number.longValue(); + long number = ((Number) numberType.compose(value)).longValue(); + if (overflow) + { + bigSum = bigSum.add(BigInteger.valueOf(number)); + } + else + { + long prev = sum; + sum += number; + + if (((prev ^ sum) & (number ^ sum)) < 0) + { + overflow = true; + bigSum = BigInteger.valueOf(prev).add(BigInteger.valueOf(number)); + } + } } } } http://git-wip-us.apache.org/repos/asf/cassandra/blob/7872318d/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java index 411d5ee..2e7dc1a 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java @@ -19,6 +19,7 @@ package org.apache.cassandra.cql3.validation.operations; import java.math.BigDecimal; import java.math.MathContext; +import java.math.BigInteger; import java.math.RoundingMode; import java.nio.ByteBuffer; import java.text.SimpleDateFormat; @@ -28,6 +29,7 @@ import java.util.Date; import java.util.Locale; import java.util.TimeZone; import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.DoubleStream; import org.apache.commons.lang3.time.DateUtils; @@ -1969,4 +1971,112 @@ public class AggregationTest extends CQLTester assertRows(execute("select avg(val) from %s where bucket in (1, 2, 3);"), row(a)); } + + @Test + public void testAggregatesWithoutOverflow() throws Throwable + { + createTable("create table %s (bucket int primary key, v1 tinyint, v2 smallint, v3 int, v4 bigint, v5 
varint)"); + for (int i = 1; i <= 3; i++) + execute("insert into %s (bucket, v1, v2, v3, v4, v5) values (?, ?, ?, ?, ?, ?)", i, + (byte) ((Byte.MAX_VALUE / 3) + i), (short) ((Short.MAX_VALUE / 3) + i), (Integer.MAX_VALUE / 3) + i, (Long.MAX_VALUE / 3) + i, + BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.valueOf(i))); + + assertRows(execute("select avg(v1), avg(v2), avg(v3), avg(v4), avg(v5) from %s where bucket in (1, 2, 3);"), + row((byte) ((Byte.MAX_VALUE / 3) + 2), (short) ((Short.MAX_VALUE / 3) + 2), (Integer.MAX_VALUE / 3) + 2, (Long.MAX_VALUE / 3) + 2, + BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.valueOf(2)))); + + for (int i = 1; i <= 3; i++) + execute("insert into %s (bucket, v1, v2, v3, v4, v5) values (?, ?, ?, ?, ?, ?)", i + 3, + (byte) (100 + i), (short) (100 + i), 100 + i, 100L + i, BigInteger.valueOf(100 + i)); + + assertRows(execute("select avg(v1), avg(v2), avg(v3), avg(v4), avg(v5) from %s where bucket in (4, 5, 6);"), + row((byte) 102, (short) 102, 102, 102L, BigInteger.valueOf(102))); + } + + @Test + public void testAggregateOverflow() throws Throwable + { + createTable("create table %s (bucket int primary key, v1 tinyint, v2 smallint, v3 int, v4 bigint, v5 varint)"); + for (int i = 1; i <= 3; i++) + execute("insert into %s (bucket, v1, v2, v3, v4, v5) values (?, ?, ?, ?, ?, ?)", i, + Byte.MAX_VALUE, Short.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE, BigInteger.valueOf(Long.MAX_VALUE).multiply(BigInteger.valueOf(2))); + + assertRows(execute("select avg(v1), avg(v2), avg(v3), avg(v4), avg(v5) from %s where bucket in (1, 2, 3);"), + row(Byte.MAX_VALUE, Short.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE, BigInteger.valueOf(Long.MAX_VALUE).multiply(BigInteger.valueOf(2)))); + + execute("truncate %s"); + + for (int i = 1; i <= 3; i++) + execute("insert into %s (bucket, v1, v2, v3, v4, v5) values (?, ?, ?, ?, ?, ?)", i, + Byte.MIN_VALUE, Short.MIN_VALUE, Integer.MIN_VALUE, Long.MIN_VALUE, 
BigInteger.valueOf(Long.MIN_VALUE).multiply(BigInteger.valueOf(2))); + + assertRows(execute("select avg(v1), avg(v2), avg(v3), avg(v4), avg(v5) from %s where bucket in (1, 2, 3);"), + row(Byte.MIN_VALUE, Short.MIN_VALUE, Integer.MIN_VALUE, Long.MIN_VALUE, BigInteger.valueOf(Long.MIN_VALUE).multiply(BigInteger.valueOf(2)))); + + } + + @Test + public void testDoubleAggregatesPrecision() throws Throwable + { + createTable("create table %s (bucket int primary key, v1 float, v2 double, v3 decimal)"); + + for (int i = 1; i <= 3; i++) + execute("insert into %s (bucket, v1, v2, v3) values (?, ?, ?, ?)", i, + Float.MAX_VALUE, Double.MAX_VALUE, BigDecimal.valueOf(Double.MAX_VALUE).add(BigDecimal.valueOf(2))); + + assertRows(execute("select avg(v1), avg(v2), avg(v3) from %s where bucket in (1, 2, 3);"), + row(Float.MAX_VALUE, Double.MAX_VALUE, BigDecimal.valueOf(Double.MAX_VALUE).add(BigDecimal.valueOf(2)))); + + execute("insert into %s (bucket, v1, v2, v3) values (?, ?, ?, ?)", 4, (float) 100.10, 100.10, BigDecimal.valueOf(100.10)); + execute("insert into %s (bucket, v1, v2, v3) values (?, ?, ?, ?)", 5, (float) 110.11, 110.11, BigDecimal.valueOf(110.11)); + execute("insert into %s (bucket, v1, v2, v3) values (?, ?, ?, ?)", 6, (float) 120.12, 120.12, BigDecimal.valueOf(120.12)); + + assertRows(execute("select avg(v1), avg(v2), avg(v3) from %s where bucket in (4, 5, 6);"), + row((float) 110.11, 110.11, BigDecimal.valueOf(110.11))); + } + + @Test + public void testNan() throws Throwable + { + createTable("create table %s (bucket int primary key, v1 float, v2 double)"); + + for (int i = 1; i <= 10; i++) + if (i != 5) + execute("insert into %s (bucket, v1, v2) values (?, ?, ?)", i, (float) i, (double) i); + + execute("insert into %s (bucket, v1, v2) values (?, ?, ?)", 5, Float.NaN, Double.NaN); + + assertRows(execute("select avg(v1), avg(v2) from %s where bucket in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10);"), + row(Float.NaN, Double.NaN)); + } + + @Test + public void testInfinity() throws 
Throwable + { + createTable("create table %s (bucket int primary key, v1 float, v2 double)"); + for (boolean positive: new boolean[] { true, false}) + { + final float FLOAT_INFINITY = positive ? Float.POSITIVE_INFINITY : Float.NEGATIVE_INFINITY; + final double DOUBLE_INFINITY = positive ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; + + for (int i = 1; i <= 10; i++) + if (i != 5) + execute("insert into %s (bucket, v1, v2) values (?, ?, ?)", i, (float) i, (double) i); + + execute("insert into %s (bucket, v1, v2) values (?, ?, ?)", 5, FLOAT_INFINITY, DOUBLE_INFINITY); + + assertRows(execute("select avg(v1), avg(v2) from %s where bucket in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10);"), + row(FLOAT_INFINITY, DOUBLE_INFINITY)); + execute("truncate %s"); + } + } + + @Test + public void testSumPrecision() throws Throwable + { + createTable("create table %s (bucket int primary key, v1 float, v2 double, v3 decimal)"); + + for (int i = 1; i <= 17; i++) + execute("insert into %s (bucket, v1, v2, v3) values (?, ?, ?, ?)", i, (float) (i / 10.0), i / 10.0, BigDecimal.valueOf(i / 10.0)); + } }