Repository: cassandra
Updated Branches:
  refs/heads/trunk 6b7db8a53 -> d0e203645


Support counter-columns for native aggregates (sum,avg,max,min)

patch by Robert Stupp; reviewed by Benjamin Lerer for CASSANDRA-9977


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/e4eabd90
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/e4eabd90
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/e4eabd90

Branch: refs/heads/trunk
Commit: e4eabd901522742550074d5c3c5f25b642037891
Parents: 4d0f140
Author: Robert Stupp <[email protected]>
Authored: Mon Jan 4 16:34:27 2016 +0100
Committer: Robert Stupp <[email protected]>
Committed: Mon Jan 4 16:34:27 2016 +0100

----------------------------------------------------------------------
 .../cassandra/cql3/functions/AggregateFcts.java | 230 ++++++++++++++-----
 .../cql3/validation/entities/UFTest.java        |  26 +++
 .../validation/operations/AggregationTest.java  |  41 ++++
 3 files changed, 239 insertions(+), 58 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/e4eabd90/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
index 7b5bdb8..a1b67e1 100644
--- a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
+++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
@@ -47,6 +47,7 @@ public abstract class AggregateFcts
         functions.add(sumFunctionForDouble);
         functions.add(sumFunctionForDecimal);
         functions.add(sumFunctionForVarint);
+        functions.add(sumFunctionForCounter);
 
         // avg for primitives
         functions.add(avgFunctionForByte);
@@ -57,6 +58,7 @@ public abstract class AggregateFcts
         functions.add(avgFunctionForDouble);
         functions.add(avgFunctionForDecimal);
         functions.add(avgFunctionForVarint);
+        functions.add(avgFunctionForCounter);
 
         // count, max, and min for all standard types
         for (CQL3Type type : CQL3Type.Native.values())
@@ -64,8 +66,16 @@ public abstract class AggregateFcts
             if (type != CQL3Type.Native.VARCHAR) // varchar and text both mapping to UTF8Type
             {
                 functions.add(AggregateFcts.makeCountFunction(type.getType()));
-                functions.add(AggregateFcts.makeMaxFunction(type.getType()));
-                functions.add(AggregateFcts.makeMinFunction(type.getType()));
+                if (type != CQL3Type.Native.COUNTER)
+                {
+                    functions.add(AggregateFcts.makeMaxFunction(type.getType()));
+                    functions.add(AggregateFcts.makeMinFunction(type.getType()));
+                }
+                else
+                {
+                    functions.add(AggregateFcts.maxFunctionForCounter);
+                    functions.add(AggregateFcts.minFunctionForCounter);
+                }
             }
         }
 
@@ -515,31 +525,7 @@ public abstract class AggregateFcts
             {
                 public Aggregate newAggregate()
                 {
-                    return new Aggregate()
-                    {
-                        private long sum;
-
-                        public void reset()
-                        {
-                            sum = 0;
-                        }
-
-                        public ByteBuffer compute(int protocolVersion)
-                        {
-                            return ((LongType) returnType()).decompose(sum);
-                        }
-
-                        public void addInput(int protocolVersion, List<ByteBuffer> values)
-                        {
-                            ByteBuffer value = values.get(0);
-
-                            if (value == null)
-                                return;
-
-                            Number number = ((Number) argTypes().get(0).compose(value));
-                            sum += number.longValue();
-                        }
-                    };
+                    return new LongSumAggregate();
                 }
             };
 
@@ -551,37 +537,7 @@ public abstract class AggregateFcts
             {
                 public Aggregate newAggregate()
                 {
-                    return new Aggregate()
-                    {
-                        private long sum;
-
-                        private int count;
-
-                        public void reset()
-                        {
-                            count = 0;
-                            sum = 0;
-                        }
-
-                        public ByteBuffer compute(int protocolVersion)
-                        {
-                            long avg = count == 0 ? 0 : sum / count;
-
-                            return ((LongType) returnType()).decompose(avg);
-                        }
-
-                        public void addInput(int protocolVersion, List<ByteBuffer> values)
-                        {
-                            ByteBuffer value = values.get(0);
-
-                            if (value == null)
-                                return;
-
-                            count++;
-                            Number number = ((Number) argTypes().get(0).compose(value));
-                            sum += number.longValue();
-                        }
-                    };
+                    return new LongAvgAggregate();
                 }
             };
 
