Repository: cassandra
Updated Branches:
  refs/heads/trunk 194a72f64 -> 816c905ae


Support scripting languages for UDFs

Patch by Robert Stupp; reviewed by Tyler Hobbs for CASSANDRA-7526


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/816c905a
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/816c905a
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/816c905a

Branch: refs/heads/trunk
Commit: 816c905ae2a3d98f877babc72c1f8d8650aa5d24
Parents: 194a72f
Author: Robert Stupp <sn...@snazy.de>
Authored: Wed Oct 8 15:06:55 2014 -0500
Committer: Tyler Hobbs <ty...@datastax.com>
Committed: Wed Oct 8 15:06:55 2014 -0500

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 .../cql3/functions/ScriptBasedUDF.java          | 150 +++++++++++++
 .../cassandra/cql3/functions/UDFunction.java    |   3 +-
 .../statements/CreateFunctionStatement.java     |   5 +
 .../org/apache/cassandra/cql3/CQLTester.java    |   8 +
 test/unit/org/apache/cassandra/cql3/UFTest.java | 213 +++++++++++++++++++
 6 files changed, 379 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index d3f2046..9d522b7 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 3.0
+ * Support for scripting languages in user-defined functions (CASSANDRA-7526)
  * Support for aggregation functions (CASSANDRA-4914)
  * Improve query to read paxos table on propose (CASSANDRA-7929)
  * Remove cassandra-cli (CASSANDRA-7920)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java b/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java
