Support for user-defined aggregate functions Patch by Robert Stupp; reviewed by Tyler Hobbs for CASSANDRA-8053
Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/e2f35c76 Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/e2f35c76 Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/e2f35c76 Branch: refs/heads/trunk Commit: e2f35c767e479da9761628578299b54872d7eea9 Parents: 857de55 Author: Robert Stupp <[email protected]> Authored: Thu Dec 11 11:46:28 2014 -0600 Committer: Tyler Hobbs <[email protected]> Committed: Thu Dec 11 11:46:28 2014 -0600 ---------------------------------------------------------------------- CHANGES.txt | 1 + pylib/cqlshlib/cql3handling.py | 28 +- src/java/org/apache/cassandra/auth/Auth.java | 12 + .../org/apache/cassandra/config/KSMetaData.java | 1 + src/java/org/apache/cassandra/cql3/Cql.g | 61 ++ .../apache/cassandra/cql3/QueryProcessor.java | 15 + .../cql3/functions/AbstractFunction.java | 10 + .../cassandra/cql3/functions/AggregateFcts.java | 64 +- .../cql3/functions/AggregateFunction.java | 10 +- .../cassandra/cql3/functions/Function.java | 4 + .../cassandra/cql3/functions/FunctionCall.java | 2 +- .../cassandra/cql3/functions/Functions.java | 24 +- .../cql3/functions/JavaSourceUDFFactory.java | 6 +- .../cassandra/cql3/functions/UDAggregate.java | 280 ++++++++ .../cassandra/cql3/functions/UDFunction.java | 193 ++---- .../cassandra/cql3/functions/UDHelper.java | 123 ++++ .../selection/AbstractFunctionSelector.java | 4 +- .../selection/AggregateFunctionSelector.java | 6 +- .../cassandra/cql3/selection/FieldSelector.java | 2 +- .../cassandra/cql3/selection/Selection.java | 8 +- .../cassandra/cql3/selection/Selector.java | 2 +- .../cql3/selection/SelectorFactories.java | 2 +- .../statements/CreateAggregateStatement.java | 194 ++++++ .../statements/CreateFunctionStatement.java | 11 +- .../cql3/statements/DropAggregateStatement.java | 136 ++++ .../cql3/statements/DropFunctionStatement.java | 17 +- .../org/apache/cassandra/db/DefsTables.java | 89 ++- .../org/apache/cassandra/db/SystemKeyspace.java | 21 +- .../cassandra/service/IMigrationListener.java | 3 + .../cassandra/service/MigrationManager.java | 45 +- .../org/apache/cassandra/transport/Server.java | 12 + .../apache/cassandra/cql3/AggregationTest.java | 640 ++++++++++++++++++- .../org/apache/cassandra/cql3/CQLTester.java | 14 + test/unit/org/apache/cassandra/cql3/UFTest.java | 8 - 34 files changed, 1795 insertions(+), 253 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/CHANGES.txt ---------------------------------------------------------------------- diff --git a/CHANGES.txt b/CHANGES.txt index 34e740e..6ff61e7 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,5 @@ 3.0 + * Support for user-defined aggregation functions (CASSANDRA-8053) * Fix NPE in SelectStatement with empty IN values (CASSANDRA-8419) * Refactor SelectStatement, return IN results in natural order instead of IN value list order (CASSANDRA-7981) http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/pylib/cqlshlib/cql3handling.py ---------------------------------------------------------------------- diff --git a/pylib/cqlshlib/cql3handling.py b/pylib/cqlshlib/cql3handling.py index f8a3069..84af796 100644 --- a/pylib/cqlshlib/cql3handling.py +++ b/pylib/cqlshlib/cql3handling.py @@ -41,7 +41,7 @@ class Cql3ParsingRuleSet(CqlParsingRuleSet): 'select', 'from', 'where', 'and', 'key', 'insert', 'update', 'with', 'limit', 'using', 'use', 'set', 'begin', 'apply', 'batch', 'truncate', 'delete', 'in', 'create', - 'function', 'keyspace', 'schema', 'columnfamily', 'table', 'index', 'on', 'drop', + 'function', 'aggregate', 'keyspace', 'schema', 'columnfamily', 'table', 'index', 'on', 'drop', 'primary', 'into', 'values', 'timestamp', 'ttl', 'alter', 'add', 'type', 'compact', 'storage', 'order', 'by', 'asc', 'desc', 'clustering', 'token', 'writetime', 'map', 'list', 'to', 'custom', 'if', 'not' @@ -209,7 +209,10 @@ JUNK ::= /([ \t\r\f\v]+|(--|[/][/])[^\n\r]*([\n\r]|$)|[/][*].*?[*][/])/ ; <mapLiteral> ::= "{" <term> ":" <term> ( "," <term> ":" <term> )* "}" ; -<functionName> ::= <identifier> ( "." <identifier> )? +<userFunctionName> ::= <identifier> ( "." <identifier> )? + ; + +<functionName> ::= <userFunctionName> | "TOKEN" ; @@ -233,12 +236,14 @@ JUNK ::= /([ \t\r\f\v]+|(--|[/][/])[^\n\r]*([\n\r]|$)|[/][*].*?[*][/])/ ; | <createIndexStatement> | <createUserTypeStatement> | <createFunctionStatement> + | <createAggregateStatement> | <createTriggerStatement> | <dropKeyspaceStatement> | <dropColumnFamilyStatement> | <dropIndexStatement> | <dropUserTypeStatement> | <dropFunctionStatement> + | <dropAggregateStatement> | <dropTriggerStatement> | <alterTableStatement> | <alterKeyspaceStatement> @@ -1010,7 +1015,7 @@ syntax_rules += r''' <createFunctionStatement> ::= "CREATE" ("OR" "REPLACE")? "FUNCTION" ("IF" "NOT" "EXISTS")? ("NON"? "DETERMINISTIC")? - <functionName> + <userFunctionName> ( "(" ( newcol=<cident> <storageType> ( "," [newcolname]=<cident> <storageType> )* )? ")" )? @@ -1018,6 +1023,18 @@ syntax_rules += r''' "LANGUAGE" <cident> "AS" <stringLiteral> ; +<createAggregateStatement> ::= "CREATE" ("OR" "REPLACE")? "AGGREGATE" + ("IF" "NOT" "EXISTS")? + <userFunctionName> + ( "(" + ( <storageType> ( "," <storageType> )* )? + ")" )? + "SFUNC" <identifier> + "STYPE" <storageType> + ( "FINALFUNC" <identifier> )? + ( "INITCOND" <term> )? + ; + ''' explain_completion('createIndexStatement', 'indexname', '<new_index_name>') @@ -1049,7 +1066,10 @@ syntax_rules += r''' <dropUserTypeStatement> ::= "DROP" "TYPE" ut=<userTypeName> ; -<dropFunctionStatement> ::= "DROP" "FUNCTION" ( "IF" "EXISTS" )? <functionName> +<dropFunctionStatement> ::= "DROP" "FUNCTION" ( "IF" "EXISTS" )? <userFunctionName> + ; + +<dropAggregateStatement> ::= "DROP" "AGGREGATE" ( "IF" "EXISTS" )? <userFunctionName> ; ''' http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/auth/Auth.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/auth/Auth.java b/src/java/org/apache/cassandra/auth/Auth.java index 041ce2b..cdcfa0e 100644 --- a/src/java/org/apache/cassandra/auth/Auth.java +++ b/src/java/org/apache/cassandra/auth/Auth.java @@ -340,6 +340,10 @@ public class Auth implements AuthMBean { } + public void onDropAggregate(String ksName, String aggregateName) + { + } + public void onCreateKeyspace(String ksName) { } @@ -356,6 +360,10 @@ public class Auth implements AuthMBean { } + public void onCreateAggregate(String ksName, String aggregateName) + { + } + public void onUpdateKeyspace(String ksName) { } @@ -371,5 +379,9 @@ public class Auth implements AuthMBean public void onUpdateFunction(String ksName, String functionName) { } + + public void onUpdateAggregate(String ksName, String aggregateName) + { + } } } http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/config/KSMetaData.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/config/KSMetaData.java b/src/java/org/apache/cassandra/config/KSMetaData.java index 494f98b..e5576ad 100644 --- a/src/java/org/apache/cassandra/config/KSMetaData.java +++ b/src/java/org/apache/cassandra/config/KSMetaData.java @@ -186,6 +186,7 @@ public final class KSMetaData mutation.delete(SystemKeyspace.SCHEMA_TRIGGERS_TABLE, timestamp); mutation.delete(SystemKeyspace.SCHEMA_USER_TYPES_TABLE, timestamp); mutation.delete(SystemKeyspace.SCHEMA_FUNCTIONS_TABLE, timestamp); + mutation.delete(SystemKeyspace.SCHEMA_AGGREGATES_TABLE, timestamp); mutation.delete(SystemKeyspace.BUILT_INDEXES_TABLE, timestamp); return mutation; http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/Cql.g ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/Cql.g b/src/java/org/apache/cassandra/cql3/Cql.g index 1997544..ed133e7 100644 --- a/src/java/org/apache/cassandra/cql3/Cql.g +++ b/src/java/org/apache/cassandra/cql3/Cql.g @@ -245,6 +245,8 @@ cqlStatement returns [ParsedStatement stmt] | st27=dropTypeStatement { $stmt = st27; } | st28=createFunctionStatement { $stmt = st28; } | st29=dropFunctionStatement { $stmt = st29; } + | st30=createAggregateStatement { $stmt = st30; } + | st31=dropAggregateStatement { $stmt = st31; } ; /* @@ -488,6 +490,55 @@ batchStatementObjective returns [ModificationStatement.Parsed statement] | d=deleteStatement { $statement = d; } ; +createAggregateStatement returns [CreateAggregateStatement expr] + @init { + boolean orReplace = false; + boolean ifNotExists = false; + + List<CQL3Type.Raw> argsTypes = new ArrayList<>(); + } + : K_CREATE (K_OR K_REPLACE { orReplace = true; })? + K_AGGREGATE + (K_IF K_NOT K_EXISTS { ifNotExists = true; })? + fn=functionName + '(' + ( + v=comparatorType { argsTypes.add(v); } + ( ',' v=comparatorType { argsTypes.add(v); } )* + )? + ')' + K_SFUNC sfunc = allowedFunctionName + K_STYPE stype = comparatorType + ( + K_FINALFUNC ffunc = allowedFunctionName + )? + ( + K_INITCOND ival = term + )? + { $expr = new CreateAggregateStatement(fn, argsTypes, sfunc, stype, ffunc, ival, orReplace, ifNotExists); } + ; + +dropAggregateStatement returns [DropAggregateStatement expr] + @init { + boolean ifExists = false; + List<CQL3Type.Raw> argsTypes = new ArrayList<>(); + boolean argsPresent = false; + } + : K_DROP K_AGGREGATE + (K_IF K_EXISTS { ifExists = true; } )? + fn=functionName + ( + '(' + ( + v=comparatorType { argsTypes.add(v); } + ( ',' v=comparatorType { argsTypes.add(v); } )* + )? + ')' + { argsPresent = true; } + )? + { $expr = new DropAggregateStatement(fn, argsTypes, argsPresent, ifExists); } + ; + createFunctionStatement returns [CreateFunctionStatement expr] @init { boolean orReplace = false; @@ -1271,6 +1322,11 @@ basic_unreserved_keyword returns [String str] | K_CONTAINS | K_STATIC | K_FUNCTION + | K_AGGREGATE + | K_SFUNC + | K_STYPE + | K_FINALFUNC + | K_INITCOND | K_RETURNS | K_LANGUAGE | K_NON @@ -1384,6 +1440,11 @@ K_STATIC: S T A T I C; K_FROZEN: F R O Z E N; K_FUNCTION: F U N C T I O N; +K_AGGREGATE: A G G R E G A T E; +K_SFUNC: S F U N C; +K_STYPE: S T Y P E; +K_FINALFUNC: F I N A L F U N C; +K_INITCOND: I N I T C O N D; K_RETURNS: R E T U R N S; K_LANGUAGE: L A N G U A G E; K_NON: N O N; http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/QueryProcessor.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/QueryProcessor.java b/src/java/org/apache/cassandra/cql3/QueryProcessor.java index 82b354e..8bd5daa 100644 --- a/src/java/org/apache/cassandra/cql3/QueryProcessor.java +++ b/src/java/org/apache/cassandra/cql3/QueryProcessor.java @@ -613,11 +613,21 @@ public class QueryProcessor implements QueryHandler removeInvalidPreparedStatementsForFunction(thriftPreparedStatements.values().iterator(), ksName, functionName); } } + public void onCreateAggregate(String ksName, String aggregateName) { + if (Functions.getOverloadCount(new FunctionName(ksName, aggregateName)) > 1) + { + // in case there are other overloads, we have to remove all overloads since argument type + // matching may change (due to type casting) + removeInvalidPreparedStatementsForFunction(preparedStatements.values().iterator(), ksName, aggregateName); + removeInvalidPreparedStatementsForFunction(thriftPreparedStatements.values().iterator(), ksName, aggregateName); + } + } public void onUpdateKeyspace(String ksName) { } public void onUpdateColumnFamily(String ksName, String cfName) { } public void onUpdateUserType(String ksName, String typeName) { } public void onUpdateFunction(String ksName, String functionName) { } + public void onUpdateAggregate(String ksName, String aggregateName) { } public void onDropKeyspace(String ksName) { @@ -634,6 +644,11 @@ public class QueryProcessor implements QueryHandler removeInvalidPreparedStatementsForFunction(preparedStatements.values().iterator(), ksName, functionName); removeInvalidPreparedStatementsForFunction(thriftPreparedStatements.values().iterator(), ksName, functionName); } + public void onDropAggregate(String ksName, String aggregateName) + { + removeInvalidPreparedStatementsForFunction(preparedStatements.values().iterator(), ksName, aggregateName); + removeInvalidPreparedStatementsForFunction(thriftPreparedStatements.values().iterator(), ksName, aggregateName); + } private void removeInvalidPreparedStatementsForFunction(Iterator<ParsedStatement.Prepared> iterator, String ksName, String functionName) http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/functions/AbstractFunction.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/AbstractFunction.java b/src/java/org/apache/cassandra/cql3/functions/AbstractFunction.java index d5a40a0..e2d69b8 100644 --- a/src/java/org/apache/cassandra/cql3/functions/AbstractFunction.java +++ b/src/java/org/apache/cassandra/cql3/functions/AbstractFunction.java @@ -66,6 +66,16 @@ public abstract class AbstractFunction implements Function && Objects.equal(this.returnType, that.returnType); } + public boolean usesFunction(String ksName, String functionName) + { + return name.keyspace.equals(ksName) && name.name.equals(functionName); + } + + public boolean hasReferenceTo(Function function) + { + return false; + } + @Override public int hashCode() { http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java index f72ed44..865dfbf 100644 --- a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java +++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java @@ -53,12 +53,12 @@ public abstract class AggregateFcts count = 0; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { return ((LongType) returnType()).decompose(Long.valueOf(count)); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { count++; } @@ -84,12 +84,12 @@ public abstract class AggregateFcts sum = BigDecimal.ZERO; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { return ((DecimalType) returnType()).decompose(sum); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -123,7 +123,7 @@ public abstract class AggregateFcts sum = BigDecimal.ZERO; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { if (count == 0) return ((DecimalType) returnType()).decompose(BigDecimal.ZERO); @@ -131,7 +131,7 @@ public abstract class AggregateFcts return ((DecimalType) returnType()).decompose(sum.divide(BigDecimal.valueOf(count))); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -163,12 +163,12 @@ public abstract class AggregateFcts sum = BigInteger.ZERO; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { return ((IntegerType) returnType()).decompose(sum); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -202,7 +202,7 @@ public abstract class AggregateFcts sum = BigInteger.ZERO; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { if (count == 0) return ((IntegerType) returnType()).decompose(BigInteger.ZERO); @@ -210,7 +210,7 @@ public abstract class AggregateFcts return ((IntegerType) returnType()).decompose(sum.divide(BigInteger.valueOf(count))); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -242,12 +242,12 @@ public abstract class AggregateFcts sum = 0; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { return ((Int32Type) returnType()).decompose(sum); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -281,14 +281,14 @@ public abstract class AggregateFcts sum = 0; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { int avg = count == 0 ? 0 : sum / count; return ((Int32Type) returnType()).decompose(avg); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -320,12 +320,12 @@ public abstract class AggregateFcts sum = 0; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { return ((LongType) returnType()).decompose(sum); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -359,14 +359,14 @@ public abstract class AggregateFcts sum = 0; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { long avg = count == 0 ? 0 : sum / count; return ((LongType) returnType()).decompose(avg); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -398,12 +398,12 @@ public abstract class AggregateFcts sum = 0; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { return ((FloatType) returnType()).decompose(sum); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -437,14 +437,14 @@ public abstract class AggregateFcts sum = 0; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { float avg = count == 0 ? 0 : sum / count; return ((FloatType) returnType()).decompose(avg); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -476,12 +476,12 @@ public abstract class AggregateFcts sum = 0; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { return ((DoubleType) returnType()).decompose(sum); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -515,14 +515,14 @@ public abstract class AggregateFcts sum = 0; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { double avg = count == 0 ? 0 : sum / count; return ((DoubleType) returnType()).decompose(avg); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -558,12 +558,12 @@ public abstract class AggregateFcts max = null; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { return max; } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -599,12 +599,12 @@ public abstract class AggregateFcts min = null; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { return min; } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); @@ -640,12 +640,12 @@ public abstract class AggregateFcts count = 0; } - public ByteBuffer compute() + public ByteBuffer compute(int protocolVersion) { return ((LongType) returnType()).decompose(count); } - public void addInput(List<ByteBuffer> values) + public void addInput(int protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/functions/AggregateFunction.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFunction.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFunction.java index 47eee4b..ddbc9d1 100644 --- a/src/java/org/apache/cassandra/cql3/functions/AggregateFunction.java +++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFunction.java @@ -20,6 +20,8 @@ package org.apache.cassandra.cql3.functions; import java.nio.ByteBuffer; import java.util.List; +import org.apache.cassandra.exceptions.InvalidRequestException; + /** * Performs a calculation on a set of values and return a single value. */ @@ -30,7 +32,7 @@ public interface AggregateFunction extends Function * * @return a new <code>Aggregate</code> instance. */ - public Aggregate newAggregate(); + public Aggregate newAggregate() throws InvalidRequestException; /** * An aggregation operation. @@ -40,16 +42,18 @@ public interface AggregateFunction extends Function /** * Adds the specified input to this aggregate. * + * @param protocolVersion native protocol version * @param values the values to add to the aggregate. */ - public void addInput(List<ByteBuffer> values); + public void addInput(int protocolVersion, List<ByteBuffer> values) throws InvalidRequestException; /** * Computes and returns the aggregate current value. * + * @param protocolVersion native protocol version * @return the aggregate current value. */ - public ByteBuffer compute(); + public ByteBuffer compute(int protocolVersion) throws InvalidRequestException; /** * Reset this aggregate. http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/functions/Function.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/Function.java b/src/java/org/apache/cassandra/cql3/functions/Function.java index 9e41fe4..4d2b993 100644 --- a/src/java/org/apache/cassandra/cql3/functions/Function.java +++ b/src/java/org/apache/cassandra/cql3/functions/Function.java @@ -51,4 +51,8 @@ public interface Function * @return <code>true</code> if the function is an aggregate function, <code>false</code> otherwise. */ public boolean isAggregate(); + + boolean usesFunction(String ksName, String functionName); + + boolean hasReferenceTo(Function function); } http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/functions/FunctionCall.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/FunctionCall.java b/src/java/org/apache/cassandra/cql3/functions/FunctionCall.java index 01443d2..72ac63e 100644 --- a/src/java/org/apache/cassandra/cql3/functions/FunctionCall.java +++ b/src/java/org/apache/cassandra/cql3/functions/FunctionCall.java @@ -44,7 +44,7 @@ public class FunctionCall extends Term.NonTerminal public boolean usesFunction(String ksName, String functionName) { - return fun.name().keyspace.equals(ksName) && fun.name().name.equals(functionName); + return fun.usesFunction(ksName, functionName); } public void collectMarkerSpecification(VariableSpecifications boundNames) http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/functions/Functions.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/Functions.java b/src/java/org/apache/cassandra/cql3/functions/Functions.java index a8fdf0f..7d94e47 100644 --- a/src/java/org/apache/cassandra/cql3/functions/Functions.java +++ b/src/java/org/apache/cassandra/cql3/functions/Functions.java @@ -42,7 +42,8 @@ public abstract class Functions // to handle it as a special case. private static final FunctionName TOKEN_FUNCTION_NAME = FunctionName.nativeFunction("token"); - private static final String SELECT_UDFS = "SELECT * FROM " + SystemKeyspace.NAME + '.' + SystemKeyspace.SCHEMA_FUNCTIONS_TABLE; + private static final String SELECT_UD_FUNCTION = "SELECT * FROM " + SystemKeyspace.NAME + '.' + SystemKeyspace.SCHEMA_FUNCTIONS_TABLE; + private static final String SELECT_UD_AGGREGATE = "SELECT * FROM " + SystemKeyspace.NAME + '.' + SystemKeyspace.SCHEMA_AGGREGATES_TABLE; private Functions() {} @@ -101,8 +102,10 @@ public abstract class Functions public static void loadUDFFromSchema() { logger.debug("Loading UDFs"); - for (UntypedResultSet.Row row : QueryProcessor.executeOnceInternal(SELECT_UDFS)) + for (UntypedResultSet.Row row : QueryProcessor.executeOnceInternal(SELECT_UD_FUNCTION)) addFunction(UDFunction.fromSchema(row)); + for (UntypedResultSet.Row row : QueryProcessor.executeOnceInternal(SELECT_UD_AGGREGATE)) + addFunction(UDAggregate.fromSchema(row)); } public static ColumnSpecification makeArgSpec(String receiverKs, String receiverCf, Function fun, int i) @@ -268,7 +271,7 @@ public abstract class Functions } // This is *not* thread safe but is only called in DefsTables that is synchronized. - public static void addFunction(UDFunction fun) + public static void addFunction(AbstractFunction fun) { // We shouldn't get there unless that function don't exist assert find(fun.name(), fun.argTypes()) == null; @@ -284,12 +287,21 @@ public abstract class Functions } // Same remarks than for addFunction - public static void replaceFunction(UDFunction fun) + public static void replaceFunction(AbstractFunction fun) { removeFunction(fun.name(), fun.argTypes()); addFunction(fun); } + public static List<Function> getReferencesTo(Function old) + { + List<Function> references = new ArrayList<>(); + for (Function function : declared.values()) + if (function.hasReferenceTo(old)) + references.add(function); + return references; + } + public static Collection<Function> all() { return declared.values(); @@ -316,6 +328,7 @@ public abstract class Functions public void onCreateColumnFamily(String ksName, String cfName) { } public void onCreateUserType(String ksName, String typeName) { } public void onCreateFunction(String ksName, String functionName) { } + public void onCreateAggregate(String ksName, String aggregateName) { } public void onUpdateKeyspace(String ksName) { } public void onUpdateColumnFamily(String ksName, String cfName) { } @@ -325,11 +338,12 @@ public abstract class Functions ((UDFunction)function).userTypeUpdated(ksName, typeName); } public void onUpdateFunction(String ksName, String functionName) { } + public void onUpdateAggregate(String ksName, String aggregateName) { } public void onDropKeyspace(String ksName) { } public void onDropColumnFamily(String ksName, String cfName) { } public void onDropUserType(String ksName, String typeName) { } public void onDropFunction(String ksName, String functionName) { } - + public void onDropAggregate(String ksName, String aggregateName) { } } } http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/functions/JavaSourceUDFFactory.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/JavaSourceUDFFactory.java b/src/java/org/apache/cassandra/cql3/functions/JavaSourceUDFFactory.java index 560f077..5b1f5bd 100644 --- a/src/java/org/apache/cassandra/cql3/functions/JavaSourceUDFFactory.java +++ b/src/java/org/apache/cassandra/cql3/functions/JavaSourceUDFFactory.java @@ -57,11 +57,11 @@ public final class JavaSourceUDFFactory throws InvalidRequestException { // argDataTypes is just the C* internal argTypes converted to the Java Driver DataType - DataType[] argDataTypes = UDFunction.driverTypes(argTypes); + DataType[] argDataTypes = UDHelper.driverTypes(argTypes); // returnDataType is just the C* internal returnType converted to the Java Driver DataType - DataType returnDataType = UDFunction.driverType(returnType); + DataType returnDataType = UDHelper.driverType(returnType); // javaParamTypes is just the Java representation for argTypes resp. argDataTypes - Class<?>[] javaParamTypes = UDFunction.javaTypes(argDataTypes); + Class<?>[] javaParamTypes = UDHelper.javaTypes(argDataTypes); // javaReturnType is just the Java representation for returnType resp. returnDataType Class<?> javaReturnType = returnDataType.asJavaClass(); http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/functions/UDAggregate.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/UDAggregate.java b/src/java/org/apache/cassandra/cql3/functions/UDAggregate.java new file mode 100644 index 0000000..f259265 --- /dev/null +++ b/src/java/org/apache/cassandra/cql3/functions/UDAggregate.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.cql3.functions; + +import java.nio.ByteBuffer; +import java.util.*; + +import com.google.common.base.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.cql3.*; +import org.apache.cassandra.db.*; +import org.apache.cassandra.db.composites.Composite; +import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.db.marshal.TypeParser; +import org.apache.cassandra.db.marshal.UTF8Type; +import org.apache.cassandra.exceptions.*; + +/** + * Base class for user-defined-aggregates. + */ +public class UDAggregate extends AbstractFunction implements AggregateFunction +{ + protected static final Logger logger = LoggerFactory.getLogger(UDAggregate.class); + + protected final AbstractType<?> stateType; + protected final ByteBuffer initcond; + private final ScalarFunction stateFunction; + private final ScalarFunction finalFunction; + + public UDAggregate(FunctionName name, + List<AbstractType<?>> argTypes, + AbstractType<?> returnType, + ScalarFunction stateFunc, + ScalarFunction finalFunc, + ByteBuffer initcond) + { + super(name, argTypes, returnType); + this.stateFunction = stateFunc; + this.finalFunction = finalFunc; + this.stateType = stateFunc != null ? stateFunc.returnType() : null; + this.initcond = initcond; + } + + public boolean hasReferenceTo(Function function) + { + return stateFunction == function || finalFunction == function; + } + + public boolean usesFunction(String ksName, String functionName) + { + return super.usesFunction(ksName, functionName) + || stateFunction != null && stateFunction.name().keyspace.equals(ksName) && stateFunction.name().name.equals(functionName) + || finalFunction != null && finalFunction.name().keyspace.equals(ksName) && finalFunction.name().name.equals(functionName); + } + + public boolean isAggregate() + { + return true; + } + + public boolean isPure() + { + return false; + } + + public boolean isNative() + { + return false; + } + + public Aggregate newAggregate() throws InvalidRequestException + { + return new Aggregate() + { + private ByteBuffer state; + { + reset(); + } + + public void addInput(int protocolVersion, List<ByteBuffer> values) throws InvalidRequestException + { + List<ByteBuffer> copy = new ArrayList<>(values.size() + 1); + copy.add(state); + copy.addAll(values); + state = stateFunction.execute(protocolVersion, copy); + } + + public ByteBuffer compute(int protocolVersion) throws InvalidRequestException + { + if (finalFunction == null) + return state; + return finalFunction.execute(protocolVersion, Collections.singletonList(state)); + } + + public void reset() + { + state = initcond != null ? initcond.duplicate() : null; + } + }; + } + + private static ScalarFunction resolveScalar(FunctionName aName, FunctionName fName, List<AbstractType<?>> argTypes) throws InvalidRequestException + { + Function func = Functions.find(fName, argTypes); + if (func == null) + throw new InvalidRequestException(String.format("Referenced state function '%s %s' for aggregate '%s' does not exist", + fName, Arrays.toString(UDHelper.driverTypes(argTypes)), aName)); + if (!(func instanceof ScalarFunction)) + throw new InvalidRequestException(String.format("Referenced state function '%s %s' for aggregate '%s' is not a scalar function", + fName, Arrays.toString(UDHelper.driverTypes(argTypes)), aName)); + return (ScalarFunction) func; + } + + private static Mutation makeSchemaMutation(FunctionName name) + { + UTF8Type kv = (UTF8Type)SystemKeyspace.SchemaAggregatesTable.getKeyValidator(); + return new Mutation(SystemKeyspace.NAME, kv.decompose(name.keyspace)); + } + + public Mutation toSchemaDrop(long timestamp) + { + Mutation mutation = makeSchemaMutation(name); + ColumnFamily cf = mutation.addOrGet(SystemKeyspace.SCHEMA_AGGREGATES_TABLE); + + Composite prefix = SystemKeyspace.SchemaAggregatesTable.comparator.make(name.name, UDHelper.computeSignature(argTypes)); + int ldt = (int) (System.currentTimeMillis() / 1000); + cf.addAtom(new RangeTombstone(prefix, prefix.end(), timestamp, ldt)); + + return mutation; + } + + public static Map<Composite, UDAggregate> fromSchema(Row row) + { + UntypedResultSet results = QueryProcessor.resultify("SELECT * FROM system." + SystemKeyspace.SCHEMA_AGGREGATES_TABLE, row); + Map<Composite, UDAggregate> udfs = new HashMap<>(results.size()); + for (UntypedResultSet.Row result : results) + udfs.put(SystemKeyspace.SchemaAggregatesTable.comparator.make(result.getString("aggregate_name"), result.getBlob("signature")), + fromSchema(result)); + return udfs; + } + + public Mutation toSchemaUpdate(long timestamp) + { + Mutation mutation = makeSchemaMutation(name); + ColumnFamily cf = mutation.addOrGet(SystemKeyspace.SCHEMA_AGGREGATES_TABLE); + + Composite prefix = SystemKeyspace.SchemaAggregatesTable.comparator.make(name.name, UDHelper.computeSignature(argTypes)); + CFRowAdder adder = new CFRowAdder(cf, prefix, timestamp); + + adder.resetCollection("argument_types"); + adder.add("return_type", returnType.toString()); + adder.add("state_func", stateFunction.name().name); + if (stateType != null) + adder.add("state_type", stateType.toString()); + if (finalFunction != null) + adder.add("final_func", finalFunction.name().name); + if (initcond != null) + adder.add("initcond", initcond); + + for (AbstractType<?> argType : argTypes) + adder.addListEntry("argument_types", argType.toString()); + + return mutation; + } + + public static UDAggregate fromSchema(UntypedResultSet.Row row) + { + String ksName = row.getString("keyspace_name"); + String functionName = row.getString("aggregate_name"); + FunctionName name = new FunctionName(ksName, functionName); + + List<String> types = row.getList("argument_types", UTF8Type.instance); + + List<AbstractType<?>> argTypes; + if (types == null) + { + argTypes = Collections.emptyList(); + } + else + { + argTypes = new ArrayList<>(types.size()); + for (String type : types) + argTypes.add(parseType(type)); + } + + AbstractType<?> returnType = parseType(row.getString("return_type")); + + FunctionName stateFunc = new FunctionName(ksName, row.getString("state_func")); + FunctionName finalFunc = row.has("final_func") ? new FunctionName(ksName, row.getString("final_func")) : null; + AbstractType<?> stateType = row.has("state_type") ? parseType(row.getString("state_type")) : null; + ByteBuffer initcond = row.has("initcond") ? row.getBytes("initcond") : null; + + try + { + return create(name, argTypes, returnType, stateFunc, finalFunc, stateType, initcond); + } + catch (InvalidRequestException reason) + { + return createBroken(name, argTypes, returnType, initcond, reason); + } + } + + private static UDAggregate createBroken(FunctionName name, List<AbstractType<?>> argTypes, AbstractType<?> returnType, + ByteBuffer initcond, final InvalidRequestException reason) + { + return new UDAggregate(name, argTypes, returnType, null, null, initcond) { + public Aggregate newAggregate() throws InvalidRequestException + { + throw new InvalidRequestException(String.format("Aggregate '%s' exists but hasn't been loaded successfully for the following reason: %s. " + + "Please see the server log for more details", this, reason.getMessage())); + } + }; + } + + private static UDAggregate create(FunctionName name, List<AbstractType<?>> argTypes, AbstractType<?> returnType, + FunctionName stateFunc, FunctionName finalFunc, AbstractType<?> stateType, ByteBuffer initcond) + throws InvalidRequestException + { + List<AbstractType<?>> stateTypes = new ArrayList<>(argTypes.size() + 1); + stateTypes.add(stateType); + stateTypes.addAll(argTypes); + List<AbstractType<?>> finalTypes = Collections.<AbstractType<?>>singletonList(stateType); + return new UDAggregate(name, argTypes, returnType, + resolveScalar(name, stateFunc, stateTypes), + finalFunc != null ? resolveScalar(name, finalFunc, finalTypes) : null, + initcond); + } + + private static AbstractType<?> parseType(String str) + { + // We only use this when reading the schema where we shouldn't get an error + try + { + return TypeParser.parse(str); + } + catch (SyntaxException | ConfigurationException e) + { + throw new RuntimeException(e); + } + } + + @Override + public boolean equals(Object o) + { + if (!(o instanceof UDAggregate)) + return false; + + UDAggregate that = (UDAggregate) o; + return Objects.equal(this.name, that.name) + && Functions.typeEquals(this.argTypes, that.argTypes) + && Functions.typeEquals(this.returnType, that.returnType) + && Objects.equal(this.stateFunction, that.stateFunction) + && Objects.equal(this.finalFunction, that.finalFunction) + && Objects.equal(this.stateType, that.stateType) + && Objects.equal(this.initcond, that.initcond); + } + + @Override + public int hashCode() + { + return Objects.hashCode(name, argTypes, returnType, stateFunction, finalFunction, stateType, initcond); + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/functions/UDFunction.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java b/src/java/org/apache/cassandra/cql3/functions/UDFunction.java index 973c70a..8b42e51 100644 --- a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java +++ b/src/java/org/apache/cassandra/cql3/functions/UDFunction.java @@ -17,12 +17,7 @@ */ package org.apache.cassandra.cql3.functions; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.lang.reflect.Method; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; import java.util.*; import com.google.common.base.Objects; @@ -43,7 +38,6 @@ import org.apache.cassandra.db.marshal.TypeParser; import org.apache.cassandra.exceptions.*; import org.apache.cassandra.service.MigrationManager; import org.apache.cassandra.utils.ByteBufferUtil; -import org.apache.cassandra.utils.FBUtilities; /** * Base class for User Defined Functions. @@ -52,80 +46,10 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct { protected static final Logger logger = LoggerFactory.getLogger(UDFunction.class); - // TODO make these c'tors and methods public in Java-Driver - see https://datastax-oss.atlassian.net/browse/JAVA-502 - static final MethodHandle methodParseOne; - static - { - try - { - Class<?> cls = Class.forName("com.datastax.driver.core.CassandraTypeParser"); - Method m = cls.getDeclaredMethod("parseOne", String.class); - m.setAccessible(true); - methodParseOne = MethodHandles.lookup().unreflect(m); - } - catch (Exception e) - { - throw new RuntimeException(e); - } - } - - /** - * Construct an array containing the Java classes for the given Java Driver {@link com.datastax.driver.core.DataType}s. - * - * @param dataTypes array with UDF argument types - * @return array of same size with UDF arguments - */ - public static Class<?>[] javaTypes(DataType[] dataTypes) - { - Class<?> paramTypes[] = new Class[dataTypes.length]; - for (int i = 0; i < paramTypes.length; i++) - paramTypes[i] = dataTypes[i].asJavaClass(); - return paramTypes; - } - - /** - * Construct an array containing the Java Driver {@link com.datastax.driver.core.DataType}s for the - * C* internal types. - * - * @param abstractTypes list with UDF argument types - * @return array with argument types as {@link com.datastax.driver.core.DataType} - */ - public static DataType[] driverTypes(List<AbstractType<?>> abstractTypes) - { - DataType[] argDataTypes = new DataType[abstractTypes.size()]; - for (int i = 0; i < argDataTypes.length; i++) - argDataTypes[i] = driverType(abstractTypes.get(i)); - return argDataTypes; - } - - /** - * Returns the Java Driver {@link com.datastax.driver.core.DataType} for the C* internal type. - */ - public static DataType driverType(AbstractType abstractType) - { - CQL3Type cqlType = abstractType.asCQL3Type(); - try - { - return (DataType) methodParseOne.invoke(cqlType.getType().toString()); - } - catch (RuntimeException | Error e) - { - // immediately rethrow these... - throw e; - } - catch (Throwable e) - { - throw new RuntimeException("cannot parse driver type " + cqlType.getType().toString(), e); - } - } - - // instance vars - protected final List<ColumnIdentifier> argNames; - protected final String language; protected final String body; - protected final boolean deterministic; + private final boolean deterministic; protected final DataType[] argDataTypes; protected final DataType returnDataType; @@ -138,8 +62,8 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct String body, boolean deterministic) { - this(name, argNames, argTypes, driverTypes(argTypes), returnType, - driverType(returnType), language, body, deterministic); + this(name, argNames, argTypes, UDHelper.driverTypes(argTypes), returnType, + UDHelper.driverType(returnType), language, body, deterministic); } protected UDFunction(FunctionName name, @@ -151,7 +75,7 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct String language, String body, boolean deterministic) - { + { super(name, argTypes, returnType); assert new HashSet<>(argNames).size() == argNames.size() : "duplicate argument names"; this.argNames = argNames; @@ -162,36 +86,6 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct this.returnDataType = returnDataType; } - /** - * Used by UDF implementations (both Java code generated by {@link org.apache.cassandra.cql3.functions.JavaSourceUDFFactory} - * and script executor {@link org.apache.cassandra.cql3.functions.ScriptBasedUDF}) to convert the C* - * serialized representation to the Java object representation. - * - * @param protocolVersion the native protocol version used for serialization - * @param argIndex index of the UDF input argument - */ - protected Object compose(int protocolVersion, int argIndex, ByteBuffer value) - { - return value == null ? null : argDataTypes[argIndex].deserialize(value, ProtocolVersion.fromInt(protocolVersion)); - } - - /** - * Used by UDF implementations (both Java code generated by {@link org.apache.cassandra.cql3.functions.JavaSourceUDFFactory} - * and script executor {@link org.apache.cassandra.cql3.functions.ScriptBasedUDF}) to convert the Java - * object representation for the return value to the C* serialized representation. - * - * @param protocolVersion the native protocol version used for serialization - */ - protected ByteBuffer decompose(int protocolVersion, Object value) - { - return value == null ? null : returnDataType.serialize(value, ProtocolVersion.fromInt(protocolVersion)); - } - - public boolean isAggregate() - { - return false; - } - public static UDFunction create(FunctionName name, List<ColumnIdentifier> argNames, List<AbstractType<?>> argTypes, @@ -218,12 +112,12 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct * than saying that the function doesn't exist) */ private static UDFunction createBrokenFunction(FunctionName name, - List<ColumnIdentifier> argNames, - List<AbstractType<?>> argTypes, - AbstractType<?> returnType, - String language, - String body, - final InvalidRequestException reason) + List<ColumnIdentifier> argNames, + List<AbstractType<?>> argTypes, + AbstractType<?> returnType, + String language, + String body, + final InvalidRequestException reason) { return new UDFunction(name, argNames, argTypes, returnType, language, body, true) { @@ -235,18 +129,9 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct }; } - // We allow method overloads, so a function is not uniquely identified by its name only, but - // also by its argument types. To distinguish overloads of given function name in the schema - // we use a "signature" which is just a SHA-1 of it's argument types (we could replace that by - // using a "signature" UDT that would be comprised of the function name and argument types, - // which we could then use as clustering column. But as we haven't yet used UDT in system tables, - // We'll left that decision to #6717). - private static ByteBuffer computeSignature(List<AbstractType<?>> argTypes) + public boolean isAggregate() { - MessageDigest digest = FBUtilities.newMessageDigest("SHA-1"); - for (AbstractType<?> type : argTypes) - digest.update(type.asCQL3Type().toString().getBytes(StandardCharsets.UTF_8)); - return ByteBuffer.wrap(digest.digest()); + return false; } public boolean isPure() @@ -259,6 +144,31 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct return false; } + /** + * Used by UDF implementations (both Java code generated by {@link org.apache.cassandra.cql3.functions.JavaSourceUDFFactory} + * and script executor {@link org.apache.cassandra.cql3.functions.ScriptBasedUDF}) to convert the C* + * serialized representation to the Java object representation. + * + * @param protocolVersion the native protocol version used for serialization + * @param argIndex index of the UDF input argument + */ + protected Object compose(int protocolVersion, int argIndex, ByteBuffer value) + { + return value == null ? null : argDataTypes[argIndex].deserialize(value, ProtocolVersion.fromInt(protocolVersion)); + } + + /** + * Used by UDF implementations (both Java code generated by {@link org.apache.cassandra.cql3.functions.JavaSourceUDFFactory} + * and script executor {@link org.apache.cassandra.cql3.functions.ScriptBasedUDF}) to convert the Java + * object representation for the return value to the C* serialized representation. + * + * @param protocolVersion the native protocol version used for serialization + */ + protected ByteBuffer decompose(int protocolVersion, Object value) + { + return value == null ? null : returnDataType.serialize(value, ProtocolVersion.fromInt(protocolVersion)); + } + private static Mutation makeSchemaMutation(FunctionName name) { UTF8Type kv = (UTF8Type)SystemKeyspace.SchemaFunctionsTable.getKeyValidator(); @@ -270,19 +180,29 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct Mutation mutation = makeSchemaMutation(name); ColumnFamily cf = mutation.addOrGet(SystemKeyspace.SCHEMA_FUNCTIONS_TABLE); - Composite prefix = SystemKeyspace.SchemaFunctionsTable.comparator.make(name.name, computeSignature(argTypes)); + Composite prefix = SystemKeyspace.SchemaFunctionsTable.comparator.make(name.name, UDHelper.computeSignature(argTypes)); int ldt = (int) (System.currentTimeMillis() / 1000); cf.addAtom(new RangeTombstone(prefix, prefix.end(), timestamp, ldt)); return mutation; } + public static Map<Composite, UDFunction> fromSchema(Row row) + { + UntypedResultSet results = QueryProcessor.resultify("SELECT * FROM system." + SystemKeyspace.SCHEMA_FUNCTIONS_TABLE, row); + Map<Composite, UDFunction> udfs = new HashMap<>(results.size()); + for (UntypedResultSet.Row result : results) + udfs.put(SystemKeyspace.SchemaFunctionsTable.comparator.make(result.getString("function_name"), result.getBlob("signature")), + fromSchema(result)); + return udfs; + } + public Mutation toSchemaUpdate(long timestamp) { Mutation mutation = makeSchemaMutation(name); ColumnFamily cf = mutation.addOrGet(SystemKeyspace.SCHEMA_FUNCTIONS_TABLE); - Composite prefix = SystemKeyspace.SchemaFunctionsTable.comparator.make(name.name, computeSignature(argTypes)); + Composite prefix = SystemKeyspace.SchemaFunctionsTable.comparator.make(name.name, UDHelper.computeSignature(argTypes)); CFRowAdder adder = new CFRowAdder(cf, prefix, timestamp); adder.resetCollection("argument_names"); @@ -360,15 +280,6 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct } } - public static Map<Composite, UDFunction> fromSchema(Row row) - { - UntypedResultSet results = QueryProcessor.resultify("SELECT * FROM system." + SystemKeyspace.SCHEMA_FUNCTIONS_TABLE, row); - Map<Composite, UDFunction> udfs = new HashMap<>(results.size()); - for (UntypedResultSet.Row result : results) - udfs.put(SystemKeyspace.SchemaFunctionsTable.comparator.make(result.getString("function_name"), result.getBlob("signature")), fromSchema(result)); - return udfs; - } - @Override public boolean equals(Object o) { @@ -377,9 +288,9 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct UDFunction that = (UDFunction)o; return Objects.equal(this.name, that.name) - && Objects.equal(this.argNames, that.argNames) && Functions.typeEquals(this.argTypes, that.argTypes) && Functions.typeEquals(this.returnType, that.returnType) + && Objects.equal(this.argNames, that.argNames) && Objects.equal(this.language, that.language) && Objects.equal(this.body, that.body) && Objects.equal(this.deterministic, that.deterministic); @@ -388,7 +299,7 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct @Override public int hashCode() { - return Objects.hashCode(name, argNames, argTypes, returnType, language, body, deterministic); + return Objects.hashCode(name, argTypes, returnType, argNames, language, body, deterministic); } public void userTypeUpdated(String ksName, String typeName) @@ -408,7 +319,7 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct org.apache.cassandra.db.marshal.UserType ut = ksm.userTypes.getType(ByteBufferUtil.bytes(typeName)); - DataType newUserType = driverType(ut); + DataType newUserType = UDHelper.driverType(ut); argDataTypes[i] = newUserType; argTypes.set(i, ut); http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/functions/UDHelper.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/functions/UDHelper.java b/src/java/org/apache/cassandra/cql3/functions/UDHelper.java new file mode 100644 index 0000000..2a17c75 --- /dev/null +++ b/src/java/org/apache/cassandra/cql3/functions/UDHelper.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.cql3.functions; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.util.*; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.datastax.driver.core.DataType; +import org.apache.cassandra.cql3.*; +import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.utils.FBUtilities; + +/** + * Helper class for User Defined Functions + Aggregates. + */ +final class UDHelper +{ + protected static final Logger logger = LoggerFactory.getLogger(UDHelper.class); + + // TODO make these c'tors and methods public in Java-Driver - see https://datastax-oss.atlassian.net/browse/JAVA-502 + static final MethodHandle methodParseOne; + static + { + try + { + Class<?> cls = Class.forName("com.datastax.driver.core.CassandraTypeParser"); + Method m = cls.getDeclaredMethod("parseOne", String.class); + m.setAccessible(true); + methodParseOne = MethodHandles.lookup().unreflect(m); + } + catch (Exception e) + { + throw new RuntimeException(e); + } + } + + /** + * Construct an array containing the Java classes for the given Java Driver {@link com.datastax.driver.core.DataType}s. + * + * @param dataTypes array with UDF argument types + * @return array of same size with UDF arguments + */ + public static Class<?>[] javaTypes(DataType[] dataTypes) + { + Class<?> paramTypes[] = new Class[dataTypes.length]; + for (int i = 0; i < paramTypes.length; i++) + paramTypes[i] = dataTypes[i].asJavaClass(); + return paramTypes; + } + + /** + * Construct an array containing the Java Driver {@link com.datastax.driver.core.DataType}s for the + * C* internal types. + * + * @param abstractTypes list with UDF argument types + * @return array with argument types as {@link com.datastax.driver.core.DataType} + */ + public static DataType[] driverTypes(List<AbstractType<?>> abstractTypes) + { + DataType[] argDataTypes = new DataType[abstractTypes.size()]; + for (int i = 0; i < argDataTypes.length; i++) + argDataTypes[i] = driverType(abstractTypes.get(i)); + return argDataTypes; + } + + /** + * Returns the Java Driver {@link com.datastax.driver.core.DataType} for the C* internal type. + */ + public static DataType driverType(AbstractType abstractType) + { + CQL3Type cqlType = abstractType.asCQL3Type(); + try + { + return (DataType) methodParseOne.invoke(cqlType.getType().toString()); + } + catch (RuntimeException | Error e) + { + // immediately rethrow these... + throw e; + } + catch (Throwable e) + { + throw new RuntimeException("cannot parse driver type " + cqlType.getType().toString(), e); + } + } + + // We allow method overloads, so a function is not uniquely identified by its name only, but + // also by its argument types. To distinguish overloads of given function name in the schema + // we use a "signature" which is just a SHA-1 of it's argument types (we could replace that by + // using a "signature" UDT that would be comprised of the function name and argument types, + // which we could then use as clustering column. But as we haven't yet used UDT in system tables, + // We'll left that decision to #6717). + protected static ByteBuffer computeSignature(List<AbstractType<?>> argTypes) + { + MessageDigest digest = FBUtilities.newMessageDigest("SHA-1"); + for (AbstractType<?> type : argTypes) + digest.update(type.asCQL3Type().toString().getBytes(StandardCharsets.UTF_8)); + return ByteBuffer.wrap(digest.digest()); + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/selection/AbstractFunctionSelector.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/selection/AbstractFunctionSelector.java b/src/java/org/apache/cassandra/cql3/selection/AbstractFunctionSelector.java index 3778d41..2bf169d 100644 --- a/src/java/org/apache/cassandra/cql3/selection/AbstractFunctionSelector.java +++ b/src/java/org/apache/cassandra/cql3/selection/AbstractFunctionSelector.java @@ -69,10 +69,10 @@ abstract class AbstractFunctionSelector<T extends Function> extends Selector public boolean usesFunction(String ksName, String functionName) { - return fun.name().keyspace.equals(ksName) && fun.name().name.equals(functionName); + return fun.usesFunction(ksName, functionName); } - public Selector newInstance() + public Selector newInstance() throws InvalidRequestException { return fun.isAggregate() ? new AggregateFunctionSelector(fun, factories.newInstances()) : new ScalarFunctionSelector(fun, factories.newInstances()); http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/selection/AggregateFunctionSelector.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/selection/AggregateFunctionSelector.java b/src/java/org/apache/cassandra/cql3/selection/AggregateFunctionSelector.java index 7702796..27a8294 100644 --- a/src/java/org/apache/cassandra/cql3/selection/AggregateFunctionSelector.java +++ b/src/java/org/apache/cassandra/cql3/selection/AggregateFunctionSelector.java @@ -44,12 +44,12 @@ final class AggregateFunctionSelector extends AbstractFunctionSelector<Aggregate args.set(i, s.getOutput(protocolVersion)); s.reset(); } - this.aggregate.addInput(args); + this.aggregate.addInput(protocolVersion, args); } public ByteBuffer getOutput(int protocolVersion) throws InvalidRequestException { - return aggregate.compute(); + return aggregate.compute(protocolVersion); } public void reset() @@ -57,7 +57,7 @@ final class AggregateFunctionSelector extends AbstractFunctionSelector<Aggregate aggregate.reset(); } - AggregateFunctionSelector(Function fun, List<Selector> argSelectors) + AggregateFunctionSelector(Function fun, List<Selector> argSelectors) throws InvalidRequestException { super((AggregateFunction) fun, argSelectors); http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/selection/FieldSelector.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/selection/FieldSelector.java b/src/java/org/apache/cassandra/cql3/selection/FieldSelector.java index d695598..76dbb22 100644 --- a/src/java/org/apache/cassandra/cql3/selection/FieldSelector.java +++ b/src/java/org/apache/cassandra/cql3/selection/FieldSelector.java @@ -47,7 +47,7 @@ final class FieldSelector extends Selector return type.fieldType(field); } - public Selector newInstance() + public Selector newInstance() throws InvalidRequestException { return new FieldSelector(type, field, factory.newInstance()); } http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/selection/Selection.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/selection/Selection.java b/src/java/org/apache/cassandra/cql3/selection/Selection.java index e44a39f..58e994a 100644 --- a/src/java/org/apache/cassandra/cql3/selection/Selection.java +++ b/src/java/org/apache/cassandra/cql3/selection/Selection.java @@ -213,7 +213,7 @@ public abstract class Selection return metadata; } - protected abstract Selectors newSelectors(); + protected abstract Selectors newSelectors() throws InvalidRequestException; /** * @return the list of CQL3 columns value this SelectionClause needs. @@ -223,7 +223,7 @@ public abstract class Selection return columns; } - public ResultSetBuilder resultSetBuilder(long now) + public ResultSetBuilder resultSetBuilder(long now) throws InvalidRequestException { return new ResultSetBuilder(now); } @@ -273,7 +273,7 @@ public abstract class Selection final int[] ttls; final long now; - private ResultSetBuilder(long now) + private ResultSetBuilder(long now) throws InvalidRequestException { this.resultSet = new ResultSet(getResultMetadata().copy(), new ArrayList<List<ByteBuffer>>()); this.selectors = newSelectors(); @@ -468,7 +468,7 @@ public abstract class Selection return factories.containsOnlyAggregateFunctions(); } - protected Selectors newSelectors() + protected Selectors newSelectors() throws InvalidRequestException { return new Selectors() { http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/selection/Selector.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/selection/Selector.java b/src/java/org/apache/cassandra/cql3/selection/Selector.java index 0c1933f..3ed773b 100644 --- a/src/java/org/apache/cassandra/cql3/selection/Selector.java +++ b/src/java/org/apache/cassandra/cql3/selection/Selector.java @@ -65,7 +65,7 @@ public abstract class Selector implements AssignmentTestable * * @return a new <code>Selector</code> instance */ - public abstract Selector newInstance(); + public abstract Selector newInstance() throws InvalidRequestException; /** * Checks if this factory creates selectors instances that creates aggregates. http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/selection/SelectorFactories.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/selection/SelectorFactories.java b/src/java/org/apache/cassandra/cql3/selection/SelectorFactories.java index 9f6025c..3afd1ec 100644 --- a/src/java/org/apache/cassandra/cql3/selection/SelectorFactories.java +++ b/src/java/org/apache/cassandra/cql3/selection/SelectorFactories.java @@ -155,7 +155,7 @@ final class SelectorFactories implements Iterable<Selector.Factory> * Creates a list of new <code>Selector</code> instances. * @return a list of new <code>Selector</code> instances. */ - public List<Selector> newInstances() + public List<Selector> newInstances() throws InvalidRequestException { List<Selector> selectors = new ArrayList<>(factories.size()); for (Selector.Factory factory : factories) http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/statements/CreateAggregateStatement.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/statements/CreateAggregateStatement.java b/src/java/org/apache/cassandra/cql3/statements/CreateAggregateStatement.java new file mode 100644 index 0000000..9816e58 --- /dev/null +++ b/src/java/org/apache/cassandra/cql3/statements/CreateAggregateStatement.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.cql3.statements; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.apache.cassandra.auth.Permission; +import org.apache.cassandra.config.Schema; +import org.apache.cassandra.cql3.CQL3Type; +import org.apache.cassandra.cql3.ColumnIdentifier; +import org.apache.cassandra.cql3.ColumnSpecification; +import org.apache.cassandra.cql3.QueryOptions; +import org.apache.cassandra.cql3.Term; +import org.apache.cassandra.cql3.functions.*; +import org.apache.cassandra.db.marshal.AbstractType; +import org.apache.cassandra.exceptions.InvalidRequestException; +import org.apache.cassandra.exceptions.RequestValidationException; +import org.apache.cassandra.exceptions.UnauthorizedException; +import org.apache.cassandra.service.ClientState; +import org.apache.cassandra.service.MigrationManager; +import org.apache.cassandra.thrift.ThriftValidation; +import org.apache.cassandra.transport.Event; + +/** + * A <code>CREATE AGGREGATE</code> statement parsed from a CQL query. + */ +public final class CreateAggregateStatement extends SchemaAlteringStatement +{ + private final boolean orReplace; + private final boolean ifNotExists; + private FunctionName functionName; + private String stateFunc; + private String finalFunc; + private final CQL3Type.Raw stateTypeRaw; + + private final List<CQL3Type.Raw> argRawTypes; + private final Term.Raw ival; + + public CreateAggregateStatement(FunctionName functionName, + List<CQL3Type.Raw> argRawTypes, + String stateFunc, + CQL3Type.Raw stateType, + String finalFunc, + Term.Raw ival, + boolean orReplace, + boolean ifNotExists) + { + this.functionName = functionName; + this.argRawTypes = argRawTypes; + this.stateFunc = stateFunc; + this.finalFunc = finalFunc; + this.stateTypeRaw = stateType; + this.ival = ival; + this.orReplace = orReplace; + this.ifNotExists = ifNotExists; + } + + public void prepareKeyspace(ClientState state) throws InvalidRequestException + { + if (!functionName.hasKeyspace() && state.getRawKeyspace() != null) + functionName = new FunctionName(state.getKeyspace(), functionName.name); + + if (!functionName.hasKeyspace()) + throw new InvalidRequestException("Functions must be fully qualified with a keyspace name if a keyspace is not set for the session"); + + ThriftValidation.validateKeyspaceNotSystem(functionName.keyspace); + } + + public void checkAccess(ClientState state) throws UnauthorizedException, InvalidRequestException + { + // TODO CASSANDRA-7557 (function DDL permission) + + state.hasKeyspaceAccess(functionName.keyspace, Permission.CREATE); + } + + public void validate(ClientState state) throws InvalidRequestException + { + if (ifNotExists && orReplace) + throw new InvalidRequestException("Cannot use both 'OR REPLACE' and 'IF NOT EXISTS' directives"); + + if (Schema.instance.getKSMetaData(functionName.keyspace) == null) + throw new InvalidRequestException(String.format("Cannot add aggregate '%s' to non existing keyspace '%s'.", functionName.name, functionName.keyspace)); + } + + public Event.SchemaChange changeEvent() + { + return null; + } + + public boolean announceMigration(boolean isLocalOnly) throws RequestValidationException + { + List<AbstractType<?>> argTypes = new ArrayList<>(argRawTypes.size()); + for (CQL3Type.Raw rawType : argRawTypes) + argTypes.add(rawType.prepare(functionName.keyspace).getType()); + + FunctionName stateFuncName = new FunctionName(functionName.keyspace, stateFunc); + FunctionName finalFuncName; + + ScalarFunction fFinal = null; + AbstractType<?> stateType = stateTypeRaw.prepare(functionName.keyspace).getType(); + Function f = Functions.find(stateFuncName, stateArguments(stateType, argTypes)); + if (!(f instanceof ScalarFunction)) + throw new InvalidRequestException("State function " + stateFuncSig(stateFuncName, stateTypeRaw, argRawTypes) + " does not exist or is not a scalar function"); + ScalarFunction fState = (ScalarFunction)f; + + AbstractType<?> returnType; + if (finalFunc != null) + { + finalFuncName = new FunctionName(functionName.keyspace, finalFunc); + f = Functions.find(finalFuncName, Collections.<AbstractType<?>>singletonList(stateType)); + if (!(f instanceof ScalarFunction)) + throw new InvalidRequestException("Final function " + finalFuncName + "(" + stateTypeRaw + ") does not exist"); + fFinal = (ScalarFunction) f; + returnType = fFinal.returnType(); + } + else + { + returnType = fState.returnType(); + if (!returnType.equals(stateType)) + throw new InvalidRequestException("State function " + stateFuncSig(stateFuncName, stateTypeRaw, argRawTypes) + " return type must be the same as the first argument type (if no final function is used)"); + } + + Function old = Functions.find(functionName, argTypes); + if (old != null) + { + if (ifNotExists) + return false; + if (!orReplace) + throw new InvalidRequestException(String.format("Function %s already exists", old)); + if (!(old instanceof AggregateFunction)) + throw new InvalidRequestException(String.format("Aggregate %s can only replace an aggregate", old)); + + // Means we're replacing the function. We still need to validate that 1) it's not a native function and 2) that the return type + // matches (or that could break existing code badly) + if (old.isNative()) + throw new InvalidRequestException(String.format("Cannot replace native aggregate %s", old)); + if (!old.returnType().isValueCompatibleWith(returnType)) + throw new InvalidRequestException(String.format("Cannot replace aggregate %s, the new return type %s is not compatible with the return type %s of existing function", + functionName, returnType.asCQL3Type(), old.returnType().asCQL3Type())); + } + + ByteBuffer initcond = null; + if (ival != null) + { + ColumnSpecification receiver = new ColumnSpecification(functionName.keyspace, "--dummy--", new ColumnIdentifier("(aggregate_initcond)", true), stateType); + initcond = ival.prepare(functionName.keyspace, receiver).bindAndGet(QueryOptions.DEFAULT); + } + + UDAggregate udAggregate = new UDAggregate(functionName, argTypes, returnType, + fState, + fFinal, + initcond); + + MigrationManager.announceNewAggregate(udAggregate, isLocalOnly); + + return true; + } + + private String stateFuncSig(FunctionName stateFuncName, CQL3Type.Raw stateTypeRaw, List<CQL3Type.Raw> argRawTypes) + { + StringBuilder sb = new StringBuilder(); + sb.append(stateFuncName.toString()).append('(').append(stateTypeRaw); + for (CQL3Type.Raw argRawType : argRawTypes) + sb.append(", ").append(argRawType); + sb.append(')'); + return sb.toString(); + } + + private List<AbstractType<?>> stateArguments(AbstractType<?> stateType, List<AbstractType<?>> argTypes) + { + List<AbstractType<?>> r = new ArrayList<>(argTypes.size() + 1); + r.add(stateType); + r.addAll(argTypes); + return r; + } +} http://git-wip-us.apache.org/repos/asf/cassandra/blob/e2f35c76/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java b/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java index 8d8c27a..dbdecf9 100644 --- a/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java @@ -50,7 +50,6 @@ public final class CreateFunctionStatement extends SchemaAlteringStatement private final List<ColumnIdentifier> argNames; private final List<CQL3Type.Raw> argRawTypes; private final CQL3Type.Raw rawReturnType; - private String currentKeyspace; public CreateFunctionStatement(FunctionName functionName, String language, @@ -75,13 +74,11 @@ public final class CreateFunctionStatement extends SchemaAlteringStatement public void prepareKeyspace(ClientState state) throws InvalidRequestException { - currentKeyspace = state.getRawKeyspace(); - - if (!functionName.hasKeyspace() && currentKeyspace != null) - functionName = new FunctionName(currentKeyspace, functionName.name); + if (!functionName.hasKeyspace() && state.getRawKeyspace() != null) + functionName = new FunctionName(state.getRawKeyspace(), functionName.name); if (!functionName.hasKeyspace()) - throw new InvalidRequestException("You need to be logged in a keyspace or use a fully qualified function name"); + throw new InvalidRequestException("Functions must be fully qualified with a keyspace name if a keyspace is not set for the session"); ThriftValidation.validateKeyspaceNotSystem(functionName.keyspace); } @@ -126,6 +123,8 @@ public final class CreateFunctionStatement extends SchemaAlteringStatement return false; if (!orReplace) throw new InvalidRequestException(String.format("Function %s already exists", old)); + if (!(old instanceof ScalarFunction)) + throw new InvalidRequestException(String.format("Function %s can only replace a function", old)); if (!Functions.typeEquals(old.returnType(), returnType)) throw new InvalidRequestException(String.format("Cannot replace function %s, the new return type %s is not compatible with the return type %s of existing function",
