Repository: incubator-systemml Updated Branches: refs/heads/master 326c1c00e -> 6158bfaf9
[SYSTEMML-1235] Migrate GNMFTest to new MLContext Closes #381. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/6158bfaf Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/6158bfaf Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/6158bfaf Branch: refs/heads/master Commit: 6158bfaf9079b7a3882e709cbc6d873180c5f373 Parents: 326c1c0 Author: Deron Eriksson <[email protected]> Authored: Tue Feb 7 16:36:51 2017 -0800 Committer: Deron Eriksson <[email protected]> Committed: Tue Feb 7 16:36:51 2017 -0800 ---------------------------------------------------------------------- .../functions/mlcontext/GNMFTest.java | 90 ++++++++++++-------- 1 file changed, 53 insertions(+), 37 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6158bfaf/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java index 89a4363..99ab53b 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java @@ -26,18 +26,24 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; -import org.apache.spark.SparkContext; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; import org.apache.spark.mllib.linalg.distributed.MatrixEntry; +import org.apache.spark.rdd.RDD; import org.apache.sysml.api.DMLException; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; -import org.apache.sysml.api.MLContext; -import org.apache.sysml.api.MLContextProxy; -import org.apache.sysml.api.MLOutput; +import org.apache.sysml.api.mlcontext.MLContext; +import org.apache.sysml.api.mlcontext.MLResults; +import org.apache.sysml.api.mlcontext.Matrix; +import org.apache.sysml.api.mlcontext.MatrixFormat; +import org.apache.sysml.api.mlcontext.MatrixMetadata; +import org.apache.sysml.api.mlcontext.Script; +import org.apache.sysml.api.mlcontext.ScriptFactory; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.parser.ParseException; import org.apache.sysml.runtime.DMLRuntimeException; @@ -52,6 +58,7 @@ import org.apache.sysml.runtime.util.MapReduceTool; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.utils.TestUtils; import org.junit.Assert; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -67,12 +74,26 @@ public class GNMFTest extends AutomatedTestBase int numRegisteredInputs; int numRegisteredOutputs; - + + private static SparkConf conf; + private static JavaSparkContext sc; + private static MLContext ml; + public GNMFTest(int in, int out) { numRegisteredInputs = in; numRegisteredOutputs = out; } - + + @BeforeClass + public static void setUpClass() { + if (conf == null) + conf = SparkExecutionContext.createSystemMLSparkConf() + .setAppName("GNMFTest").setMaster("local"); + if (sc == null) + sc = new JavaSparkContext(conf); + ml = new MLContext(sc); + } + @Parameters public static Collection<Object[]> data() { Object[][] data = new Object[][] { { 0, 0 }, { 3, 2 }, { 2, 2 }, { 2, 1 }, { 2, 0 }, { 3, 0 }}; @@ -145,43 +166,46 @@ public class GNMFTest extends AutomatedTestBase DMLScript.USE_LOCAL_SPARK_CONFIG = true; RUNTIME_PLATFORM oldRT = DMLScript.rtplatform; - MLContext mlCtx = null; - SparkContext sc = null; try { DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; - - mlCtx = getMLContextForTesting(); - sc = mlCtx.getSparkContext(); - mlCtx.reset(true); // Cleanup config to ensure future MLContext testcases have correct 'cp.parallel.matrixmult' - + + Script script = ScriptFactory.dmlFromFile(fullDMLScriptName); + // set positional argument values + for (int argNum = 1; argNum <= proArgs.size(); argNum++) { + script.in("$" + argNum, proArgs.get(argNum-1)); + } + // Read two matrices through RDD and one through HDFS if(numRegisteredInputs >= 1) { - JavaRDD<String> vIn = sc.textFile(input("v"), 2).toJavaRDD(); - mlCtx.registerInput("V", vIn, "text", m, n); + JavaRDD<String> vIn = sc.sc().textFile(input("v"), 2).toJavaRDD(); + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, m, n); + script.in("V", vIn, mm); } if(numRegisteredInputs >= 2) { - JavaRDD<String> wIn = sc.textFile(input("w"), 2).toJavaRDD(); - mlCtx.registerInput("W", wIn, "text", m, k); + JavaRDD<String> wIn = sc.sc().textFile(input("w"), 2).toJavaRDD(); + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, m, k); + script.in("W", wIn, mm); } if(numRegisteredInputs >= 3) { - JavaRDD<String> hIn = sc.textFile(input("h"), 2).toJavaRDD(); - mlCtx.registerInput("H", hIn, "text", k, n); + JavaRDD<String> hIn = sc.sc().textFile(input("h"), 2).toJavaRDD(); + MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, k, n); + script.in("H", hIn, mm); } // Output one matrix to HDFS and get one as RDD if(numRegisteredOutputs >= 1) { - mlCtx.registerOutput("H"); + script.out("H"); } if(numRegisteredOutputs >= 2) { - mlCtx.registerOutput("W"); - mlCtx.setConfig("cp.parallel.matrixmult", "false"); + script.out("W"); + ml.setConfigProperty("cp.parallel.matrixmult", "false"); } - MLOutput out = mlCtx.execute(fullDMLScriptName, programArgs); + MLResults results = ml.execute(script); if(numRegisteredOutputs >= 2) { String configStr = ConfigurationManager.getDMLConfig().getConfigInfo(); @@ -190,7 +214,7 @@ public class GNMFTest extends AutomatedTestBase } if(numRegisteredOutputs >= 1) { - JavaRDD<String> hOut = out.getStringRDD("H", "text"); + RDD<String> hOut = results.getRDDStringIJV("H"); String fName = output("h"); try { MapReduceTool.deleteFileIfExistOnHDFS( fName ); @@ -201,10 +225,11 @@ public class GNMFTest extends AutomatedTestBase } if(numRegisteredOutputs >= 2) { -// Test converter: Text -> CoordinateMatrix -> BinaryBlock -> Text -// JavaRDD<String> wOut = out.getStringRDD("W", "text"); - JavaRDD<MatrixEntry> matRDD = out.getStringRDD("W", "text").map(new StringToMatrixEntry()); - MatrixCharacteristics mcW = out.getMatrixCharacteristics("W"); + JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("W"); + JavaRDD<MatrixEntry> matRDD = javaRDDStringIJV.map(new StringToMatrixEntry()); + Matrix matrix = results.getMatrix("W"); + MatrixCharacteristics mcW = matrix.getMatrixMetadata().asMatrixCharacteristics(); + CoordinateMatrix coordinateMatrix = new CoordinateMatrix(matRDD.rdd(), mcW.getRows(), mcW.getCols()); JavaPairRDD<MatrixIndexes, MatrixBlock> binaryRDD = RDDConverterUtilsExt.coordinateMatrixToBinaryBlock(sc, coordinateMatrix, mcW, true); JavaRDD<String> wOut = RDDConverterUtils.binaryBlockToTextCell(binaryRDD, mcW); @@ -227,19 +252,10 @@ public class GNMFTest extends AutomatedTestBase HashMap<CellIndex, Double> hmHR = readRMatrixFromFS("h"); TestUtils.compareMatrices(hmWDML, hmWR, 0.000001, "hmWDML", "hmWR"); TestUtils.compareMatrices(hmHDML, hmHR, 0.000001, "hmHDML", "hmHR"); - - //cleanup mlcontext (prevent test memory leaks) - mlCtx.reset(); } finally { DMLScript.rtplatform = oldRT; DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig; - - if (sc != null) { - sc.stop(); - } - SparkExecutionContext.resetSparkContextStatic(); - MLContextProxy.setActive(false); } }
