Repository: incubator-systemml Updated Branches: refs/heads/master 4cbb02819 -> 02a9f2770
[SYSTEMML-568] Frame MLContext support Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/02a9f277 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/02a9f277 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/02a9f277 Branch: refs/heads/master Commit: 02a9f277000bd144c729311dac6c04bcb520180f Parents: 4cbb028 Author: Arvind Surve <[email protected]> Authored: Sun Aug 28 23:01:20 2016 -0700 Committer: Arvind Surve <[email protected]> Committed: Sun Aug 28 23:01:20 2016 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/api/MLContext.java | 218 ++++++++++-- .../java/org/apache/sysml/api/MLOutput.java | 49 ++- .../api/mlcontext/MLContextConversionUtil.java | 34 ++ .../sysml/api/mlcontext/MLContextUtil.java | 6 + .../sysml/runtime/util/UtilFunctions.java | 40 +++ .../functions/frame/FrameConverterTest.java | 33 +- .../functions/mlcontext/FrameTest.java | 351 +++++++++++++++++++ src/test/scripts/functions/frame/FrameGeneral.R | 35 ++ .../scripts/functions/frame/FrameGeneral.dml | 30 ++ 9 files changed, 732 insertions(+), 64 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/main/java/org/apache/sysml/api/MLContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java index 405478f..8f6e95f 100644 --- a/src/main/java/org/apache/sysml/api/MLContext.java +++ b/src/main/java/org/apache/sysml/api/MLContext.java @@ -23,6 +23,7 @@ package org.apache.sysml.api; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Scanner; @@ -62,6 +63,7 @@ 
import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; +import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; @@ -73,11 +75,13 @@ import org.apache.sysml.runtime.instructions.spark.functions.ConvertStringToLong import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockPairFunction; import org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction; import org.apache.sysml.runtime.instructions.spark.functions.SparkListener; +import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.MatrixFormatMetaData; import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties; import org.apache.sysml.runtime.matrix.data.FileFormatProperties; +import org.apache.sysml.runtime.matrix.data.FrameBlock; import org.apache.sysml.runtime.matrix.data.InputInfo; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; @@ -263,6 +267,21 @@ public class MLContext { } /** + * Register DataFrame as input. DataFrame is assumed to be in row format and each cell can be converted into + * SystemML frame row. Each column could be of type, Double, Float, Long, Integer, String or Boolean. + * <p> + * Marks the variable in the DML script as input variable. 
+ * Note that this expects a "varName = read(...)" statement in the DML script which through non-MLContext invocation + * would have been created by reading a HDFS file. + * @param varName + * @param df + * @throws DMLRuntimeException + */ + public void registerFrameInput(String varName, DataFrame df) throws DMLRuntimeException { + registerFrameInput(varName, df, false); + } + + /** * Register DataFrame as input. * Marks the variable in the DML script as input variable. * Note that this expects a "varName = read(...)" statement in the DML script which through non-MLContext invocation @@ -279,6 +298,21 @@ public class MLContext { } /** + * Register DataFrame as input. DataFrame is assumed to be in row format and each cell can be converted into + * SystemML frame row. Each column could be of type, Double, Float, Long, Integer, String or Boolean. + * <p> + * @param varName + * @param df + * @param containsID true if the DataFrame has a column ID which denotes the row ID. + * @throws DMLRuntimeException + */ + public void registerFrameInput(String varName, DataFrame df, boolean containsID) throws DMLRuntimeException { + MatrixCharacteristics mcOut = new MatrixCharacteristics(); + JavaPairRDD<Long, FrameBlock> rdd = FrameRDDConverterUtils.dataFrameToBinaryBlock(new JavaSparkContext(_sc), df, mcOut, containsID); + registerInput(varName, rdd, mcOut.getRows(), mcOut.getCols(), null); + } + + /** * Experimental API. Not supported in Python MLContext API. * @param varName * @param df @@ -520,6 +554,87 @@ public class MLContext { checkIfRegisteringInputAllowed(); } + /** + * Register Frame with CSV/Text as inputs: with dimensions. + * File properties (example: delim, fill, ..) can be specified through props else defaults will be used. + * <p> + * Marks the variable in the DML script as input variable. + * Note that this expects a "varName = read(...)" statement in the DML script which through non-MLContext invocation + * would have been created by reading a HDFS file.
+ * @param varName + * @param rddIn + * @param format + * @param rlen + * @param clen + * @param props + * @param schema + * List of column types. + * @throws DMLRuntimeException + */ + public void registerInput(String varName, JavaRDD<String> rddIn, String format, long rlen, long clen, FileFormatProperties props, + List<ValueType> schema) throws DMLRuntimeException { + if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) { + throw new DMLRuntimeException("The registerInput functionality only supported for spark runtime. Please use MLContext(sc) instead of default constructor."); + } + + long nnz = -1; + if(_variables == null) + _variables = new LocalVariableMap(); + if(_inVarnames == null) + _inVarnames = new ArrayList<String>(); + + JavaPairRDD<LongWritable, Text> rddText = rddIn.mapToPair(new ConvertStringToLongTextPair()); + + MatrixCharacteristics mc = new MatrixCharacteristics(rlen, clen, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, nnz); + FrameObject fo = new FrameObject(null, new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo)); + JavaPairRDD<Long, FrameBlock> rdd = null; + if( format.equals("csv") ) { + //TODO replace default block size + + rdd = FrameRDDConverterUtils.csvToBinaryBlock(new JavaSparkContext(getSparkContext()), rddText, mc, false, ",", false, -1); + } + else if( format.equals("text") ) { + if(rlen == -1 || clen == -1) { + throw new DMLRuntimeException("The metadata is required in registerInput for format:" + format); + } + //TODO replace default block size + rdd = FrameRDDConverterUtils.textCellToBinaryBlock(new JavaSparkContext(getSparkContext()), rddText, mc, schema); + } + else { + + throw new DMLRuntimeException("Incorrect format in registerInput: " + format); + } + if(props != null) + fo.setFileFormatProperties(props); + + fo.setRDDHandle(new RDDObject(rdd, varName)); + _variables.put(varName, fo); +
_inVarnames.add(varName); + checkIfRegisteringInputAllowed(); + } + + private void registerInput(String varName, JavaPairRDD<Long, FrameBlock> rdd, long rlen, long clen, FileFormatProperties props) throws DMLRuntimeException { + if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) { + throw new DMLRuntimeException("The registerInput functionality only supported for spark runtime. Please use MLContext(sc) instead of default constructor."); + } + + if(_variables == null) + _variables = new LocalVariableMap(); + if(_inVarnames == null) + _inVarnames = new ArrayList<String>(); + + MatrixCharacteristics mc = new MatrixCharacteristics(rlen, clen, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, -1); + FrameObject fo = new FrameObject(null, new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo)); + + if(props != null) + fo.setFileFormatProperties(props); + + fo.setRDDHandle(new RDDObject(rdd, varName)); + _variables.put(varName, fo); + _inVarnames.add(varName); + checkIfRegisteringInputAllowed(); + } + // ------------------------------------------------------------------------------------ // 3. 
Binary blocked RDD: Support JavaPairRDD<MatrixIndexes,MatrixBlock> @@ -1008,37 +1123,70 @@ public class MLContext { // Do not check metadata file for registered reads ((DataExpression) source).setCheckMetadata(false); - MatrixObject mo = null; - try { - mo = getMatrixObject(target); - int blp = source.getBeginLine(); int bcp = source.getBeginColumn(); - int elp = source.getEndLine(); int ecp = source.getEndColumn(); - ((DataExpression) source).addVarParam(DataExpression.READROWPARAM, new IntIdentifier(mo.getNumRows(), source.getFilename(), blp, bcp, elp, ecp)); - ((DataExpression) source).addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(mo.getNumColumns(), source.getFilename(), blp, bcp, elp, ecp)); - ((DataExpression) source).addVarParam(DataExpression.READNUMNONZEROPARAM, new IntIdentifier(mo.getNnz(), source.getFilename(), blp, bcp, elp, ecp)); - ((DataExpression) source).addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("matrix", source.getFilename(), blp, bcp, elp, ecp)); - ((DataExpression) source).addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source.getFilename(), blp, bcp, elp, ecp)); + if (((DataExpression)source).getDataType() == Expression.DataType.MATRIX) { + + MatrixObject mo = null; - if(mo.getMetaData() instanceof MatrixFormatMetaData) { - MatrixFormatMetaData metaData = (MatrixFormatMetaData) mo.getMetaData(); - if(metaData.getOutputInfo() == OutputInfo.CSVOutputInfo) { - ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_CSV, source.getFilename(), blp, bcp, elp, ecp)); - } - else if(metaData.getOutputInfo() == OutputInfo.TextCellOutputInfo) { - ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_TEXT, source.getFilename(), blp, bcp, elp, ecp)); - } - else if(metaData.getOutputInfo() == OutputInfo.BinaryBlockOutputInfo) { - ((DataExpression) 
source).addVarParam(DataExpression.ROWBLOCKCOUNTPARAM, new IntIdentifier(mo.getNumRowsPerBlock(), source.getFilename(), blp, bcp, elp, ecp)); - ((DataExpression) source).addVarParam(DataExpression.COLUMNBLOCKCOUNTPARAM, new IntIdentifier(mo.getNumColumnsPerBlock(), source.getFilename(), blp, bcp, elp, ecp)); - ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_BINARY, source.getFilename(), blp, bcp, elp, ecp)); + try { + mo = getMatrixObject(target); + int blp = source.getBeginLine(); int bcp = source.getBeginColumn(); + int elp = source.getEndLine(); int ecp = source.getEndColumn(); + ((DataExpression) source).addVarParam(DataExpression.READROWPARAM, new IntIdentifier(mo.getNumRows(), source.getFilename(), blp, bcp, elp, ecp)); + ((DataExpression) source).addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(mo.getNumColumns(), source.getFilename(), blp, bcp, elp, ecp)); + ((DataExpression) source).addVarParam(DataExpression.READNUMNONZEROPARAM, new IntIdentifier(mo.getNnz(), source.getFilename(), blp, bcp, elp, ecp)); + ((DataExpression) source).addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("matrix", source.getFilename(), blp, bcp, elp, ecp)); + ((DataExpression) source).addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source.getFilename(), blp, bcp, elp, ecp)); + + if(mo.getMetaData() instanceof MatrixFormatMetaData) { + MatrixFormatMetaData metaData = (MatrixFormatMetaData) mo.getMetaData(); + if(metaData.getOutputInfo() == OutputInfo.CSVOutputInfo) { + ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_CSV, source.getFilename(), blp, bcp, elp, ecp)); + } + else if(metaData.getOutputInfo() == OutputInfo.TextCellOutputInfo) { + ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_TEXT, source.getFilename(), blp, 
bcp, elp, ecp)); + } + else if(metaData.getOutputInfo() == OutputInfo.BinaryBlockOutputInfo) { + ((DataExpression) source).addVarParam(DataExpression.ROWBLOCKCOUNTPARAM, new IntIdentifier(mo.getNumRowsPerBlock(), source.getFilename(), blp, bcp, elp, ecp)); + ((DataExpression) source).addVarParam(DataExpression.COLUMNBLOCKCOUNTPARAM, new IntIdentifier(mo.getNumColumnsPerBlock(), source.getFilename(), blp, bcp, elp, ecp)); + ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_BINARY, source.getFilename(), blp, bcp, elp, ecp)); + } + else { + throw new LanguageException("Unsupported format through MLContext"); + } } - else { - throw new LanguageException("Unsupported format through MLContext"); + } catch (DMLRuntimeException e) { + throw new LanguageException(e); + } + } else if (((DataExpression)source).getDataType() == Expression.DataType.FRAME) { + FrameObject mo = null; + try { + mo = getFrameObject(target); + int blp = source.getBeginLine(); int bcp = source.getBeginColumn(); + int elp = source.getEndLine(); int ecp = source.getEndColumn(); + ((DataExpression) source).addVarParam(DataExpression.READROWPARAM, new IntIdentifier(mo.getNumRows(), source.getFilename(), blp, bcp, elp, ecp)); + ((DataExpression) source).addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(mo.getNumColumns(), source.getFilename(), blp, bcp, elp, ecp)); + ((DataExpression) source).addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("frame", source.getFilename(), blp, bcp, elp, ecp)); + ((DataExpression) source).addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source.getFilename(), blp, bcp, elp, ecp)); //TODO change to schema + + if(mo.getMetaData() instanceof MatrixFormatMetaData) { + MatrixFormatMetaData metaData = (MatrixFormatMetaData) mo.getMetaData(); + if(metaData.getOutputInfo() == OutputInfo.CSVOutputInfo) { + ((DataExpression) 
source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_CSV, source.getFilename(), blp, bcp, elp, ecp)); + } + else if(metaData.getOutputInfo() == OutputInfo.TextCellOutputInfo) { + ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_TEXT, source.getFilename(), blp, bcp, elp, ecp)); + } + else if(metaData.getOutputInfo() == OutputInfo.BinaryBlockOutputInfo) { + ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_BINARY, source.getFilename(), blp, bcp, elp, ecp)); + } + else { + throw new LanguageException("Unsupported format through MLContext"); + } } + } catch (DMLRuntimeException e) { + throw new LanguageException(e); } - } catch (DMLRuntimeException e) { - throw new LanguageException(e); - } + } } } @@ -1129,6 +1277,18 @@ public class MLContext { throw new DMLRuntimeException("ERROR: getMatrixObject not set for variable:" + varName); } + private FrameObject getFrameObject(String varName) throws DMLRuntimeException { + if(_variables != null) { + Data mo = _variables.get(varName); + if(mo instanceof FrameObject) { + return (FrameObject) mo; + } + else { + throw new DMLRuntimeException("ERROR: Incorrect type"); + } + } + throw new DMLRuntimeException("ERROR: getMatrixObject not set for variable:" + varName); + } private int compareVersion(String versionStr1, String versionStr2) { Scanner s1 = null; @@ -1329,7 +1489,7 @@ public class MLContext { if(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK) { - Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> retVal = null; + Map<String, JavaPairRDD<?,?>> retVal = null; // Depending on whether registerInput/registerOutput was called initialize the variables String[] inputs; String[] outputs; @@ -1361,7 +1521,7 @@ public class MLContext { for( String ovar : _outVarnames ) { if( 
_variables.keySet().contains(ovar) ) { if(retVal == null) { - retVal = new HashMap<String, JavaPairRDD<MatrixIndexes,MatrixBlock>>(); + retVal = new HashMap<String, JavaPairRDD<?,?>>(); } retVal.put(ovar, ((SparkExecutionContext) ec).getBinaryBlockRDDHandleForVariable(ovar)); outMetadata.put(ovar, ec.getMatrixCharacteristics(ovar)); // For converting output to dataframe http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/main/java/org/apache/sysml/api/MLOutput.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/MLOutput.java b/src/main/java/org/apache/sysml/api/MLOutput.java index 55daf17..916a652 100644 --- a/src/main/java/org/apache/sysml/api/MLOutput.java +++ b/src/main/java/org/apache/sysml/api/MLOutput.java @@ -27,6 +27,7 @@ import java.util.Map.Entry; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.mllib.linalg.DenseVector; @@ -41,8 +42,11 @@ import org.apache.spark.sql.types.StructType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock; +import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties; +import org.apache.sysml.runtime.matrix.data.FrameBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.util.UtilFunctions; @@ -55,7 +59,7 @@ import scala.Tuple2; */ 
public class MLOutput { - Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> _outputs; + Map<String, JavaPairRDD<?,?>> _outputs; private Map<String, MatrixCharacteristics> _outMetadata = null; public MatrixBlock getMatrixBlock(String varName) throws DMLRuntimeException { @@ -66,14 +70,32 @@ public class MLOutput { mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros()); return mb; } - public MLOutput(Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> outputs, Map<String, MatrixCharacteristics> outMetadata) { + + public MLOutput(Map<String, JavaPairRDD<?,?>> outputs, Map<String, MatrixCharacteristics> outMetadata) { this._outputs = outputs; this._outMetadata = outMetadata; } + @SuppressWarnings("unchecked") public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockedRDD(String varName) throws DMLRuntimeException { if(_outputs.containsKey(varName)) { - return _outputs.get(varName); + JavaPairRDD<?,?> tmp = _outputs.get(varName); + if (tmp.first()._2() instanceof MatrixBlock) + return (JavaPairRDD<MatrixIndexes,MatrixBlock>)tmp; + else + return null; + } + throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table."); + } + + @SuppressWarnings("unchecked") + public JavaPairRDD<Long,FrameBlock> getFrameBinaryBlockedRDD(String varName) throws DMLRuntimeException { + if(_outputs.containsKey(varName)) { + JavaPairRDD<?,?> tmp = _outputs.get(varName); + if (tmp.first()._2() instanceof FrameBlock) + return (JavaPairRDD<Long,FrameBlock>)tmp; + else + return null; } throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table."); } @@ -197,6 +219,27 @@ public class MLOutput { } + public JavaRDD<String> getStringFrameRDD(String varName, String format, CSVFileFormatProperties fprop ) throws DMLRuntimeException { + JavaPairRDD<Long, FrameBlock> binaryRDD = getFrameBinaryBlockedRDD(varName); + MatrixCharacteristics mcIn = getMatrixCharacteristics(varName); + if(format.equals("csv")) { + return 
FrameRDDConverterUtils.binaryBlockToCsv(binaryRDD, mcIn, fprop, false); + } + else if(format.equals("text")) { + return FrameRDDConverterUtils.binaryBlockToTextCell(binaryRDD, mcIn); + } + else { + throw new DMLRuntimeException("The output format:" + format + " is not implemented yet."); + } + + } + + public DataFrame getDataFrameRDD(String varName, JavaSparkContext jsc) throws DMLRuntimeException { + JavaPairRDD<Long, FrameBlock> binaryRDD = getFrameBinaryBlockedRDD(varName); + MatrixCharacteristics mcIn = getMatrixCharacteristics(varName); + return FrameRDDConverterUtils.binaryBlockToDataFrame(binaryRDD, mcIn, jsc); + } + public MLMatrix getMLMatrix(MLContext ml, SQLContext sqlContext, String varName) throws DMLRuntimeException { if(sqlContext == null) { throw new DMLRuntimeException("SQLContext is not created."); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java index 3a482ef..0c98dea 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java @@ -42,6 +42,7 @@ import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.CacheException; +import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.spark.data.RDDObject; @@ -54,6 +55,7 @@ import 
org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.Da import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.DataFrameToBinaryBlockFunction; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.MatrixFormatMetaData; +import org.apache.sysml.runtime.matrix.data.FrameBlock; import org.apache.sysml.runtime.matrix.data.IJV; import org.apache.sysml.runtime.matrix.data.InputInfo; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -187,6 +189,38 @@ public class MLContextConversionUtil { } /** + * Convert a {@code FrameBlock} to a {@code FrameObject}. + * + * @param variableName + * name of the variable associated with the frame + * @param frameBlock + * frame as a FrameBlock + * @param matrixMetadata + * the matrix metadata + * @return the {@code FrameBlock} converted to a {@code FrameObject} + */ + public static FrameObject frameBlockToframeObject(String variableName, FrameBlock frameBlock, + MatrixMetadata matrixMetadata) { + try { + MatrixCharacteristics matrixCharacteristics; + if (matrixMetadata != null) { + matrixCharacteristics = matrixMetadata.asMatrixCharacteristics(); + } else { + matrixCharacteristics = new MatrixCharacteristics(); + } + MatrixFormatMetaData mtd = new MatrixFormatMetaData(matrixCharacteristics, + OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo); + FrameObject frameObject = new FrameObject(MLContextUtil.scratchSpace() + "/" + + variableName, mtd); + frameObject.acquireModify(frameBlock); + frameObject.release(); + return frameObject; + } catch (CacheException e) { + throw new MLContextException("Exception converting MatrixBlock to MatrixObject", e); + } + } + + /** * Convert a {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} to a * {@code MatrixObject}. 
* http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java index ea7857e..120df32 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java @@ -53,6 +53,7 @@ import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.IntObject; import org.apache.sysml.runtime.instructions.cp.StringObject; +import org.apache.sysml.runtime.matrix.data.FrameBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; @@ -456,6 +457,11 @@ public final class MLContextUtil { MatrixObject matrixObject = MLContextConversionUtil.matrixBlockToMatrixObject(name, matrixBlock, matrixMetadata); return matrixObject; + } else if (value instanceof FrameBlock) { + FrameBlock frameBlock = (FrameBlock) value; + FrameObject frameObject = MLContextConversionUtil.frameBlockToframeObject(name, frameBlock, + matrixMetadata); + return frameObject; } else if (value instanceof DataFrame) { DataFrame dataFrame = (DataFrame) value; MatrixObject matrixObject = MLContextConversionUtil http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java index 88221f2..4b98f88 100644 --- a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java +++ 
b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java @@ -23,6 +23,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -659,4 +664,39 @@ public class UtilFunctions return DataTypes.createStructType(fields); } + /* + * It will return JavaRDD<Row> based on csv data input file. + */ + public static JavaRDD<Row> getRowRDD(JavaSparkContext sc, String fnameIn, String separator, List<ValueType> schema) + { + // Load a text file and convert each line to a java rdd. + JavaRDD<String> dataRdd = sc.textFile(fnameIn); + return dataRdd.map(new RowGenerator(schema)); + } + + /* + * Row Generator class based on individual line in CSV file. 
+ */ + private static class RowGenerator implements Function<String,Row> + { + private static final long serialVersionUID = -6736256507697511070L; + + List<ValueType> _schema = null; + + public RowGenerator(List<ValueType> schema) + { + _schema = schema; + } + + @Override + public Row call(String record) throws Exception { + String[] fields = record.split(","); + Object[] objects = new Object[fields.length]; + for (int i=0; i<fields.length; i++) { + objects[i] = UtilFunctions.stringToObject(_schema.get(i), fields[i]); + } + return RowFactory.create(objects); + } + } + } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameConverterTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameConverterTest.java b/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameConverterTest.java index 441a63b..fb076bd 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameConverterTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameConverterTest.java @@ -27,13 +27,11 @@ import java.util.List; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.StructType; import org.apache.sysml.api.DMLScript; @@ -522,7 +520,7 @@ public class FrameConverterTest extends AutomatedTestBase //Create DataFrame SQLContext sqlContext = new SQLContext(sc); StructType dfSchema = UtilFunctions.convertFrameSchemaToDFSchema(schema); - JavaRDD<Row> 
rowRDD = getRowRDD(sc, fnameIn, separator); + JavaRDD<Row> rowRDD = UtilFunctions.getRowRDD(sc, fnameIn, separator, schema); DataFrame df = sqlContext.createDataFrame(rowRDD, dfSchema); JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils @@ -552,33 +550,4 @@ public class FrameConverterTest extends AutomatedTestBase sec.close(); } - - /* - * It will return JavaRDD<Row> based on csv data input file. - */ - JavaRDD<Row> getRowRDD(JavaSparkContext sc, String fnameIn, String separator) - { - // Load a text file and convert each line to a java rdd. - JavaRDD<String> dataRdd = sc.textFile(fnameIn); - return dataRdd.map(new RowGenerator()); - } - - /* - * Row Generator class based on individual line in CSV file. - */ - private static class RowGenerator implements Function<String,Row> - { - private static final long serialVersionUID = -6736256507697511070L; - - @Override - public Row call(String record) throws Exception { - String[] fields = record.split(","); - Object[] objects = new Object[fields.length]; - for (int i=0; i<fields.length; i++) { - objects[i] = fields[i]; - } - return RowFactory.create(objects); - } - } - } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java new file mode 100644 index 0000000..b6184cf --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.mlcontext; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; + +import org.apache.hadoop.io.LongWritable; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructType; +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.api.DMLException; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.api.MLContext; +import org.apache.sysml.api.MLOutput; +import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.parser.DataExpression; +import org.apache.sysml.parser.ParseException; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; +import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import 
org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties; +import org.apache.sysml.runtime.matrix.data.FrameBlock; +import org.apache.sysml.runtime.matrix.data.InputInfo; +import org.apache.sysml.runtime.matrix.data.OutputInfo; +import org.apache.sysml.runtime.util.MapReduceTool; +import org.apache.sysml.runtime.util.UtilFunctions; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + + +public class FrameTest extends AutomatedTestBase +{ + private final static String TEST_DIR = "functions/frame/"; + private final static String TEST_NAME = "FrameGeneral"; + private final static String TEST_CLASS_DIR = TEST_DIR + FrameTest.class.getSimpleName() + "/"; + + private final static int min=0; + private final static int max=100; + private final static int rows = 2245; + private final static int cols = 1264; + + private final static double sparsity1 = 1.0; + private final static double sparsity2 = 0.35; + + private final static double epsilon=0.0000000001; + + + private final static List<ValueType> schemaMixedLargeListStr = Collections.nCopies(cols/4, ValueType.STRING); + private final static List<ValueType> schemaMixedLargeListDble = Collections.nCopies(cols/4, ValueType.DOUBLE); + private final static List<ValueType> schemaMixedLargeListInt = Collections.nCopies(cols/4, ValueType.INT); + private final static List<ValueType> schemaMixedLargeListBool = Collections.nCopies(cols/4, ValueType.BOOLEAN); + private static ValueType[] schemaMixedLarge = null; + static { + final List<ValueType> schemaMixedLargeList = new ArrayList<ValueType>(schemaMixedLargeListStr); + schemaMixedLargeList.addAll(schemaMixedLargeListDble); + schemaMixedLargeList.addAll(schemaMixedLargeListInt); + schemaMixedLargeList.addAll(schemaMixedLargeListBool); + schemaMixedLarge = new ValueType[schemaMixedLargeList.size()]; + schemaMixedLarge = (ValueType[]) 
schemaMixedLargeList.toArray(schemaMixedLarge); + } + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, + new String[] {"AB", "C"})); + } + + @Test + public void testCSVInCSVOut() throws IOException, DMLException, ParseException { + testFrameGeneral(InputInfo.CSVInputInfo, OutputInfo.CSVOutputInfo); + } + + @Test + public void testCSVInTextOut() throws IOException, DMLException, ParseException { + testFrameGeneral(InputInfo.TextCellInputInfo, OutputInfo.CSVOutputInfo); + } + + @Test + public void testTextInCSVOut() throws IOException, DMLException, ParseException { + testFrameGeneral(InputInfo.CSVInputInfo, OutputInfo.TextCellOutputInfo); + } + + @Test + public void testTextInTextOut() throws IOException, DMLException, ParseException { + testFrameGeneral(InputInfo.TextCellInputInfo, OutputInfo.TextCellOutputInfo); + } + + @Test + public void testDataFrameInCSVOut() throws IOException, DMLException, ParseException { + testFrameGeneral(InputInfo.CSVInputInfo, true, false); + } + + @Test + public void testDataFrameInTextOut() throws IOException, DMLException, ParseException { + testFrameGeneral(InputInfo.TextCellInputInfo, true, false); + } + + @Test + public void testDataFrameInDataFrameOut() throws IOException, DMLException, ParseException { + testFrameGeneral(true, true); + } + + private void testFrameGeneral(InputInfo iinfo, OutputInfo oinfo) throws IOException, DMLException, ParseException { + testFrameGeneral(iinfo, oinfo, false, false); + } + + private void testFrameGeneral(InputInfo iinfo, boolean bFromDataFrame, boolean bToDataFrame) throws IOException, DMLException, ParseException { + testFrameGeneral(iinfo, OutputInfo.CSVOutputInfo, bFromDataFrame, bToDataFrame); + } + + private void testFrameGeneral(boolean bFromDataFrame, boolean bToDataFrame) throws IOException, DMLException, ParseException { + testFrameGeneral(InputInfo.BinaryBlockInputInfo, OutputInfo.CSVOutputInfo, bFromDataFrame, 
bToDataFrame); + } + + private void testFrameGeneral(InputInfo iinfo, OutputInfo oinfo, boolean bFromDataFrame, boolean bToDataFrame) throws IOException, DMLException, ParseException { + + boolean oldConfig = DMLScript.USE_LOCAL_SPARK_CONFIG; + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + RUNTIME_PLATFORM oldRT = DMLScript.rtplatform; + DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; + + this.scriptType = ScriptType.DML; + + int rowstart = 234, rowend = 1478, colstart = 125, colend = 568; + int bRows = rowend-rowstart+1, bCols = colend-colstart+1; + + int rowstartC = 124, rowendC = 1178, colstartC = 143, colendC = 368; + int cRows = rowendC-rowstartC+1, cCols = colendC-colstartC+1; + + HashMap<String, ValueType[]> outputSchema = new HashMap<String, ValueType[]>(); + HashMap<String, MatrixCharacteristics> outputMC = new HashMap<String, MatrixCharacteristics>(); + + TestConfiguration config = getTestConfiguration(TEST_NAME); + + loadTestConfiguration(config); + + List<String> proArgs = new ArrayList<String>(); + proArgs.add(input("A")); + proArgs.add(Integer.toString(rows)); + proArgs.add(Integer.toString(cols)); + proArgs.add(input("B")); + proArgs.add(Integer.toString(bRows)); + proArgs.add(Integer.toString(bCols)); + proArgs.add(Integer.toString(rowstart)); + proArgs.add(Integer.toString(rowend)); + proArgs.add(Integer.toString(colstart)); + proArgs.add(Integer.toString(colend)); + proArgs.add(output("A")); + proArgs.add(Integer.toString(rowstartC)); + proArgs.add(Integer.toString(rowendC)); + proArgs.add(Integer.toString(colstartC)); + proArgs.add(Integer.toString(colendC)); + proArgs.add(output("C")); + programArgs = proArgs.toArray(new String[proArgs.size()]); + + fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml"; + + ValueType[] schema = schemaMixedLarge; + + //initialize the frame data. 
+ List<ValueType> lschema = Arrays.asList(schema); + + fullRScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".R"; + rCmd = "Rscript" + " " + fullRScriptName + " " + + inputDir() + " " + rowstart + " " + rowend + " " + colstart + " " + colend + " " + expectedDir() + + " " + rowstartC + " " + rowendC + " " + colstartC + " " + colendC; + + double sparsity=sparsity1;//rand.nextDouble(); + double[][] A = getRandomMatrix(rows, cols, min, max, sparsity, 1111 /*\\System.currentTimeMillis()*/); + writeInputFrameWithMTD("A", A, true, lschema, oinfo); + + sparsity=sparsity2;//rand.nextDouble(); + double[][] B = getRandomMatrix((int)(bRows), (int)(bCols), min, max, sparsity, 2345 /*System.currentTimeMillis()*/); + //Following way of creation causes serialization issue in frame processing + //List<ValueType> lschemaB = lschema.subList((int)colstart-1, (int)colend); + ValueType[] schemaB = new ValueType[bCols]; + for (int i = 0; i < bCols; ++i) + schemaB[i] = schema[colstart-1+i]; + List<ValueType> lschemaB = Arrays.asList(schemaB); + writeInputFrameWithMTD("B", B, true, lschemaB, oinfo); + + ValueType[] schemaC = new ValueType[colendC-colstartC+1]; + for (int i = 0; i < cCols; ++i) + schemaC[i] = schema[colstartC-1+i]; + + MLContext mlCtx = getMLContextForTesting(); + SparkContext sc = mlCtx.getSparkContext(); + JavaSparkContext jsc = new JavaSparkContext(sc); + + DataFrame dfA = null, dfB = null; + if(bFromDataFrame) + { + //Create DataFrame for input A + SQLContext sqlContext = new SQLContext(sc); + StructType dfSchemaA = UtilFunctions.convertFrameSchemaToDFSchema(lschema); + JavaRDD<Row> rowRDDA = UtilFunctions.getRowRDD(jsc, input("A"), DataExpression.DEFAULT_DELIM_DELIMITER, lschema); + dfA = sqlContext.createDataFrame(rowRDDA, dfSchemaA); + + //Create DataFrame for input B + StructType dfSchemaB = UtilFunctions.convertFrameSchemaToDFSchema(lschemaB); + JavaRDD<Row> rowRDDB = UtilFunctions.getRowRDD(jsc, input("B"), DataExpression.DEFAULT_DELIM_DELIMITER, lschemaB); + dfB 
= sqlContext.createDataFrame(rowRDDB, dfSchemaB); + } + + try + { + mlCtx.reset(true); // Cleanup config to ensure future MLContext testcases have correct 'cp.parallel.matrixmult' + + String format = "csv"; + if(oinfo == OutputInfo.TextCellOutputInfo) + format = "text"; + + if(bFromDataFrame) + mlCtx.registerFrameInput("A", dfA, false); + else { + JavaRDD<String> aIn = jsc.textFile(input("A")); + mlCtx.registerInput("A", aIn, format, rows, cols, new CSVFileFormatProperties(), lschema); + } + + if(bFromDataFrame) + mlCtx.registerFrameInput("B", dfB, false); + else { + JavaRDD<String> bIn = jsc.textFile(input("B")); + mlCtx.registerInput("B", bIn, format, bRows, bCols, new CSVFileFormatProperties(), lschemaB); + } + + // Output one frame to HDFS and get one as RDD //TODO HDFS input/output to do + mlCtx.registerOutput("A"); + mlCtx.registerOutput("C"); + + MLOutput out = mlCtx.execute(fullDMLScriptName, programArgs); + + format = "csv"; + if(iinfo == InputInfo.TextCellInputInfo) + format = "text"; + + String fName = output("AB"); + try { + MapReduceTool.deleteFileIfExistOnHDFS( fName ); + } catch (IOException e) { + throw new DMLRuntimeException("Error: While deleting file on HDFS"); + } + + if(!bToDataFrame) + { + JavaRDD<String> aOut = out.getStringFrameRDD("A", format, new CSVFileFormatProperties()); + aOut.saveAsTextFile(fName); + } else { + DataFrame df = out.getDataFrameRDD("A", jsc); + + //Convert back DataFrame to binary block for comparison using original binary to converted DF and back to binary + MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, -1, -1, -1); + JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils + .dataFrameToBinaryBlock(jsc, df, mc, false) + .mapToPair(new LongFrameToLongWritableFrameFunction()); + rddOut.saveAsHadoopFile(output("AB"), LongWritable.class, FrameBlock.class, OutputInfo.BinaryBlockOutputInfo.outputFormatClass); + } + + fName = output("C"); + try { + MapReduceTool.deleteFileIfExistOnHDFS( fName 
); + } catch (IOException e) { + throw new DMLRuntimeException("Error: While deleting file on HDFS"); + } + if(!bToDataFrame) + { + JavaRDD<String> aOut = out.getStringFrameRDD("C", format, new CSVFileFormatProperties()); + aOut.saveAsTextFile(fName); + } else { + DataFrame df = out.getDataFrameRDD("C", jsc); + + //Convert back DataFrame to binary block for comparison using original binary to converted DF and back to binary + MatrixCharacteristics mc = new MatrixCharacteristics(cRows, cCols, -1, -1, -1); + JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils + .dataFrameToBinaryBlock(jsc, df, mc, false) + .mapToPair(new LongFrameToLongWritableFrameFunction()); + rddOut.saveAsHadoopFile(fName, LongWritable.class, FrameBlock.class, OutputInfo.BinaryBlockOutputInfo.outputFormatClass); + } + + runRScript(true); + + outputSchema.put("AB", schema); + outputMC.put("AB", new MatrixCharacteristics(rows, cols, -1, -1)); + outputSchema.put("C", schemaC); + outputMC.put("C", new MatrixCharacteristics(cRows, cCols, -1, -1)); + + for(String file: config.getOutputFiles()) + { + MatrixCharacteristics md = outputMC.get(file); + FrameBlock frameBlock = readDMLFrameFromHDFS(file, iinfo, md); + FrameBlock frameRBlock = readRFrameFromHDFS(file+".csv", InputInfo.CSVInputInfo, md); + ValueType[] schemaOut = outputSchema.get(file); + verifyFrameData(frameBlock, frameRBlock, schemaOut); + System.out.println("File " + file + " processed successfully."); + } + + //cleanup mlcontext (prevent test memory leaks) + mlCtx.reset(); + + System.out.println("Frame MLContext test completed successfully."); + } + finally { + DMLScript.rtplatform = oldRT; + DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig; + } + } + + private void verifyFrameData(FrameBlock frame1, FrameBlock frame2, ValueType[] schema) { + for ( int i=0; i<frame1.getNumRows(); i++ ) + for( int j=0; j<frame1.getNumColumns(); j++ ) { + Object val1 = UtilFunctions.stringToObject(schema[j], 
UtilFunctions.objectToString(frame1.get(i, j))); + Object val2 = UtilFunctions.stringToObject(schema[j], UtilFunctions.objectToString(frame2.get(i, j))); + if( TestUtils.compareToR(schema[j], val1, val2, epsilon) != 0) + Assert.fail("The DML data for cell ("+ i + "," + j + ") is " + val1 + + ", not same as the R value " + val2); + } + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/test/scripts/functions/frame/FrameGeneral.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/frame/FrameGeneral.R b/src/test/scripts/functions/frame/FrameGeneral.R new file mode 100644 index 0000000..079c74c --- /dev/null +++ b/src/test/scripts/functions/frame/FrameGeneral.R @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") + +A=read.csv(paste(args[1], "A.csv", sep=""), header = FALSE, stringsAsFactors=FALSE) +B=read.csv(paste(args[1], "B.csv", sep=""), header = FALSE, stringsAsFactors=FALSE) + +A[args[2]:args[3],args[4]:args[5]]=0 +A[args[2]:args[3],args[4]:args[5]]=B +write.csv(A, paste(args[6], "AB.csv", sep=""), row.names = FALSE, quote = FALSE) + +C=A[args[7]:args[8],args[9]:args[10]] +write.csv(C, paste(args[6], "C.csv", sep=""), row.names = FALSE, quote = FALSE) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/test/scripts/functions/frame/FrameGeneral.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/frame/FrameGeneral.dml b/src/test/scripts/functions/frame/FrameGeneral.dml new file mode 100644 index 0000000..9d9a2f7 --- /dev/null +++ b/src/test/scripts/functions/frame/FrameGeneral.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- +# +# Left Indexing test +A=read($1, data_type="frame", rows=$2, cols=$3) +B=read($4, data_type="frame", rows=$5, cols=$6) +A[$7:$8,$9:$10]=B +write(A, $11) + +# Right Indexing test +C=A[$12:$13,$14:$15] +write(C, $16)
