Repository: cassandra
Updated Branches:
  refs/heads/trunk 194a72f64 -> 816c905ae


Support scripting languages for UDFs

Patch by Robert Stupp; reviewed by Tyler Hobbs for CASSANDRA-7526


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/816c905a
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/816c905a
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/816c905a

Branch: refs/heads/trunk
Commit: 816c905ae2a3d98f877babc72c1f8d8650aa5d24
Parents: 194a72f
Author: Robert Stupp <sn...@snazy.de>
Authored: Wed Oct 8 15:06:55 2014 -0500
Committer: Tyler Hobbs <ty...@datastax.com>
Committed: Wed Oct 8 15:06:55 2014 -0500

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 .../cql3/functions/ScriptBasedUDF.java          | 150 +++++++++++++
 .../cassandra/cql3/functions/UDFunction.java    |   3 +-
 .../statements/CreateFunctionStatement.java     |   5 +
 .../org/apache/cassandra/cql3/CQLTester.java    |   8 +
 test/unit/org/apache/cassandra/cql3/UFTest.java | 213 +++++++++++++++++++
 6 files changed, 379 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index d3f2046..9d522b7 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 3.0
+ * Support for scripting languages in user-defined functions (CASSANDRA-7526)
  * Support for aggregation functions (CASSANDRA-4914)
  * Improve query to read paxos table on propose (CASSANDRA-7929)
  * Remove cassandra-cli (CASSANDRA-7920)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java b/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java
