This is an automated email from the ASF dual-hosted git repository.

janardhan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new d1a1492  [SYSTEMDS-1863] Full MLContext test for LinearReg
d1a1492 is described below

commit d1a1492c2da608f7be0a5458beaadabb44b06c2b
Author: Janardhan Pulivarthi <[email protected]>
AuthorDate: Mon Jul 20 12:18:57 2020 +0530

    [SYSTEMDS-1863] Full MLContext test for LinearReg
    
      * Takes advantage of existing R algorithm scripts used for
        codegen testing.
      * This would improve the testing by allowing us to provide all
        the necessary inputs into the script.
---
 .../org/apache/sysds/test/AutomatedTestBase.java   |  2 ++
 .../functions/mlcontext/MLContextLinregTest.java   | 38 +++++++++++++++++++---
 2 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index a66ee1e..0183e34 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -1649,6 +1649,8 @@ public abstract class AutomatedTestBase {
        }
 
        protected String getRScript() {
+               if(fullRScriptName != null)
+                       return fullRScriptName;
                return sourceDirectory + selectedTest + ".R";
        }
 
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
index a5cddb8..0e45cb4 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextLinregTest.java
@@ -22,8 +22,13 @@ package org.apache.sysds.test.functions.mlcontext;
 import static org.apache.sysds.api.mlcontext.ScriptFactory.dmlFromFile;
 
 import org.apache.log4j.Logger;
-import org.junit.Test;
 import org.apache.sysds.api.mlcontext.Script;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
 
 public class MLContextLinregTest extends MLContextTestBase {
        protected static Logger log = 
Logger.getLogger(MLContextLinregTest.class);
@@ -37,6 +42,11 @@ public class MLContextLinregTest extends MLContextTestBase {
                CG, DS,
        }
 
+       private final static double eps = 1e-3;
+
+       private final static int rows = 2468;
+       private final static int cols = 507;
+
        @Test
        public void testLinregCGSparse() {
                runLinregTestMLC(LinregType.CG, true);
@@ -59,24 +69,42 @@ public class MLContextLinregTest extends MLContextTestBase {
 
        private void runLinregTestMLC(LinregType type, boolean sparse) {
 
-               double[][] X = getRandomMatrix(10, 3, 0, 1, sparse ? sparsity2 
: sparsity1, 7);
-               double[][] Y = getRandomMatrix(10, 1, 0, 10, 1.0, 3);
+               double[][] X = getRandomMatrix(rows, cols, 0, 1, sparse ? 
sparsity2 : sparsity1, 7);
+               double[][] Y = getRandomMatrix(rows, 1, 0, 10, 1.0, 3);
+
+               // Hack Alert
+               // overwrite baseDirectory to the place where test data is 
stored.
+               baseDirectory = "target/testTemp/functions/mlcontext/";
+
+               fullRScriptName = 
"src/test/scripts/functions/codegenalg/Algorithm_LinregCG.R";
+
+               writeInputMatrixWithMTD("X", X, true);
+               writeInputMatrixWithMTD("y", Y, true);
+
+               rCmd = getRCmd(inputDir(), "0", "0.000001", "0", "0.001", 
expectedDir());
+               runRScript(true);
+
+               MatrixBlock outmat = new MatrixBlock();
 
                switch (type) {
                case CG:
                        Script lrcg = dmlFromFile(TEST_SCRIPT_CG);
                        lrcg.in("X", X).in("y", Y).in("$icpt", "0").in("$tol", 
"0.000001").in("$maxi", "0").in("$reg", "0.000001")
                                        .out("beta_out");
-                       ml.execute(lrcg);
+                       outmat = 
ml.execute(lrcg).getMatrix("beta_out").toMatrixBlock();
 
                        break;
 
                case DS:
                        Script lrds = dmlFromFile(TEST_SCRIPT_DS);
                        lrds.in("X", X).in("y", Y).in("$icpt", "0").in("$reg", 
"0.000001").out("beta_out");
-                       ml.execute(lrds);
+                       outmat = 
ml.execute(lrds).getMatrix("beta_out").toMatrixBlock();
 
                        break;
                }
+
+               //compare matrices
+               HashMap<MatrixValue.CellIndex, Double> rfile = 
readRMatrixFromFS("w");
+               TestUtils.compareMatrices(rfile, outmat, eps);
        }
 }

Reply via email to