[SYSTEMML-2362] Fix robustness of text-mm read in inlined functions Since the validation and function inlining of statement blocks are interleaved, a given statement might be validated several times. Data expressions modify their state during validation by inferring the given input format and dimensions on demand. For the mm format, this combination (a read inside an inlined function) is problematic because the dimensions inferred during an earlier validation are then treated as invalid user-specified parameters. This patch makes the related code path more robust by passively checking that any already-present parameters (rows, cols, nnz) are consistent with the size information in the MatrixMarket header in these scenarios.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/233b38c1 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/233b38c1 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/233b38c1 Branch: refs/heads/master Commit: 233b38c1482972cf8038f26de8fb992d37ab3d03 Parents: 68f37af Author: Matthias Boehm <[email protected]> Authored: Tue Jun 5 19:45:11 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Jun 5 19:45:11 2018 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/api/jmlc/Connection.java | 6 +- .../apache/sysml/api/mlcontext/MLContext.java | 2 +- .../org/apache/sysml/parser/DataExpression.java | 118 ++++++++----------- .../sysml/runtime/util/MapReduceTool.java | 2 +- .../applications/arima_box-jenkins/arima.dml | 1 - 5 files changed, 57 insertions(+), 72 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/233b38c1/src/main/java/org/apache/sysml/api/jmlc/Connection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/jmlc/Connection.java b/src/main/java/org/apache/sysml/api/jmlc/Connection.java index 32655fb..7d631f8 100644 --- a/src/main/java/org/apache/sysml/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysml/api/jmlc/Connection.java @@ -367,10 +367,10 @@ public class Connection implements Closeable jmtd.getInt(DataExpression.ROWBLOCKCOUNTPARAM) : -1; int bclen = jmtd.containsKey(DataExpression.COLUMNBLOCKCOUNTPARAM)? jmtd.getInt(DataExpression.COLUMNBLOCKCOUNTPARAM) : -1; - long nnz = jmtd.containsKey(DataExpression.READNUMNONZEROPARAM)? - jmtd.getLong(DataExpression.READNUMNONZEROPARAM) : -1; + long nnz = jmtd.containsKey(DataExpression.READNNZPARAM)? 
+ jmtd.getLong(DataExpression.READNNZPARAM) : -1; String format = jmtd.getString(DataExpression.FORMAT_TYPE); - InputInfo iinfo = InputInfo.stringExternalToInputInfo(format); + InputInfo iinfo = InputInfo.stringExternalToInputInfo(format); //read matrix file return readDoubleMatrix(fname, iinfo, rows, cols, brlen, bclen, nnz); http://git-wip-us.apache.org/repos/asf/systemml/blob/233b38c1/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java index acc69a0..3690354 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java @@ -523,7 +523,7 @@ public class MLContext implements ConfigurableAPI if (mo != null) { exp.addVarParam(DataExpression.READROWPARAM, new IntIdentifier(mo.getNumRows(), source)); exp.addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(mo.getNumColumns(), source)); - exp.addVarParam(DataExpression.READNUMNONZEROPARAM, new IntIdentifier(mo.getNnz(), source)); + exp.addVarParam(DataExpression.READNNZPARAM, new IntIdentifier(mo.getNnz(), source)); exp.addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("matrix", source)); exp.addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source)); http://git-wip-us.apache.org/repos/asf/systemml/blob/233b38c1/src/main/java/org/apache/sysml/parser/DataExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DataExpression.java b/src/main/java/org/apache/sysml/parser/DataExpression.java index eef9c08..acccdf3 100644 --- a/src/main/java/org/apache/sysml/parser/DataExpression.java +++ b/src/main/java/org/apache/sysml/parser/DataExpression.java @@ -65,7 +65,7 @@ public class DataExpression extends DataIdentifier 
public static final String IO_FILENAME = "iofilename"; public static final String READROWPARAM = "rows"; public static final String READCOLPARAM = "cols"; - public static final String READNUMNONZEROPARAM = "nnz"; + public static final String READNNZPARAM = "nnz"; public static final String FORMAT_TYPE = "format"; public static final String FORMAT_TYPE_VALUE_TEXT = "text"; @@ -101,7 +101,7 @@ public class DataExpression extends DataIdentifier // Valid parameter names in a metadata file public static final String[] READ_VALID_MTD_PARAM_NAMES = - { IO_FILENAME, READROWPARAM, READCOLPARAM, READNUMNONZEROPARAM, FORMAT_TYPE, + { IO_FILENAME, READROWPARAM, READCOLPARAM, READNNZPARAM, FORMAT_TYPE, ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, DATATYPEPARAM, VALUETYPEPARAM, SCHEMAPARAM, DESCRIPTIONPARAM, AUTHORPARAM, CREATEDPARAM, // Parameters related to delimited/csv files. @@ -110,10 +110,10 @@ public class DataExpression extends DataIdentifier public static final String[] READ_VALID_PARAM_NAMES = { IO_FILENAME, READROWPARAM, READCOLPARAM, FORMAT_TYPE, DATATYPEPARAM, VALUETYPEPARAM, SCHEMAPARAM, - ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, READNUMNONZEROPARAM, + ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, READNNZPARAM, // Parameters related to delimited/csv files. 
DELIM_FILL_VALUE, DELIM_DELIMITER, DELIM_FILL, DELIM_HAS_HEADER_ROW, DELIM_NA_STRINGS - }; + }; /* Default Values for delimited (CSV) files */ public static final String DEFAULT_DELIM_DELIMITER = ","; @@ -379,13 +379,12 @@ public class DataExpression extends DataIdentifier @Override public Expression rewriteExpression(String prefix) { - HashMap<String,Expression> newVarParams = new HashMap<>(); for( Entry<String, Expression> e : _varParams.entrySet() ){ String key = e.getKey(); Expression newExpr = e.getValue().rewriteExpression(prefix); newVarParams.put(key, newExpr); - } + } DataExpression retVal = new DataExpression(_opcode, newVarParams, this); retVal._strInit = this._strInit; @@ -465,18 +464,22 @@ public class DataExpression extends DataIdentifier // if required, initialize values setFilename(value.getFilename()); - if (getBeginLine() == 0) setBeginLine(value.getBeginLine()); - if (getBeginColumn() == 0) setBeginColumn(value.getBeginColumn()); - if (getEndLine() == 0) setEndLine(value.getEndLine()); - if (getEndColumn() == 0) setEndColumn(value.getEndColumn()); + if (getBeginLine() == 0) setBeginLine(value.getBeginLine()); + if (getBeginColumn() == 0) setBeginColumn(value.getBeginColumn()); + if (getEndLine() == 0) setEndLine(value.getEndLine()); + if (getEndColumn() == 0) setEndColumn(value.getEndColumn()); if (getText() == null) setText(value.getText()); - } public void removeVarParam(String name) { _varParams.remove(name); } + public void removeVarParam(String... 
names) { + for( String name : names ) + removeVarParam(name); + } + private String getInputFileName(HashMap<String, ConstIdentifier> currConstVars, boolean conditional) { String filename = null; @@ -531,8 +534,8 @@ public class DataExpression extends DataIdentifier inputParamExpr.validateExpression(ids, currConstVars, conditional); if ( getVarParam(s).getOutput().getDataType() != DataType.SCALAR && !s.equals(RAND_DATA)) { raiseValidateError("Non-scalar data types are not supported for data expression.", conditional,LanguageErrorCodes.INVALID_PARAMETERS); - } - } + } + } //general data expression constant propagation performConstantPropagationRand( currConstVars ); @@ -550,7 +553,7 @@ public class DataExpression extends DataIdentifier // replace DataOp MATRIX with RAND -- Rand handles matrix generation for Scalar values // replace data parameter with min / max within Rand case below this.setOpCode(DataOp.RAND); - } + } // IMPORTANT: for each operation, one must handle unnamed parameters @@ -558,7 +561,6 @@ public class DataExpression extends DataIdentifier switch (this.getOpCode()) { case READ: - if (getVarParam(DATATYPEPARAM) != null && !(getVarParam(DATATYPEPARAM) instanceof StringIdentifier)){ raiseValidateError("for read statement, parameter " + DATATYPEPARAM + " can only be a string. " + "Valid values are: " + Statement.MATRIX_DATA_TYPE +", " + Statement.SCALAR_DATA_TYPE, conditional); @@ -581,11 +583,11 @@ public class DataExpression extends DataIdentifier ) { raiseValidateError("Invalid parameters in read statement of a scalar: " + - toString() + ". Only " + VALUETYPEPARAM + " is allowed.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); + toString() + ". 
Only " + VALUETYPEPARAM + " is allowed.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); } } - JSONObject configObject = null; + JSONObject configObject = null; // Process expressions in input filename String inputFileName = getInputFileName(currConstVars, conditional); @@ -604,7 +606,7 @@ public class DataExpression extends DataIdentifier { String fsext = InfrastructureAnalyzer.isLocalMode() ? "FS (local mode)" : "HDFS"; raiseValidateError("Read input file does not exist on "+fsext+": " + - inputFileName, conditional); + inputFileName, conditional); } // track whether format type has been inferred @@ -615,8 +617,7 @@ public class DataExpression extends DataIdentifier // check if file is matrix market format if (formatTypeString == null && shouldReadMTD){ - boolean isMatrixMarketFormat = checkHasMatrixMarketFormat(inputFileName, mtdFileName, conditional); - if (isMatrixMarketFormat) { + if ( checkHasMatrixMarketFormat(inputFileName, mtdFileName, conditional) ) { formatTypeString = FORMAT_TYPE_VALUE_MATRIXMARKET; addVarParam(FORMAT_TYPE, new StringIdentifier(FORMAT_TYPE_VALUE_MATRIXMARKET, this)); inferredFormatType = true; @@ -626,17 +627,13 @@ public class DataExpression extends DataIdentifier // check if file is delimited format if (formatTypeString == null && shouldReadMTD ) { - boolean isDelimitedFormat = checkHasDelimitedFormat(inputFileName, conditional); - - if (isDelimitedFormat) { + if (checkHasDelimitedFormat(inputFileName, conditional)) { addVarParam(FORMAT_TYPE, new StringIdentifier(FORMAT_TYPE_VALUE_CSV, this)); formatTypeString = FORMAT_TYPE_VALUE_CSV; inferredFormatType = true; - // shouldReadMTD = false; } } - if (formatTypeString != null && formatTypeString.equalsIgnoreCase(FORMAT_TYPE_VALUE_MATRIXMARKET)){ /* * handle MATRIXMARKET_FORMAT_TYPE format @@ -649,14 +646,6 @@ public class DataExpression extends DataIdentifier * C) get size information from sizing info line --- M N L */ - for (String key : _varParams.keySet()){ - if ( 
!(key.equals(IO_FILENAME) || key.equals(FORMAT_TYPE) ) ){ - raiseValidateError("Invalid parameters in readMM statement: " + - toString() + ". Only " + IO_FILENAME + " is allowed.", conditional, LanguageErrorCodes.INVALID_PARAMETERS); - } - } - - // should NOT attempt to read MTD file for MatrixMarket format shouldReadMTD = false; @@ -670,8 +659,8 @@ public class DataExpression extends DataIdentifier String firstLine = headerLines[0].trim(); if (!firstLine.equals(legalHeaderMM)){ raiseValidateError("Unsupported format in MatrixMarket file: " + - headerLines[0] + ". Only supported format in MatrixMarket file has header line " + legalHeaderMM, - conditional, LanguageErrorCodes.INVALID_PARAMETERS); + headerLines[0] + ". Only supported format in MatrixMarket file has header line " + legalHeaderMM, + conditional, LanguageErrorCodes.INVALID_PARAMETERS); } // process 2nd line of MatrixMarket format -- must have size information @@ -683,39 +672,36 @@ public class DataExpression extends DataIdentifier headerLines[1] + ". Only supported format in MatrixMarket file has size line: <NUM ROWS> <NUM COLS> <NUM NON-ZEROS>, where each value is an integer.", conditional); } - try { - long rowsCount = Long.parseLong(sizeInfo[0]); - if (rowsCount < 0) - throw new Exception("invalid rows count"); - addVarParam(READROWPARAM, new IntIdentifier(rowsCount, this)); - } catch (Exception e) { - raiseValidateError("In MatrixMarket file " + getVarParam(IO_FILENAME) + " invalid row count " - + sizeInfo[0] + " (must be long value >= 0). 
Sizing info line from file: " + headerLines[1], - conditional, LanguageErrorCodes.INVALID_PARAMETERS); + long rowsCount = Long.parseLong(sizeInfo[0]); + if (rowsCount < 0) + raiseValidateError("MM file: invalid number of rows: "+rowsCount); + else if( getVarParam(READROWPARAM) != null ) { + long rowsCount2 = Long.parseLong(getVarParam(READROWPARAM).toString()); + if( rowsCount2 != rowsCount ) + raiseValidateError("MM file: invalid specified number of rows: "+rowsCount2+" vs "+rowsCount); } + addVarParam(READROWPARAM, new IntIdentifier(rowsCount, this)); + - try { - long colsCount = Long.parseLong(sizeInfo[1]); - if (colsCount < 0) - throw new Exception("invalid cols count"); - addVarParam(READCOLPARAM, new IntIdentifier(colsCount, this)); - } catch (Exception e) { - raiseValidateError("In MatrixMarket file " + getVarParam(IO_FILENAME) + " invalid column count " - + sizeInfo[1] + " (must be long value >= 0). Sizing info line from file: " - + headerLines[1], conditional, LanguageErrorCodes.INVALID_PARAMETERS); + long colsCount = Long.parseLong(sizeInfo[1]); + if (colsCount < 0) + raiseValidateError("MM file: invalid number of columns: "+colsCount); + else if( getVarParam(READCOLPARAM) != null ) { + long colsCount2 = Long.parseLong(getVarParam(READCOLPARAM).toString()); + if( colsCount2 != colsCount ) + raiseValidateError("MM file: invalid specified number of columns: "+colsCount2+" vs "+colsCount); } - - try { - long nnzCount = Long.parseLong(sizeInfo[2]); - if (nnzCount < 0) - throw new Exception("invalid nnz count"); - addVarParam("nnz", new IntIdentifier(nnzCount, this)); - } catch (Exception e) { - raiseValidateError("In MatrixMarket file " + getVarParam(IO_FILENAME) - + " invalid number non-zeros " + sizeInfo[2] - + " (must be long value >= 0). 
Sizing info line from file: " + headerLines[1], - conditional, LanguageErrorCodes.INVALID_PARAMETERS); + addVarParam(READCOLPARAM, new IntIdentifier(colsCount, this)); + + long nnzCount = Long.parseLong(sizeInfo[2]); + if (nnzCount < 0) + raiseValidateError("MM file: invalid number of non-zeros: "+nnzCount); + else if( getVarParam(READNNZPARAM) != null ) { + long nnzCount2 = Long.parseLong(getVarParam(READNNZPARAM).toString()); + if( nnzCount2 != nnzCount ) + raiseValidateError("MM file: invalid specified number of non-zeros: "+nnzCount2+" vs "+nnzCount); } + addVarParam(READNNZPARAM, new IntIdentifier(nnzCount, this)); } } @@ -755,7 +741,7 @@ public class DataExpression extends DataIdentifier || key.equals(DELIM_HAS_HEADER_ROW) || key.equals(DELIM_DELIMITER) || key.equals(DELIM_FILL) || key.equals(DELIM_FILL_VALUE) || key.equals(READROWPARAM) || key.equals(READCOLPARAM) - || key.equals(READNUMNONZEROPARAM) || key.equals(DATATYPEPARAM) || key.equals(VALUETYPEPARAM) + || key.equals(READNNZPARAM) || key.equals(DATATYPEPARAM) || key.equals(VALUETYPEPARAM) || key.equals(SCHEMAPARAM)) ) { String msg = "Only parameters allowed are: " + IO_FILENAME + "," @@ -1576,7 +1562,7 @@ public class DataExpression extends DataIdentifier private void performConstantPropagationReadWrite( HashMap<String, ConstIdentifier> currConstVars ) { //here, we propagate constants for all read/write parameters that are required during validate. 
- String[] paramNamesForEval = new String[]{FORMAT_TYPE, IO_FILENAME, READROWPARAM, READCOLPARAM, READNUMNONZEROPARAM}; + String[] paramNamesForEval = new String[]{FORMAT_TYPE, IO_FILENAME, READROWPARAM, READCOLPARAM, READNNZPARAM}; //replace data identifiers with const identifiers performConstantPropagation(currConstVars, paramNamesForEval); http://git-wip-us.apache.org/repos/asf/systemml/blob/233b38c1/src/main/java/org/apache/sysml/runtime/util/MapReduceTool.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/MapReduceTool.java b/src/main/java/org/apache/sysml/runtime/util/MapReduceTool.java index c6d9604..239c5e1 100644 --- a/src/main/java/org/apache/sysml/runtime/util/MapReduceTool.java +++ b/src/main/java/org/apache/sysml/runtime/util/MapReduceTool.java @@ -477,7 +477,7 @@ public class MapReduceTool mtd.put(DataExpression.ROWBLOCKCOUNTPARAM, mc.getRowsPerBlock()); mtd.put(DataExpression.COLUMNBLOCKCOUNTPARAM, mc.getColsPerBlock()); } - mtd.put(DataExpression.READNUMNONZEROPARAM, mc.getNonZeros()); + mtd.put(DataExpression.READNNZPARAM, mc.getNonZeros()); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/233b38c1/src/test/scripts/applications/arima_box-jenkins/arima.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/applications/arima_box-jenkins/arima.dml b/src/test/scripts/applications/arima_box-jenkins/arima.dml index 6c39339..ab56d39 100644 --- a/src/test/scripts/applications/arima_box-jenkins/arima.dml +++ b/src/test/scripts/applications/arima_box-jenkins/arima.dml @@ -128,7 +128,6 @@ readParamters = function (String default_solver, Integer default_max_func_invoc, #length of the season s = ifdef($s, default_s) - while(FALSE){} #TODO } [X, solver, max_func_invoc, include_mean, p, d, q, P, D, Q, s, dest, result_format] = readParamters ("jacobi", 1000, FALSE, 0,0,0,0,0,0,1, "arima-results.csv", "csv")
