[MINOR] Cleanup lib commons math (input/output consistency)

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0209edd7
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0209edd7
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0209edd7
Branch: refs/heads/master Commit: 0209edd7c8e84fac9f4d2c9c80b0ea7e664e798e Parents: 053e7ff Author: Matthias Boehm <[email protected]> Authored: Sat Apr 21 22:13:42 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Apr 21 22:18:47 2018 -0700 ---------------------------------------------------------------------- .../cp/BinaryMatrixMatrixCPInstruction.java | 8 +-- .../cp/MultiReturnBuiltinCPInstruction.java | 15 ++---- .../cp/UnaryMatrixCPInstruction.java | 17 ++++--- .../runtime/matrix/data/LibCommonsMath.java | 51 ++++++++++---------- .../sysml/runtime/util/DataConverter.java | 5 +- 5 files changed, 44 insertions(+), 52 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/0209edd7/src/main/java/org/apache/sysml/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java index 5ab853d..f8bb499 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java @@ -19,7 +19,6 @@ package org.apache.sysml.runtime.instructions.cp; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.matrix.data.LibCommonsMath; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -35,11 +34,12 @@ public class BinaryMatrixMatrixCPInstruction extends BinaryCPInstruction { @Override public void processInstruction(ExecutionContext ec) { - String opcode = getOpcode(); - if ( LibCommonsMath.isSupportedMatrixMatrixOperation(opcode) ) { + if ( 
LibCommonsMath.isSupportedMatrixMatrixOperation(getOpcode()) ) { MatrixBlock solution = LibCommonsMath.matrixMatrixOperations( - ec.getMatrixObject(input1.getName()), (MatrixObject)ec.getVariable(input2.getName()), opcode); + ec.getMatrixInput(input1.getName()), ec.getMatrixInput(input2.getName()), getOpcode()); ec.setMatrixOutput(output.getName(), solution, getExtendedOpcode()); + ec.releaseMatrixInput(input1.getName()); + ec.releaseMatrixInput(input2.getName()); return; } http://git-wip-us.apache.org/repos/asf/systemml/blob/0209edd7/src/main/java/org/apache/sysml/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java index 2aa04d8..acbbcc4 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java @@ -24,7 +24,6 @@ import java.util.ArrayList; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.matrix.data.LibCommonsMath; @@ -98,16 +97,12 @@ public class MultiReturnBuiltinCPInstruction extends ComputationCPInstruction { @Override public void processInstruction(ExecutionContext ec) { - String opcode = getOpcode(); - MatrixObject mo = ec.getMatrixObject(input1.getName()); - MatrixBlock[] out = null; - - if(LibCommonsMath.isSupportedMultiReturnOperation(opcode)) - out = LibCommonsMath.multiReturnOperations(mo, 
opcode); - else - throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode); - + if(!LibCommonsMath.isSupportedMultiReturnOperation(getOpcode())) + throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + getOpcode()); + MatrixBlock in = ec.getMatrixInput(input1.getName()); + MatrixBlock[] out = LibCommonsMath.multiReturnOperations(in, getOpcode()); + ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); for(int i=0; i < _outputs.size(); i++) { ec.setMatrixOutput(_outputs.get(i).getName(), out[i], getExtendedOpcode()); } http://git-wip-us.apache.org/repos/asf/systemml/blob/0209edd7/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryMatrixCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryMatrixCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryMatrixCPInstruction.java index 2855246..22da558 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryMatrixCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryMatrixCPInstruction.java @@ -32,21 +32,22 @@ public class UnaryMatrixCPInstruction extends UnaryCPInstruction { @Override public void processInstruction(ExecutionContext ec) { - String output_name = output.getName(); - String opcode = getOpcode(); - if(LibCommonsMath.isSupportedUnaryOperation(opcode)) { - MatrixBlock retBlock = LibCommonsMath.unaryOperations(ec.getMatrixObject(input1.getName()),getOpcode()); - ec.setMatrixOutput(output_name, retBlock, getExtendedOpcode()); + MatrixBlock inBlock = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); + MatrixBlock retBlock = null; + + if(LibCommonsMath.isSupportedUnaryOperation(getOpcode())) { + retBlock = LibCommonsMath.unaryOperations(inBlock, getOpcode()); + ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); } else { UnaryOperator 
u_op = (UnaryOperator) _optr; - MatrixBlock inBlock = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); - MatrixBlock retBlock = (MatrixBlock) (inBlock.unaryOperations(u_op, new MatrixBlock())); + retBlock = (MatrixBlock) (inBlock.unaryOperations(u_op, new MatrixBlock())); ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); // Ensure right dense/sparse output representation (guarded by released input memory) if( checkGuardedRepresentationChange(inBlock, retBlock) ) retBlock.examSparsity(); - ec.setMatrixOutput(output_name, retBlock, getExtendedOpcode()); } + + ec.setMatrixOutput(output.getName(), retBlock, getExtendedOpcode()); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/0209edd7/src/main/java/org/apache/sysml/runtime/matrix/data/LibCommonsMath.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibCommonsMath.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibCommonsMath.java index 22572ba..517c7fe 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibCommonsMath.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibCommonsMath.java @@ -28,7 +28,6 @@ import org.apache.commons.math3.linear.QRDecomposition; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.SingularValueDecomposition; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.util.DataConverter; /** @@ -56,7 +55,7 @@ public class LibCommonsMath return ( opcode.equals("solve") ); } - public static MatrixBlock unaryOperations(MatrixObject inj, String opcode) { + public static MatrixBlock unaryOperations(MatrixBlock inj, String opcode) { Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(inj); if(opcode.equals("inverse")) return computeMatrixInverse(matrixInput); @@ -65,7 +64,7 @@ public 
class LibCommonsMath return null; } - public static MatrixBlock[] multiReturnOperations(MatrixObject in, String opcode) { + public static MatrixBlock[] multiReturnOperations(MatrixBlock in, String opcode) { if(opcode.equals("qr")) return computeQR(in); else if (opcode.equals("lu")) @@ -77,7 +76,7 @@ public class LibCommonsMath return null; } - public static MatrixBlock matrixMatrixOperations(MatrixObject in1, MatrixObject in2, String opcode) { + public static MatrixBlock matrixMatrixOperations(MatrixBlock in1, MatrixBlock in2, String opcode) { if(opcode.equals("solve")) { if (in1.getNumRows() != in1.getNumColumns()) throw new DMLRuntimeException("The A matrix, in solve(A,b) should have squared dimensions."); @@ -93,7 +92,7 @@ public class LibCommonsMath * @param in2 matrix object 2 * @return matrix block */ - private static MatrixBlock computeSolve(MatrixObject in1, MatrixObject in2) { + private static MatrixBlock computeSolve(MatrixBlock in1, MatrixBlock in2) { Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in1); Array2DRowRealMatrix vectorInput = DataConverter.convertToArray2DRowRealMatrix(in2); @@ -116,7 +115,7 @@ public class LibCommonsMath * @param in matrix object * @return array of matrix blocks */ - private static MatrixBlock[] computeQR(MatrixObject in) { + private static MatrixBlock[] computeQR(MatrixBlock in) { Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in); // Perform QR decomposition @@ -137,7 +136,7 @@ public class LibCommonsMath * @param in matrix object * @return array of matrix blocks */ - private static MatrixBlock[] computeLU(MatrixObject in) { + private static MatrixBlock[] computeLU(MatrixBlock in) { if ( in.getNumRows() != in.getNumColumns() ) { throw new DMLRuntimeException("LU Decomposition can only be done on a square matrix. 
Input matrix is rectangular (rows=" + in.getNumRows() + ", cols="+ in.getNumColumns() +")"); } @@ -165,7 +164,7 @@ public class LibCommonsMath * @param in matrix object * @return array of matrix blocks */ - private static MatrixBlock[] computeEigen(MatrixObject in) { + private static MatrixBlock[] computeEigen(MatrixBlock in) { if ( in.getNumRows() != in.getNumColumns() ) { throw new DMLRuntimeException("Eigen Decomposition can only be done on a square matrix. Input matrix is rectangular (rows=" + in.getNumRows() + ", cols="+ in.getNumColumns() +")"); } @@ -180,23 +179,23 @@ public class LibCommonsMath //Sort the eigen values (and vectors) in increasing order (to be compatible w/ LAPACK.DSYEVR()) int n = eValues.length; for (int i = 0; i < n; i++) { - int k = i; - double p = eValues[i]; - for (int j = i + 1; j < n; j++) { - if (eValues[j] < p) { - k = j; - p = eValues[j]; - } - } - if (k != i) { - eValues[k] = eValues[i]; - eValues[i] = p; - for (int j = 0; j < n; j++) { - p = eVectors[j][i]; - eVectors[j][i] = eVectors[j][k]; - eVectors[j][k] = p; - } - } + int k = i; + double p = eValues[i]; + for (int j = i + 1; j < n; j++) { + if (eValues[j] < p) { + k = j; + p = eValues[j]; + } + } + if (k != i) { + eValues[k] = eValues[i]; + eValues[i] = p; + for (int j = 0; j < n; j++) { + p = eVectors[j][i]; + eVectors[j][i] = eVectors[j][k]; + eVectors[j][k] = p; + } + } } MatrixBlock mbValues = DataConverter.convertToMatrixBlock(eValues, true); @@ -216,7 +215,7 @@ public class LibCommonsMath * @param in Input matrix * @return An array containing U, Sigma & V */ - private static MatrixBlock[] computeSvd(MatrixObject in) { + private static MatrixBlock[] computeSvd(MatrixBlock in) { Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in); SingularValueDecomposition svd = new SingularValueDecomposition(matrixInput); http://git-wip-us.apache.org/repos/asf/systemml/blob/0209edd7/src/main/java/org/apache/sysml/runtime/util/DataConverter.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/DataConverter.java b/src/main/java/org/apache/sysml/runtime/util/DataConverter.java index 8748de2..0c834b1 100644 --- a/src/main/java/org/apache/sysml/runtime/util/DataConverter.java +++ b/src/main/java/org/apache/sysml/runtime/util/DataConverter.java @@ -31,7 +31,6 @@ import java.util.Map.Entry; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.instructions.cp.BooleanObject; import org.apache.sysml.runtime.io.MatrixReader; import org.apache.sysml.runtime.io.MatrixReaderFactory; @@ -780,10 +779,8 @@ public class DataConverter * @param mo matrix object * @return matrix as a commons-math3 Array2DRowRealMatrix */ - public static Array2DRowRealMatrix convertToArray2DRowRealMatrix(MatrixObject mo) { - MatrixBlock mb = mo.acquireRead(); + public static Array2DRowRealMatrix convertToArray2DRowRealMatrix(MatrixBlock mb) { double[][] data = DataConverter.convertToDoubleMatrix(mb); - mo.release(); return new Array2DRowRealMatrix(data, false); }
