Repository: systemml Updated Branches: refs/heads/master be465dd65 -> 81419ae6a
[MINOR] Support the list datatype in external UDF - Also added RemoveDuplicates to show the usage. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/81419ae6 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/81419ae6 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/81419ae6 Branch: refs/heads/master Commit: 81419ae6a0abcc13e2e84307b7af38732c1892cd Parents: be465dd Author: Niketan Pansare <[email protected]> Authored: Wed Aug 29 19:40:19 2018 -0700 Committer: Niketan Pansare <[email protected]> Committed: Wed Aug 29 19:40:19 2018 -0700 ---------------------------------------------------------------------- .../ExternalFunctionInvocationInstruction.java | 18 ++- .../org/apache/sysml/udf/FunctionParameter.java | 1 + src/main/java/org/apache/sysml/udf/List.java | 36 +++++ .../apache/sysml/udf/lib/RemoveDuplicates.java | 139 +++++++++++++++++++ 4 files changed, 193 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/81419ae6/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java b/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java index c4e4198..6151370 100644 --- a/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java +++ b/src/main/java/org/apache/sysml/udf/ExternalFunctionInvocationInstruction.java @@ -33,6 +33,7 @@ import org.apache.sysml.runtime.instructions.cp.BooleanObject; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.IntObject; +import org.apache.sysml.runtime.instructions.cp.ListObject; import 
org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.cp.StringObject; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -78,7 +79,6 @@ public class ExternalFunctionInvocationInstruction extends Instruction verifyAndAttachOutputs(ec, fun, outputs); } - @SuppressWarnings("incomplete-switch") private ArrayList<FunctionParameter> getInputObjects(CPOperand[] inputs, LocalVariableMap vars) { ArrayList<FunctionParameter> ret = new ArrayList<>(); for( CPOperand input : inputs ) { @@ -94,6 +94,12 @@ public class ExternalFunctionInvocationInstruction extends Instruction case OBJECT: ret.add(new BinaryObject(vars.get(input.getName()))); break; + case LIST: + ret.add(new List((ListObject) vars.get(input.getName()))); + break; + default: + throw new DMLRuntimeException("Unsupported data type: " + +input.getDataType().name()); } } return ret; @@ -125,11 +131,14 @@ public class ExternalFunctionInvocationInstruction extends Instruction CPOperand output = outputs[i]; switch( fun.getFunctionOutput(i).getType() ) { case Matrix: + { Matrix m = (Matrix) fun.getFunctionOutput(i); MatrixObject newVar = createOutputMatrixObject( m ); ec.setVariable(output.getName(), newVar); break; + } case Scalar: + { Scalar s = (Scalar) fun.getFunctionOutput(i); ScalarObject scalarObject = null; switch( s.getScalarType() ) { @@ -151,6 +160,13 @@ public class ExternalFunctionInvocationInstruction extends Instruction } ec.setVariable(output.getName(), scalarObject); break; + } + case List: + { + List l = (List) fun.getFunctionOutput(i); + ec.setVariable(output.getName(), l.getListObject()); + break; + } default: throw new DMLRuntimeException("Unsupported data type: " +fun.getFunctionOutput(i).getType().name()); http://git-wip-us.apache.org/repos/asf/systemml/blob/81419ae6/src/main/java/org/apache/sysml/udf/FunctionParameter.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/udf/FunctionParameter.java b/src/main/java/org/apache/sysml/udf/FunctionParameter.java index a45065c..1b3022d 100644 --- a/src/main/java/org/apache/sysml/udf/FunctionParameter.java +++ b/src/main/java/org/apache/sysml/udf/FunctionParameter.java @@ -39,6 +39,7 @@ public abstract class FunctionParameter implements Serializable Matrix, Scalar, Object, + List } private FunctionParameterType _type; http://git-wip-us.apache.org/repos/asf/systemml/blob/81419ae6/src/main/java/org/apache/sysml/udf/List.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/udf/List.java b/src/main/java/org/apache/sysml/udf/List.java new file mode 100644 index 0000000..b2f39b7 --- /dev/null +++ b/src/main/java/org/apache/sysml/udf/List.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.sysml.udf;
+
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Function parameter that carries a runtime {@link ListObject} into or out
+ * of an external UDF. The parameter is tagged with
+ * {@code FunctionParameterType.List}.
+ */
+public class List extends FunctionParameter {
+    private static final long serialVersionUID = -3230908817131624857L;
+
+    /** The wrapped list; stored exactly as handed to the constructor. */
+    protected ListObject _lObj;
+
+    /**
+     * Wraps the given list object.
+     *
+     * @param obj list object to expose through this parameter (kept by reference, not copied)
+     */
+    public List(ListObject obj) {
+        super(FunctionParameterType.List);
+        this._lObj = obj;
+    }
+
+    /**
+     * Returns the wrapped list object.
+     *
+     * @return the same {@link ListObject} reference that was passed to the constructor
+     */
+    public ListObject getListObject() {
+        return this._lObj;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/81419ae6/src/main/java/org/apache/sysml/udf/lib/RemoveDuplicates.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/udf/lib/RemoveDuplicates.java b/src/main/java/org/apache/sysml/udf/lib/RemoveDuplicates.java
new file mode 100644
index 0000000..5c5e0d5
--- /dev/null
+++ b/src/main/java/org/apache/sysml/udf/lib/RemoveDuplicates.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.udf.lib;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Random;
+
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.udf.FunctionParameter;
+import org.apache.sysml.udf.List;
+import org.apache.sysml.udf.Matrix;
+import org.apache.sysml.udf.PackageFunction;
+import org.apache.sysml.udf.Matrix.ValueType;
+
+/**
+ * Removes duplicate matrices from a list of matrices and additionally
+ * returns, for every input position, the 1-based position of that matrix
+ * in the de-duplicated output list.
+ *
+ * <p>Two entries count as duplicates only when reading them yields the very
+ * same {@link MatrixBlock} instance (reference comparison); matrices that
+ * merely hold equal values in distinct blocks are both kept.
+ *
+ * <p>Usage:
+ * <pre>
+ * <code>
+ * distinct = externalFunction(list[unknown] inL) return (list[unknown] outL, matrix[double] idx) implemented in (classname="org.apache.sysml.udf.lib.RemoveDuplicates", exectype="mem");
+ * X = rand(rows=10, cols=10)
+ * Y = X*sum(X);
+ * Z = sum(X)*X;
+ * W = X*sum(X);
+ * inL = list(Y, Z, W)
+ * [outL, idx] = distinct(inL);
+ * print(">>\n" + toString(idx));
+ * </code>
+ * </pre>
+ *
+ * <p>The above code prints:
+ * <pre>
+ * &gt;&gt;
+ * 1.000
+ * 2.000
+ * 1.000
+ * </pre>
+ */
+public class RemoveDuplicates extends PackageFunction {
+    private static final long serialVersionUID = -3905212831582648882L;
+
+    /** Output 0: the input list with duplicates removed (set by {@link #execute()}). */
+    private List outputList;
+    /** Output 1: column vector mapping each input position to its output position (set by {@link #execute()}). */
+    private Matrix indexes;
+    /** Source of the random suffix used in the name of the index matrix. */
+    private Random rand = new Random();
+
+    /**
+     * @return always 2: the de-duplicated list and the index matrix
+     */
+    @Override
+    public int getNumFunctionOutputs() {
+        return 2;
+    }
+
+    /**
+     * Returns one of the two outputs produced by {@link #execute()}.
+     *
+     * @param pos 0 for the de-duplicated list, 1 for the index matrix
+     * @return the requested output; {@code null} if {@link #execute()} has not assigned it yet
+     * @throws RuntimeException if {@code pos} is neither 0 nor 1
+     */
+    @Override
+    public FunctionParameter getFunctionOutput(int pos) {
+        if(pos == 0)
+            return outputList;
+        else if(pos == 1)
+            return indexes;
+        throw new RuntimeException("Invalid function output being requested");
+    }
+
+    /**
+     * Finds the position of {@code mb} in {@code list} by reference identity.
+     *
+     * @param list blocks collected so far
+     * @param mb block to look for
+     * @return 0-based position of the first entry that is the same instance as {@code mb}, or -1
+     */
+    private int indexOf(java.util.List<MatrixBlock> list, MatrixBlock mb) {
+        // list.indexOf(mb) cannot be used here: it goes through MatrixBlock.equals,
+        // which fails with "equals should never be called for matrix blocks".
+        for(int i = 0; i < list.size(); i++) {
+            if(list.get(i) == mb) {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+    /**
+     * Reads the list passed as function input 0, keeps the first occurrence of
+     * every distinct matrix block, and records for each input position the
+     * 1-based position of its block in the output list.
+     *
+     * @throws RuntimeException if a list element is not a {@link MatrixObject},
+     *         or if storing the index matrix fails with an {@link IOException}
+     */
+    @Override
+    public void execute() {
+        java.util.List<Data> inputData = ((List)getFunctionInput(0)).getListObject().getData();
+        java.util.List<Data> outputData = new ArrayList<>();
+        java.util.List<MatrixBlock> outputMB = new ArrayList<>();
+        // One row per input element, single column.
+        indexes = new Matrix( "tmp_" + rand.nextLong(), inputData.size(), 1, ValueType.Double );
+        MatrixBlock indexesMB = allocateDenseMatrixBlock(indexes);
+        double [] indexesData = indexesMB.getDenseBlockValues();
+
+        for(int i = 0; i < inputData.size(); i++) {
+            Data elem = inputData.get(i);
+            if(!(elem instanceof MatrixObject)) {
+                throw new RuntimeException("Only list of matrices is supported in RemoveDuplicates");
+            }
+            MatrixObject mo = (MatrixObject) elem;
+            MatrixBlock mb = mo.acquireRead();
+            try {
+                // Search once and reuse the result for both the membership test and the mapping.
+                int index = indexOf(outputMB, mb);
+                if(index < 0) {
+                    outputMB.add(mb);
+                    outputData.add(elem);
+                    index = outputMB.size() - 1;
+                }
+                indexesData[i] = index + 1;
+            }
+            finally {
+                // Pair every acquireRead() with a release(), even if the body above throws.
+                mo.release();
+            }
+        }
+        // Every entry is a 1-based position (>= 1), so all cells are non-zero.
+        indexesMB.setNonZeros(indexesData.length);
+        try {
+            indexes.setMatrixDoubleArray(indexesMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
+        } catch (IOException e) {
+            throw new RuntimeException("Exception while executing RemoveDuplicates", e);
+        }
+        outputList = new List(new ListObject(outputData));
+    }
+
+    /**
+     * Creates a matrix block with the dimensions of {@code mat} and allocates its dense storage.
+     *
+     * @param mat matrix whose row and column counts size the block (narrowed to {@code int})
+     * @return the newly allocated block
+     */
+    private static MatrixBlock allocateDenseMatrixBlock(Matrix mat) {
+        int rows = (int) mat.getNumRows();
+        int cols = (int) mat.getNumCols();
+        // NOTE(review): the third constructor argument is presumably the "sparse" flag - confirm.
+        MatrixBlock mb = new MatrixBlock(rows, cols, false);
+        mb.allocateDenseBlock();
+        return mb;
+    }
+}
