Repository: cassandra Updated Branches: refs/heads/trunk 6b7db8a53 -> d0e203645
Support counter-columns for native aggregates (sum,avg,max,min) patch by Robert Stupp; reviewed by Benjamin Lerer for CASSANDRA-9977 Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/e4eabd90 Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/e4eabd90 Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/e4eabd90 Branch: refs/heads/trunk Commit: e4eabd901522742550074d5c3c5f25b642037891 Parents: 4d0f140 Author: Robert Stupp <[email protected]> Authored: Mon Jan 4 16:34:27 2016 +0100 Committer: Robert Stupp <[email protected]> Committed: Mon Jan 4 16:34:27 2016 +0100 ---------------------------------------------------------------------- .../cassandra/cql3/functions/AggregateFcts.java | 230 ++++++++++++++----- .../cql3/validation/entities/UFTest.java | 26 +++ .../validation/operations/AggregationTest.java | 41 ++++ 3 files changed, 239 insertions(+), 58 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cassandra/blob/e4eabd90/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java index 7b5bdb8..a1b67e1 100644 --- a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java +++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java @@ -47,6 +47,7 @@ public abstract class AggregateFcts functions.add(sumFunctionForDouble); functions.add(sumFunctionForDecimal); functions.add(sumFunctionForVarint); + functions.add(sumFunctionForCounter); // avg for primitives functions.add(avgFunctionForByte); @@ -57,6 +58,7 @@ public abstract class AggregateFcts functions.add(avgFunctionForDouble); functions.add(avgFunctionForDecimal); functions.add(avgFunctionForVarint); + functions.add(avgFunctionForCounter); // count, max, and min for all standard types for (CQL3Type type : CQL3Type.Native.values()) @@ -64,8 +66,16 @@ public abstract class AggregateFcts if (type != CQL3Type.Native.VARCHAR) // varchar and text both mapping to UTF8Type { functions.add(AggregateFcts.makeCountFunction(type.getType())); - functions.add(AggregateFcts.makeMaxFunction(type.getType())); - functions.add(AggregateFcts.makeMinFunction(type.getType())); + if (type != CQL3Type.Native.COUNTER) + { + functions.add(AggregateFcts.makeMaxFunction(type.getType())); + functions.add(AggregateFcts.makeMinFunction(type.getType())); + } + else + { + functions.add(AggregateFcts.maxFunctionForCounter); + functions.add(AggregateFcts.minFunctionForCounter); + } } } @@ -515,31 +525,7 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new Aggregate() - { - private long sum; - - public void reset() - { - sum = 0; - } - - public ByteBuffer compute(int protocolVersion) - { - return ((LongType) returnType()).decompose(sum); - } - - public void addInput(int protocolVersion, List<ByteBuffer> values) - { - ByteBuffer value = values.get(0); - - if (value == null) - return; - - Number number = ((Number) argTypes().get(0).compose(value)); - sum += number.longValue(); - } - }; + return new LongSumAggregate(); } }; @@ -551,37 +537,7 @@ public abstract class AggregateFcts { public Aggregate newAggregate() { - return new Aggregate() - { - private long sum; - - private int count; - - public void reset() - { - count = 0; - sum = 0; - } - - public ByteBuffer compute(int protocolVersion) - { - long avg = count == 0 ? 0 : sum / count; - - return ((LongType) returnType()).decompose(avg); - } - - public void addInput(int protocolVersion, List<ByteBuffer> values) - { - ByteBuffer value = values.get(0); - - if (value == null) - return; - - count++; - Number number = ((Number) argTypes().get(0).compose(value)); - sum += number.longValue(); - } - }; + return new LongAvgAggregate(); } }; @@ -742,6 +698,106 @@ public abstract class AggregateFcts }; /** + * The SUM function for counter column values. + */ + public static final AggregateFunction sumFunctionForCounter = + new NativeAggregateFunction("sum", CounterColumnType.instance, CounterColumnType.instance) + { + public Aggregate newAggregate() + { + return new LongSumAggregate(); + } + }; + + /** + * AVG function for counter column values. + */ + public static final AggregateFunction avgFunctionForCounter = + new NativeAggregateFunction("avg", CounterColumnType.instance, CounterColumnType.instance) + { + public Aggregate newAggregate() + { + return new LongAvgAggregate(); + } + }; + + /** + * The MIN function for counter column values. + */ + public static final AggregateFunction minFunctionForCounter = + new NativeAggregateFunction("min", CounterColumnType.instance, CounterColumnType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private Long min; + + public void reset() + { + min = null; + } + + public ByteBuffer compute(int protocolVersion) + { + return min != null ? LongType.instance.decompose(min) : null; + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + long lval = LongType.instance.compose(value); + + if (min == null || lval < min) + min = lval; + } + }; + } + }; + + /** + * AVG function for counter column values. + */ + public static final AggregateFunction maxFunctionForCounter = + new NativeAggregateFunction("max", CounterColumnType.instance, CounterColumnType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private Long max; + + public void reset() + { + max = null; + } + + public ByteBuffer compute(int protocolVersion) + { + return max != null ? LongType.instance.decompose(max) : null; + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + long lval = LongType.instance.compose(value); + + if (max == null || lval > max) + max = lval; + } + }; + } + }; + + /** * Creates a MAX function for the specified type. * * @param inputType the function input and output type @@ -862,4 +918,62 @@ public abstract class AggregateFcts } }; } + + private static class LongSumAggregate implements AggregateFunction.Aggregate + { + private long sum; + + public void reset() + { + sum = 0; + } + + public ByteBuffer compute(int protocolVersion) + { + return LongType.instance.decompose(sum); + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + Number number = LongType.instance.compose(value); + sum += number.longValue(); + } + } + + private static class LongAvgAggregate implements AggregateFunction.Aggregate + { + private long sum; + + private int count; + + public void reset() + { + count = 0; + sum = 0; + } + + public ByteBuffer compute(int protocolVersion) + { + long avg = count == 0 ? 0 : sum / count; + + return LongType.instance.decompose(avg); + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + Number number = LongType.instance.compose(value); + sum += number.longValue(); + } + } } http://git-wip-us.apache.org/repos/asf/cassandra/blob/e4eabd90/test/unit/org/apache/cassandra/cql3/validation/entities/UFTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/cql3/validation/entities/UFTest.java b/test/unit/org/apache/cassandra/cql3/validation/entities/UFTest.java index 467a082..704a6c9 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/entities/UFTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/entities/UFTest.java @@ -707,6 +707,32 @@ public class UFTest extends CQLTester } @Test + public void testJavaFunctionCounter() throws Throwable + { + createTable("CREATE TABLE %s (key int primary key, val counter)"); + + String fName = createFunction(KEYSPACE, "counter", + "CREATE OR REPLACE FUNCTION %s(val counter) " + + "CALLED ON NULL INPUT " + + "RETURNS bigint " + + "LANGUAGE JAVA " + + "AS 'return val + 1;';"); + + execute("UPDATE %s SET val = val + 1 WHERE key = 1"); + assertRows(execute("SELECT key, val, " + fName + "(val) FROM %s"), + row(1, 1L, 2L)); + execute("UPDATE %s SET val = val + 1 WHERE key = 1"); + assertRows(execute("SELECT key, val, " + fName + "(val) FROM %s"), + row(1, 2L, 3L)); + execute("UPDATE %s SET val = val + 2 WHERE key = 1"); + assertRows(execute("SELECT key, val, " + fName + "(val) FROM %s"), + row(1, 4L, 5L)); + execute("UPDATE %s SET val = val - 2 WHERE key = 1"); + assertRows(execute("SELECT key, val, " + fName + "(val) FROM %s"), + row(1, 2L, 3L)); + } + + @Test public void testFunctionInTargetKeyspace() throws Throwable { createTable("CREATE TABLE %s (key int primary key, val double)"); http://git-wip-us.apache.org/repos/asf/cassandra/blob/e4eabd90/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java index 2713895..221f48e 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java @@ -174,6 +174,47 @@ public class AggregationTest extends CQLTester } @Test + public void testAggregateOnCounters() throws Throwable + { + createTable("CREATE TABLE %s (a int, b counter, primary key (a))"); + + // Test with empty table + assertColumnNames(execute("SELECT count(b), max(b) as max, b FROM %s"), + "system.count(b)", "max", "b"); + assertRows(execute("SELECT count(b), max(b) as max, b FROM %s"), + row(0L, null, null)); + + execute("UPDATE %s SET b = b + 1 WHERE a = 1"); + execute("UPDATE %s SET b = b + 1 WHERE a = 1"); + + assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"), + row(1L, 2L, 2L, 2L, 2L)); + flush(); + assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"), + row(1L, 2L, 2L, 2L, 2L)); + + execute("UPDATE %s SET b = b + 2 WHERE a = 1"); + + assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"), + row(1L, 4L, 4L, 4L, 4L)); + + execute("UPDATE %s SET b = b - 2 WHERE a = 1"); + + assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"), + row(1L, 2L, 2L, 2L, 2L)); + flush(); + assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"), + row(1L, 2L, 2L, 2L, 2L)); + + execute("UPDATE %s SET b = b + 1 WHERE a = 2"); + execute("UPDATE %s SET b = b + 1 WHERE a = 2"); + execute("UPDATE %s SET b = b + 2 WHERE a = 2"); + + assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"), + row(2L, 4L, 2L, 3L, 6L)); + } + + @Test public void testAggregateWithUdtFields() throws Throwable { String myType = createType("CREATE TYPE %s (x int)");
