Repository: incubator-systemml
Updated Branches:
  refs/heads/master 326c1c00e -> 6158bfaf9


[SYSTEMML-1235] Migrate GNMFTest to new MLContext

Closes #381.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/6158bfaf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/6158bfaf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/6158bfaf

Branch: refs/heads/master
Commit: 6158bfaf9079b7a3882e709cbc6d873180c5f373
Parents: 326c1c0
Author: Deron Eriksson <[email protected]>
Authored: Tue Feb 7 16:36:51 2017 -0800
Committer: Deron Eriksson <[email protected]>
Committed: Tue Feb 7 16:36:51 2017 -0800

----------------------------------------------------------------------
 .../functions/mlcontext/GNMFTest.java           | 90 ++++++++++++--------
 1 file changed, 53 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6158bfaf/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
index 89a4363..99ab53b 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
@@ -26,18 +26,24 @@ import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 
-import org.apache.spark.SparkContext;
+import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
 import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
+import org.apache.spark.rdd.RDD;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.api.MLContext;
-import org.apache.sysml.api.MLContextProxy;
-import org.apache.sysml.api.MLOutput;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLResults;
+import org.apache.sysml.api.mlcontext.Matrix;
+import org.apache.sysml.api.mlcontext.MatrixFormat;
+import org.apache.sysml.api.mlcontext.MatrixMetadata;
+import org.apache.sysml.api.mlcontext.Script;
+import org.apache.sysml.api.mlcontext.ScriptFactory;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.parser.ParseException;
 import org.apache.sysml.runtime.DMLRuntimeException;
@@ -52,6 +58,7 @@ import org.apache.sysml.runtime.util.MapReduceTool;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.utils.TestUtils;
 import org.junit.Assert;
+import org.junit.BeforeClass;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -67,12 +74,26 @@ public class GNMFTest extends AutomatedTestBase
        
        int numRegisteredInputs;
        int numRegisteredOutputs;
-       
+
+       private static SparkConf conf;
+       private static JavaSparkContext sc;
+       private static MLContext ml;
+
        public GNMFTest(int in, int out) {
                numRegisteredInputs = in;
                numRegisteredOutputs = out;
        }
-       
+
+       @BeforeClass
+       public static void setUpClass() {
+               if (conf == null)
+                       conf = SparkExecutionContext.createSystemMLSparkConf()
+                               .setAppName("GNMFTest").setMaster("local");
+               if (sc == null)
+                       sc = new JavaSparkContext(conf);
+               ml = new MLContext(sc);
+       }
+
        @Parameters
         public static Collection<Object[]> data() {
           Object[][] data = new Object[][] { { 0, 0 }, { 3, 2 }, { 2, 2 }, { 
2, 1 }, { 2, 0 }, { 3, 0 }};
@@ -145,43 +166,46 @@ public class GNMFTest extends AutomatedTestBase
                DMLScript.USE_LOCAL_SPARK_CONFIG = true;
                RUNTIME_PLATFORM oldRT = DMLScript.rtplatform;
                
-               MLContext mlCtx = null;
-               SparkContext sc = null;
                try 
                {
                        DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
-               
-                       mlCtx = getMLContextForTesting();
-                       sc = mlCtx.getSparkContext();
-                       mlCtx.reset(true); // Cleanup config to ensure future 
MLContext testcases have correct 'cp.parallel.matrixmult'
-                       
+
+                       Script script = 
ScriptFactory.dmlFromFile(fullDMLScriptName);
+                       // set positional argument values
+                       for (int argNum = 1; argNum <= proArgs.size(); 
argNum++) {
+                               script.in("$" + argNum, proArgs.get(argNum-1));
+                       }
+
                        // Read two matrices through RDD and one through HDFS
                        if(numRegisteredInputs >= 1) {
-                               JavaRDD<String> vIn = sc.textFile(input("v"), 
2).toJavaRDD();
-                               mlCtx.registerInput("V", vIn, "text", m, n);
+                               JavaRDD<String> vIn = 
sc.sc().textFile(input("v"), 2).toJavaRDD();
+                               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.IJV, m, n);
+                               script.in("V", vIn, mm);
                        }
                        
                        if(numRegisteredInputs >= 2) {
-                               JavaRDD<String> wIn = sc.textFile(input("w"), 
2).toJavaRDD();
-                               mlCtx.registerInput("W", wIn, "text", m, k);
+                               JavaRDD<String> wIn = 
sc.sc().textFile(input("w"), 2).toJavaRDD();
+                               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.IJV, m, k);
+                               script.in("W", wIn, mm);
                        }
                        
                        if(numRegisteredInputs >= 3) {
-                               JavaRDD<String> hIn = sc.textFile(input("h"), 
2).toJavaRDD();
-                               mlCtx.registerInput("H", hIn, "text", k, n);
+                               JavaRDD<String> hIn = 
sc.sc().textFile(input("h"), 2).toJavaRDD();
+                               MatrixMetadata mm = new 
MatrixMetadata(MatrixFormat.IJV, k, n);
+                               script.in("H", hIn, mm);
                        }
                        
                        // Output one matrix to HDFS and get one as RDD
                        if(numRegisteredOutputs >= 1) {
-                               mlCtx.registerOutput("H");
+                               script.out("H");
                        }
                        
                        if(numRegisteredOutputs >= 2) {
-                               mlCtx.registerOutput("W");
-                               mlCtx.setConfig("cp.parallel.matrixmult", 
"false");
+                               script.out("W");
+                               ml.setConfigProperty("cp.parallel.matrixmult", 
"false");
                        }
                        
