Repository: cassandra Updated Branches: refs/heads/cassandra-2.2 fc202a756 -> 01f3d0a15 refs/heads/trunk 056115fff -> 7dff15011
sum() and avg() functions missing for smallint and tinyint types patch by Robert Stupp; reviewed by Aleksey Yeschenko for CASSANDRA-9671 Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/01f3d0a1 Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/01f3d0a1 Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/01f3d0a1 Branch: refs/heads/cassandra-2.2 Commit: 01f3d0a15476ccada7cefeb3c4fbbc157404fc8b Parents: fc202a7 Author: Robert Stupp <[email protected]> Authored: Sun Jul 12 21:19:20 2015 +0200 Committer: Robert Stupp <[email protected]> Committed: Sun Jul 12 21:19:20 2015 +0200 ---------------------------------------------------------------------- CHANGES.txt | 1 + .../cassandra/cql3/functions/AggregateFcts.java | 158 +++++++++++++++++++ .../cassandra/cql3/functions/Functions.java | 4 + .../validation/operations/AggregationTest.java | 45 ++++-- 4 files changed, 195 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cassandra/blob/01f3d0a1/CHANGES.txt ---------------------------------------------------------------------- diff --git a/CHANGES.txt b/CHANGES.txt index 15796c4..03458b2 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,5 @@ 2.2.0-rc3 + * sum() and avg() functions missing for smallint and tinyint types (CASSANDRA-9671) * Revert CASSANDRA-9542 (allow native functions in UDA) (CASSANDRA-9771) Merged from 2.1: * Fix clientutil jar and tests (CASSANDRA-9760) http://git-wip-us.apache.org/repos/asf/cassandra/blob/01f3d0a1/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java index 865dfbf..1b22da6 100644 --- a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java +++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java @@ -23,12 +23,14 @@ import java.nio.ByteBuffer; import java.util.List; import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.db.marshal.ByteType; import org.apache.cassandra.db.marshal.DecimalType; import org.apache.cassandra.db.marshal.DoubleType; import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.db.marshal.Int32Type; import org.apache.cassandra.db.marshal.IntegerType; import org.apache.cassandra.db.marshal.LongType; +import org.apache.cassandra.db.marshal.ShortType; /** * Factory methods for aggregate functions. @@ -228,6 +230,162 @@ public abstract class AggregateFcts /** * The SUM function for int32 values. */ + public static final AggregateFunction sumFunctionForByte = + new NativeAggregateFunction("sum", ByteType.instance, ByteType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private byte sum; + + public void reset() + { + sum = 0; + } + + public ByteBuffer compute(int protocolVersion) + { + return ((ByteType) returnType()).decompose(sum); + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.byteValue(); + } + }; + } + }; + + /** + * AVG function for int32 values. + */ + public static final AggregateFunction avgFunctionForByte = + new NativeAggregateFunction("avg", ByteType.instance, ByteType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private byte sum; + + private int count; + + public void reset() + { + count = 0; + sum = 0; + } + + public ByteBuffer compute(int protocolVersion) + { + int avg = count == 0 ? 0 : sum / count; + + return ((ByteType) returnType()).decompose((byte) avg); + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.byteValue(); + } + }; + } + }; + + /** + * The SUM function for int32 values. + */ + public static final AggregateFunction sumFunctionForShort = + new NativeAggregateFunction("sum", ShortType.instance, ShortType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private short sum; + + public void reset() + { + sum = 0; + } + + public ByteBuffer compute(int protocolVersion) + { + return ((ShortType) returnType()).decompose(sum); + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.shortValue(); + } + }; + } + }; + + /** + * AVG function for int32 values. + */ + public static final AggregateFunction avgFunctionForShort = + new NativeAggregateFunction("avg", ShortType.instance, ShortType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private short sum; + + private int count; + + public void reset() + { + count = 0; + sum = 0; + } + + public ByteBuffer compute(int protocolVersion) + { + int avg = count == 0 ? 0 : sum / count; + + return ((ShortType) returnType()).decompose((short) avg); + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.shortValue(); + } + }; + } + }; + + /** + * The SUM function for int32 values. + */ public static final AggregateFunction sumFunctionForInt32 = new NativeAggregateFunction("sum", Int32Type.instance, Int32Type.instance) { http://git-wip-us.apache.org/repos/asf/cassandra/blob/01f3d0a1/src/java/org/apache/cassandra/cql3/functions/Functions.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/Functions.java b/src/java/org/apache/cassandra/cql3/functions/Functions.java index 85f2817..e31fc9f 100644 --- a/src/java/org/apache/cassandra/cql3/functions/Functions.java +++ b/src/java/org/apache/cassandra/cql3/functions/Functions.java @@ -83,12 +83,16 @@ public abstract class Functions declare(AggregateFcts.makeMinFunction(type.getType())); } } + declare(AggregateFcts.sumFunctionForByte); + declare(AggregateFcts.sumFunctionForShort); declare(AggregateFcts.sumFunctionForInt32); declare(AggregateFcts.sumFunctionForLong); declare(AggregateFcts.sumFunctionForFloat); declare(AggregateFcts.sumFunctionForDouble); declare(AggregateFcts.sumFunctionForDecimal); declare(AggregateFcts.sumFunctionForVarint); + declare(AggregateFcts.avgFunctionForByte); + declare(AggregateFcts.avgFunctionForShort); declare(AggregateFcts.avgFunctionForInt32); declare(AggregateFcts.avgFunctionForLong); declare(AggregateFcts.avgFunctionForFloat); http://git-wip-us.apache.org/repos/asf/cassandra/blob/01f3d0a1/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java index 7455dbc..62461b8 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java @@ -42,27 +42,46 @@ public class AggregationTest extends CQLTester @Test public void testFunctions() throws Throwable { - createTable("CREATE TABLE %s (a int, b int, c double, d decimal, primary key (a, b))"); + createTable("CREATE TABLE %s (a int, b int, c double, d decimal, e smallint, f tinyint, primary key (a, b))"); // Test with empty table assertColumnNames(execute("SELECT COUNT(*) FROM %s"), "count"); assertRows(execute("SELECT COUNT(*) FROM %s"), row(0L)); - assertColumnNames(execute("SELECT max(b), min(b), sum(b), avg(b) , max(c), sum(c), avg(c), sum(d), avg(d) FROM %s"), - "system.max(b)", "system.min(b)", "system.sum(b)", "system.avg(b)", "system.max(c)", "system.sum(c)", "system.avg(c)", "system.sum(d)", "system.avg(d)"); - assertRows(execute("SELECT max(b), min(b), sum(b), avg(b) , max(c), sum(c), avg(c), sum(d), avg(d) FROM %s"), - row(null, null, 0, 0, null, 0.0, 0.0, new BigDecimal("0"), new BigDecimal("0"))); - - execute("INSERT INTO %s (a, b, c, d) VALUES (1, 1, 11.5, 11.5)"); - execute("INSERT INTO %s (a, b, c, d) VALUES (1, 2, 9.5, 1.5)"); - execute("INSERT INTO %s (a, b, c, d) VALUES (1, 3, 9.0, 2.0)"); - - assertRows(execute("SELECT max(b), min(b), sum(b), avg(b) , max(c), sum(c), avg(c), sum(d), avg(d) FROM %s"), - row(3, 1, 6, 2, 11.5, 30.0, 10.0, new BigDecimal("15.0"), new BigDecimal("5.0"))); + assertColumnNames(execute("SELECT max(b), min(b), sum(b), avg(b)," + + "max(c), sum(c), avg(c)," + + "sum(d), avg(d)," + + "max(e), min(e), sum(e), avg(e)," + + "max(f), min(f), sum(f), avg(f) FROM %s"), + "system.max(b)", "system.min(b)", "system.sum(b)", "system.avg(b)", + "system.max(c)", "system.sum(c)", "system.avg(c)", + "system.sum(d)", "system.avg(d)", + "system.max(e)", "system.min(e)", "system.sum(e)", "system.avg(e)", + "system.max(f)", "system.min(f)", "system.sum(f)", "system.avg(f)"); + assertRows(execute("SELECT max(b), min(b), sum(b), avg(b)," + + "max(c), sum(c), avg(c)," + + "sum(d), avg(d)," + + "max(e), min(e), sum(e), avg(e)," + + "max(f), min(f), sum(f), avg(f) FROM %s"), + row(null, null, 0, 0, null, 0.0, 0.0, new BigDecimal("0"), new BigDecimal("0"), + null, null, (short)0, (short)0, + null, null, (byte)0, (byte)0)); + + execute("INSERT INTO %s (a, b, c, d, e, f) VALUES (1, 1, 11.5, 11.5, 1, 1)"); + execute("INSERT INTO %s (a, b, c, d, e, f) VALUES (1, 2, 9.5, 1.5, 2, 2)"); + execute("INSERT INTO %s (a, b, c, d, e, f) VALUES (1, 3, 9.0, 2.0, 3, 3)"); + + assertRows(execute("SELECT max(b), min(b), sum(b), avg(b) , max(c), sum(c), avg(c), sum(d), avg(d)," + + "max(e), min(e), sum(e), avg(e)," + + "max(f), min(f), sum(f), avg(f)" + + " FROM %s"), + row(3, 1, 6, 2, 11.5, 30.0, 10.0, new BigDecimal("15.0"), new BigDecimal("5.0"), + (short)3, (short)1, (short)6, (short)2, + (byte)3, (byte)1, (byte)6, (byte)2)); execute("INSERT INTO %s (a, b, d) VALUES (1, 5, 1.0)"); assertRows(execute("SELECT COUNT(*) FROM %s"), row(4L)); assertRows(execute("SELECT COUNT(1) FROM %s"), row(4L)); - assertRows(execute("SELECT COUNT(b), count(c) FROM %s"), row(4L, 3L)); + assertRows(execute("SELECT COUNT(b), count(c), count(e), count(f) FROM %s"), row(4L, 3L, 3L, 3L)); } @Test
