Repository: systemml
Updated Branches:
  refs/heads/master be465dd65 -> 81419ae6a


[MINOR] Support the list datatype in external UDF

- Also added RemoveDuplicates to show the usage.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/81419ae6
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/81419ae6
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/81419ae6

Branch: refs/heads/master
Commit: 81419ae6a0abcc13e2e84307b7af38732c1892cd
Parents: be465dd
Author: Niketan Pansare <[email protected]>
Authored: Wed Aug 29 19:40:19 2018 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Wed Aug 29 19:40:19 2018 -0700

----------------------------------------------------------------------
 .../ExternalFunctionInvocationInstruction.java  |  18 ++-
 .../org/apache/sysml/udf/FunctionParameter.java |   1 +
 src/main/java/org/apache/sysml/udf/List.java    |  36 +++++
 .../apache/sysml/udf/lib/RemoveDuplicates.java  | 139 +++++++++++++++++++
 4 files changed, 193 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/81419ae6/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java 
b/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java
index c4e4198..6151370 100644
--- 
a/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java
+++ 
b/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java
@@ -33,6 +33,7 @@ import org.apache.sysml.runtime.instructions.cp.BooleanObject;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.DoubleObject;
 import org.apache.sysml.runtime.instructions.cp.IntObject;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
 import org.apache.sysml.runtime.instructions.cp.StringObject;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -78,7 +79,6 @@ public class ExternalFunctionInvocationInstruction extends 
Instruction
                verifyAndAttachOutputs(ec, fun, outputs);
        }
        
-       @SuppressWarnings("incomplete-switch")
        private ArrayList<FunctionParameter> getInputObjects(CPOperand[] 
inputs, LocalVariableMap vars) {
                ArrayList<FunctionParameter> ret = new ArrayList<>();
                for( CPOperand input : inputs ) {
@@ -94,6 +94,12 @@ public class ExternalFunctionInvocationInstruction extends 
Instruction
                                case OBJECT:
                                        ret.add(new 
BinaryObject(vars.get(input.getName())));
                                        break;
+                               case LIST:
+                                       ret.add(new List((ListObject) 
vars.get(input.getName())));
+                                       break;
+                               default:
+                                       throw new 
DMLRuntimeException("Unsupported data type: "
+                                                       
+input.getDataType().name());
                        }
                }
                return ret;
@@ -125,11 +131,14 @@ public class ExternalFunctionInvocationInstruction 
extends Instruction
                        CPOperand output = outputs[i];
                        switch( fun.getFunctionOutput(i).getType() ) {
                                case Matrix:
+                               {
                                        Matrix m = (Matrix) 
fun.getFunctionOutput(i);
                                        MatrixObject newVar = 
createOutputMatrixObject( m );
                                        ec.setVariable(output.getName(), 
newVar);
                                        break;
+                               }
                                case Scalar:
+                               {
                                        Scalar s = (Scalar) 
fun.getFunctionOutput(i);
                                        ScalarObject scalarObject = null;
                                        switch( s.getScalarType() ) {
@@ -151,6 +160,13 @@ public class ExternalFunctionInvocationInstruction extends 
Instruction
                                        }
                                        ec.setVariable(output.getName(), 
scalarObject);
                                        break;
+                               }
+                               case List:
+                               {
+                                       List l = (List) 
fun.getFunctionOutput(i);
+                                       ec.setVariable(output.getName(), 
l.getListObject());
+                                       break;
+                               }
                                default:
                                        throw new 
DMLRuntimeException("Unsupported data type: "
                                                
+fun.getFunctionOutput(i).getType().name());

http://git-wip-us.apache.org/repos/asf/systemml/blob/81419ae6/src/main/java/org/apache/sysml/udf/FunctionParameter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/udf/FunctionParameter.java 
b/src/main/java/org/apache/sysml/udf/FunctionParameter.java
index a45065c..1b3022d 100644
--- a/src/main/java/org/apache/sysml/udf/FunctionParameter.java
+++ b/src/main/java/org/apache/sysml/udf/FunctionParameter.java
@@ -39,6 +39,7 @@ public abstract class FunctionParameter implements 
Serializable
                Matrix, 
                Scalar, 
                Object,
+               List
        }
        
        private FunctionParameterType _type;