new file mode 100644
index 0000000..73fc43b
--- /dev/null
+++ b/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.cql3.functions;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import javax.script.Bindings;
+import javax.script.Compilable;
+import javax.script.CompiledScript;
+import javax.script.ScriptEngine;
+import javax.script.ScriptEngineFactory;
+import javax.script.ScriptEngineManager;
+import javax.script.ScriptException;
+import javax.script.SimpleBindings;
+
+import org.apache.cassandra.cql3.ColumnIdentifier;
+import org.apache.cassandra.db.marshal.AbstractType;
+import org.apache.cassandra.exceptions.InvalidRequestException;
+
+public class ScriptBasedUDF extends UDFunction
+{
+    static final Map<String, Compilable> scriptEngines = new HashMap<>();
+
+    static {
+        ScriptEngineManager scriptEngineManager = new ScriptEngineManager();
+        for (ScriptEngineFactory scriptEngineFactory : scriptEngineManager.getEngineFactories())
+        {
+            ScriptEngine scriptEngine = scriptEngineFactory.getScriptEngine();
+            boolean compilable = scriptEngine instanceof Compilable;
+            if (compilable)
+            {
+                logger.info("Found scripting engine {} {} - {} {} - language names: {}",
+                            scriptEngineFactory.getEngineName(), scriptEngineFactory.getEngineVersion(),
+                            scriptEngineFactory.getLanguageName(), scriptEngineFactory.getLanguageVersion(),
+                            scriptEngineFactory.getNames());
+                for (String name : scriptEngineFactory.getNames())
+                    scriptEngines.put(name, (Compilable) scriptEngine);
+            }
+        }
+    }
+
+    private final CompiledScript script;
+
+    ScriptBasedUDF(FunctionName name,
+                   List<ColumnIdentifier> argNames,
+                   List<AbstractType<?>> argTypes,
+                   AbstractType<?> returnType,
+                   String language,
+                   String body,
+                   boolean deterministic)
+    throws InvalidRequestException
+    {
+        super(name, argNames, argTypes, returnType, language, body, deterministic);
+
+        Compilable scriptEngine = scriptEngines.get(language);
+        if (scriptEngine == null)
+            throw new InvalidRequestException(String.format("Invalid language '%s' for function '%s'", language, name));
+
+        try
+        {
+            this.script = scriptEngine.compile(body);
+        }
+        catch (RuntimeException | ScriptException e)
+        {
+            logger.info("Failed to compile function '{}' for language {}: ", name, language, e);
+            throw new InvalidRequestException(
+                    String.format("Failed to compile function '%s' for language %s: %s", name, language, e));
+        }
+    }
+
+    public ByteBuffer execute(List<ByteBuffer> parameters) throws InvalidRequestException
+    {
+        Object[] params = new Object[argTypes.size()];
+        for (int i = 0; i < params.length; i++)
+        {
+            ByteBuffer bb = parameters.get(i);
+            if (bb != null)
+                params[i] = argTypes.get(i).compose(bb);
+        }
+
+        try
+        {
+            Bindings bindings = new SimpleBindings();
+            for (int i = 0; i < params.length; i++)
+                bindings.put(argNames.get(i).toString(), params[i]);
+
+            Object result = script.eval(bindings);
+            if (result == null)
+                return null;
+
+            Class<?> javaReturnType = returnType.getSerializer().getType();
+            Class<?> resultType = result.getClass();
+            if (!javaReturnType.isAssignableFrom(resultType))
+            {
+                if (result instanceof Number)
+                {
+                    Number rNumber = (Number) result;
+                    if (javaReturnType == Integer.class)
+                        result = rNumber.intValue();
+                    else if (javaReturnType == Long.class)
+                        result = rNumber.longValue();
+                    else if (javaReturnType == Float.class)
+                        result = rNumber.floatValue();
+                    else if (javaReturnType == Double.class)
+                        result = rNumber.doubleValue();
+                    else if (javaReturnType == BigInteger.class)
+                    {
+                        if (rNumber instanceof BigDecimal)
+                            result = ((BigDecimal)rNumber).toBigInteger();
+                        else if (rNumber instanceof Double || rNumber instanceof Float)
+                            result = new BigDecimal(rNumber.toString()).toBigInteger();
+                        else
+                            result = BigInteger.valueOf(rNumber.longValue());
+                    }
+                    else if (javaReturnType == BigDecimal.class)
+                        // String c'tor of BigDecimal is more accurate than valueOf(double)
+                        result = new BigDecimal(rNumber.toString());
+                }
+            }
+
+            @SuppressWarnings("unchecked") ByteBuffer r = ((AbstractType) returnType).decompose(result);
+            return r;
+        }
+        catch (RuntimeException | ScriptException e)
+        {
+            logger.info("Execution of UDF '{}' failed", name, e);
+            throw new InvalidRequestException("Execution of user-defined function '" + name + "' failed: " + e);
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java b/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
index 264998c..8c51b9e 100644
--- a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
+++ b/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
@@ -59,6 +59,7 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
                          boolean deterministic)
     {
         super(name, argTypes, returnType);
+        assert new HashSet<>(argNames).size() == argNames.size() : "duplicate argument names";
         this.argNames = argNames;
         this.language = language;
         this.body = body;
@@ -83,7 +84,7 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
         {
             case "class": return new ReflectionBasedUDF(name, argNames, argTypes, returnType, language, body, deterministic);
             case "java": return JavaSourceUDFFactory.buildUDF(name, argNames, argTypes, returnType, body, deterministic);
-            default: throw new InvalidRequestException(String.format("Invalid language %s for '%s'", language, name));
+            default: return new ScriptBasedUDF(name, argNames, argTypes, returnType, language, body, deterministic);
         }
     }
 

http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java b/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
index a54409e..712a474 100644
--- a/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
+++ b/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
@@ -18,6 +18,7 @@
 package org.apache.cassandra.cql3.statements;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 
 import org.apache.cassandra.auth.Permission;
@@ -89,6 +90,10 @@ public final class CreateFunctionStatement extends SchemaAlteringStatement
 
     public boolean announceMigration(boolean isLocalOnly) throws RequestValidationException
     {
+        if (new HashSet<>(argNames).size() != argNames.size())
+            throw new InvalidRequestException(String.format("duplicate argument names for given function %s with argument names %s",
+                                                            functionName, argNames));
+
         List<AbstractType<?>> argTypes = new ArrayList<>(argRawTypes.size());
         for (CQL3Type.Raw rawType : argRawTypes)
             // We have no proper keyspace to give, which means that this will break (NPE currently)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/test/unit/org/apache/cassandra/cql3/CQLTester.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java
index 31708aa..a456ea8 100644
--- a/test/unit/org/apache/cassandra/cql3/CQLTester.java
+++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java
@@ -18,6 +18,8 @@
 package org.apache.cassandra.cql3;
 
 import java.io.File;
+import java.math.BigDecimal;
+import java.math.BigInteger;
 import java.nio.ByteBuffer;
 import java.util.*;
 import java.util.concurrent.CountDownLatch;
@@ -613,6 +615,12 @@ public abstract class CQLTester
         if (value instanceof Double)
             return DoubleType.instance;
 
+        if (value instanceof BigInteger)
+            return IntegerType.instance;
+
+        if (value instanceof BigDecimal)
+            return DecimalType.instance;
+
         if (value instanceof String)
             return UTF8Type.instance;
 

http://git-wip-us.apache.org/repos/asf/cassandra/blob/816c905a/test/unit/org/apache/cassandra/cql3/UFTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/UFTest.java b/test/unit/org/apache/cassandra/cql3/UFTest.java
index 5dd77bf..46db578 100644
--- a/test/unit/org/apache/cassandra/cql3/UFTest.java
+++ b/test/unit/org/apache/cassandra/cql3/UFTest.java
@@ -17,6 +17,9 @@
  */
 package org.apache.cassandra.cql3;
 
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -447,4 +450,214 @@ public class UFTest extends CQLTester
         assertRows(execute("SELECT language, body FROM system.schema_functions WHERE namespace='foo' AND name='pgfun1'"),
                    row("java", functionBody));
     }
+
+    @Test
+    public void testJavascriptFunction() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, val double)");
+
+        String functionBody = "\n" +
+                              "  Math.sin(val);\n";
+
+        String cql = "CREATE OR REPLACE FUNCTION jsft(val double) RETURNS double LANGUAGE javascript\n" +
+                     "AS '" + functionBody + "';";
+
+        execute(cql);
+
+        assertRows(execute("SELECT language, body FROM system.schema_functions WHERE namespace='' AND name='jsft'"),
+                   row("javascript", functionBody));
+
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 1, 1d);
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 2, 2d);
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 3, 3d);
+        assertRows(execute("SELECT key, val, jsft(val) FROM %s"),
+                   row(1, 1d, Math.sin(1d)),
+                   row(2, 2d, Math.sin(2d)),
+                   row(3, 3d, Math.sin(3d))
+        );
+    }
+
+    @Test
+    public void testJavascriptBadReturnType() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, val double)");
+
+        execute("CREATE OR REPLACE FUNCTION jsft(val double) RETURNS double LANGUAGE javascript\n" +
+                "AS '\"string\";';");
+
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 1, 1d);
+        // throws IRE with ClassCastException
+        assertInvalid("SELECT key, val, jsft(val) FROM %s");
+    }
+
+    @Test
+    public void testJavascriptThrow() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, val double)");
+
+        execute("CREATE OR REPLACE FUNCTION jsft(val double) RETURNS double LANGUAGE javascript\n" +
+                "AS 'throw \"fool\";';");
+
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 1, 1d);
+        // throws IRE with ScriptException
+        assertInvalid("SELECT key, val, jsft(val) FROM %s");
+    }
+
+    @Test
+    public void testDuplicateArgNames() throws Throwable
+    {
+        assertInvalid("CREATE OR REPLACE FUNCTION scrinv(val double, val text) RETURNS text LANGUAGE javascript\n" +
+                      "AS '\"foo bar\";';");
+    }
+
+    @Test
+    public void testJavascriptCompileFailure() throws Throwable
+    {
+        assertInvalid("CREATE OR REPLACE FUNCTION scrinv(val double) RETURNS double LANGUAGE javascript\n" +
+                      "AS 'foo bar';");
+    }
+
+    @Test
+    public void testScriptInvalidLanguage() throws Throwable
+    {
+        assertInvalid("CREATE OR REPLACE FUNCTION scrinv(val double) RETURNS double LANGUAGE artificial_intelligence\n" +
+                      "AS 'question for 42?';");
+    }
+
+    @Test
+    public void testScriptReturnTypeCasting() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, val double)");
+        execute("INSERT INTO %s (key, val) VALUES (?, ?)", 1, 1d);
+
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS boolean LANGUAGE javascript\n" +
+                "AS 'true;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, true));
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS boolean LANGUAGE javascript\n" +
+                "AS 'false;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, false));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = int , return type = int
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS int LANGUAGE javascript\n" +
+                "AS '100;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = int , return type = double
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS int LANGUAGE javascript\n" +
+                "AS '100.;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = double , return type = int
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS double LANGUAGE javascript\n" +
+                "AS '100;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100d));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = double , return type = double
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS double LANGUAGE javascript\n" +
+                "AS '100.;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100d));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = bigint , return type = int
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS bigint LANGUAGE javascript\n" +
+                "AS '100;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100L));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = bigint , return type = double
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS bigint LANGUAGE javascript\n" +
+                "AS '100.;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, 100L));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = varint , return type = int
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS varint LANGUAGE javascript\n" +
+                "AS '100;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, BigInteger.valueOf(100L)));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = varint , return type = double
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS varint LANGUAGE javascript\n" +
+                "AS '100.;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, BigInteger.valueOf(100L)));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = decimal , return type = int
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS decimal LANGUAGE javascript\n" +
+                "AS '100;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, BigDecimal.valueOf(100d)));
+        execute("DROP FUNCTION js(double)");
+
+        // declared rtype = decimal , return type = double
+        execute("CREATE OR REPLACE FUNCTION js(val double) RETURNS decimal LANGUAGE javascript\n" +
+                "AS '100.;';");
+        assertRows(execute("SELECT key, val, js(val) FROM %s"),
+                   row(1, 1d, BigDecimal.valueOf(100d)));
+        execute("DROP FUNCTION js(double)");
+    }
+
+    @Test
+    public void testScriptParamReturnTypes() throws Throwable
+    {
+        createTable("CREATE TABLE %s (key int primary key, ival int, lval bigint, fval float, dval double, vval varint, ddval decimal)");
+        execute("INSERT INTO %s (key, ival, lval, fval, dval, vval, ddval) VALUES (?, ?, ?, ?, ?, ?, ?)", 1,
+                1, 1L, 1f, 1d, BigInteger.valueOf(1L), BigDecimal.valueOf(1d));
+
+        // type = int
+        execute("CREATE OR REPLACE FUNCTION jsint(val int) RETURNS int LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, ival, jsint(ival) FROM %s"),
+                   row(1, 1, 2));
+        execute("DROP FUNCTION jsint(int)");
+
+        // bigint
+        execute("CREATE OR REPLACE FUNCTION jsbigint(val bigint) RETURNS bigint LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, lval, jsbigint(lval) FROM %s"),
+                   row(1, 1L, 2L));
+        execute("DROP FUNCTION jsbigint(bigint)");
+
+        // float
+        execute("CREATE OR REPLACE FUNCTION jsfloat(val float) RETURNS float LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, fval, jsfloat(fval) FROM %s"),
+                   row(1, 1f, 2f));
+        execute("DROP FUNCTION jsfloat(float)");
+
+        // double
+        execute("CREATE OR REPLACE FUNCTION jsdouble(val double) RETURNS double LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, dval, jsdouble(dval) FROM %s"),
+                   row(1, 1d, 2d));
+        execute("DROP FUNCTION jsdouble(double)");
+
+        // varint
+        execute("CREATE OR REPLACE FUNCTION jsvarint(val varint) RETURNS varint LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, vval, jsvarint(vval) FROM %s"),
+                   row(1, BigInteger.valueOf(1L), BigInteger.valueOf(2L)));
+        execute("DROP FUNCTION jsvarint(varint)");
+
+        // decimal
+        execute("CREATE OR REPLACE FUNCTION jsdecimal(val decimal) RETURNS decimal LANGUAGE javascript\n" +
+                "AS 'val+1;';");
+        assertRows(execute("SELECT key, ddval, jsdecimal(ddval) FROM %s"),
+                   row(1, BigDecimal.valueOf(1d), BigDecimal.valueOf(2d)));
+        execute("DROP FUNCTION jsdecimal(decimal)");
+    }
 }

Reply via email to