new file mode 100644
index 0000000..73fc43b
--- /dev/null
+++ b/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.cql3.functions;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import javax.script.Bindings;
+import javax.script.Compilable;
+import javax.script.CompiledScript;
+import javax.script.ScriptEngine;
+import javax.script.ScriptEngineFactory;
+import javax.script.ScriptEngineManager;
+import javax.script.ScriptException;
+import javax.script.SimpleBindings;
+
+import org.apache.cassandra.cql3.ColumnIdentifier;
+import org.apache.cassandra.db.marshal.AbstractType;
+import org.apache.cassandra.exceptions.InvalidRequestException;
+
+public class ScriptBasedUDF extends UDFunction
+{
+    static final Map<String, Compilable> scriptEngines = new HashMap<>();
+
+    static {
+        ScriptEngineManager scriptEngineManager = new ScriptEngineManager();
+        for (ScriptEngineFactory scriptEngineFactory : scriptEngineManager.getEngineFactories())
+        {
+            ScriptEngine scriptEngine = scriptEngineFactory.getScriptEngine();
+            boolean compilable = scriptEngine instanceof Compilable;
+            if (compilable)
+            {
+                logger.info("Found scripting engine {} {} - {} {} - language names: {}",
+                            scriptEngineFactory.getEngineName(), scriptEngineFactory.getEngineVersion(),
+                            scriptEngineFactory.getLanguageName(), scriptEngineFactory.getLanguageVersion(),
+                            scriptEngineFactory.getNames());
+                for (String name : scriptEngineFactory.getNames())
+                    scriptEngines.put(name, (Compilable) scriptEngine);
+            }
+        }
+    }
+
+    private final CompiledScript script;
+
+    ScriptBasedUDF(FunctionName name,
+                   List<ColumnIdentifier> argNames,
+                   List<AbstractType<?>> argTypes,
+                   AbstractType<?> returnType,
+                   String language,
+                   String body,
+                   boolean deterministic)
+    throws InvalidRequestException
+    {
+        super(name, argNames, argTypes, returnType, language, body, deterministic);
+
+        Compilable scriptEngine = scriptEngines.get(language);
+        if (scriptEngine == null)
+            throw new InvalidRequestException(String.format("Invalid language '%s' for function '%s'", language, name));
+
+        try
+        {
+            this.script = scriptEngine.compile(body);
+        }
+        catch (RuntimeException | ScriptException e)
+        {
+            logger.info("Failed to compile function '{}' for language {}: ", name, language, e);
+            throw new InvalidRequestException(
+                    String.format("Failed to compile function '%s' for language %s: %s", name, language, e));
+        }
+    }
+
+    public ByteBuffer execute(List<ByteBuffer> parameters) throws InvalidRequestException
+    {
+        Object[] params = new Object[argTypes.size()];
+        for (int i = 0; i < params.length; i++)
+        {
+            ByteBuffer bb = parameters.get(i);
+            if (bb != null)
+                params[i] = argTypes.get(i).compose(bb);
+        }
+
+        try
+        {
+            Bindings bindings = new SimpleBindings();
+            for (int i = 0; i < params.length; i++)
+                bindings.put(argNames.get(i).toString(), params[i]);
+
+            Object result = script.eval(bindings);
+            if (result == null)
+                return null;
+
+            Class<?> javaReturnType = returnType.getSerializer().getType();
+            Class<?> resultType = result.getClass();
+            if (!javaReturnType.isAssignableFrom(resultType))
+            {
+                if (result instanceof Number)
+                {
+                    Number rNumber = (Number) result;
+                    if (javaReturnType == Integer.class)
+                        result = rNumber.intValue();
+                    else if (javaReturnType == Long.class)
+                        result = rNumber.longValue();
+                    else if (javaReturnType == Float.class)
+                        result = rNumber.floatValue();
+                    else if (javaReturnType == Double.class)
+                        result = rNumber.doubleValue();
+                    else if (javaReturnType == BigInteger.class)
+                    {
+                        if (rNumber instanceof BigDecimal)
+                            result = ((BigDecimal)rNumber).toBigInteger();
+                        else if (rNumber instanceof Double || rNumber instanceof Float)
+                            result = new BigDecimal(rNumber.toString()).toBigInteger();
+                        else
+                            result = BigInteger.valueOf(rNumber.longValue());
+                    }
+                    else if (javaReturnType == BigDecimal.class)
+                        // String c'tor of BigDecimal is more accurate than valueOf(double)
+                        result = new BigDecimal(rNumber.toString());
+                }
+            }
+
+            @SuppressWarnings("unchecked") ByteBuffer r = ((AbstractType) returnType).decompose(result);
+            return r;
+        }
+        catch (RuntimeException | ScriptException e)
+        {
+            logger.info("Execution of UDF '{}' failed", name, e);
+            throw new InvalidRequestException("Execution of user-defined function '" + name + "' failed: " + e);
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java b/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
index 264998c..8c51b9e 100644
--- a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
+++ b/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
@@ -59,6 +59,7 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
                          boolean deterministic)
     {
         super(name, argTypes, returnType);
+        assert new HashSet<>(argNames).size() == argNames.size() : "duplicate argument names";
         this.argNames = argNames;
         this.language = language;
         this.body = body;
@@ -83,7 +84,7 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
         {
             case "class": return new ReflectionBasedUDF(name, argNames, argTypes, returnType, language, body, deterministic);
             case "java": return JavaSourceUDFFactory.buildUDF(name, argNames, argTypes, returnType, body, deterministic);
-            default: throw new InvalidRequestException(String.format("Invalid language %s for '%s'", language, name));
+            default: return new ScriptBasedUDF(name, argNames, argTypes, returnType, language, body, deterministic);
         }
     }
 
http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java b/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
index a54409e..712a474 100644
--- a/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
+++ b/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
@@ -18,6 +18,7 @@
 package org.apache.cassandra.cql3.statements;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 
 import org.apache.cassandra.auth.Permission;
@@ -89,6 +90,10 @@ public final class CreateFunctionStatement extends SchemaAlteringStatement
 
     public boolean announceMigration(boolean isLocalOnly) throws RequestValidationException
     {
+        if (new HashSet<>(argNames).size() != argNames.size())
+            throw new InvalidRequestException(String.format("duplicate argument names for given function %s with argument names %s",
+                                                            functionName, argNames));
+
         List<AbstractType<?>> argTypes = new ArrayList<>(argRawTypes.size());
         for (CQL3Type.Raw rawType : argRawTypes)
             // We have no proper keyspace to give, which means that this will break (NPE currently)
http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/test/unit/org/apache/cassandra/cql3/CQLTester.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java
index 31708aa..a456ea8 100644
--- a/test/unit/org/apache/cassandra/cql3/CQLTester.java
+++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java
@@ -18,6 +18,8 @@
 package org.apache.cassandra.cql3;
 
 import java.io.File;
+import java.math.BigDecimal;
+import java.math.BigInteger;
 import java.nio.ByteBuffer;
 import java.util.*;
 import java.util.concurrent.CountDownLatch;
@@ -613,6 +615,12 @@ public abstract class CQLTester
         if (value instanceof Double)
             return DoubleType.instance;
 
+        if (value instanceof BigInteger)
+            return IntegerType.instance;
+
+        if (value instanceof BigDecimal)
+            return DecimalType.instance;
+
         if (value instanceof String)
             return UTF8Type.instance;
 
http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/test/unit/org/apache/cassandra/cql3/UFTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/UFTest.java b/test/unit/org/apache/cassandra/cql3/UFTest.java
index 5dd77bf..46db578 100644
--- a/test/unit/org/apache/cassandra/cql3/UFTest.java
+++ b/test/unit/org/apache/cassandra/cql3/UFTest.java
@@ -17,6 +17,9 @@
  */
 package org.apache.cassandra.cql3;
 
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -447,4 +450,214 @@ public class UFTest extends CQLTester
         assertRows(execute("SELECT language, body FROM system.schema_functions WHERE namespace='foo' AND name='pgfun1'"),
                    row("java", functionBody));
     }
+
+    @Test
+    public void testJavascriptFunction() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, val double)");
+
+        String functionBody = "\n" +
+                              " Math.sin(val);\n";
+
+        String cql = "CREATE OR REPLACE FUNCTION jsft(val double) RETURNS double LANGUAGE javascript\n" +
+                     "AS '" + functionBody + "';";
+
+        execute(cql);
+
+        assertRows(execute("SELECT language, body FROM system.schema_functions WHERE namespace='' AND name='jsft'"),
+                   row("javascript", functionBody));
+
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 1, 1d);
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 2, 2d);
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 3, 3d);
+        assertRows(execute("SELECT key, val, jsft(val) FROM %s"),
+                   row(1, 1d, Math.sin(1d)),
+                   row(2, 2d, Math.sin(2d)),
+                   row(3, 3d, Math.sin(3d))
+        );
+    }
+
+    @Test
+    public void testJavascriptBadReturnType() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, val double)");
+
+        execute("CREATE OR REPLACE FUNCTION jsft(val double) RETURNS double LANGUAGE javascript\n" +
+                "AS '\"string\";';");
+
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 1, 1d);
+        // throws IRE with ClassCastException
+        assertInvalid("SELECT key, val, jsft(val) FROM %s");
+    }
+
+    @Test
+    public void testJavascriptThrow() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, val double)");
+
+        execute("CREATE OR REPLACE FUNCTION jsft(val double) RETURNS double LANGUAGE javascript\n" +
+                "AS 'throw \"fool\";';");
+
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 1, 1d);
+        // throws IRE with ScriptException
+        assertInvalid("SELECT key, val, jsft(val) FROM %s");
+    }
+
+    @Test
+    public void testDuplicateArgNames() throws Throwable
+    {
+        assertInvalid("CREATE OR REPLACE FUNCTION scrinv(val double, val text) RETURNS text LANGUAGE javascript\n" +
+                      "AS '\"foo bar\";';");
+    }
+
+    @Test
+    public void testJavascriptCompileFailure() throws Throwable
+    {
+        assertInvalid("CREATE OR REPLACE FUNCTION scrinv(val double) RETURNS double LANGUAGE javascript\n" +
+                      "AS 'foo bar';");
+    }
+
+    @Test
+    public void testScriptInvalidLanguage() throws Throwable
+    {
+        assertInvalid("CREATE OR REPLACE FUNCTION scrinv(val double) RETURNS double LANGUAGE artificial_intelligence\n" +
+                      "AS 'question for 42?';");
+    }
+
+    @Test
+    public void testScriptReturnTypeCasting() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, val double)");
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 1, 1d);
+
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS boolean LANGUAGE javascript\n" +
+                "AS 'true;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, true));
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS boolean LANGUAGE javascript\n" +
+                "AS 'false;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, false));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = int , return type = int
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS int LANGUAGE javascript\n" +
+                "AS '100;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = int , return type = double
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS int LANGUAGE javascript\n" +
+                "AS '100.;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = double , return type = int
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS double LANGUAGE javascript\n" +
+                "AS '100;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100d));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = double , return type = double
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS double LANGUAGE javascript\n" +
+                "AS '100.;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100d));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = bigint , return type = int
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS bigint LANGUAGE javascript\n" +
+                "AS '100;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100L));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = bigint , return type = double
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS bigint LANGUAGE javascript\n" +
+                "AS '100.;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100L));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = varint , return type = int
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS varint LANGUAGE javascript\n" +
+                "AS '100;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, BigInteger.valueOf(100L)));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = varint , return type = double
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS varint LANGUAGE javascript\n" +
+                "AS '100.;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, BigInteger.valueOf(100L)));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = decimal , return type = int
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS decimal LANGUAGE javascript\n" +
+                "AS '100;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, BigDecimal.valueOf(100d)));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = decimal , return type = double
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS decimal LANGUAGE javascript\n" +
+                "AS '100.;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, BigDecimal.valueOf(100d)));
+        execute("DROP FUNCTION js(double)");
+    }
+
+    @Test
+    public void testScriptParamReturnTypes() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, ival int, lval bigint, fval float, dval double, vval varint, ddval decimal)");
+        execute("INSERT INTO %s (key, ival, lval, fval, dval, vval, ddval) VALUES (?, ?, ?, ?, ?, ?, ?)", 1,
+                1, 1L, 1f, 1d, BigInteger.valueOf(1L), BigDecimal.valueOf(1d));
+
+        // type = int
+        execute("CREATE OR REPLACE FUNCTION jsint(val int) RETURNS int LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, ival, jsint(ival) FROM %s"),
+                   row(1, 1, 2));
+        execute("DROP FUNCTION jsint(int)");
+
+        // bigint
+        execute("CREATE OR REPLACE FUNCTION jsbigint(val bigint) RETURNS bigint LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, lval, jsbigint(lval) FROM %s"),
+                   row(1, 1L, 2L));
+        execute("DROP FUNCTION jsbigint(bigint)");
+
+        // float
+        execute("CREATE OR REPLACE FUNCTION jsfloat(val float) RETURNS float LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, fval, jsfloat(fval) FROM %s"),
+                   row(1, 1f, 2f));
+        execute("DROP FUNCTION jsfloat(float)");
+
+        // double
+        execute("CREATE OR REPLACE FUNCTION jsdouble(val double) RETURNS double LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, dval, jsdouble(dval) FROM %s"),
+                   row(1, 1d, 2d));
+        execute("DROP FUNCTION jsdouble(double)");
+
+        // varint
+        execute("CREATE OR REPLACE FUNCTION jsvarint(val varint) RETURNS varint LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, vval, jsvarint(vval) FROM %s"),
+                   row(1, BigInteger.valueOf(1L), BigInteger.valueOf(2L)));
+        execute("DROP FUNCTION jsvarint(varint)");
+
+        // decimal
+        execute("CREATE OR REPLACE FUNCTION jsdecimal(val decimal) RETURNS decimal LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, ddval, jsdecimal(ddval) FROM %s"),
+                   row(1, BigDecimal.valueOf(1d), BigDecimal.valueOf(2d)));
+        execute("DROP FUNCTION jsdecimal(decimal)");
+    }
 }