@@ -742,6 +698,106 @@ public abstract class AggregateFcts
             };
 
     /**
+     * The SUM function for counter column values.
+     */
+    public static final AggregateFunction sumFunctionForCounter =
+    new NativeAggregateFunction("sum", CounterColumnType.instance, CounterColumnType.instance)
+    {
+        public Aggregate newAggregate()
+        {
+            return new LongSumAggregate();
+        }
+    };
+
+    /**
+     * AVG function for counter column values.
+     */
+    public static final AggregateFunction avgFunctionForCounter =
+    new NativeAggregateFunction("avg", CounterColumnType.instance, CounterColumnType.instance)
+    {
+        public Aggregate newAggregate()
+        {
+            return new LongAvgAggregate();
+        }
+    };
+
+    /**
+     * The MIN function for counter column values.
+     */
+    public static final AggregateFunction minFunctionForCounter =
+    new NativeAggregateFunction("min", CounterColumnType.instance, CounterColumnType.instance)
+    {
+        public Aggregate newAggregate()
+        {
+            return new Aggregate()
+            {
+                private Long min;
+
+                public void reset()
+                {
+                    min = null;
+                }
+
+                public ByteBuffer compute(int protocolVersion)
+                {
+                    return min != null ? LongType.instance.decompose(min) : null;
+                }
+
+                public void addInput(int protocolVersion, List<ByteBuffer> values)
+                {
+                    ByteBuffer value = values.get(0);
+
+                    if (value == null)
+                        return;
+
+                    long lval = LongType.instance.compose(value);
+
+                    if (min == null || lval < min)
+                        min = lval;
+                }
+            };
+        }
+    };
+
+    /**
+     * The MAX function for counter column values.
+     */
+    public static final AggregateFunction maxFunctionForCounter =
+    new NativeAggregateFunction("max", CounterColumnType.instance, CounterColumnType.instance)
+    {
+        public Aggregate newAggregate()
+        {
+            return new Aggregate()
+            {
+                private Long max;
+
+                public void reset()
+                {
+                    max = null;
+                }
+
+                public ByteBuffer compute(int protocolVersion)
+                {
+                    return max != null ? LongType.instance.decompose(max) : null;
+                }
+
+                public void addInput(int protocolVersion, List<ByteBuffer> values)
+                {
+                    ByteBuffer value = values.get(0);
+
+                    if (value == null)
+                        return;
+
+                    long lval = LongType.instance.compose(value);
+
+                    if (max == null || lval > max)
+                        max = lval;
+                }
+            };
+        }
+    };
+
+    /**
      * Creates a MAX function for the specified type.
      *
      * @param inputType the function input and output type
@@ -862,4 +918,62 @@ public abstract class AggregateFcts
             }
         };
     }
+
+    private static class LongSumAggregate implements AggregateFunction.Aggregate
+    {
+        private long sum;
+
+        public void reset()
+        {
+            sum = 0;
+        }
+
+        public ByteBuffer compute(int protocolVersion)
+        {
+            return LongType.instance.decompose(sum);
+        }
+
+        public void addInput(int protocolVersion, List<ByteBuffer> values)
+        {
+            ByteBuffer value = values.get(0);
+
+            if (value == null)
+                return;
+
+            Number number = LongType.instance.compose(value);
+            sum += number.longValue();
+        }
+    }
+
+    private static class LongAvgAggregate implements AggregateFunction.Aggregate
+    {
+        private long sum;
+
+        private int count;
+
+        public void reset()
+        {
+            count = 0;
+            sum = 0;
+        }
+
+        public ByteBuffer compute(int protocolVersion)
+        {
+            long avg = count == 0 ? 0 : sum / count;
+
+            return LongType.instance.decompose(avg);
+        }
+
+        public void addInput(int protocolVersion, List<ByteBuffer> values)
+        {
+            ByteBuffer value = values.get(0);
+
+            if (value == null)
+                return;
+
+            count++;
+            Number number = LongType.instance.compose(value);
+            sum += number.longValue();
+        }
+    }
 }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/e4eabd90/test/unit/org/apache/cassandra/cql3/validation/entities/UFTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/validation/entities/UFTest.java b/test/unit/org/apache/cassandra/cql3/validation/entities/UFTest.java
index 467a082..704a6c9 100644
--- a/test/unit/org/apache/cassandra/cql3/validation/entities/UFTest.java
+++ b/test/unit/org/apache/cassandra/cql3/validation/entities/UFTest.java
@@ -707,6 +707,32 @@ public class UFTest extends CQLTester
     }
 
     @Test
+    public void testJavaFunctionCounter() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, val counter)");
+
+        String fName = createFunction(KEYSPACE, "counter",
+                                      "CREATE OR REPLACE FUNCTION %s(val counter) " +
+                                      "CALLED ON NULL INPUT " +
+                                      "RETURNS bigint " +
+                                      "LANGUAGE JAVA " +
+                                      "AS 'return val + 1;';");
+
+        execute("UPDATE %s SET val = val + 1 WHERE key = 1");
+        assertRows(execute("SELECT key, val, " + fName + "(val) FROM %s"),
+                   row(1, 1L, 2L));
+        execute("UPDATE %s SET val = val + 1 WHERE key = 1");
+        assertRows(execute("SELECT key, val, " + fName + "(val) FROM %s"),
+                   row(1, 2L, 3L));
+        execute("UPDATE %s SET val = val + 2 WHERE key = 1");
+        assertRows(execute("SELECT key, val, " + fName + "(val) FROM %s"),
+                   row(1, 4L, 5L));
+        execute("UPDATE %s SET val = val - 2 WHERE key = 1");
+        assertRows(execute("SELECT key, val, " + fName + "(val) FROM %s"),
+                   row(1, 2L, 3L));
+    }
+
+    @Test
     public void testFunctionInTargetKeyspace() throws Throwable
     {
         createTable("CREATE TABLE %s (key int primary key, val double)");

http://git-wip-us.apache.org/repos/asf/cassandra/blob/e4eabd90/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
index 2713895..221f48e 100644
--- a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
+++ b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
@@ -174,6 +174,47 @@ public class AggregationTest extends CQLTester
     }
 
     @Test
+    public void testAggregateOnCounters() throws Throwable
+    {
+        createTable("CREATE TABLE %s (a int, b counter, primary key (a))");
+
+        // Test with empty table
+        assertColumnNames(execute("SELECT count(b), max(b) as max, b FROM %s"),
+                          "system.count(b)", "max", "b");
+        assertRows(execute("SELECT count(b), max(b) as max, b FROM %s"),
+                   row(0L, null, null));
+
+        execute("UPDATE %s SET b = b + 1 WHERE a = 1");
+        execute("UPDATE %s SET b = b + 1 WHERE a = 1");
+
+        assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"),
+                   row(1L, 2L, 2L, 2L, 2L));
+        flush();
+        assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"),
+                   row(1L, 2L, 2L, 2L, 2L));
+
+        execute("UPDATE %s SET b = b + 2 WHERE a = 1");
+
+        assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"),
+                   row(1L, 4L, 4L, 4L, 4L));
+
+        execute("UPDATE %s SET b = b - 2 WHERE a = 1");
+
+        assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"),
+                   row(1L, 2L, 2L, 2L, 2L));
+        flush();
+        assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"),
+                   row(1L, 2L, 2L, 2L, 2L));
+
+        execute("UPDATE %s SET b = b + 1 WHERE a = 2");
+        execute("UPDATE %s SET b = b + 1 WHERE a = 2");
+        execute("UPDATE %s SET b = b + 2 WHERE a = 2");
+
+        assertRows(execute("SELECT count(b), max(b) as max, min(b) as min, avg(b) as avg, sum(b) as sum FROM %s"),
+                   row(2L, 4L, 2L, 3L, 6L));
+    }
+
+    @Test
     public void testAggregateWithUdtFields() throws Throwable
     {
         String myType = createType("CREATE TYPE %s (x int)");

Reply via email to