Repository: cassandra Updated Branches: refs/heads/trunk f72668552 -> 0cad81aeb
Add support for aggregation functions patch by blerer; reviewed by slebresne for CASSANDRA-4914 Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/0cad81ae Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/0cad81ae Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/0cad81ae Branch: refs/heads/trunk Commit: 0cad81aeb9ddaf5dad8b2ab9c6ff6955402c9310 Parents: f726685 Author: Benjamin Lerer <[email protected]> Authored: Tue Oct 7 12:18:52 2014 +0200 Committer: Sylvain Lebresne <[email protected]> Committed: Tue Oct 7 12:18:52 2014 +0200 ---------------------------------------------------------------------- CHANGES.txt | 1 + src/java/org/apache/cassandra/cql3/Cql.g | 16 +- .../org/apache/cassandra/cql3/ResultSet.java | 21 - .../cassandra/cql3/functions/AggregateFcts.java | 661 +++++++++++++++++++ .../cql3/functions/AggregateFunction.java | 59 ++ .../cql3/functions/BytesConversionFcts.java | 8 +- .../cassandra/cql3/functions/Function.java | 23 +- .../cassandra/cql3/functions/FunctionCall.java | 25 +- .../cassandra/cql3/functions/Functions.java | 17 + .../cql3/functions/NativeAggregateFunction.java | 36 + .../cql3/functions/NativeScalarFunction.java | 36 + .../cql3/functions/ScalarFunction.java | 38 ++ .../cassandra/cql3/functions/TimeuuidFcts.java | 10 +- .../cassandra/cql3/functions/TokenFct.java | 2 +- .../cassandra/cql3/functions/UDFunction.java | 7 +- .../cassandra/cql3/functions/UuidFcts.java | 2 +- .../cql3/statements/SelectStatement.java | 72 +- .../cassandra/cql3/statements/Selection.java | 301 +++++++-- .../apache/cassandra/cql3/AggregationTest.java | 88 +++ 19 files changed, 1282 insertions(+), 141 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/CHANGES.txt ---------------------------------------------------------------------- diff --git a/CHANGES.txt b/CHANGES.txt index 
d10747c..132396e 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,5 @@ 3.0 + * Support for aggregation functions (CASSANDRA-4914) * Improve query to read paxos table on propose (CASSANDRA-7929) * Remove cassandra-cli (CASSANDRA-7920) * Optimize java source-based UDF invocation (CASSANDRA-7924) http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/Cql.g ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/Cql.g b/src/java/org/apache/cassandra/cql3/Cql.g index e4bfd32..2ec9746 100644 --- a/src/java/org/apache/cassandra/cql3/Cql.g +++ b/src/java/org/apache/cassandra/cql3/Cql.g @@ -262,14 +262,12 @@ useStatement returns [UseStatement stmt] selectStatement returns [SelectStatement.RawStatement expr] @init { boolean isDistinct = false; - boolean isCount = false; - ColumnIdentifier countAlias = null; Term.Raw limit = null; Map<ColumnIdentifier, Boolean> orderings = new LinkedHashMap<ColumnIdentifier, Boolean>(); boolean allowFiltering = false; } : K_SELECT ( ( K_DISTINCT { isDistinct = true; } )? sclause=selectClause - | (K_COUNT '(' sclause=selectCountClause ')' { isCount = true; } (K_AS c=cident { countAlias = c; })?) ) + | sclause=selectCountClause ) K_FROM cf=columnFamilyName ( K_WHERE wclause=whereClause )? ( K_ORDER K_BY orderByClause[orderings] ( ',' orderByClause[orderings] )* )? 
@@ -278,8 +276,6 @@ selectStatement returns [SelectStatement.RawStatement expr] { SelectStatement.Parameters params = new SelectStatement.Parameters(orderings, isDistinct, - isCount, - countAlias, allowFiltering); $expr = new SelectStatement.RawStatement(cf, params, sclause, wclause, limit); } @@ -312,8 +308,13 @@ selectionFunctionArgs returns [List<Selectable> a] ; selectCountClause returns [List<RawSelector> expr] - : '\*' { $expr = Collections.<RawSelector>emptyList();} - | i=INTEGER { if (!i.getText().equals("1")) addRecognitionError("Only COUNT(1) is supported, got COUNT(" + i.getText() + ")"); $expr = Collections.<RawSelector>emptyList();} + @init{ ColumnIdentifier alias = new ColumnIdentifier("count", false); } + : K_COUNT '(' countArgument ')' (K_AS c=cident { alias = c; })? { $expr = new ArrayList<RawSelector>(); $expr.add( new RawSelector(new Selectable.WithFunction(new FunctionName("countRows"), Collections.<Selectable>emptyList()), alias));} + ; + +countArgument + : '\*' + | i=INTEGER { if (!i.getText().equals("1")) addRecognitionError("Only COUNT(1) is supported, got COUNT(" + i.getText() + ")");} ; whereClause returns [List<Relation> clause] @@ -984,6 +985,7 @@ allowedFunctionName returns [String s] : f=IDENT { $s = $f.text; } | u=unreserved_function_keyword { $s = u; } | K_TOKEN { $s = "token"; } + | K_COUNT { $s = "count"; } ; functionArgs returns [List<Term.Raw> a] http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/ResultSet.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/ResultSet.java b/src/java/org/apache/cassandra/cql3/ResultSet.java index 30b5c4e..a8a5081 100644 --- a/src/java/org/apache/cassandra/cql3/ResultSet.java +++ b/src/java/org/apache/cassandra/cql3/ResultSet.java @@ -102,27 +102,6 @@ public class ResultSet return new ResultSet(metadata.withPagingState(state), rows); } - public ResultSet 
makeCountResult(ColumnIdentifier alias) - { - assert metadata.names != null; - String ksName = metadata.names.get(0).ksName; - String cfName = metadata.names.get(0).cfName; - long count = rows.size(); - return makeCountResult(ksName, cfName, count, alias); - } - - public static ResultSet.Metadata makeCountMetadata(String ksName, String cfName, ColumnIdentifier alias) - { - ColumnSpecification spec = new ColumnSpecification(ksName, cfName, alias == null ? COUNT_COLUMN : alias, LongType.instance); - return new Metadata(Collections.singletonList(spec)); - } - - public static ResultSet makeCountResult(String ksName, String cfName, long count, ColumnIdentifier alias) - { - List<List<ByteBuffer>> newRows = Collections.singletonList(Collections.singletonList(ByteBufferUtil.bytes(count))); - return new ResultSet(makeCountMetadata(ksName, cfName, alias), newRows); - } - public CqlResult toThriftResult() { assert metadata.names != null; http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java new file mode 100644 index 0000000..f72ed44 --- /dev/null +++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java @@ -0,0 +1,661 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.cql3.functions; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.util.List; + +import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.db.marshal.DecimalType; +import org.apache.cassandra.db.marshal.DoubleType; +import org.apache.cassandra.db.marshal.FloatType; +import org.apache.cassandra.db.marshal.Int32Type; +import org.apache.cassandra.db.marshal.IntegerType; +import org.apache.cassandra.db.marshal.LongType; + +/** + * Factory methods for aggregate functions. + */ +public abstract class AggregateFcts +{ + /** + * The function used to count the number of rows of a result set. This function is called when COUNT(*) or COUNT(1) + * is specified. + */ + public static final AggregateFunction countRowsFunction = + new NativeAggregateFunction("countRows", LongType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private long count; + + public void reset() + { + count = 0; + } + + public ByteBuffer compute() + { + return ((LongType) returnType()).decompose(Long.valueOf(count)); + } + + public void addInput(List<ByteBuffer> values) + { + count++; + } + }; + } + }; + + /** + * The SUM function for decimal values. 
+ */ + public static final AggregateFunction sumFunctionForDecimal = + new NativeAggregateFunction("sum", DecimalType.instance, DecimalType.instance) + { + @Override + public Aggregate newAggregate() + { + return new Aggregate() + { + private BigDecimal sum = BigDecimal.ZERO; + + public void reset() + { + sum = BigDecimal.ZERO; + } + + public ByteBuffer compute() + { + return ((DecimalType) returnType()).decompose(sum); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + BigDecimal number = ((BigDecimal) argTypes().get(0).compose(value)); + sum = sum.add(number); + } + }; + } + }; + + /** + * The AVG function for decimal values. + */ + public static final AggregateFunction avgFunctionForDecimal = + new NativeAggregateFunction("avg", DecimalType.instance, DecimalType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private BigDecimal sum = BigDecimal.ZERO; + + private int count; + + public void reset() + { + count = 0; + sum = BigDecimal.ZERO; + } + + public ByteBuffer compute() + { + if (count == 0) + return ((DecimalType) returnType()).decompose(BigDecimal.ZERO); + + return ((DecimalType) returnType()).decompose(sum.divide(BigDecimal.valueOf(count))); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + BigDecimal number = ((BigDecimal) argTypes().get(0).compose(value)); + sum = sum.add(number); + } + }; + } + }; + + /** + * The SUM function for varint values. 
+ */ + public static final AggregateFunction sumFunctionForVarint = + new NativeAggregateFunction("sum", IntegerType.instance, IntegerType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private BigInteger sum = BigInteger.ZERO; + + public void reset() + { + sum = BigInteger.ZERO; + } + + public ByteBuffer compute() + { + return ((IntegerType) returnType()).decompose(sum); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + BigInteger number = ((BigInteger) argTypes().get(0).compose(value)); + sum = sum.add(number); + } + }; + } + }; + + /** + * The AVG function for varint values. + */ + public static final AggregateFunction avgFunctionForVarint = + new NativeAggregateFunction("avg", IntegerType.instance, IntegerType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private BigInteger sum = BigInteger.ZERO; + + private int count; + + public void reset() + { + count = 0; + sum = BigInteger.ZERO; + } + + public ByteBuffer compute() + { + if (count == 0) + return ((IntegerType) returnType()).decompose(BigInteger.ZERO); + + return ((IntegerType) returnType()).decompose(sum.divide(BigInteger.valueOf(count))); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + BigInteger number = ((BigInteger) argTypes().get(0).compose(value)); + sum = sum.add(number); + } + }; + } + }; + + /** + * The SUM function for int32 values. 
+ */ + public static final AggregateFunction sumFunctionForInt32 = + new NativeAggregateFunction("sum", Int32Type.instance, Int32Type.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private int sum; + + public void reset() + { + sum = 0; + } + + public ByteBuffer compute() + { + return ((Int32Type) returnType()).decompose(sum); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.intValue(); + } + }; + } + }; + + /** + * AVG function for int32 values. + */ + public static final AggregateFunction avgFunctionForInt32 = + new NativeAggregateFunction("avg", Int32Type.instance, Int32Type.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private int sum; + + private int count; + + public void reset() + { + count = 0; + sum = 0; + } + + public ByteBuffer compute() + { + int avg = count == 0 ? 0 : sum / count; + + return ((Int32Type) returnType()).decompose(avg); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.intValue(); + } + }; + } + }; + + /** + * The SUM function for long values. 
+ */ + public static final AggregateFunction sumFunctionForLong = + new NativeAggregateFunction("sum", LongType.instance, LongType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private long sum; + + public void reset() + { + sum = 0; + } + + public ByteBuffer compute() + { + return ((LongType) returnType()).decompose(sum); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.longValue(); + } + }; + } + }; + + /** + * AVG function for long values. + */ + public static final AggregateFunction avgFunctionForLong = + new NativeAggregateFunction("avg", LongType.instance, LongType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private long sum; + + private int count; + + public void reset() + { + count = 0; + sum = 0; + } + + public ByteBuffer compute() + { + long avg = count == 0 ? 0 : sum / count; + + return ((LongType) returnType()).decompose(avg); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.longValue(); + } + }; + } + }; + + /** + * The SUM function for float values. 
+ */ + public static final AggregateFunction sumFunctionForFloat = + new NativeAggregateFunction("sum", FloatType.instance, FloatType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private float sum; + + public void reset() + { + sum = 0; + } + + public ByteBuffer compute() + { + return ((FloatType) returnType()).decompose(sum); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.floatValue(); + } + }; + } + }; + + /** + * AVG function for float values. + */ + public static final AggregateFunction avgFunctionForFloat = + new NativeAggregateFunction("avg", FloatType.instance, FloatType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private float sum; + + private int count; + + public void reset() + { + count = 0; + sum = 0; + } + + public ByteBuffer compute() + { + float avg = count == 0 ? 0 : sum / count; + + return ((FloatType) returnType()).decompose(avg); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.floatValue(); + } + }; + } + }; + + /** + * The SUM function for double values. 
+ */ + public static final AggregateFunction sumFunctionForDouble = + new NativeAggregateFunction("sum", DoubleType.instance, DoubleType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private double sum; + + public void reset() + { + sum = 0; + } + + public ByteBuffer compute() + { + return ((DoubleType) returnType()).decompose(sum); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.doubleValue(); + } + }; + } + }; + + /** + * AVG function for double values. + */ + public static final AggregateFunction avgFunctionForDouble = + new NativeAggregateFunction("avg", DoubleType.instance, DoubleType.instance) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private double sum; + + private int count; + + public void reset() + { + count = 0; + sum = 0; + } + + public ByteBuffer compute() + { + double avg = count == 0 ? 0 : sum / count; + + return ((DoubleType) returnType()).decompose(avg); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + Number number = ((Number) argTypes().get(0).compose(value)); + sum += number.doubleValue(); + } + }; + } + }; + + /** + * Creates a MAX function for the specified type. + * + * @param inputType the function input and output type + * @return a MAX function for the specified type. 
+ */ + public static AggregateFunction makeMaxFunction(final AbstractType<?> inputType) + { + return new NativeAggregateFunction("max", inputType, inputType) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private ByteBuffer max; + + public void reset() + { + max = null; + } + + public ByteBuffer compute() + { + return max; + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + if (max == null || returnType().compare(max, value) < 0) + max = value; + } + }; + } + }; + } + + /** + * Creates a MIN function for the specified type. + * + * @param inputType the function input and output type + * @return a MIN function for the specified type. + */ + public static AggregateFunction makeMinFunction(final AbstractType<?> inputType) + { + return new NativeAggregateFunction("min", inputType, inputType) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private ByteBuffer min; + + public void reset() + { + min = null; + } + + public ByteBuffer compute() + { + return min; + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + if (min == null || returnType().compare(min, value) > 0) + min = value; + } + }; + } + }; + } + + /** + * Creates a COUNT function for the specified type. + * + * @param inputType the function input type + * @return a COUNT function for the specified type. 
+ */ + public static AggregateFunction makeCountFunction(AbstractType<?> inputType) + { + return new NativeAggregateFunction("count", LongType.instance, inputType) + { + public Aggregate newAggregate() + { + return new Aggregate() + { + private long count; + + public void reset() + { + count = 0; + } + + public ByteBuffer compute() + { + return ((LongType) returnType()).decompose(count); + } + + public void addInput(List<ByteBuffer> values) + { + ByteBuffer value = values.get(0); + + if (value == null) + return; + + count++; + } + }; + } + }; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/AggregateFunction.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFunction.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFunction.java new file mode 100644 index 0000000..47eee4b --- /dev/null +++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFunction.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.cql3.functions; + +import java.nio.ByteBuffer; +import java.util.List; + +/** + * Performs a calculation on a set of values and return a single value. + */ +public interface AggregateFunction extends Function +{ + /** + * Creates a new <code>Aggregate</code> instance. + * + * @return a new <code>Aggregate</code> instance. + */ + public Aggregate newAggregate(); + + /** + * An aggregation operation. + */ + interface Aggregate + { + /** + * Adds the specified input to this aggregate. + * + * @param values the values to add to the aggregate. + */ + public void addInput(List<ByteBuffer> values); + + /** + * Computes and returns the aggregate current value. + * + * @return the aggregate current value. + */ + public ByteBuffer compute(); + + /** + * Reset this aggregate. + */ + public void reset(); + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/BytesConversionFcts.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/BytesConversionFcts.java b/src/java/org/apache/cassandra/cql3/functions/BytesConversionFcts.java index 6ea0a55..1cd1d69 100644 --- a/src/java/org/apache/cassandra/cql3/functions/BytesConversionFcts.java +++ b/src/java/org/apache/cassandra/cql3/functions/BytesConversionFcts.java @@ -34,7 +34,7 @@ public abstract class BytesConversionFcts public static Function makeToBlobFunction(AbstractType<?> fromType) { String name = fromType.asCQL3Type() + "asblob"; - return new NativeFunction(name, BytesType.instance, fromType) + return new NativeScalarFunction(name, BytesType.instance, fromType) { public ByteBuffer execute(List<ByteBuffer> parameters) { @@ -46,7 +46,7 @@ public abstract class BytesConversionFcts public static Function makeFromBlobFunction(final AbstractType<?> toType) { final String name = "blobas" + toType.asCQL3Type(); - return new NativeFunction(name, toType, 
BytesType.instance) + return new NativeScalarFunction(name, toType, BytesType.instance) { public ByteBuffer execute(List<ByteBuffer> parameters) throws InvalidRequestException { @@ -66,7 +66,7 @@ public abstract class BytesConversionFcts }; } - public static final Function VarcharAsBlobFct = new NativeFunction("varcharasblob", BytesType.instance, UTF8Type.instance) + public static final Function VarcharAsBlobFct = new NativeScalarFunction("varcharasblob", BytesType.instance, UTF8Type.instance) { public ByteBuffer execute(List<ByteBuffer> parameters) { @@ -74,7 +74,7 @@ public abstract class BytesConversionFcts } }; - public static final Function BlobAsVarcharFact = new NativeFunction("blobasvarchar", UTF8Type.instance, BytesType.instance) + public static final Function BlobAsVarcharFact = new NativeScalarFunction("blobasvarchar", UTF8Type.instance, BytesType.instance) { public ByteBuffer execute(List<ByteBuffer> parameters) { http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/Function.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/Function.java b/src/java/org/apache/cassandra/cql3/functions/Function.java index dc2a0db..cba9fcf 100644 --- a/src/java/org/apache/cassandra/cql3/functions/Function.java +++ b/src/java/org/apache/cassandra/cql3/functions/Function.java @@ -17,11 +17,9 @@ */ package org.apache.cassandra.cql3.functions; -import java.nio.ByteBuffer; import java.util.List; import org.apache.cassandra.db.marshal.AbstractType; -import org.apache.cassandra.exceptions.InvalidRequestException; public interface Function { @@ -29,11 +27,24 @@ public interface Function public List<AbstractType<?>> argTypes(); public AbstractType<?> returnType(); - public ByteBuffer execute(List<ByteBuffer> parameters) throws InvalidRequestException; - - // Whether the function is a pure function (as in doesn't depend on, nor produce side 
effects). + /** + * Checks whether the function is a pure function (as in doesn't depend on, nor produce side effects) or not. + * + * @return <code>true</code> if the function is a pure function, <code>false</code> otherwise. + */ public boolean isPure(); - // Whether the function is a native/harcoded one. + /** + * Checks whether the function is a native/hard coded one or not. + * + * @return <code>true</code> if the function is a native/hard coded one, <code>false</code> otherwise. + */ public boolean isNative(); + + /** + * Checks whether the function is an aggregate function or not. + * + * @return <code>true</code> if the function is an aggregate function, <code>false</code> otherwise. + */ + public boolean isAggregate(); } http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/FunctionCall.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/FunctionCall.java b/src/java/org/apache/cassandra/cql3/functions/FunctionCall.java index 0a8fe58..3b80fc0 100644 --- a/src/java/org/apache/cassandra/cql3/functions/FunctionCall.java +++ b/src/java/org/apache/cassandra/cql3/functions/FunctionCall.java @@ -22,7 +22,6 @@ import java.util.ArrayList; import java.util.List; import org.apache.cassandra.cql3.*; -import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.CollectionType; import org.apache.cassandra.db.marshal.ListType; import org.apache.cassandra.db.marshal.MapType; @@ -33,10 +32,10 @@ import org.apache.cassandra.serializers.MarshalException; public class FunctionCall extends Term.NonTerminal { - private final Function fun; + private final ScalarFunction fun; private final List<Term> terms; - private FunctionCall(Function fun, List<Term> terms) + private FunctionCall(ScalarFunction fun, List<Term> terms) { this.fun = fun; this.terms = terms; @@ -68,7 +67,7 @@ public class FunctionCall extends 
Term.NonTerminal return executeInternal(fun, buffers); } - private static ByteBuffer executeInternal(Function fun, List<ByteBuffer> params) throws InvalidRequestException + private static ByteBuffer executeInternal(ScalarFunction fun, List<ByteBuffer> params) throws InvalidRequestException { ByteBuffer result = fun.execute(params); try @@ -125,12 +124,16 @@ public class FunctionCall extends Term.NonTerminal Function fun = Functions.get(keyspace, name, terms, receiver.ksName, receiver.cfName); if (fun == null) throw new InvalidRequestException(String.format("Unknown function %s called", name)); + if (fun.isAggregate()) + throw new InvalidRequestException("Aggregation function are not supported in the where clause"); + + ScalarFunction scalarFun = (ScalarFunction) fun; // Functions.get() will complain if no function "name" type check with the provided arguments. // We still have to validate that the return type matches however - if (!receiver.type.isValueCompatibleWith(fun.returnType())) + if (!receiver.type.isValueCompatibleWith(scalarFun.returnType())) throw new InvalidRequestException(String.format("Type error: cannot assign result of function %s (type %s) to %s (type %s)", - fun.name(), fun.returnType().asCQL3Type(), + scalarFun.name(), scalarFun.returnType().asCQL3Type(), receiver.name, receiver.type.asCQL3Type())); if (fun.argTypes().size() != terms.size()) @@ -141,7 +144,7 @@ public class FunctionCall extends Term.NonTerminal boolean allTerminal = true; for (int i = 0; i < terms.size(); i++) { - Term t = terms.get(i).prepare(keyspace, Functions.makeArgSpec(receiver.ksName, receiver.cfName, fun, i)); + Term t = terms.get(i).prepare(keyspace, Functions.makeArgSpec(receiver.ksName, receiver.cfName, scalarFun, i)); if (t instanceof NonTerminal) allTerminal = false; parameters.add(t); @@ -149,13 +152,13 @@ public class FunctionCall extends Term.NonTerminal // If all parameters are terminal and the function is pure, we can // evaluate it now, otherwise we'd have to 
wait execution time - return allTerminal && fun.isPure() - ? makeTerminal(fun, execute(fun, parameters), QueryOptions.DEFAULT.getProtocolVersion()) - : new FunctionCall(fun, parameters); + return allTerminal && scalarFun.isPure() + ? makeTerminal(scalarFun, execute(scalarFun, parameters), QueryOptions.DEFAULT.getProtocolVersion()) + : new FunctionCall(scalarFun, parameters); } // All parameters must be terminal - private static ByteBuffer execute(Function fun, List<Term> parameters) throws InvalidRequestException + private static ByteBuffer execute(ScalarFunction fun, List<Term> parameters) throws InvalidRequestException { List<ByteBuffer> buffers = new ArrayList<ByteBuffer>(parameters.size()); for (Term t : parameters) http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/Functions.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/Functions.java b/src/java/org/apache/cassandra/cql3/functions/Functions.java index 18feb36..e8d6181 100644 --- a/src/java/org/apache/cassandra/cql3/functions/Functions.java +++ b/src/java/org/apache/cassandra/cql3/functions/Functions.java @@ -48,6 +48,7 @@ public abstract class Functions static { + declare(AggregateFcts.countRowsFunction); declare(TimeuuidFcts.nowFct); declare(TimeuuidFcts.minTimeuuidFct); declare(TimeuuidFcts.maxTimeuuidFct); @@ -64,9 +65,25 @@ public abstract class Functions declare(BytesConversionFcts.makeToBlobFunction(type.getType())); declare(BytesConversionFcts.makeFromBlobFunction(type.getType())); + + declare(AggregateFcts.makeCountFunction(type.getType())); + declare(AggregateFcts.makeMaxFunction(type.getType())); + declare(AggregateFcts.makeMinFunction(type.getType())); } declare(BytesConversionFcts.VarcharAsBlobFct); declare(BytesConversionFcts.BlobAsVarcharFact); + declare(AggregateFcts.sumFunctionForInt32); + declare(AggregateFcts.sumFunctionForLong); + 
declare(AggregateFcts.sumFunctionForFloat); + declare(AggregateFcts.sumFunctionForDouble); + declare(AggregateFcts.sumFunctionForDecimal); + declare(AggregateFcts.sumFunctionForVarint); + declare(AggregateFcts.avgFunctionForInt32); + declare(AggregateFcts.avgFunctionForLong); + declare(AggregateFcts.avgFunctionForFloat); + declare(AggregateFcts.avgFunctionForDouble); + declare(AggregateFcts.avgFunctionForVarint); + declare(AggregateFcts.avgFunctionForDecimal); } private static void declare(Function fun) http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/NativeAggregateFunction.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/NativeAggregateFunction.java b/src/java/org/apache/cassandra/cql3/functions/NativeAggregateFunction.java new file mode 100644 index 0000000..88aab4b --- /dev/null +++ b/src/java/org/apache/cassandra/cql3/functions/NativeAggregateFunction.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.cql3.functions; + +import org.apache.cassandra.db.marshal.AbstractType; + +/** + * Base class for the <code>AggregateFunction</code> native classes. + */ +public abstract class NativeAggregateFunction extends NativeFunction implements AggregateFunction +{ + protected NativeAggregateFunction(String name, AbstractType<?> returnType, AbstractType<?>... argTypes) + { + super(name, returnType, argTypes); + } + + public final boolean isAggregate() + { + return true; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/NativeScalarFunction.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/NativeScalarFunction.java b/src/java/org/apache/cassandra/cql3/functions/NativeScalarFunction.java new file mode 100644 index 0000000..8f7f221 --- /dev/null +++ b/src/java/org/apache/cassandra/cql3/functions/NativeScalarFunction.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.cql3.functions; + +import org.apache.cassandra.db.marshal.AbstractType; + +/** + * Base class for the <code>ScalarFunction</code> native classes. + */ +public abstract class NativeScalarFunction extends NativeFunction implements ScalarFunction +{ + protected NativeScalarFunction(String name, AbstractType<?> returnType, AbstractType<?>... argsType) + { + super(name, returnType, argsType); + } + + public final boolean isAggregate() + { + return false; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/ScalarFunction.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/ScalarFunction.java b/src/java/org/apache/cassandra/cql3/functions/ScalarFunction.java new file mode 100644 index 0000000..ba2a374 --- /dev/null +++ b/src/java/org/apache/cassandra/cql3/functions/ScalarFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.cql3.functions; + +import java.nio.ByteBuffer; +import java.util.List; + +import org.apache.cassandra.exceptions.InvalidRequestException; + +/** + * Determines a single output value based on a single input value. + */ +public interface ScalarFunction extends Function +{ + /** + * Applies this function to the specified parameters. + * + * @param parameters the input parameters + * @return the result of applying this function to the parameters + * @throws InvalidRequestException if this function cannot be applied to the parameters + */ + public ByteBuffer execute(List<ByteBuffer> parameters) throws InvalidRequestException; +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/TimeuuidFcts.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/TimeuuidFcts.java b/src/java/org/apache/cassandra/cql3/functions/TimeuuidFcts.java index 9b7bbf0..e481cf5 100644 --- a/src/java/org/apache/cassandra/cql3/functions/TimeuuidFcts.java +++ b/src/java/org/apache/cassandra/cql3/functions/TimeuuidFcts.java @@ -29,7 +29,7 @@ import org.apache.cassandra.utils.UUIDGen; public abstract class TimeuuidFcts { - public static final Function nowFct = new NativeFunction("now", TimeUUIDType.instance) + public static final Function nowFct = new NativeScalarFunction("now", TimeUUIDType.instance) { public ByteBuffer execute(List<ByteBuffer> parameters) { @@ -43,7 +43,7 @@ public abstract class TimeuuidFcts } }; - public static final Function minTimeuuidFct = new NativeFunction("mintimeuuid", TimeUUIDType.instance, TimestampType.instance) + public static final Function minTimeuuidFct = new NativeScalarFunction("mintimeuuid", TimeUUIDType.instance, TimestampType.instance) { public ByteBuffer execute(List<ByteBuffer> parameters) { @@ -55,7 +55,7 @@ public abstract class TimeuuidFcts } }; - public static final Function maxTimeuuidFct 
= new NativeFunction("maxtimeuuid", TimeUUIDType.instance, TimestampType.instance) + public static final Function maxTimeuuidFct = new NativeScalarFunction("maxtimeuuid", TimeUUIDType.instance, TimestampType.instance) { public ByteBuffer execute(List<ByteBuffer> parameters) { @@ -67,7 +67,7 @@ public abstract class TimeuuidFcts } }; - public static final Function dateOfFct = new NativeFunction("dateof", TimestampType.instance, TimeUUIDType.instance) + public static final Function dateOfFct = new NativeScalarFunction("dateof", TimestampType.instance, TimeUUIDType.instance) { public ByteBuffer execute(List<ByteBuffer> parameters) { @@ -79,7 +79,7 @@ public abstract class TimeuuidFcts } }; - public static final Function unixTimestampOfFct = new NativeFunction("unixtimestampof", LongType.instance, TimeUUIDType.instance) + public static final Function unixTimestampOfFct = new NativeScalarFunction("unixtimestampof", LongType.instance, TimeUUIDType.instance) { public ByteBuffer execute(List<ByteBuffer> parameters) { http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/TokenFct.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/TokenFct.java b/src/java/org/apache/cassandra/cql3/functions/TokenFct.java index 2504a66..ca4d473 100644 --- a/src/java/org/apache/cassandra/cql3/functions/TokenFct.java +++ b/src/java/org/apache/cassandra/cql3/functions/TokenFct.java @@ -28,7 +28,7 @@ import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.service.StorageService; -public class TokenFct extends NativeFunction +public class TokenFct extends NativeScalarFunction { // The actual token function depends on the partitioner used private static final IPartitioner partitioner = StorageService.getPartitioner(); 
http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/UDFunction.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java b/src/java/org/apache/cassandra/cql3/functions/UDFunction.java index 3ef5764..264998c 100644 --- a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java +++ b/src/java/org/apache/cassandra/cql3/functions/UDFunction.java @@ -40,7 +40,7 @@ import org.apache.cassandra.utils.FBUtilities; /** * Base class for User Defined Functions. */ -public abstract class UDFunction extends AbstractFunction +public abstract class UDFunction extends AbstractFunction implements ScalarFunction { protected static final Logger logger = LoggerFactory.getLogger(UDFunction.class); @@ -65,6 +65,11 @@ public abstract class UDFunction extends AbstractFunction this.deterministic = deterministic; } + public boolean isAggregate() + { + return false; + } + public static UDFunction create(FunctionName name, List<ColumnIdentifier> argNames, List<AbstractType<?>> argTypes, http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/functions/UuidFcts.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/UuidFcts.java b/src/java/org/apache/cassandra/cql3/functions/UuidFcts.java index 1bf4c17..b3cef85 100644 --- a/src/java/org/apache/cassandra/cql3/functions/UuidFcts.java +++ b/src/java/org/apache/cassandra/cql3/functions/UuidFcts.java @@ -26,7 +26,7 @@ import org.apache.cassandra.serializers.UUIDSerializer; public abstract class UuidFcts { - public static final Function uuidFct = new NativeFunction("uuid", UUIDType.instance) + public static final Function uuidFct = new NativeScalarFunction("uuid", UUIDType.instance) { public ByteBuffer execute(List<ByteBuffer> parameters) { 
http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index 6e0929c..d7485c0 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -97,7 +97,7 @@ public class SelectStatement implements CQLStatement, MeasurableForPreparedCache private boolean selectsOnlyStaticColumns; // Used by forSelection below - private static final Parameters defaultParameters = new Parameters(Collections.<ColumnIdentifier, Boolean>emptyMap(), false, false, null, false); + private static final Parameters defaultParameters = new Parameters(Collections.<ColumnIdentifier, Boolean>emptyMap(), false, false); private static final Predicate<ColumnDefinition> isStaticFilter = new Predicate<ColumnDefinition>() { @@ -156,9 +156,7 @@ public class SelectStatement implements CQLStatement, MeasurableForPreparedCache public ResultSet.Metadata getResultMetadata() { - return parameters.isCount - ? ResultSet.makeCountMetadata(keyspace(), columnFamily(), parameters.countAlias) - : selection.getResultMetadata(); + return selection.getResultMetadata(); } public long measureForPreparedCache(MemoryMeter meter) @@ -203,32 +201,31 @@ public class SelectStatement implements CQLStatement, MeasurableForPreparedCache Pageable command = getPageableCommand(options, limit, now); int pageSize = options.getPageSize(); - // A count query will never be paged for the user, but we always page it internally to avoid OOM. + + // An aggregation query will never be paged for the user, but we always page it internally to avoid OOM. 
// If we user provided a pageSize we'll use that to page internally (because why not), otherwise we use our default // Note that if there are some nodes in the cluster with a version less than 2.0, we can't use paging (CASSANDRA-6707). - if (parameters.isCount && pageSize <= 0) + if (selection.isAggregate() && pageSize <= 0) pageSize = DEFAULT_COUNT_PAGE_SIZE; if (pageSize <= 0 || command == null || !QueryPagers.mayNeedPaging(command, pageSize)) { return execute(command, options, limit, now); } - else - { - QueryPager pager = QueryPagers.pager(command, cl, options.getPagingState()); - if (parameters.isCount) - return pageCountQuery(pager, options, pageSize, now, limit); - // We can't properly do post-query ordering if we page (see #6722) - if (needsPostQueryOrdering()) - throw new InvalidRequestException("Cannot page queries with both ORDER BY and a IN restriction on the partition key; you must either remove the " - + "ORDER BY or the IN and sort client side, or disable paging for this query"); + QueryPager pager = QueryPagers.pager(command, cl, options.getPagingState()); + if (selection.isAggregate()) + return pageAggregateQuery(pager, options, pageSize, now); - List<Row> page = pager.fetchPage(pageSize); - ResultMessage.Rows msg = processResults(page, options, limit, now); + // We can't properly do post-query ordering if we page (see #6722) + if (needsPostQueryOrdering()) + throw new InvalidRequestException("Cannot page queries with both ORDER BY and a IN restriction on the partition key; you must either remove the " + + "ORDER BY or the IN and sort client side, or disable paging for this query"); - return pager.isExhausted() ? msg : msg.withPagingState(pager.state()); - } + List<Row> page = pager.fetchPage(pageSize); + ResultMessage.Rows msg = processResults(page, options, limit, now); + + return pager.isExhausted() ? 
msg : msg.withPagingState(pager.state()); } private Pageable getPageableCommand(QueryOptions options, int limit, long now) throws RequestValidationException @@ -263,28 +260,27 @@ public class SelectStatement implements CQLStatement, MeasurableForPreparedCache return processResults(rows, options, limit, now); } - private ResultMessage.Rows pageCountQuery(QueryPager pager, QueryOptions options, int pageSize, long now, int limit) throws RequestValidationException, RequestExecutionException + private ResultMessage.Rows pageAggregateQuery(QueryPager pager, QueryOptions options, int pageSize, long now) + throws RequestValidationException, RequestExecutionException { - int count = 0; + Selection.ResultSetBuilder result = selection.resultSetBuilder(now); while (!pager.isExhausted()) { - int maxLimit = pager.maxRemaining(); - logger.debug("New maxLimit for paged count query is {}", maxLimit); - ResultSet rset = process(pager.fetchPage(pageSize), options, maxLimit, now); - count += rset.rows.size(); - } + for (org.apache.cassandra.db.Row row : pager.fetchPage(pageSize)) + { + // No columns match the query, skip + if (row.cf == null) + continue; - // We sometimes query one more result than the user limit asks to handle exclusive bounds with compact tables (see updateLimitForQuery). - // So do make sure the count is not greater than what the user asked for. - ResultSet result = ResultSet.makeCountResult(keyspace(), columnFamily(), Math.min(count, limit), parameters.countAlias); - return new ResultMessage.Rows(result); + processColumnFamily(row.key.getKey(), row.cf, options, now, result); + } + } + return new ResultMessage.Rows(result.build()); } public ResultMessage.Rows processResults(List<Row> rows, QueryOptions options, int limit, long now) throws RequestValidationException { - // Even for count, we need to process the result as it'll group some column together in sparse column families ResultSet rset = process(rows, options, limit, now); - rset = parameters.isCount ? 
rset.makeCountResult(parameters.countAlias) : rset; return new ResultMessage.Rows(rset); } @@ -313,7 +309,6 @@ public class SelectStatement implements CQLStatement, MeasurableForPreparedCache public ResultSet process(List<Row> rows) throws InvalidRequestException { - assert !parameters.isCount; // not yet needed QueryOptions options = QueryOptions.DEFAULT; return process(rows, options, getLimit(options), System.currentTimeMillis()); } @@ -1333,10 +1328,6 @@ public class SelectStatement implements CQLStatement, MeasurableForPreparedCache CFMetaData cfm = ThriftValidation.validateColumnFamily(keyspace(), columnFamily()); VariableSpecifications boundNames = getBoundVariables(); - // Select clause - if (parameters.isCount && !selectClause.isEmpty()) - throw new InvalidRequestException("Only COUNT(*) and COUNT(1) operations are currently supported."); - Selection selection = selectClause.isEmpty() ? Selection.wildcard(cfm) : Selection.fromSelectors(cfm, selectClause); @@ -2070,7 +2061,6 @@ public class SelectStatement implements CQLStatement, MeasurableForPreparedCache .add("selectClause", selectClause) .add("whereClause", whereClause) .add("isDistinct", parameters.isDistinct) - .add("isCount", parameters.isCount) .toString(); } } @@ -2079,20 +2069,14 @@ public class SelectStatement implements CQLStatement, MeasurableForPreparedCache { private final Map<ColumnIdentifier, Boolean> orderings; private final boolean isDistinct; - private final boolean isCount; - private final ColumnIdentifier countAlias; private final boolean allowFiltering; public Parameters(Map<ColumnIdentifier, Boolean> orderings, boolean isDistinct, - boolean isCount, - ColumnIdentifier countAlias, boolean allowFiltering) { this.orderings = orderings; this.isDistinct = isDistinct; - this.isCount = isCount; - this.countAlias = countAlias; this.allowFiltering = allowFiltering; } } http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/src/java/org/apache/cassandra/cql3/statements/Selection.java 
---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/statements/Selection.java b/src/java/org/apache/cassandra/cql3/statements/Selection.java index 20211b2..88ba265 100644 --- a/src/java/org/apache/cassandra/cql3/statements/Selection.java +++ b/src/java/org/apache/cassandra/cql3/statements/Selection.java @@ -22,13 +22,16 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; -import com.google.common.collect.Iterators; - -import org.apache.cassandra.cql3.*; -import org.apache.cassandra.cql3.functions.Function; -import org.apache.cassandra.cql3.functions.Functions; import org.apache.cassandra.config.CFMetaData; import org.apache.cassandra.config.ColumnDefinition; +import org.apache.cassandra.cql3.AssignmentTestable; +import org.apache.cassandra.cql3.ColumnIdentifier; +import org.apache.cassandra.cql3.ColumnSpecification; +import org.apache.cassandra.cql3.ResultSet; +import org.apache.cassandra.cql3.functions.AggregateFunction; +import org.apache.cassandra.cql3.functions.Function; +import org.apache.cassandra.cql3.functions.Functions; +import org.apache.cassandra.cql3.functions.ScalarFunction; import org.apache.cassandra.db.Cell; import org.apache.cassandra.db.CounterCell; import org.apache.cassandra.db.ExpiringCell; @@ -36,11 +39,13 @@ import org.apache.cassandra.db.context.CounterContext; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.Int32Type; import org.apache.cassandra.db.marshal.LongType; -import org.apache.cassandra.db.marshal.UserType; import org.apache.cassandra.db.marshal.UTF8Type; +import org.apache.cassandra.db.marshal.UserType; import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.utils.ByteBufferUtil; +import com.google.common.collect.Iterators; + public abstract class Selection { private final Collection<ColumnDefinition> columns; @@ -166,7 +171,8 @@ public abstract class 
Selection throw new InvalidRequestException(String.format("Unknown function '%s'", withFun.functionName)); if (metadata != null) metadata.add(makeFunctionSpec(cfm, withFun, fun.returnType(), raw.alias)); - return new FunctionSelector(fun, args); + return fun.isAggregate() ? new AggregateFunctionSelector(fun, args) + : new ScalarFunctionSelector(fun, args); } } @@ -245,7 +251,13 @@ public abstract class Selection } } - protected abstract List<ByteBuffer> handleRow(ResultSetBuilder rs) throws InvalidRequestException; + protected abstract void addInputRow(ResultSetBuilder rs) throws InvalidRequestException; + + protected abstract boolean isAggregate(); + + protected abstract List<ByteBuffer> getOutputRow() throws InvalidRequestException; + + protected abstract void reset(); /** * @return the list of CQL3 columns value this SelectionClause needs. @@ -267,6 +279,26 @@ public abstract class Selection : c.value(); } + /** + * Checks that selectors are either all aggregates or that none of them is. + * + * @param selectors the selectors to test. + * @param messageTemplate the error message template + * @param messageArgs the error message arguments + * @throws InvalidRequestException if some of the selectors are aggregate but not all of them + */ + private static void validateSelectors(List<Selector> selectors, String messageTemplate, Object... 
messageArgs) + throws InvalidRequestException + { + int aggregates = 0; + for (Selector s : selectors) + if (s.isAggregate()) + ++aggregates; + + if (aggregates != 0 && aggregates != selectors.size()) + throw new InvalidRequestException(String.format(messageTemplate, messageArgs)); + } + public class ResultSetBuilder { private final ResultSet resultSet; @@ -321,7 +353,14 @@ public abstract class Selection public void newRow() throws InvalidRequestException { if (current != null) - resultSet.addRow(handleRow(this)); + { + addInputRow(this); + if (!isAggregate()) + { + resultSet.addRow(getOutputRow()); + reset(); + } + } current = new ArrayList<ByteBuffer>(columns.size()); } @@ -329,7 +368,9 @@ public abstract class Selection { if (current != null) { - resultSet.addRow(handleRow(this)); + addInputRow(this); + resultSet.addRow(getOutputRow()); + reset(); current = null; } return resultSet; @@ -341,6 +382,8 @@ public abstract class Selection { private final boolean isWildcard; + private List<ByteBuffer> current; + public SimpleSelection(Collection<ColumnDefinition> columns, boolean isWildcard) { this(columns, new ArrayList<ColumnSpecification>(columns), isWildcard); @@ -357,23 +400,48 @@ public abstract class Selection this.isWildcard = isWildcard; } - protected List<ByteBuffer> handleRow(ResultSetBuilder rs) - { - return rs.current; - } - @Override public boolean isWildcard() { return isWildcard; } + + protected void addInputRow(ResultSetBuilder rs) throws InvalidRequestException + { + current = rs.current; + } + + protected boolean isAggregate() + { + return false; + } + + protected List<ByteBuffer> getOutputRow() throws InvalidRequestException + { + return current; + } + + protected void reset() + { + current = null; + } } private static abstract class Selector implements AssignmentTestable { - public abstract ByteBuffer compute(ResultSetBuilder rs) throws InvalidRequestException; + public abstract void addInput(ResultSetBuilder rs) throws InvalidRequestException; + 
+ public abstract ByteBuffer getOutput() throws InvalidRequestException; + public abstract AbstractType<?> getType(); + public boolean isAggregate() + { + return false; + } + + public abstract void reset(); + public AssignmentTestable.TestResult testAssignment(String keyspace, ColumnSpecification receiver) { if (receiver.type.equals(getType())) @@ -390,6 +458,7 @@ public abstract class Selection private final String columnName; private final int idx; private final AbstractType<?> type; + private ByteBuffer current; public SimpleSelector(String columnName, int idx, AbstractType<?> type) { @@ -398,9 +467,19 @@ public abstract class Selection this.type = type; } - public ByteBuffer compute(ResultSetBuilder rs) + public void addInput(ResultSetBuilder rs) throws InvalidRequestException { - return rs.current.get(idx); + current = rs.current.get(idx); + } + + public ByteBuffer getOutput() throws InvalidRequestException + { + return current; + } + + public void reset() + { + current = null; } public AbstractType<?> getType() @@ -415,26 +494,17 @@ public abstract class Selection } } - private static class FunctionSelector extends Selector + private static abstract class AbstractFunctionSelector<T extends Function> extends Selector { - private final Function fun; - private final List<Selector> argSelectors; + protected final T fun; + protected final List<Selector> argSelectors; - public FunctionSelector(Function fun, List<Selector> argSelectors) + public AbstractFunctionSelector(T fun, List<Selector> argSelectors) { this.fun = fun; this.argSelectors = argSelectors; } - public ByteBuffer compute(ResultSetBuilder rs) throws InvalidRequestException - { - List<ByteBuffer> args = new ArrayList<ByteBuffer>(argSelectors.size()); - for (Selector s : argSelectors) - args.add(s.compute(rs)); - - return fun.execute(args); - } - public AbstractType<?> getType() { return fun.returnType(); @@ -455,6 +525,102 @@ public abstract class Selection } } + private static class 
ScalarFunctionSelector extends AbstractFunctionSelector<ScalarFunction> + { + public ScalarFunctionSelector(Function fun, List<Selector> argSelectors) throws InvalidRequestException + { + super((ScalarFunction) fun, argSelectors); + validateSelectors(argSelectors, + "the %s function arguments must be either all aggregates or all none aggregates", + fun.name().name); + } + + public boolean isAggregate() + { + // We cannot just return false as it is possible to have a scalar function wrapping an aggregation function + if (argSelectors.isEmpty()) + return false; + + return argSelectors.get(0).isAggregate(); + } + + public void addInput(ResultSetBuilder rs) throws InvalidRequestException + { + for (Selector s : argSelectors) + s.addInput(rs); + } + + public void reset() + { + } + + public ByteBuffer getOutput() throws InvalidRequestException + { + List<ByteBuffer> args = new ArrayList<ByteBuffer>(argSelectors.size()); + for (Selector s : argSelectors) + { + args.add(s.getOutput()); + s.reset(); + } + return fun.execute(args); + } + } + + private static class AggregateFunctionSelector extends AbstractFunctionSelector<AggregateFunction> + { + private final AggregateFunction.Aggregate aggregate; + + public AggregateFunctionSelector(Function fun, List<Selector> argSelectors) throws InvalidRequestException + { + super((AggregateFunction) fun, argSelectors); + + validateAgruments(argSelectors); + this.aggregate = this.fun.newAggregate(); + } + + public boolean isAggregate() + { + return true; + } + + public void addInput(ResultSetBuilder rs) throws InvalidRequestException + { + List<ByteBuffer> args = new ArrayList<ByteBuffer>(argSelectors.size()); + // Aggregation of aggregation is not supported + for (Selector s : argSelectors) + { + s.addInput(rs); + args.add(s.getOutput()); + s.reset(); + } + this.aggregate.addInput(args); + } + + public ByteBuffer getOutput() throws InvalidRequestException + { + return aggregate.compute(); + } + + public void reset() + { + 
aggregate.reset(); + } + + /** + * Checks that the arguments are not themselves aggregation functions. + * + * @param argSelectors the selectors to check + * @throws InvalidRequestException if one of the arguments is an aggregation function + */ + private static void validateAgruments(List<Selector> argSelectors) throws InvalidRequestException + { + for (Selector selector : argSelectors) + if (selector.isAggregate()) + throw new InvalidRequestException( + "aggregate functions cannot be used as arguments of aggregate functions"); + } + } + private static class FieldSelector extends Selector { private final UserType type; @@ -468,9 +634,19 @@ public abstract class Selection this.selected = selected; } - public ByteBuffer compute(ResultSetBuilder rs) throws InvalidRequestException + public boolean isAggregate() + { + return selected.isAggregate(); + } + + public void addInput(ResultSetBuilder rs) throws InvalidRequestException { - ByteBuffer value = selected.compute(rs); + selected.addInput(rs); + } + + public ByteBuffer getOutput() throws InvalidRequestException + { + ByteBuffer value = selected.getOutput(); if (value == null) return null; ByteBuffer[] buffers = type.split(value); @@ -482,6 +658,11 @@ public abstract class Selection return type.fieldType(field); } + public void reset() + { + selected.reset(); + } + @Override public String toString() { @@ -494,6 +675,7 @@ public abstract class Selection private final String columnName; private final int idx; private final boolean isWritetime; + private ByteBuffer current; public WritetimeOrTTLSelector(String columnName, int idx, boolean isWritetime) { @@ -502,16 +684,28 @@ public abstract class Selection this.isWritetime = isWritetime; } - public ByteBuffer compute(ResultSetBuilder rs) + public void addInput(ResultSetBuilder rs) { if (isWritetime) { long ts = rs.timestamps[idx]; - return ts >= 0 ? ByteBufferUtil.bytes(ts) : null; + current = ts >= 0 ? 
ByteBufferUtil.bytes(ts) : null; + } + else + { + int ttl = rs.ttls[idx]; + current = ttl > 0 ? ByteBufferUtil.bytes(ttl) : null; } + } - int ttl = rs.ttls[idx]; - return ttl > 0 ? ByteBufferUtil.bytes(ttl) : null; + public ByteBuffer getOutput() + { + return current; + } + + public void reset() + { + current = null; } public AbstractType<?> getType() @@ -530,20 +724,47 @@ public abstract class Selection { private final List<Selector> selectors; - public SelectionWithFunctions(Collection<ColumnDefinition> columns, List<ColumnSpecification> metadata, List<Selector> selectors, boolean collectTimestamps, boolean collectTTLs) + public SelectionWithFunctions(Collection<ColumnDefinition> columns, + List<ColumnSpecification> metadata, + List<Selector> selectors, + boolean collectTimestamps, + boolean collectTTLs) throws InvalidRequestException { super(columns, metadata, collectTimestamps, collectTTLs); this.selectors = selectors; + + validateSelectors(selectors, "the select clause must either contains only aggregates or none"); } - protected List<ByteBuffer> handleRow(ResultSetBuilder rs) throws InvalidRequestException + protected void addInputRow(ResultSetBuilder rs) throws InvalidRequestException + { + for (Selector selector : selectors) + { + selector.addInput(rs); + } + } + + protected List<ByteBuffer> getOutputRow() throws InvalidRequestException { List<ByteBuffer> result = new ArrayList<ByteBuffer>(); for (Selector selector : selectors) { - result.add(selector.compute(rs)); + result.add(selector.getOutput()); } return result; } + + protected void reset() + { + for (Selector selector : selectors) + { + selector.reset(); + } + } + + public boolean isAggregate() + { + return selectors.get(0).isAggregate(); + } } } http://git-wip-us.apache.org/repos/asf/cassandra/blob/0cad81ae/test/unit/org/apache/cassandra/cql3/AggregationTest.java ---------------------------------------------------------------------- diff --git 
a/test/unit/org/apache/cassandra/cql3/AggregationTest.java b/test/unit/org/apache/cassandra/cql3/AggregationTest.java new file mode 100644 index 0000000..87b7ca7 --- /dev/null +++ b/test/unit/org/apache/cassandra/cql3/AggregationTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.cassandra.cql3; + +import java.math.BigDecimal; +import java.text.SimpleDateFormat; +import java.util.Calendar; +import java.util.Date; +import java.util.TimeZone; + +import org.apache.commons.lang3.time.DateUtils; +import org.junit.Test; + +public class AggregationTest extends CQLTester +{ + @Test + public void testFunctions() throws Throwable + { + createTable("CREATE TABLE %s (a int, b int, c double, d decimal, primary key (a, b))"); + execute("INSERT INTO %s (a, b, c, d) VALUES (1, 1, 11.5, 11.5)"); + execute("INSERT INTO %s (a, b, c, d) VALUES (1, 2, 9.5, 1.5)"); + execute("INSERT INTO %s (a, b, c, d) VALUES (1, 3, 9.0, 2.0)"); + + assertRows(execute("SELECT max(b), min(b), sum(b), avg(b) , max(c), sum(c), avg(c), sum(d), avg(d) FROM %s"), + row(3, 1, 6, 2, 11.5, 30.0, 10.0, new BigDecimal("15.0"), new BigDecimal("5.0"))); + + execute("INSERT INTO %s (a, b, d) VALUES (1, 5, 1.0)"); + assertRows(execute("SELECT COUNT(*) FROM %s"), row(4L)); + assertRows(execute("SELECT COUNT(1) FROM %s"), row(4L)); + assertRows(execute("SELECT COUNT(b), count(c) FROM %s"), row(4L, 3L)); + } + + @Test + public void testInvalidCalls() throws Throwable + { + createTable("CREATE TABLE %s (a int, b int, c int, primary key (a, b))"); + execute("INSERT INTO %s (a, b, c) VALUES (1, 1, 10)"); + execute("INSERT INTO %s (a, b, c) VALUES (1, 2, 9)"); + execute("INSERT INTO %s (a, b, c) VALUES (1, 3, 8)"); + + assertInvalidSyntax("SELECT max(b), max(c) FROM %s WHERE max(a) = 1"); + assertInvalid("SELECT max(b), c FROM %s"); + assertInvalid("SELECT b, max(c) FROM %s"); + assertInvalid("SELECT max(sum(c)) FROM %s"); + assertInvalidSyntax("SELECT COUNT(2) FROM %s"); + } + + @Test + public void testNestedFunctions() throws Throwable + { + createTable("CREATE TABLE %s (a int primary key, b timeuuid, c double, d double)"); + + execute("INSERT INTO %s (a, b, c, d) VALUES (1, maxTimeuuid('2011-02-03 04:05:00+0000'), -1.2, 2.1)"); + execute("INSERT INTO %s (a, b, c, d) 
VALUES (2, maxTimeuuid('2011-02-03 04:06:00+0000'), 1.3, -3.4)"); + execute("INSERT INTO %s (a, b, c, d) VALUES (3, maxTimeuuid('2011-02-03 04:10:00+0000'), 1.4, 1.2)"); + + SimpleDateFormat format = new SimpleDateFormat("yyyy-MM-dd hh:mm:ss"); + format.setTimeZone(TimeZone.getTimeZone("GMT")); + Date date = format.parse("2011-02-03 04:10:00"); + date = DateUtils.truncate(date, Calendar.MILLISECOND); + + assertRows(execute("SELECT max(a), max(unixTimestampOf(b)) FROM %s"), row(3, date.getTime())); + assertRows(execute("SELECT max(a), unixTimestampOf(max(b)) FROM %s"), row(3, date.getTime())); + execute("CREATE OR REPLACE FUNCTION copySign(magnitude double, sign double) RETURNS double LANGUAGE JAVA\n" + + "AS 'return Double.valueOf(Math.copySign(magnitude.doubleValue(), sign.doubleValue()));';"); + + assertRows(execute("SELECT copySign(max(c), min(c)) FROM %s"), row(-1.4)); + assertRows(execute("SELECT copySign(c, d) FROM %s"), row(1.2), row(-1.3), row(1.4)); + assertRows(execute("SELECT max(copySign(c, d)) FROM %s"), row(1.4)); + assertInvalid("SELECT copySign(c, max(c)) FROM %s"); + assertInvalid("SELECT copySign(max(c), c) FROM %s"); + } +}
