[SYSTEMML-561] Generalized cp tostring instruction for frames, cleanup

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/8a0df5b8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/8a0df5b8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/8a0df5b8
Branch: refs/heads/master Commit: 8a0df5b856577e53f348e7963822ed8ac15dc2b0 Parents: 29ad143 Author: Matthias Boehm <[email protected]> Authored: Fri Jun 10 23:10:54 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Jun 10 23:10:54 2016 -0700 ---------------------------------------------------------------------- .../ParameterizedBuiltinFunctionExpression.java | 33 +++++--- .../context/ExecutionContext.java | 18 +++++ .../cp/ParameterizedBuiltinCPInstruction.java | 81 ++++++++++---------- .../sysml/runtime/util/DataConverter.java | 75 +++++++++++++++++- 4 files changed, 152 insertions(+), 55 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8a0df5b8/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java index eea6f55..679075b 100644 --- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java @@ -736,23 +736,32 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier return; } - - private void validateCastAsString(DataIdentifier output, boolean conditional) throws LanguageException { - + /** + * + * @param output + * @param conditional + * @throws LanguageException + */ + private void validateCastAsString(DataIdentifier output, boolean conditional) + throws LanguageException + { HashMap<String, Expression> varParams = getVarParams(); - // null is for the matrix argument - String[] validArgsArr = {null, "rows", "cols", "decimal", "sparse", "sep", "linesep"}; + + // replace parameter name for matrix argument + if( varParams.containsKey(null) ) + 
varParams.put("target", varParams.remove(null)); + + // check validate parameter names + String[] validArgsArr = {"target", "rows", "cols", "decimal", "sparse", "sep", "linesep"}; HashSet<String> validArgs = new HashSet<String>(Arrays.asList(validArgsArr)); - for (String k : varParams.keySet()){ - if (!validArgs.contains(k)){ - String errMsg = "Invalid parameter " + k + " for as.string, valid parameters are " + validArgsArr[0]; - for (int i=1; i<validArgsArr.length; ++i) - errMsg += "," + validArgsArr[i]; - raiseValidateError(errMsg, conditional, LanguageErrorCodes.INVALID_PARAMETERS); + for( String k : varParams.keySet() ) { + if( !validArgs.contains(k) ) { + raiseValidateError("Invalid parameter " + k + " for toString, valid parameters are " + + Arrays.toString(validArgsArr), conditional, LanguageErrorCodes.INVALID_PARAMETERS); } } - // Output is a string + // set output characteristics output.setDataType(DataType.SCALAR); output.setValueType(ValueType.STRING); output.setDimensions(0, 0); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8a0df5b8/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index cc4fcf2..6356a76 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -179,6 +179,12 @@ public class ExecutionContext return (FrameObject) dat; } + /** + * + * @param varname + * @return + * @throws DMLRuntimeException + */ public CacheableData<?> getCacheableData(String varname) throws DMLRuntimeException { @@ -193,6 +199,18 @@ public class ExecutionContext return (CacheableData<?>) dat; } + /** + * + * @param varname + * @throws 
DMLRuntimeException + */ + public void releaseCacheableData(String varname) + throws DMLRuntimeException + { + CacheableData<?> dat = getCacheableData(varname); + dat.release(); + } + public MatrixCharacteristics getMatrixCharacteristics( String varname ) throws DMLRuntimeException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8a0df5b8/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 70e7a11..e6a536d 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -26,6 +26,7 @@ import org.apache.sysml.parser.ParameterizedBuiltinFunctionExpression; import org.apache.sysml.parser.Statement; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.CacheableData; import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; @@ -48,6 +49,14 @@ import org.apache.sysml.runtime.util.DataConverter; public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction { + private static final int TOSTRING_MAXROWS = 100; + private static final int TOSTRING_MAXCOLS = 100; + private static final int TOSTRING_DECIMAL = 3; + private static final boolean TOSTRING_SPARSE = false; + private static final String TOSTRING_SEPARATOR = " "; + private static final String TOSTRING_LINESEPARATOR = "\n"; + + private int arity; protected 
HashMap<String,String> params; @@ -62,7 +71,13 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction return arity; } - public HashMap<String,String> getParameterMap() { return params; } + public HashMap<String,String> getParameterMap() { + return params; + } + + public String getParam(String key) { + return getParameterMap().get(key); + } public static HashMap<String, String> constructParameterMap(String[] params) { // process all elements in "params" except first(opcode) and last(output) @@ -294,48 +309,30 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction ec.setFrameOutput(output.getName(), meta); } else if ( opcode.equalsIgnoreCase("toString")) { - // Default Arguments - final int MAXROWS = 100; - final int MAXCOLS = 100; - final int DECIMAL = 3; - final boolean SPARSE = false; - final String SEPARATOR = " "; - final String LINESEPARATOR = "\n"; - - int rows=MAXROWS, cols=MAXCOLS, decimal=DECIMAL; - boolean sparse = SPARSE; - String separator=SEPARATOR, lineseparator=LINESEPARATOR; - - String rowsStr = getParameterMap().get("rows"); - if (rowsStr != null){ rows = Integer.parseInt(rowsStr); } - - String colsStr = getParameterMap().get("cols"); - if (colsStr != null) { cols = Integer.parseInt(rowsStr); } - - String decimalStr = getParameterMap().get("decimal"); - if (decimalStr != null) { decimal = Integer.parseInt(decimalStr); } - - String sparseStr = getParameterMap().get("sparse"); - if (sparseStr != null) { sparse = Boolean.parseBoolean(sparseStr); } - - String separatorStr = getParameterMap().get("sep"); - if (separatorStr != null) { separator = separatorStr; } - - String lineseparatorStr = getParameterMap().get("linesep"); - if (lineseparatorStr != null) { lineseparator = lineseparatorStr; } - - // The matrix argument is "null" - String matrixStr = getParameterMap().get("null"); - Data data = ec.getVariable(matrixStr); - if (!(data instanceof MatrixObject)) - throw new DMLRuntimeException("toString 
only converts matrix objects to string"); - MatrixBlock matrix = ec.getMatrixInput(matrixStr); - - String outputStr = DataConverter.convertToString(matrix, sparse, separator, lineseparator, rows, cols, decimal); - - ec.releaseMatrixInput(matrixStr); - ec.setScalarOutput(output.getName(), new StringObject(outputStr)); + //handle input parameters + int rows = (getParam("rows")!=null) ? Integer.parseInt(getParam("rows")) : TOSTRING_MAXROWS; + int cols = (getParam("cols") != null) ? Integer.parseInt(getParam("cols")) : TOSTRING_MAXCOLS; + int decimal = (getParam("decimal") != null) ? Integer.parseInt(getParam("decimal")) : TOSTRING_DECIMAL; + boolean sparse = (getParam("sparse") != null) ? Boolean.parseBoolean(getParam("sparse")) : TOSTRING_SPARSE; + String separator = (getParam("sep") != null) ? getParam("sep") : TOSTRING_SEPARATOR; + String lineseparator = (getParam("linesep") != null) ? getParam("linesep") : TOSTRING_LINESEPARATOR; + //get input matrix/frame and convert to string + CacheableData<?> data = ec.getCacheableData(getParam("target")); + String out = null; + if( data instanceof MatrixObject ) { + MatrixBlock matrix = (MatrixBlock) data.acquireRead(); + out = DataConverter.toString(matrix, sparse, separator, lineseparator, rows, cols, decimal); + } + else if( data instanceof FrameObject ) { + FrameBlock frame = (FrameBlock) data.acquireRead(); + out = DataConverter.toString(frame, sparse, separator, lineseparator, rows, cols, decimal); + } + else { + throw new DMLRuntimeException("toString only converts matrix or frames to string"); + } + ec.releaseCacheableData(getParam("target")); + ec.setScalarOutput(output.getName(), new StringObject(out)); } else { throw new DMLRuntimeException("Unknown opcode : " + opcode); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8a0df5b8/src/main/java/org/apache/sysml/runtime/util/DataConverter.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/runtime/util/DataConverter.java b/src/main/java/org/apache/sysml/runtime/util/DataConverter.java index 6b46c6c..846b2b2 100644 --- a/src/main/java/org/apache/sysml/runtime/util/DataConverter.java +++ b/src/main/java/org/apache/sysml/runtime/util/DataConverter.java @@ -913,7 +913,7 @@ public class DataConverter * @param decimal number of decimal places to print, -1 for default * @return */ - public static String convertToString(MatrixBlock mb, boolean sparse, String separator, String lineseparator, int rowsToPrint, int colsToPrint, int decimal){ + public static String toString(MatrixBlock mb, boolean sparse, String separator, String lineseparator, int rowsToPrint, int colsToPrint, int decimal){ StringBuffer sb = new StringBuffer(); // Setup number of rows and columns to print @@ -973,4 +973,77 @@ public class DataConverter return sb.toString(); } + + /** + * + * @param fb + * @param sparse + * @param separator + * @param lineseparator + * @param rowsToPrint + * @param colsToPrint + * @param decimal + * @return + */ + public static String toString(FrameBlock fb, boolean sparse, String separator, String lineseparator, int rowsToPrint, int colsToPrint, int decimal) + { + StringBuffer sb = new StringBuffer(); + + // Setup number of rows and columns to print + int rlen = fb.getNumRows(); + int clen = fb.getNumColumns(); + int rowLength = rlen; + int colLength = clen; + if (rowsToPrint >= 0) + rowLength = rowsToPrint < rlen ? rowsToPrint : rlen; + if (colsToPrint >= 0) + colLength = colsToPrint < clen ? 
colsToPrint : clen; + + //print frame header + sb.append("# FRAME: "); + sb.append("nrow = " + fb.getNumRows() + ", "); + sb.append("ncol = " + fb.getNumColumns() + lineseparator); + + //print column names + sb.append("#"); sb.append(separator); + for( int j=0; j<colLength; j++ ) { + sb.append(fb.getColumnNames().get(j)); + if( j != colLength-1 ) + sb.append(separator); + } + sb.append(lineseparator); + + //print schema + sb.append("#"); sb.append(separator); + for( int j=0; j<colLength; j++ ) { + sb.append(fb.getSchema().get(j)); + if( j != colLength-1 ) + sb.append(separator); + } + sb.append(lineseparator); + + //print data + DecimalFormat df = new DecimalFormat(); + df.setGroupingUsed(false); + if (decimal >= 0) + df.setMinimumFractionDigits(decimal); + + Iterator<Object[]> iter = fb.getObjectRowIterator(0, rowLength); + while( iter.hasNext() ) { + Object[] row = iter.next(); + for( int j=0; j<colLength; j++ ) { + if( row[j]!=null ) { + if( fb.getSchema().get(j) == ValueType.DOUBLE ) + sb.append(df.format(row[j])); + else + sb.append(row[j]); + if( j != colLength-1 ) + sb.append(separator); + } + } + sb.append(lineseparator); + } + + return sb.toString(); + } }
