Repository: incubator-systemml
Updated Branches:
  refs/heads/master 107230055 -> 5d84a095a


[SYSTEMML-1190] Allow Scala UDF to be passed to SystemML via external UDF 
mechanism

The registration mechanism is inspired by Spark SQLContext's UDF registration. The
key construct is ml.udf.register("fn to be used in DML", scala UDF).

The restrictions for Scala UDFs are as follows:
- Only types supported by the DML language are allowed for parameters and return 
types (i.e. Int, Double, Boolean, String, double[][]).
- At minimum, the function must have 1 argument and 1 return value.
- At most, the function can have 10 arguments and 10 return values.

Closes #349.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/5d84a095
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/5d84a095
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/5d84a095

Branch: refs/heads/master
Commit: 5d84a095aadc1d6f1e2e4c06439589591ca24aec
Parents: 1072300
Author: Niketan Pansare <[email protected]>
Authored: Mon Jan 23 13:27:35 2017 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Mon Jan 23 13:31:07 2017 -0800

----------------------------------------------------------------------
 docs/spark-mlcontext-programming-guide.md       |  39 ++
 .../apache/sysml/api/mlcontext/MLContext.java   |  12 +-
 .../sysml/api/mlcontext/ScriptExecutor.java     |   8 +
 .../ExternalFunctionProgramBlock.java           |   5 +
 .../ExternalFunctionProgramBlockCP.java         |   2 +
 .../controlprogram/FunctionProgramBlock.java    |   3 +-
 .../cp/FunctionCallCPInstruction.java           |   2 +
 .../ExternalFunctionInvocationInstruction.java  |   2 +
 .../apache/sysml/udf/lib/GenericFunction.java   | 102 +++++
 .../sysml/api/ExternalUDFRegistration.scala     | 404 +++++++++++++++++++
 10 files changed, 577 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5d84a095/docs/spark-mlcontext-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/spark-mlcontext-programming-guide.md 