http://git-wip-us.apache.org/repos/asf/systemml/blob/81419ae6/src/main/java/org/apache/sysml/udf/List.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/udf/List.java 
b/src/main/java/org/apache/sysml/udf/List.java
new file mode 100644
index 0000000..b2f39b7
--- /dev/null
+++ b/src/main/java/org/apache/sysml/udf/List.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.udf;
+
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Function parameter of type {@code List}; a thin holder around a
+ * {@link ListObject}.
+ */
+public class List extends FunctionParameter {
+       private static final long serialVersionUID = -3230908817131624857L;
+       protected ListObject _lObj;
+
+       /**
+        * Wraps the given list object as a parameter of type {@code List}.
+        * 
+        * @param obj list object to hold
+        */
+       public List(ListObject obj) {
+               super(FunctionParameterType.List);
+               this._lObj = obj;
+       }
+       
+       /**
+        * @return the list object currently held by this parameter
+        */
+       public ListObject getListObject() {
+               return this._lObj;
+       }
+
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/81419ae6/src/main/java/org/apache/sysml/udf/lib/RemoveDuplicates.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/udf/lib/RemoveDuplicates.java 
b/src/main/java/org/apache/sysml/udf/lib/RemoveDuplicates.java
new file mode 100644
index 0000000..5c5e0d5
--- /dev/null
+++ b/src/main/java/org/apache/sysml/udf/lib/RemoveDuplicates.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.udf.lib;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Random;
+
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.udf.FunctionParameter;
+import org.apache.sysml.udf.List;
+import org.apache.sysml.udf.Matrix;
+import org.apache.sysml.udf.PackageFunction;
+import org.apache.sysml.udf.Matrix.ValueType;
+
+/**
+ * Use this class to remove duplicate matrices from a list of matrices.
+ * It also returns the indexes which map each element of the original input
+ * list to its position in the output list.
+ * 
+ * Two list elements are treated as duplicates only when reading them yields
+ * the very same {@link MatrixBlock} instance (reference identity); matrices
+ * that merely hold equal values in distinct blocks are both kept.
+ * 
+ * Usage:
+ * <pre>
+ * <code>
+ * distinct = externalFunction(list[unknown] inL) return (list[unknown] outL, matrix[double] idx) implemented in (classname="org.apache.sysml.udf.lib.RemoveDuplicates", exectype="mem");
+ * X = rand(rows=10, cols=10)
+ * Y = X*sum(X);
+ * Z = sum(X)*X;
+ * W = X*sum(X);
+ * inL = list(Y, Z, W)
+ * [outL, idx] = distinct(inL);
+ * print(">>\n" + toString(idx));
+ * </code>
+ * </pre>
+ * 
+ * The above code prints:
+ * <pre>
+ * >>
+ * 1.000
+ * 2.000
+ * 1.000
+ * </pre>
+ */
+public class RemoveDuplicates extends PackageFunction {
+       private static final long serialVersionUID = -3905212831582648882L;
+
+       /** First output: the input list with repeated matrices dropped; set by {@link #execute()}. */
+       private List outputList;
+       /** Second output: one row per input element, holding its 1-based position in {@link #outputList}. */
+       private Matrix indexes;
+       /** Source of the random suffix used when naming the index matrix. */
+       private Random rand = new Random();
+       
+       /**
+        * @return the number of outputs produced by this function, always 2
+        */
+       @Override
+       public int getNumFunctionOutputs() {
+               return 2;
+       }
+
+       /**
+        * Returns one of the two outputs as populated by {@link #execute()}.
+        * 
+        * @param pos 0 for the de-duplicated list, 1 for the index matrix
+        * @return the output at the given position
+        * @throws RuntimeException if pos is neither 0 nor 1
+        */
+       @Override
+       public FunctionParameter getFunctionOutput(int pos) {
+               if(pos == 0)
+                       return outputList;
+               else if(pos == 1)
+                       return indexes;
+               throw new RuntimeException("Invalid function output being requested");
+       }
+       
+       /**
+        * Finds a matrix block in the list by reference identity.
+        * 
+        * {@code list.indexOf(mb)} is deliberately not used: it relies on
+        * {@code MatrixBlock.equals}, which was observed to fail with
+        * "equals should never be called for matrix blocks".
+        * 
+        * @param list blocks collected so far
+        * @param mb block to look for
+        * @return 0-based position of mb in list, or -1 if it is not present
+        */
+       private static int indexOf(java.util.List<MatrixBlock> list, MatrixBlock mb) {
+               for(int i = 0; i < list.size(); i++) {
+                       if(list.get(i) == mb) {
+                               return i;
+                       }
+               }
+               return -1;
+       }
+
+       /**
+        * Walks the input list, keeps the first occurrence of every matrix and
+        * records for each input element its 1-based position in the output list.
+        * 
+        * @throws RuntimeException if an element of the list is not a matrix, or
+        *         if the index matrix cannot be stored
+        */
+       @Override
+       public void execute() {
+               java.util.List<Data> inputData = ((List)getFunctionInput(0)).getListObject().getData();
+               java.util.List<Data> outputData = new ArrayList<>();
+               java.util.List<MatrixBlock> outputMB = new ArrayList<>();
+               indexes = new Matrix( "tmp_" + rand.nextLong(), inputData.size(), 1, ValueType.Double );
+               MatrixBlock indexesMB = allocateDenseMatrixBlock(indexes);
+               double [] indexesData = indexesMB.getDenseBlockValues();
+               
+               for(int i = 0; i < inputData.size(); i++) {
+                       Data elem = inputData.get(i);
+                       if(elem instanceof MatrixObject) {
+                               MatrixBlock mb = ((MatrixObject)elem).acquireRead();
+                               int index = indexOf(outputMB, mb);
+                               if(index >= 0) {
+                                       // already seen: reuse the position found above instead of scanning again
+                                       indexesData[i] = index + 1;
+                               }
+                               else {
+                                       outputMB.add(mb);
+                                       outputData.add(elem);
+                                       indexesData[i] = outputMB.size();
+                               }
+                               ((MatrixObject)elem).release();
+                       }
+                       else {
+                               throw new RuntimeException("Only list of matrices is supported in RemoveDuplicates");
+                       }
+               }
+               // every entry written above is >= 1, so the column has no zeros
+               indexesMB.setNonZeros(indexesData.length);
+               try {
+                       indexes.setMatrixDoubleArray(indexesMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
+               } catch (IOException e) {
+                       throw new RuntimeException("Exception while executing RemoveDuplicates", e);
+               }
+               outputList = new List(new ListObject(outputData));
+       }
+       
+       /**
+        * Creates a dense block with the dimensions declared by the given matrix.
+        * 
+        * @param mat matrix whose row and column counts size the block
+        * @return a newly allocated dense matrix block
+        */
+       private static MatrixBlock allocateDenseMatrixBlock(Matrix mat) {
+               int rows = (int) mat.getNumRows();
+               int cols = (int) mat.getNumCols();
+               MatrixBlock mb = new MatrixBlock(rows, cols, false);
+               mb.allocateDenseBlock();
+               return mb;
+       }
+}

Reply via email to