-                       MLOutput out = mlCtx.execute(fullDMLScriptName, 
programArgs);
+                       MLResults results = ml.execute(script);
                        
                        if(numRegisteredOutputs >= 2) {
                                String configStr = 
ConfigurationManager.getDMLConfig().getConfigInfo();
@@ -190,7 +214,7 @@ public class GNMFTest extends AutomatedTestBase
                        }
                        
                        if(numRegisteredOutputs >= 1) {
-                               JavaRDD<String> hOut = out.getStringRDD("H", 
"text");
+                               RDD<String> hOut = results.getRDDStringIJV("H");
                                String fName = output("h");
                                try {
                                        MapReduceTool.deleteFileIfExistOnHDFS( 
fName );
@@ -201,10 +225,11 @@ public class GNMFTest extends AutomatedTestBase
                        }
                        
                        if(numRegisteredOutputs >= 2) {
-//                             Test converter: Text -> CoordinateMatrix -> 
BinaryBlock -> Text
-//                             JavaRDD<String> wOut = out.getStringRDD("W", 
"text");
-                               JavaRDD<MatrixEntry> matRDD = 
out.getStringRDD("W", "text").map(new StringToMatrixEntry());
-                               MatrixCharacteristics mcW = 
out.getMatrixCharacteristics("W");
+                               JavaRDD<String> javaRDDStringIJV = 
results.getJavaRDDStringIJV("W");
+                               JavaRDD<MatrixEntry> matRDD = 
javaRDDStringIJV.map(new StringToMatrixEntry());
+                               Matrix matrix = results.getMatrix("W");
+                               MatrixCharacteristics mcW = 
matrix.getMatrixMetadata().asMatrixCharacteristics();
+
                                CoordinateMatrix coordinateMatrix = new 
CoordinateMatrix(matRDD.rdd(), mcW.getRows(), mcW.getCols());
                                JavaPairRDD<MatrixIndexes, MatrixBlock> 
binaryRDD = RDDConverterUtilsExt.coordinateMatrixToBinaryBlock(sc, 
coordinateMatrix, mcW, true);
                                JavaRDD<String> wOut = 
RDDConverterUtils.binaryBlockToTextCell(binaryRDD, mcW);
@@ -227,19 +252,10 @@ public class GNMFTest extends AutomatedTestBase
                        HashMap<CellIndex, Double> hmHR = 
readRMatrixFromFS("h");
                        TestUtils.compareMatrices(hmWDML, hmWR, 0.000001, 
"hmWDML", "hmWR");
                        TestUtils.compareMatrices(hmHDML, hmHR, 0.000001, 
"hmHDML", "hmHR");
-                       
-                       //cleanup mlcontext (prevent test memory leaks)
-                       mlCtx.reset();
                }
                finally {
                        DMLScript.rtplatform = oldRT;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
-
-                       if (sc != null) {
-                               sc.stop();
-                       }
-                       SparkExecutionContext.resetSparkContextStatic();
-                       MLContextProxy.setActive(false);
                }
        }
        

Reply via email to