b/docs/spark-mlcontext-programming-guide.md
index dcaa125..759d392 100644
--- a/docs/spark-mlcontext-programming-guide.md
+++ b/docs/spark-mlcontext-programming-guide.md
@@ -1636,6 +1636,45 @@ scala> for (i <- 1 to 5) {
 
 </div>
 
+## Passing Scala UDF to SystemML
+
+SystemML allows users to pass a Scala UDF (with input/output types supported by SystemML)
+to the DML script via MLContext. The restrictions for the supported Scala UDFs are as follows:
+
+1. Only types supported by the DML language are allowed for parameters and return types (i.e. Int, Double, Boolean, String, double[][]).
+2. At minimum, the function must have 1 argument and 1 return value.
+3. At most, the function can have 10 arguments and 10 return values. 
+
+{% highlight scala %}
+import org.apache.sysml.api.mlcontext._
+import org.apache.sysml.api.mlcontext.ScriptFactory._
+val ml = new MLContext(sc)
+
+// Demonstrates how to pass a simple scala UDF to SystemML
+def addOne(x:Double):Double = x + 1
+ml.udf.register("addOne", addOne _)
+val script1 = dml("v = addOne(2.0); print(v)")
+ml.execute(script1)
+
+// Demonstrates operation on local matrices (double[][])
+def addOneToDiagonal(x:Array[Array[Double]]):Array[Array[Double]] = {  for(i 
<- 0 to x.length-1) x(i)(i) = x(i)(i) + 1; x }
+ml.udf.register("addOneToDiagonal", addOneToDiagonal _)
+val script2 = dml("m1 = matrix(0, rows=3, cols=3); m2 = addOneToDiagonal(m1); 
print(toString(m2));")
+ml.execute(script2)
+
+// Demonstrates multi-return function
+def multiReturnFn(x:Double):(Double, Int) = (x + 1, (x * 2).toInt)
+ml.udf.register("multiReturnFn", multiReturnFn _)
+val script3 = dml("[v1, v2] = multiReturnFn(2.0); print(v1)")
+ml.execute(script3)
+
+// Demonstrates multi-argument multi-return function
+def multiArgReturnFn(x:Double, y:Int):(Double, Int) = (x + 1, (x * y).toInt)
+ml.udf.register("multiArgReturnFn", multiArgReturnFn _)
+val script4 = dml("[v1, v2] = multiArgReturnFn(2.0, 1); print(v2)")
+ml.execute(script4)
+{% endhighlight %}
+
 ---
 
 # Jupyter (PySpark) Notebook Example - Poisson Nonnegative Matrix Factorization

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5d84a095/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
index 641beae..9e3a67b 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
@@ -31,6 +31,7 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.api.MLContextProxy;
+import org.apache.sysml.api.ExternalUDFRegistration;
 import org.apache.sysml.api.jmlc.JMLCUtils;
 import org.apache.sysml.api.monitoring.SparkMonitoringUtil;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -118,6 +119,12 @@ public class MLContext {
 
        private List<String> scriptHistoryStrings = new ArrayList<String>();
        private Map<String, Script> scripts = new LinkedHashMap<String, 
Script>();
+       
+       /**
+        * Allows users to register external scala UDFs.
+        * The design is explained in ExternalUDFRegistration.scala.
+        */
+       public ExternalUDFRegistration udf = null;
 
        /**
         * The different explain levels supported by SystemML.
@@ -217,6 +224,8 @@ public class MLContext {
                }
 
                this.sc = sc;
+               this.udf = new ExternalUDFRegistration();
+               this.udf.setMLContext(this);
                MLContextUtil.verifySparkVersionSupported(sc);
                // by default, run in hybrid Spark mode for optimal performance
                DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
@@ -258,7 +267,7 @@ public class MLContext {
                        throw new MLContextException(e);
                }
        }
-
+       
        /**
         * Execute a DML or PYDML Script.
         *
@@ -296,6 +305,7 @@ public class MLContext {
                                script.setName(time.toString());
                        }
 
+                       scriptExecutor.udf = udf;
                        MLResults results = scriptExecutor.execute(script);
 
                        String history = 
MLContextUtil.createHistoryForScript(script, time);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5d84a095/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java 
b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index 7f80267..265a683 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -25,6 +25,7 @@ import java.util.Set;
 
 import org.apache.commons.lang3.StringUtils;
 import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.ExternalUDFRegistration;
 import org.apache.sysml.api.jmlc.JMLCUtils;
 import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
 import org.apache.sysml.api.monitoring.SparkMonitoringUtil;
@@ -119,6 +120,7 @@ public class ScriptExecutor {
        protected boolean statistics = false;
        protected ExplainLevel explainLevel;
        protected int statisticsMaxHeavyHitters = 10;
+       public ExternalUDFRegistration udf;
 
        /**
         * ScriptExecutor constructor.
@@ -450,6 +452,12 @@ public class ScriptExecutor {
                                        inputParameters, 
script.getScriptType());
 
                        String scriptExecutionString = 
script.getScriptExecutionString();
+                       if(udf != null) {
+                               // Append the headers from Scala UDF.
+                               String externalHeaders = 
udf.getExternalHeaders();
+                               if(!externalHeaders.equals(""))
+                                       scriptExecutionString = externalHeaders 
+ scriptExecutionString;
+                       }
                        dmlProgram = parser.parse(null, scriptExecutionString, 
inputParametersStringMaps);
                } catch (ParseException e) {
                        throw new MLContextException("Exception occurred while 
parsing script", e);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5d84a095/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlock.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlock.java
index 3c72ca9..b8e27ca 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlock.java
@@ -601,6 +601,11 @@ public class ExternalFunctionProgramBlock extends 
FunctionProgramBlock
                func.setConfiguration(configFile);
                func.setBaseDir(_baseDir);
                
+               
if(className.equals("org.apache.sysml.udf.lib.GenericFunction")) {
+                       
((org.apache.sysml.udf.lib.GenericFunction)func)._functionName = 
this._functionName;
+                       
((org.apache.sysml.udf.lib.GenericFunction)func)._namespace = this._namespace;
+               }
+               
                //executes function
                func.execute();
                

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5d84a095/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlockCP.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlockCP.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlockCP.java
index 1d7e5dd..d5a9125 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlockCP.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ExternalFunctionProgramBlockCP.java
@@ -100,6 +100,8 @@ public class ExternalFunctionProgramBlockCP extends 
ExternalFunctionProgramBlock
                {
                        try {
                                inst = 
(ExternalFunctionInvocationInstruction)_inst.get(i);
+                               inst._namespace = _namespace;
+                               inst._functionName = _functionName;
                                executeInstruction( ec, inst );
                        }
                        catch (Exception e){

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5d84a095/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
index cf8e73e..b2f6e53 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
@@ -34,7 +34,8 @@ import org.apache.sysml.utils.Statistics;
 
 public class FunctionProgramBlock extends ProgramBlock 
 {
-       
+       public String _functionName;
+       public String _namespace;
        protected ArrayList<ProgramBlock> _childBlocks;
        protected ArrayList<DataIdentifier> _inputParams;
        protected ArrayList<DataIdentifier> _outputParams;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5d84a095/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
index ccb90b6..67afdc8 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -167,6 +167,8 @@ public class FunctionCallCPInstruction extends CPInstruction
                
                // execute the function block
                try {
+                       fpb._functionName = this._functionName;
+                       fpb._namespace = this._namespace;
                        fpb.execute(fn_ec);
                }
                catch (DMLScriptException e) {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5d84a095/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java 
b/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java
index 85f8738..476970a 100644
--- 
a/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java
+++ 
b/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java
@@ -34,6 +34,8 @@ public class ExternalFunctionInvocationInstruction extends 
Instruction
        
        public static final String ELEMENT_DELIM = ":";
        
+       public String _namespace;
+       public String _functionName;
        protected String className; // name of class that contains the function
        protected String configFile; // optional configuration file parameter
        protected String inputParams; // string representation of input 
parameters

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5d84a095/src/main/java/org/apache/sysml/udf/lib/GenericFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/udf/lib/GenericFunction.java 
b/src/main/java/org/apache/sysml/udf/lib/GenericFunction.java
new file mode 100644
index 0000000..a8110a5
--- /dev/null
+++ b/src/main/java/org/apache/sysml/udf/lib/GenericFunction.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.udf.lib;
+
+import java.io.IOException;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.sysml.api.ExternalUDFRegistration;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.udf.FunctionParameter;
+import org.apache.sysml.udf.Matrix;
+import org.apache.sysml.udf.PackageFunction;
+import org.apache.sysml.udf.Scalar;
+
+import scala.Function0;
+
+public class GenericFunction extends PackageFunction {
+       private static final long serialVersionUID = -195996547505886575L;
+       String [] fnSignature;
+       FunctionParameter [] returnVals;
+       Function0<FunctionParameter []> scalaUDF;
+       public String _functionName;
+       public String _namespace;
+       
+       public void initialize() {
+               if(_namespace != null && 
!_namespace.equals(DMLProgram.DEFAULT_NAMESPACE)) {
+                       throw new RuntimeException("Expected the function in 
default namespace");
+               }
+               if(_functionName == null) {
+                       throw new RuntimeException("Expected the function name 
to be set");
+               }
+               if(fnSignature == null) {
+                       fnSignature = 
ExternalUDFRegistration.fnSignatureMapping().get(_functionName);
+                       scalaUDF = 
ExternalUDFRegistration.fnMapping().get(_functionName);
+                       ExternalUDFRegistration.udfMapping().put(_functionName, 
this);
+               }
+       }
+       
+       @Override
+       public int getNumFunctionOutputs() {
+               initialize();
+               String retSignature = fnSignature[fnSignature.length -1];
+               if(!retSignature.startsWith("("))
+                       return 1;
+               else {
+                       return StringUtils.countMatches(retSignature, ",") + 1;
+               }
+       }
+
+       @Override
+       public FunctionParameter getFunctionOutput(int pos) {
+               initialize();
+               if(returnVals == null || returnVals.length <= pos)
+                       throw new RuntimeException("Incorrect number of outputs 
or function not executed");
+               return returnVals[pos];
+       }
+
+       @Override
+       public void execute() {
+               initialize();
+               returnVals = scalaUDF.apply();
+       }
+       
+       public Object getInput(String type, int pos) throws 
DMLRuntimeException, IOException {
+               if(type.equals("Int") || type.equals("java.lang.Integer")) {
+                       return 
Integer.parseInt(((Scalar)getFunctionInput(pos)).getValue());
+               }
+               else if(type.equals("Double") || 
type.equals("java.lang.Double")) {
+                       return 
Double.parseDouble(((Scalar)getFunctionInput(pos)).getValue());
+               }
+               else if(type.equals("java.lang.String")) {
+                       return ((Scalar)getFunctionInput(pos)).getValue();
+               }
+               else if(type.equals("boolean") || 
type.equals("java.lang.Boolean")) {
+                       return 
Boolean.parseBoolean(((Scalar)getFunctionInput(pos)).getValue());
+               }
+               else if(type.equals("scala.Array[scala.Array[Double]]")) {
+                       return ((Matrix) 
getFunctionInput(pos)).getMatrixAsDoubleArray();
+               }
+               
+               throw new RuntimeException("Unsupported type: " + type);
+       }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5d84a095/src/main/scala/org/apache/sysml/api/ExternalUDFRegistration.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ExternalUDFRegistration.scala 
b/src/main/scala/org/apache/sysml/api/ExternalUDFRegistration.scala
new file mode 100644
index 0000000..be94b31
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ExternalUDFRegistration.scala
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api;
+
+import scala.reflect.runtime.universe._
+import java.util.ArrayList
+import org.apache.sysml.udf.FunctionParameter
+import org.apache.sysml.udf.Scalar
+import org.apache.sysml.udf.Matrix
+import org.apache.sysml.udf.Matrix.ValueType
+import org.apache.sysml.api.mlcontext.Script
+import org.apache.sysml.udf.PackageFunction
+import org.apache.sysml.udf.FunctionParameter
+import org.apache.sysml.udf.lib.GenericFunction
+import org.apache.sysml.udf.Scalar.ScalarValueType
+import java.util.HashMap
+
+/*
+ * Design of Scala external UDF functionality:
+ * The two main classes that enable this functionality are as follows:
+ * 1. ExternalUDFRegistration: The register method is overloaded to allow registration
+ * of Scala UDFs with 1 to 10 arguments. Each overload examines the input types to check
+ * whether they are supported (see getType). If they are, it creates a header of the format:
+ * 
+ * fnName = externalFunction(input arguments) return (output arguments) implemented in (classname="org.apache.sysml.udf.lib.GenericFunction",exectype="mem")
+ * 
+ * This header is prepended to the script by MLContext (via ScriptExecutor) before execution.
+ * 
+ * In addition, it populates two global data structures: fnMapping (which stores a zero-argument anonymous
+ * function) and fnSignatureMapping (used to compute the number of return values).
+ * These data structures are used by GenericFunction.
+ * 
+ * The key idea of this approach is the conversion of an arbitrary Scala UDF into a zero-argument anonymous
+ * function stored in ExternalUDFRegistration's fnMapping data structure, which GenericFunction.execute invokes.
+ * 
+ * 2. GenericFunction
+ * This generic class is called by SystemML for any registered Scala UDF. It first inserts itself into
+ * ExternalUDFRegistration's udfMapping data structure and then invokes the zero-argument anonymous
+ * function corresponding to the user-specified Scala UDF.
+ *  
+ * 
+ * In the current implementation, ExternalUDFRegistration's fnMapping, fnSignatureMapping and udfMapping
+ * fields are static, so the functions registered through one MLContext are shared with all other
+ * MLContexts (only the script headers are kept per ExternalUDFRegistration instance). This is necessary
+ * to simplify the integration with the existing external UDF framework.
+ * 
+ * Usage (methods must be passed as function values with `_` because register is overloaded):
+ * scala> import org.apache.sysml.api.mlcontext._
+ * scala> import org.apache.sysml.api.mlcontext.ScriptFactory._
+ * scala> val ml = new MLContext(sc)
+ * scala> 
+ * scala> // Demonstrates how to pass a simple scala UDF to SystemML
+ * scala> def addOne(x:Double):Double = x + 1
+ * scala> ml.udf.register("addOne", addOne _)
+ * scala> val script1 = dml("v = addOne(2.0); print(v)")
+ * scala> ml.execute(script1)
+ * scala> 
+ * scala> // Demonstrates operation on local matrices (double[][])
+ * scala> def addOneToDiagonal(x:Array[Array[Double]]):Array[Array[Double]] = {  for(i <- 0 to x.length-1) x(i)(i) = x(i)(i) + 1; x }
+ * scala> ml.udf.register("addOneToDiagonal", addOneToDiagonal _)
+ * scala> val script2 = dml("m1 = matrix(0, rows=3, cols=3); m2 = addOneToDiagonal(m1); print(toString(m2));")
+ * scala> ml.execute(script2)
+ * scala> 
+ * scala> // Demonstrates multi-return function
+ * scala> def multiReturnFn(x:Double):(Double, Int) = (x + 1, (x * 2).toInt)
+ * scala> ml.udf.register("multiReturnFn", multiReturnFn _)
+ * scala> val script3 = dml("[v1, v2] = multiReturnFn(2.0); print(v1)")
+ * scala> ml.execute(script3)
+ * scala> 
+ * scala> // Demonstrates multi-argument multi-return function
+ * scala> def multiArgReturnFn(x:Double, y:Int):(Double, Int) = (x + 1, (x * y).toInt)
+ * scala> ml.udf.register("multiArgReturnFn", multiArgReturnFn _)
+ * scala> val script4 = dml("[v1, v2] = multiArgReturnFn(2.0, 1); print(v2)")
+ * scala> ml.execute(script4)
+ */
+
+/**
+ * Process-wide (static) registries shared by all ExternalUDFRegistration instances and
+ * read by GenericFunction at execution time. Every map is keyed by the DML function name.
+ */
+object ExternalUDFRegistration {
+  /** Zero-argument dispatchers created by register(...) for each UDF. */
+  val fnMapping: HashMap[String, Function0[Array[FunctionParameter]]] =
+    new HashMap[String, Function0[Array[FunctionParameter]]]()
+
+  /** Reflected type names of each UDF: input types first, return type last. */
+  val fnSignatureMapping: HashMap[String, Array[String]] =
+    new HashMap[String, Array[String]]()
+
+  /** GenericFunction instance that registered itself for each UDF; used to read its inputs. */
+  val udfMapping: HashMap[String, GenericFunction] =
+    new HashMap[String, GenericFunction]()
+}
+
+/**
+ * This class handles the registration of external Scala UDFs via MLContext.
+ */
+class ExternalUDFRegistration {
+  /**
+   * The MLContext that owns this registration, set via setMLContext.
+   * Fully qualified on purpose: the unqualified name MLContext in this package presumably
+   * resolves to the legacy org.apache.sysml.api.MLContext, not the new-API type passed in.
+   */
+  var ml: org.apache.sysml.api.mlcontext.MLContext = null
+
+  /**
+   * Associates this registration with the MLContext that owns it.
+   *
+   * @param ml1 the owning MLContext
+   */
+  def setMLContext(ml1: org.apache.sysml.api.mlcontext.MLContext): Unit = { this.ml = ml1 }
+  
+  // One externalFunction header per UDF registered through this instance, keyed by name.
+  val scriptHeaders:HashMap[String,StringBuilder] = new HashMap[String,StringBuilder]()
+
+  /**
+   * Concatenates all non-empty externalFunction headers registered through this instance,
+   * each followed by a newline, so they can be prepended to a DML script.
+   *
+   * @return the combined header text, or "" when nothing is registered
+   */
+  def getExternalHeaders(): String = {
+    val combined = new StringBuilder
+    val headers = scriptHeaders.values().iterator()
+    while (headers.hasNext()) {
+      val text = headers.next().toString()
+      if (text.nonEmpty) combined.append(text).append("\n")
+    }
+    combined.toString()
+  }
+  
+  /**
+   * Maps the string rendering of a reflected Scala type (typeOf[T].toString) to the DML
+   * value type used in an externalFunction header.
+   *
+   * @param t reflected type name, e.g. "Double", "Int" or "scala.Array[scala.Array[Double]]"
+   * @return "string", "double", "integer", "boolean" or "matrix[double]"
+   * @throws RuntimeException if the type is not supported by DML
+   */
+  def getType(t: String):String = {
+    t match {
+      // typeOf[String] refers to the Predef alias and renders as "String";
+      // the fully-qualified Java name is accepted as well.
+      case "String" | "java.lang.String" => "string"
+      case "Double" => "double"
+      case "Int" => "integer"
+      case "Boolean" => "boolean"
+      // Support only pass by value for now.
+      // case "org.apache.sysml.runtime.matrix.data.MatrixBlock" => "matrix[double]"
+      // case "scala.Array[Double]" => "matrix[double]"
+      case "scala.Array[scala.Array[Double]]" => "matrix[double]"
+      case _ => throw new RuntimeException("Unsupported type of parameter: " + t)
+    }
+  }
+  
+  /**
+   * Builds the return-parameter list of an externalFunction header from a reflected return
+   * type. A tuple type such as "(Double, Int)" yields one output per element
+   * ("double output0, integer output1"); any other type yields a single "output0".
+   *
+   * @throws RuntimeException if any element type is not supported (see getType)
+   */
+  def getReturnType(t: String):String = {
+    if (!t.startsWith("(")) {
+      getType(t) + " output0"
+    } else {
+      val elementTypes = t.substring(1, t.length() - 1).split(",").map(_.trim)
+      elementTypes.zipWithIndex
+        .map { case (elemType, idx) => getType(elemType) + " output" + idx }
+        .mkString(", ")
+    }
+  }
+  
+  /**
+   * Creates (or replaces) the DML externalFunction header for the UDF `name` and records
+   * its signature in ExternalUDFRegistration.fnSignatureMapping.
+   *
+   * @param name      DML-visible function name
+   * @param typeInput reflected type names: every input type, followed by the return type
+   * @throws IllegalArgumentException if typeInput does not hold at least one input type
+   *                                  plus the return type
+   * @throws RuntimeException         if any type is not supported by DML; any header and
+   *                                  signature previously registered for `name` stay intact
+   */
+  def createExternalFunctionHeader(name:String, typeInput:Array[String]): Unit = {
+    require(typeInput.length >= 2,
+      "Expected at least one input type and a return type for function " + name)
+    val inputs = typeInput.init.zipWithIndex.map { case (t, i) => getType(t) + " input" + i }
+    val header: StringBuilder = new StringBuilder()
+    header.append(name + " = externalFunction(")
+    header.append(inputs.mkString(", "))
+    header.append(") return (")
+    header.append(getReturnType(typeInput.last))
+    header.append(") implemented in (classname=\"org.apache.sysml.udf.lib.GenericFunction\", exectype=\"mem\");\n")
+    // Publish only after the header was built successfully: put() replaces any older entry,
+    // so a failed re-registration no longer drops the existing header.
+    scriptHeaders.put(name, header)
+    ExternalUDFRegistration.fnSignatureMapping.put(name, typeInput)
+  }
+  
+  // ------------------------------------------------------------------------------------------
+  // Overloaded register function for 1 to 10 inputs:
+  
+  // zero-input function unsupported by SystemML
+//  def register[RT: TypeTag](name: String, func: Function0[RT]): Unit = {
+//    println(getType(typeOf[RT].toString()))
+//  }
+  
+  /**
+   * Forgets the UDF registered as `name`: drops its script header, its signature, its
+   * zero-argument dispatcher and the cached GenericFunction instance, if any.
+   * Names that were never registered are ignored.
+   *
+   * @param name DML-visible function name
+   */
+  def unregister(name: String): Unit = {
+    scriptHeaders.remove(name)
+    val registries: Seq[HashMap[String, _]] = Seq(
+      ExternalUDFRegistration.fnSignatureMapping,
+      ExternalUDFRegistration.fnMapping,
+      ExternalUDFRegistration.udfMapping)
+    registries.foreach(_.remove(name))
+  }
+  
+  /**
+   * Registers a one-argument Scala function so DML scripts can call it as `name`.
+   * Re-registering a name replaces the previous UDF.
+   *
+   * @param name DML-visible function name
+   * @param func the Scala UDF; argument and return types must be supported by getType
+   * @throws RuntimeException if an argument or return type is not supported by DML
+   */
+  def register[A1: TypeTag, RT: TypeTag](name: String, func: Function1[A1, RT]): Unit = {
+    // Reflect on the types once at registration instead of on every UDF invocation.
+    val types = Array(typeOf[A1].toString(), typeOf[RT].toString())
+    val anonfun0 = new Function0[Array[FunctionParameter]] {
+      def apply(): Array[FunctionParameter] = {
+        val udf = ExternalUDFRegistration.udfMapping.get(name)
+        def arg(i: Int): AnyRef = udf.getInput(types(i), i)
+        convertReturnToOutput(func.apply(arg(0).asInstanceOf[A1]))
+      }
+    }
+    createExternalFunctionHeader(name, types)
+    ExternalUDFRegistration.fnMapping.put(name, anonfun0)
+  }
+  
+  /**
+   * Registers a two-argument Scala function so DML scripts can call it as `name`.
+   * Re-registering a name replaces the previous UDF.
+   *
+   * @param name DML-visible function name
+   * @param func the Scala UDF; argument and return types must be supported by getType
+   * @throws RuntimeException if an argument or return type is not supported by DML
+   */
+  def register[A1: TypeTag, A2: TypeTag, RT: TypeTag](name: String, func: Function2[A1, A2, RT]): Unit = {
+    // Reflect on the types once at registration instead of on every UDF invocation.
+    val types = Array(typeOf[A1].toString(), typeOf[A2].toString(), typeOf[RT].toString())
+    val anonfun0 = new Function0[Array[FunctionParameter]] {
+      def apply(): Array[FunctionParameter] = {
+        val udf = ExternalUDFRegistration.udfMapping.get(name)
+        def arg(i: Int): AnyRef = udf.getInput(types(i), i)
+        convertReturnToOutput(func.apply(arg(0).asInstanceOf[A1], arg(1).asInstanceOf[A2]))
+      }
+    }
+    createExternalFunctionHeader(name, types)
+    ExternalUDFRegistration.fnMapping.put(name, anonfun0)
+  }
+  
+  /**
+   * Registers a three-argument Scala function so DML scripts can call it as `name`.
+   * Re-registering a name replaces the previous UDF.
+   *
+   * @param name DML-visible function name
+   * @param func the Scala UDF; argument and return types must be supported by getType
+   * @throws RuntimeException if an argument or return type is not supported by DML
+   */
+  def register[A1: TypeTag, A2: TypeTag, A3: TypeTag, RT: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): Unit = {
+    // Reflect on the types once at registration instead of on every UDF invocation.
+    val types = Array(
+        typeOf[A1].toString(), typeOf[A2].toString(), typeOf[A3].toString(),
+        typeOf[RT].toString())
+    val anonfun0 = new Function0[Array[FunctionParameter]] {
+      def apply(): Array[FunctionParameter] = {
+        val udf = ExternalUDFRegistration.udfMapping.get(name)
+        def arg(i: Int): AnyRef = udf.getInput(types(i), i)
+        convertReturnToOutput(func.apply(arg(0).asInstanceOf[A1], arg(1).asInstanceOf[A2],
+            arg(2).asInstanceOf[A3]))
+      }
+    }
+    createExternalFunctionHeader(name, types)
+    ExternalUDFRegistration.fnMapping.put(name, anonfun0)
+  }
+  
+  /**
+   * Registers a four-argument Scala function so DML scripts can call it as `name`.
+   * Re-registering a name replaces the previous UDF.
+   *
+   * @param name DML-visible function name
+   * @param func the Scala UDF; argument and return types must be supported by getType
+   * @throws RuntimeException if an argument or return type is not supported by DML
+   */
+  def register[A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, RT: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): Unit = {
+    // Reflect on the types once at registration instead of on every UDF invocation.
+    val types = Array(
+        typeOf[A1].toString(), typeOf[A2].toString(), typeOf[A3].toString(), typeOf[A4].toString(),
+        typeOf[RT].toString())
+    val anonfun0 = new Function0[Array[FunctionParameter]] {
+      def apply(): Array[FunctionParameter] = {
+        val udf = ExternalUDFRegistration.udfMapping.get(name)
+        def arg(i: Int): AnyRef = udf.getInput(types(i), i)
+        convertReturnToOutput(func.apply(arg(0).asInstanceOf[A1], arg(1).asInstanceOf[A2],
+            arg(2).asInstanceOf[A3], arg(3).asInstanceOf[A4]))
+      }
+    }
+    createExternalFunctionHeader(name, types)
+    ExternalUDFRegistration.fnMapping.put(name, anonfun0)
+  }
+  
+  /**
+   * Registers a five-argument Scala function so DML scripts can call it as `name`.
+   * Re-registering a name replaces the previous UDF.
+   *
+   * @param name DML-visible function name
+   * @param func the Scala UDF; argument and return types must be supported by getType
+   * @throws RuntimeException if an argument or return type is not supported by DML
+   */
+  def register[A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, RT: TypeTag](name: String, 
+      func: Function5[A1, A2, A3, A4, A5, RT]): Unit = {
+    // Reflect on the types once at registration instead of on every UDF invocation.
+    val types = Array(
+        typeOf[A1].toString(), typeOf[A2].toString(), typeOf[A3].toString(), typeOf[A4].toString(),
+        typeOf[A5].toString(),
+        typeOf[RT].toString())
+    val anonfun0 = new Function0[Array[FunctionParameter]] {
+      def apply(): Array[FunctionParameter] = {
+        val udf = ExternalUDFRegistration.udfMapping.get(name)
+        def arg(i: Int): AnyRef = udf.getInput(types(i), i)
+        convertReturnToOutput(func.apply(arg(0).asInstanceOf[A1], arg(1).asInstanceOf[A2],
+            arg(2).asInstanceOf[A3], arg(3).asInstanceOf[A4], arg(4).asInstanceOf[A5]))
+      }
+    }
+    createExternalFunctionHeader(name, types)
+    ExternalUDFRegistration.fnMapping.put(name, anonfun0)
+  }
+  
+  /**
+   * Registers a six-argument Scala function so DML scripts can call it as `name`.
+   * Re-registering a name replaces the previous UDF.
+   *
+   * @param name DML-visible function name
+   * @param func the Scala UDF; argument and return types must be supported by getType
+   * @throws RuntimeException if an argument or return type is not supported by DML
+   */
+  def register[A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, RT: TypeTag](name: String, 
+      func: Function6[A1, A2, A3, A4, A5, A6, RT]): Unit = {
+    // Reflect on the types once at registration instead of on every UDF invocation.
+    val types = Array(
+        typeOf[A1].toString(), typeOf[A2].toString(), typeOf[A3].toString(), typeOf[A4].toString(),
+        typeOf[A5].toString(), typeOf[A6].toString(),
+        typeOf[RT].toString())
+    val anonfun0 = new Function0[Array[FunctionParameter]] {
+      def apply(): Array[FunctionParameter] = {
+        val udf = ExternalUDFRegistration.udfMapping.get(name)
+        def arg(i: Int): AnyRef = udf.getInput(types(i), i)
+        convertReturnToOutput(func.apply(arg(0).asInstanceOf[A1], arg(1).asInstanceOf[A2],
+            arg(2).asInstanceOf[A3], arg(3).asInstanceOf[A4], arg(4).asInstanceOf[A5],
+            arg(5).asInstanceOf[A6]))
+      }
+    }
+    createExternalFunctionHeader(name, types)
+    ExternalUDFRegistration.fnMapping.put(name, anonfun0)
+  }
+  
+  def register[A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: 
TypeTag, A6: TypeTag, A7: TypeTag, RT: TypeTag](name: String, 
+      func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): Unit = {
+    val anonfun0 = new Function0[Array[FunctionParameter]] {
+       def apply(): Array[FunctionParameter] = {
+         val udf = ExternalUDFRegistration.udfMapping.get(name);
+         return 
convertReturnToOutput(func.apply(udf.getInput(typeOf[A1].toString(), 
0).asInstanceOf[A1],
+             udf.getInput(typeOf[A2].toString(), 1).asInstanceOf[A2], 
+             udf.getInput(typeOf[A3].toString(), 2).asInstanceOf[A3],
+             udf.getInput(typeOf[A4].toString(), 3).asInstanceOf[A4], 
+             udf.getInput(typeOf[A5].toString(), 4).asInstanceOf[A5], 
+             udf.getInput(typeOf[A6].toString(), 5).asInstanceOf[A6],
+             udf.getInput(typeOf[A7].toString(), 6).asInstanceOf[A7]))
+       }
+    }
+    createExternalFunctionHeader(name, Array(
+        typeOf[A1].toString(), typeOf[A2].toString(), typeOf[A3].toString(), 
typeOf[A4].toString(),
+        typeOf[A5].toString(), typeOf[A6].toString(), typeOf[A7].toString(),
+        typeOf[RT].toString()))
+    ExternalUDFRegistration.fnMapping.put(name, anonfun0)
+  }
+  
+  def register[A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: 
TypeTag, A6: TypeTag, A7: TypeTag, 
+    A8: TypeTag, RT: TypeTag](name: String, 
+      func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): Unit = {
+    val anonfun0 = new Function0[Array[FunctionParameter]] {
+       def apply(): Array[FunctionParameter] = {
+         val udf = ExternalUDFRegistration.udfMapping.get(name);
+         return 
convertReturnToOutput(func.apply(udf.getInput(typeOf[A1].toString(), 
0).asInstanceOf[A1],
+             udf.getInput(typeOf[A2].toString(), 1).asInstanceOf[A2], 
+             udf.getInput(typeOf[A3].toString(), 2).asInstanceOf[A3],
+             udf.getInput(typeOf[A4].toString(), 3).asInstanceOf[A4], 
+             udf.getInput(typeOf[A5].toString(), 4).asInstanceOf[A5], 
+             udf.getInput(typeOf[A6].toString(), 5).asInstanceOf[A6],
+             udf.getInput(typeOf[A7].toString(), 6).asInstanceOf[A7], 
+             udf.getInput(typeOf[A8].toString(), 7).asInstanceOf[A8]))
+       }
+    }
+    createExternalFunctionHeader(name, Array(
+        typeOf[A1].toString(), typeOf[A2].toString(), typeOf[A3].toString(), 
typeOf[A4].toString(),
+        typeOf[A5].toString(), typeOf[A6].toString(), typeOf[A7].toString(), 
typeOf[A8].toString(),
+        typeOf[RT].toString()))
+    ExternalUDFRegistration.fnMapping.put(name, anonfun0)
+  }
+  
+  def register[A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: 
TypeTag, A6: TypeTag, A7: TypeTag, 
+    A8: TypeTag, A9: TypeTag, RT: TypeTag](name: String, 
+      func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): Unit = {
+    val anonfun0 = new Function0[Array[FunctionParameter]] {
+       def apply(): Array[FunctionParameter] = {
+         val udf = ExternalUDFRegistration.udfMapping.get(name);
+         return 
convertReturnToOutput(func.apply(udf.getInput(typeOf[A1].toString(), 
0).asInstanceOf[A1],
+             udf.getInput(typeOf[A2].toString(), 1).asInstanceOf[A2], 
+             udf.getInput(typeOf[A3].toString(), 2).asInstanceOf[A3],
+             udf.getInput(typeOf[A4].toString(), 3).asInstanceOf[A4], 
+             udf.getInput(typeOf[A5].toString(), 4).asInstanceOf[A5], 
+             udf.getInput(typeOf[A6].toString(), 5).asInstanceOf[A6],
+             udf.getInput(typeOf[A7].toString(), 6).asInstanceOf[A7], 
+             udf.getInput(typeOf[A8].toString(), 7).asInstanceOf[A8], 
+             udf.getInput(typeOf[A9].toString(), 8).asInstanceOf[A9]))
+       }
+    }
+    createExternalFunctionHeader(name, Array(
+        typeOf[A1].toString(), typeOf[A2].toString(), typeOf[A3].toString(), 
typeOf[A4].toString(),
+        typeOf[A5].toString(), typeOf[A6].toString(), typeOf[A7].toString(), 
typeOf[A8].toString(),
+        typeOf[A9].toString(),
+        typeOf[RT].toString()))
+    ExternalUDFRegistration.fnMapping.put(name, anonfun0)
+  }
+  
+  def register[A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: 
TypeTag, A6: TypeTag, A7: TypeTag, 
+    A8: TypeTag, A9: TypeTag, A10: TypeTag, RT: TypeTag](name: String, 
+      func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): Unit = {
+    val anonfun0 = new Function0[Array[FunctionParameter]] {
+       def apply(): Array[FunctionParameter] = {
+         val udf = ExternalUDFRegistration.udfMapping.get(name);
+         return 
convertReturnToOutput(func.apply(udf.getInput(typeOf[A1].toString(), 
0).asInstanceOf[A1],
+             udf.getInput(typeOf[A2].toString(), 1).asInstanceOf[A2], 
+             udf.getInput(typeOf[A3].toString(), 2).asInstanceOf[A3],
+             udf.getInput(typeOf[A4].toString(), 3).asInstanceOf[A4], 
+             udf.getInput(typeOf[A5].toString(), 4).asInstanceOf[A5], 
+             udf.getInput(typeOf[A6].toString(), 5).asInstanceOf[A6],
+             udf.getInput(typeOf[A7].toString(), 6).asInstanceOf[A7], 
+             udf.getInput(typeOf[A8].toString(), 7).asInstanceOf[A8], 
+             udf.getInput(typeOf[A9].toString(), 8).asInstanceOf[A9],
+             udf.getInput(typeOf[A10].toString(), 9).asInstanceOf[A10]))
+       }
+    }
+    createExternalFunctionHeader(name, Array(
+        typeOf[A1].toString(), typeOf[A2].toString(), typeOf[A3].toString(), 
typeOf[A4].toString(),
+        typeOf[A5].toString(), typeOf[A6].toString(), typeOf[A7].toString(), 
typeOf[A8].toString(),
+        typeOf[A9].toString(), typeOf[A10].toString(),
+        typeOf[RT].toString()))
+    ExternalUDFRegistration.fnMapping.put(name, anonfun0)
+  }
+  
+  // 
------------------------------------------------------------------------------------------
+  
+  def convertReturnToOutput(ret:Any): Array[FunctionParameter] = {
+    ret match {
+       case x:Tuple1[Any] => Array(convertToOutput(x._1))
+       case x:Tuple2[Any, Any] => Array(convertToOutput(x._1), 
convertToOutput(x._2))
+       case x:Tuple3[Any, Any, Any] => Array(convertToOutput(x._1), 
convertToOutput(x._2), convertToOutput(x._3))
+       case x:Tuple4[Any, Any, Any, Any] => Array(convertToOutput(x._1), 
convertToOutput(x._2), convertToOutput(x._3), convertToOutput(x._4))
+       case x:Tuple5[Any, Any, Any, Any, Any] => Array(convertToOutput(x._1), 
convertToOutput(x._2), convertToOutput(x._3), convertToOutput(x._4), 
convertToOutput(x._5))
+       case x:Tuple6[Any, Any, Any, Any, Any, Any] => 
Array(convertToOutput(x._1), convertToOutput(x._2), convertToOutput(x._3), 
convertToOutput(x._4), convertToOutput(x._5), convertToOutput(x._6))
+       case x:Tuple7[Any, Any, Any, Any, Any, Any, Any] => 
Array(convertToOutput(x._1), convertToOutput(x._2), convertToOutput(x._3), 
convertToOutput(x._4), convertToOutput(x._5), convertToOutput(x._6), 
convertToOutput(x._7))
+       case x:Tuple8[Any, Any, Any, Any, Any, Any, Any, Any] => 
Array(convertToOutput(x._1), convertToOutput(x._2), convertToOutput(x._3), 
convertToOutput(x._4), convertToOutput(x._5), convertToOutput(x._6), 
convertToOutput(x._7), convertToOutput(x._8))
+       case x:Tuple9[Any, Any, Any, Any, Any, Any, Any, Any, Any] => 
Array(convertToOutput(x._1), convertToOutput(x._2), convertToOutput(x._3), 
convertToOutput(x._4), convertToOutput(x._5), convertToOutput(x._6), 
convertToOutput(x._7), 
+                                                                 
convertToOutput(x._8), convertToOutput(x._9))
+       case x:Tuple10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] => 
Array(convertToOutput(x._1), convertToOutput(x._2), convertToOutput(x._3), 
convertToOutput(x._4), convertToOutput(x._5), convertToOutput(x._6), 
convertToOutput(x._7), 
+                                                                 
convertToOutput(x._8), convertToOutput(x._9), convertToOutput(x._10))           
                                               
+       case _ => Array(convertToOutput(ret))
+     }
+  }
+   val rand = new java.util.Random()
+   def convertToOutput(x:Any): FunctionParameter = {
+     x match {
+       case x1:Int => return new Scalar(ScalarValueType.Integer, 
String.valueOf(x))
+       case x1:java.lang.Integer => return new Scalar(ScalarValueType.Integer, 
String.valueOf(x))
+       case x1:Double => return new Scalar(ScalarValueType.Double, 
String.valueOf(x))
+       case x1:java.lang.Double => return new Scalar(ScalarValueType.Double, 
String.valueOf(x))
+       case x1:java.lang.String => return new Scalar(ScalarValueType.Text, 
String.valueOf(x))
+       case x1:java.lang.Boolean => return new Scalar(ScalarValueType.Boolean, 
String.valueOf(x))
+       case x1:Boolean => return new Scalar(ScalarValueType.Boolean, 
String.valueOf(x))
+       case x1:scala.Array[scala.Array[Double]] => {
+         val mat = new Matrix( "temp" + rand.nextLong, x1.length, 
x1(0).length, ValueType.Double );
+                          mat.setMatrixDoubleArray(x1)
+                          return mat
+       }
+       case _ => throw new RuntimeException("Unsupported output type:" + 
x.getClass().getName)
+     }
+   }
+}
\ No newline at end of file


Reply via email to