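Fix ArithmeticException in avgFunctionForDecimal: BigDecimal.divide(BigDecimal) throws when the exact quotient has a non-terminating decimal expansion (e.g. 1 / 3), so the average is now computed with HALF_EVEN rounding at the scale of the running sum. Patch by Robert Stupp; reviewed by Tyler Hobbs for CASSANDRA-11485.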
Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/ce445991 Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/ce445991 Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/ce445991 Branch: refs/heads/trunk Commit: ce445991fab05d2ba404f6289796664dd581662a Parents: 9f557ff Author: Robert Stupp <sn...@snazy.de> Authored: Fri Apr 15 20:31:49 2016 +0200 Committer: Robert Stupp <sn...@snazy.de> Committed: Fri Apr 15 20:31:49 2016 +0200 ---------------------------------------------------------------------- CHANGES.txt | 1 + .../cassandra/cql3/functions/AggregateFcts.java | 8 +++++--- .../validation/operations/AggregationTest.java | 19 +++++++++++++++++++ 3 files changed, 25 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cassandra/blob/ce445991/CHANGES.txt ---------------------------------------------------------------------- diff --git a/CHANGES.txt b/CHANGES.txt index 3b4d473..85660d9 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,5 @@ 3.0.6 + * ArithmeticException in avgFunctionForDecimal (CASSANDRA-11485) * Allow only DISTINCT queries with partition keys or static columns restrictions (CASSANDRA-11339) * LogAwareFileLister should only use OLD sstable files in current folder to determine disk consistency (CASSANDRA-11470) * Notify indexers of expired rows during compaction (CASSANDRA-11329) http://git-wip-us.apache.org/repos/asf/cassandra/blob/ce445991/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java index a1b67e1..79a08cd 100644 --- a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java +++ 
b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java @@ -19,6 +19,7 @@ package org.apache.cassandra.cql3.functions; import java.math.BigDecimal; import java.math.BigInteger; +import java.math.RoundingMode; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; @@ -184,9 +185,9 @@ public abstract class AggregateFcts public ByteBuffer compute(int protocolVersion) { if (count == 0) - return ((DecimalType) returnType()).decompose(BigDecimal.ZERO); + return DecimalType.instance.decompose(BigDecimal.ZERO); - return ((DecimalType) returnType()).decompose(sum.divide(BigDecimal.valueOf(count))); + return DecimalType.instance.decompose(sum.divide(BigDecimal.valueOf(count), BigDecimal.ROUND_HALF_EVEN)); } public void addInput(int protocolVersion, List<ByteBuffer> values) @@ -197,13 +198,14 @@ public abstract class AggregateFcts return; count++; - BigDecimal number = ((BigDecimal) argTypes().get(0).compose(value)); + BigDecimal number = DecimalType.instance.compose(value); sum = sum.add(number); } }; } }; + /** * The SUM function for varint values. 
*/ http://git-wip-us.apache.org/repos/asf/cassandra/blob/ce445991/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java index e5420c9..411d5ee 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java @@ -18,6 +18,8 @@ package org.apache.cassandra.cql3.validation.operations; import java.math.BigDecimal; +import java.math.MathContext; +import java.math.RoundingMode; import java.nio.ByteBuffer; import java.text.SimpleDateFormat; import java.util.Arrays; @@ -1950,4 +1952,21 @@ public class AggregationTest extends CQLTester } } } + + @Test + public void testArithmeticCorrectness() throws Throwable + { + createTable("create table %s (bucket int primary key, val decimal)"); + execute("insert into %s (bucket, val) values (1, 0.25)"); + execute("insert into %s (bucket, val) values (2, 0.25)"); + execute("insert into %s (bucket, val) values (3, 0.5);"); + + BigDecimal a = new BigDecimal("0.25"); + a = a.add(new BigDecimal("0.25")); + a = a.add(new BigDecimal("0.5")); + a = a.divide(new BigDecimal(3), RoundingMode.HALF_EVEN); + + assertRows(execute("select avg(val) from %s where bucket in (1, 2, 3);"), + row(a)); + } }