Repository: incubator-systemml
Updated Branches:
  refs/heads/master d3cfcafcf -> 457a97db8
[SYSTEMML-1234] Migrate FrameTest to new MLContext Migrate FrameTest from old MLContext API to new MLContext API. Fix MLContextConversionUtil frameObjectToListStringIJV and frameObjectToListStringCSV to not output null values. Closes #380. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/457a97db Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/457a97db Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/457a97db Branch: refs/heads/master Commit: 457a97db8c7b5483a3750cc143c98c77a3196db5 Parents: d3cfcaf Author: Deron Eriksson <de...@us.ibm.com> Authored: Thu Feb 9 10:05:46 2017 -0800 Committer: Deron Eriksson <de...@us.ibm.com> Committed: Thu Feb 9 10:05:46 2017 -0800 ---------------------------------------------------------------------- .../api/mlcontext/MLContextConversionUtil.java | 20 ++- .../functions/mlcontext/FrameTest.java | 167 ++++++++++++------- 2 files changed, 115 insertions(+), 72 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457a97db/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java index cca9d2c..5414e4d 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java @@ -1100,7 +1100,9 @@ public class MLContextConversionUtil { if (j > 0) { sb.append(delimiter); } - sb.append(fb.get(i, j)); + if (fb.get(i, j) != null) { + sb.append(fb.get(i, j)); + } } list.add(sb.toString()); } @@ -1185,13 +1187,15 @@ public class MLContextConversionUtil { for (int i = 0; i < rows; 
i++) { sb = new StringBuilder(); for (int j = 0; j < cols; j++) { - sb = new StringBuilder(); - sb.append(i + 1); - sb.append(" "); - sb.append(j + 1); - sb.append(" "); - sb.append(fb.get(i, j)); - list.add(sb.toString()); + if (fb.get(i, j) != null) { + sb = new StringBuilder(); + sb.append(i + 1); + sb.append(" "); + sb.append(j + 1); + sb.append(" "); + sb.append(fb.get(i, j)); + list.add(sb.toString()); + } } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457a97db/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java index 1b29077..e6a947f 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java @@ -27,7 +27,7 @@ import java.util.HashMap; import java.util.List; import org.apache.hadoop.io.LongWritable; -import org.apache.spark.SparkContext; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -38,9 +38,13 @@ import org.apache.spark.sql.types.StructType; import org.apache.sysml.api.DMLException; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; -import org.apache.sysml.api.MLContext; -import org.apache.sysml.api.MLContextProxy; -import org.apache.sysml.api.MLOutput; +import org.apache.sysml.api.mlcontext.FrameFormat; +import org.apache.sysml.api.mlcontext.FrameMetadata; +import org.apache.sysml.api.mlcontext.FrameSchema; +import org.apache.sysml.api.mlcontext.MLContext; +import org.apache.sysml.api.mlcontext.MLResults; +import org.apache.sysml.api.mlcontext.Script; +import 
org.apache.sysml.api.mlcontext.ScriptFactory; import org.apache.sysml.parser.DataExpression; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.parser.ParseException; @@ -49,7 +53,6 @@ import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; -import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties; import org.apache.sysml.runtime.matrix.data.FrameBlock; import org.apache.sysml.runtime.matrix.data.InputInfo; import org.apache.sysml.runtime.matrix.data.OutputInfo; @@ -58,7 +61,10 @@ import org.apache.sysml.runtime.util.UtilFunctions; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.apache.sysml.test.utils.TestUtils; +import org.junit.After; +import org.junit.AfterClass; import org.junit.Assert; +import org.junit.BeforeClass; import org.junit.Test; @@ -93,7 +99,21 @@ public class FrameTest extends AutomatedTestBase schemaMixedLarge = new ValueType[schemaMixedLargeList.size()]; schemaMixedLarge = (ValueType[]) schemaMixedLargeList.toArray(schemaMixedLarge); } - + + private static SparkConf conf; + private static JavaSparkContext sc; + private static MLContext ml; + + @BeforeClass + public static void setUpClass() { + if (conf == null) + conf = SparkExecutionContext.createSystemMLSparkConf() + .setAppName("FrameTest").setMaster("local"); + if (sc == null) + sc = new JavaSparkContext(conf); + ml = new MLContext(sc); + } + @Override public void setUp() { addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, @@ -154,8 +174,6 @@ public class FrameTest extends AutomatedTestBase RUNTIME_PLATFORM oldRT = DMLScript.rtplatform; DMLScript.rtplatform = 
RUNTIME_PLATFORM.HYBRID_SPARK; - this.scriptType = ScriptType.DML; - int rowstart = 234, rowend = 1478, colstart = 125, colend = 568; int bRows = rowend-rowstart+1, bCols = colend-colstart+1; @@ -186,7 +204,6 @@ public class FrameTest extends AutomatedTestBase proArgs.add(Integer.toString(colstartC)); proArgs.add(Integer.toString(colendC)); proArgs.add(output("C")); - programArgs = proArgs.toArray(new String[proArgs.size()]); fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml"; @@ -199,71 +216,75 @@ public class FrameTest extends AutomatedTestBase rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + rowstart + " " + rowend + " " + colstart + " " + colend + " " + expectedDir() + " " + rowstartC + " " + rowendC + " " + colstartC + " " + colendC; - - double sparsity=sparsity1;//rand.nextDouble(); - double[][] A = getRandomMatrix(rows, cols, min, max, sparsity, 1111 /*\\System.currentTimeMillis()*/); - writeInputFrameWithMTD("A", A, true, schema, oinfo); - - sparsity=sparsity2;//rand.nextDouble(); - double[][] B = getRandomMatrix((int)(bRows), (int)(bCols), min, max, sparsity, 2345 /*System.currentTimeMillis()*/); - //Following way of creation causes serialization issue in frame processing - //List<ValueType> lschemaB = lschema.subList((int)colstart-1, (int)colend); - ValueType[] schemaB = new ValueType[bCols]; - for (int i = 0; i < bCols; ++i) - schemaB[i] = schema[colstart-1+i]; + + double sparsity = sparsity1; + double[][] A = getRandomMatrix(rows, cols, min, max, sparsity, 1111); + writeInputFrameWithMTD("A", A, true, schema, oinfo); + + sparsity = sparsity2; + double[][] B = getRandomMatrix((int) (bRows), (int) (bCols), min, max, sparsity, 2345); + + ValueType[] schemaB = new ValueType[bCols]; + for (int i = 0; i < bCols; ++i) + schemaB[i] = schema[colstart - 1 + i]; List<ValueType> lschemaB = Arrays.asList(schemaB); - writeInputFrameWithMTD("B", B, true, schemaB, oinfo); + writeInputFrameWithMTD("B", B, true, schemaB, oinfo); + + 
ValueType[] schemaC = new ValueType[colendC - colstartC + 1]; + for (int i = 0; i < cCols; ++i) + schemaC[i] = schema[colstartC - 1 + i]; - ValueType[] schemaC = new ValueType[colendC-colstartC+1]; - for (int i = 0; i < cCols; ++i) - schemaC[i] = schema[colstartC-1+i]; - - MLContext mlCtx = getMLContextForTesting(); - SparkContext sc = mlCtx.getSparkContext(); - JavaSparkContext jsc = new JavaSparkContext(sc); - Dataset<Row> dfA = null, dfB = null; if(bFromDataFrame) { //Create DataFrame for input A SQLContext sqlContext = new SQLContext(sc); StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schema, false); - JavaRDD<Row> rowRDDA = FrameRDDConverterUtils.csvToRowRDD(jsc, input("A"), DataExpression.DEFAULT_DELIM_DELIMITER, schema); + JavaRDD<Row> rowRDDA = FrameRDDConverterUtils.csvToRowRDD(sc, input("A"), DataExpression.DEFAULT_DELIM_DELIMITER, schema); dfA = sqlContext.createDataFrame(rowRDDA, dfSchemaA); //Create DataFrame for input B StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaB, false); - JavaRDD<Row> rowRDDB = FrameRDDConverterUtils.csvToRowRDD(jsc, input("B"), DataExpression.DEFAULT_DELIM_DELIMITER, schemaB); + JavaRDD<Row> rowRDDB = FrameRDDConverterUtils.csvToRowRDD(sc, input("B"), DataExpression.DEFAULT_DELIM_DELIMITER, schemaB); dfB = sqlContext.createDataFrame(rowRDDB, dfSchemaB); } try { - mlCtx.reset(true); // Cleanup config to ensure future MLContext testcases have correct 'cp.parallel.matrixmult' + Script script = ScriptFactory.dmlFromFile(fullDMLScriptName); String format = "csv"; if(oinfo == OutputInfo.TextCellOutputInfo) format = "text"; - if(bFromDataFrame) - mlCtx.registerFrameInput("A", dfA, false); - else { - JavaRDD<String> aIn = jsc.textFile(input("A")); - mlCtx.registerInput("A", aIn, format, rows, cols, new CSVFileFormatProperties(), lschema); + if(bFromDataFrame) { + script.in("A", dfA); + } else { + JavaRDD<String> aIn = sc.textFile(input("A")); + FrameSchema fs = new 
FrameSchema(lschema); + FrameFormat ff = (format.equals("text")) ? FrameFormat.IJV : FrameFormat.CSV; + FrameMetadata fm = new FrameMetadata(ff, fs, rows, cols); + script.in("A", aIn, fm); } - if(bFromDataFrame) - mlCtx.registerFrameInput("B", dfB, false); - else { - JavaRDD<String> bIn = jsc.textFile(input("B")); - mlCtx.registerInput("B", bIn, format, bRows, bCols, new CSVFileFormatProperties(), lschemaB); + if(bFromDataFrame) { + script.in("B", dfB); + } else { + JavaRDD<String> bIn = sc.textFile(input("B")); + FrameSchema fs = new FrameSchema(lschemaB); + FrameFormat ff = (format.equals("text")) ? FrameFormat.IJV : FrameFormat.CSV; + FrameMetadata fm = new FrameMetadata(ff, fs, bRows, bCols); + script.in("B", bIn, fm); } // Output one frame to HDFS and get one as RDD //TODO HDFS input/output to do - mlCtx.registerOutput("A"); - mlCtx.registerOutput("C"); + script.out("A", "C"); - MLOutput out = mlCtx.execute(fullDMLScriptName, programArgs); + // set positional argument values + for (int argNum = 1; argNum <= proArgs.size(); argNum++) { + script.in("$" + argNum, proArgs.get(argNum-1)); + } + MLResults results = ml.execute(script); format = "csv"; if(iinfo == InputInfo.TextCellInputInfo) @@ -278,15 +299,20 @@ public class FrameTest extends AutomatedTestBase if(!bToDataFrame) { - JavaRDD<String> aOut = out.getStringFrameRDD("A", format, new CSVFileFormatProperties()); - aOut.saveAsTextFile(fName); + if (format.equals("text")) { + JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("A"); + javaRDDStringIJV.saveAsTextFile(fName); + } else { + JavaRDD<String> javaRDDStringCSV = results.getJavaRDDStringCSV("A"); + javaRDDStringCSV.saveAsTextFile(fName); + } } else { - Dataset<Row> df = out.getDataFrameRDD("A", jsc); + Dataset<Row> df = results.getDataFrame("A"); //Convert back DataFrame to binary block for comparison using original binary to converted DF and back to binary MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, -1, -1, -1); 
JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils - .dataFrameToBinaryBlock(jsc, df, mc, bFromDataFrame) + .dataFrameToBinaryBlock(sc, df, mc, bFromDataFrame) .mapToPair(new LongFrameToLongWritableFrameFunction()); rddOut.saveAsHadoopFile(output("AB"), LongWritable.class, FrameBlock.class, OutputInfo.BinaryBlockOutputInfo.outputFormatClass); } @@ -299,15 +325,20 @@ public class FrameTest extends AutomatedTestBase } if(!bToDataFrame) { - JavaRDD<String> aOut = out.getStringFrameRDD("C", format, new CSVFileFormatProperties()); - aOut.saveAsTextFile(fName); + if (format.equals("text")) { + JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("C"); + javaRDDStringIJV.saveAsTextFile(fName); + } else { + JavaRDD<String> javaRDDStringCSV = results.getJavaRDDStringCSV("C"); + javaRDDStringCSV.saveAsTextFile(fName); + } } else { - Dataset<Row> df = out.getDataFrameRDD("C", jsc); + Dataset<Row> df = results.getDataFrame("C"); //Convert back DataFrame to binary block for comparison using original binary to converted DF and back to binary MatrixCharacteristics mc = new MatrixCharacteristics(cRows, cCols, -1, -1, -1); JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils - .dataFrameToBinaryBlock(jsc, df, mc, bFromDataFrame) + .dataFrameToBinaryBlock(sc, df, mc, bFromDataFrame) .mapToPair(new LongFrameToLongWritableFrameFunction()); rddOut.saveAsHadoopFile(fName, LongWritable.class, FrameBlock.class, OutputInfo.BinaryBlockOutputInfo.outputFormatClass); } @@ -329,20 +360,11 @@ public class FrameTest extends AutomatedTestBase System.out.println("File " + file + " processed successfully."); } - //cleanup mlcontext (prevent test memory leaks) - mlCtx.reset(); - System.out.println("Frame MLContext test completed successfully."); } finally { DMLScript.rtplatform = oldRT; DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig; - - if (sc != null) { - sc.stop(); - } - SparkExecutionContext.resetSparkContextStatic(); - MLContextProxy.setActive(false); 
} } @@ -357,4 +379,21 @@ public class FrameTest extends AutomatedTestBase } } -} \ No newline at end of file + @After + public void tearDown() { + super.tearDown(); + } + + @AfterClass + public static void tearDownClass() { + // stop spark context to allow single jvm tests (otherwise the + // next test that tries to create a SparkContext would fail) + sc.stop(); + sc = null; + conf = null; + + // clear status mlcontext and spark exec context + ml.close(); + ml = null; + } +}