This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new e9a5d4d  [SYSTEMDS-401] Fix spark diagM2V operations (wrong block indexes, tests)
e9a5d4d is described below

commit e9a5d4d8577eb341473a406cf2aa74150cac599d
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri May 29 20:17:58 2020 +0200

    [SYSTEMDS-401] Fix spark diagM2V operations (wrong block indexes, tests)
    
    This patch fixes the spark reorg (rdiag) instruction to create correct
    output block indexes in the case of matrix-to-vector ops, which extract
    the diagonal of a matrix into a vector representation. So far, for an
    input (r, c), we returned (r, c), which is now fixed to (r, 1).
---
 dev/docs/Tasks.txt                                 |  3 +
 .../sysds/runtime/functionobjects/DiagIndex.java   | 19 +++--
 .../spark/functions/ReorgMapFunction.java          |  6 +-
 .../reorg/{DiagV2MTest.java => FullDiagTest.java}  | 88 +++++++++++-----------
 .../reorg/{DiagV2MTest.dml => DiagM2VTest.R}       | 13 ++--
 .../reorg/{DiagV2MTest.dml => DiagM2VTest.dml}     | 10 +--
 src/test/scripts/functions/reorg/DiagV2MTest.dml   |  4 +-
 7 files changed, 67 insertions(+), 76 deletions(-)

diff --git a/dev/docs/Tasks.txt b/dev/docs/Tasks.txt
index ecaaf00..cd90a66 100644
--- a/dev/docs/Tasks.txt
+++ b/dev/docs/Tasks.txt
@@ -311,5 +311,8 @@ SYSTEMDS-390 New Builtin Functions IV
  * 394 Builtin for one-hot encoding of matrix (not frame), see table  OK
  * 395 SVM rework and utils (confusionMatrix, msvmPredict)            OK
 
+SYSTEMDS-400 Spark Backend Improvements
+ * 401 Fix output block indexes of rdiag (diagM2V)                    OK
+
 Others:
  * Break append instruction to cbind and rbind 
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/DiagIndex.java b/src/main/java/org/apache/sysds/runtime/functionobjects/DiagIndex.java
index a56c415..143e754 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/DiagIndex.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/DiagIndex.java
@@ -26,25 +26,24 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
 
 public class DiagIndex extends IndexFunction
 {
-
        private static final long serialVersionUID = -5294771266108903886L;
-
-       private static DiagIndex singleObj = null;
+       private final boolean diagV2M;
        
-       private DiagIndex() {
-               // nothing to do here
+       private DiagIndex(boolean v2m) {
+               diagV2M = v2m;
        }
        
        public static DiagIndex getDiagIndexFnObject() {
-               if ( singleObj == null )
-                       singleObj = new DiagIndex();
-               return singleObj;
+               return getDiagIndexFnObject(true);
+       }
+       
+       public static DiagIndex getDiagIndexFnObject(boolean v2m) {
+               return new DiagIndex(v2m);
        }
        
        @Override
        public void execute(MatrixIndexes in, MatrixIndexes out) {
-               //only used for V2M
-               out.setIndexes(in.getRowIndex(), in.getRowIndex());
+               out.setIndexes(in.getRowIndex(), diagV2M ? in.getRowIndex() : 
1);
        }
 
        @Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ReorgMapFunction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ReorgMapFunction.java
index 7fcaa01..14df2a6 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ReorgMapFunction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ReorgMapFunction.java
@@ -32,7 +32,6 @@ import scala.Tuple2;
 
 public class ReorgMapFunction implements PairFunction<Tuple2<MatrixIndexes, 
MatrixBlock>, MatrixIndexes, MatrixBlock> 
 {
-       
        private static final long serialVersionUID = 31065772250744103L;
        
        private ReorgOperator _reorgOp = null;
@@ -42,8 +41,8 @@ public class ReorgMapFunction implements 
PairFunction<Tuple2<MatrixIndexes, Matr
                if(opcode.equalsIgnoreCase("r'")) {
                        _indexFnObject = SwapIndex.getSwapIndexFnObject();
                }
-               else if(opcode.equalsIgnoreCase("rdiag")) {
-                       _indexFnObject = DiagIndex.getDiagIndexFnObject();
+               else if(opcode.equalsIgnoreCase("rdiag")) { //diagM2V
+                       _indexFnObject = DiagIndex.getDiagIndexFnObject(false);
                }
                else {
                        throw new DMLRuntimeException("Incorrect opcode for 
RDDReorgMapFunction:" + opcode);
@@ -65,6 +64,5 @@ public class ReorgMapFunction implements 
PairFunction<Tuple2<MatrixIndexes, Matr
                //output new tuple
                return new Tuple2<>(ixOut,blkOut);
        }
-       
 }
 
diff --git a/src/test/java/org/apache/sysds/test/functions/reorg/DiagV2MTest.java b/src/test/java/org/apache/sysds/test/functions/reorg/FullDiagTest.java
similarity index 56%
rename from src/test/java/org/apache/sysds/test/functions/reorg/DiagV2MTest.java
rename to src/test/java/org/apache/sysds/test/functions/reorg/FullDiagTest.java
index 4a44e40..c3e8444 100644
--- a/src/test/java/org/apache/sysds/test/functions/reorg/DiagV2MTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/reorg/FullDiagTest.java
@@ -23,17 +23,18 @@ import java.util.HashMap;
 import java.util.Random;
 
 import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 
-public class DiagV2MTest extends AutomatedTestBase
+public class FullDiagTest extends AutomatedTestBase
 {
        private final static String TEST_DIR = "functions/reorg/";
-       private static final String TEST_CLASS_DIR = TEST_DIR + 
DiagV2MTest.class.getSimpleName() + "/";
+       private final static String TEST_NAME1 = "DiagV2MTest";
+       private final static String TEST_NAME2 = "DiagM2VTest";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FullDiagTest.class.getSimpleName() + "/";
 
        private final static double epsilon=0.0000000001;
        private final static int rows = 1059;
@@ -42,68 +43,63 @@ public class DiagV2MTest extends AutomatedTestBase
        
        @Override
        public void setUp() {
-               addTestConfiguration("DiagV2MTest", 
-                       new TestConfiguration(TEST_CLASS_DIR, "DiagV2MTest", 
new String[] {"C"}));
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"C"}));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"C"}));
        }
        
-       public void commonReorgTest(ExecMode platform)
+       @Test
+       public void testDiagV2MCP() {
+               commonReorgTest(TEST_NAME1, ExecMode.SINGLE_NODE);
+       }
+       
+       @Test
+       public void testDiagV2MSP() {
+               commonReorgTest(TEST_NAME1, ExecMode.SPARK);
+       }
+       
+       @Test
+       public void testDiagM2VCP() {
+               commonReorgTest(TEST_NAME2, ExecMode.SINGLE_NODE);
+       }
+       
+       @Test
+       public void testDiagM2VSP() {
+               commonReorgTest(TEST_NAME2, ExecMode.SPARK);
+       }
+       
+       public void commonReorgTest(String testname, ExecMode platform)
        {
-               TestConfiguration config = getTestConfiguration("DiagV2MTest");
-           
-               ExecMode prevPlfm=rtplatform;
+               TestConfiguration config = getTestConfiguration(testname);
+               ExecMode prevPlfm = setExecMode(platform);
                
-           rtplatform = platform;
-               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               if( rtplatform == ExecMode.SPARK )
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
                try {
-               config.addVariable("rows", rows);
+                       config.addVariable("rows", rows);
                        loadTestConfiguration(config);
-                 
-                       /* This is for running the junit test the new way, 
i.e., construct the arguments directly */
+       
                        String RI_HOME = SCRIPT_DIR + TEST_DIR;
-                       fullDMLScriptName = RI_HOME + "DiagV2MTest" + ".dml";
-                       programArgs = new String[]{"-explain", "-args",  
input("A"), Long.toString(rows), output("C") };
-                       
-                       fullRScriptName = RI_HOME + "DiagV2MTest" + ".R";
+                       fullDMLScriptName = RI_HOME + testname + ".dml";
+                       programArgs = new String[]{"-explain", "-args",  
input("A"), output("C") };
+                       fullRScriptName = RI_HOME + testname + ".R";
                        rCmd = "Rscript" + " " + fullRScriptName + " " + 
inputDir() + " " + expectedDir();
        
                        Random rand=new Random(System.currentTimeMillis());
-                       double sparsity=0.599200924665577;//rand.nextDouble();
-                       double[][] A = getRandomMatrix(rows, 1, min, max, 
sparsity, 1397289950533L); // System.currentTimeMillis()
-               writeInputMatrix("A", A, true);
-               sparsity=rand.nextDouble();   
-                       
-               boolean exceptionExpected = false;
-                       int expectedNumberOfJobs = -1;
-                       runTest(true, exceptionExpected, null, 
expectedNumberOfJobs);
+                       int cols = testname.equals(TEST_NAME1) ? 1 : rows;
+                       double sparsity=0.599200924665577;
+                       double[][] A = getRandomMatrix(rows, cols, min, max, 
sparsity, 1397289950533L);
+                       writeInputMatrixWithMTD("A", A, true);
+                       sparsity=rand.nextDouble();
                        
+                       runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        runRScript(true);
                
-                       for(String file: config.getOutputFiles())
-                       {
+                       for(String file: config.getOutputFiles()) {
                                HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS(file);
                                HashMap<CellIndex, Double> rfile = 
readRMatrixFromFS(file);
-                       //      System.out.println(file+"-DML: "+dmlfile);
-                       //      System.out.println(file+"-R: "+rfile);
                                TestUtils.compareMatrices(dmlfile, rfile, 
epsilon, file+"-DML", file+"-R");
                        }
                }
                finally {
-                       rtplatform = prevPlfm;
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       resetExecMode(prevPlfm);
                }
        }
-       
-       @Test
-       public void testDiagV2MCP() {
-               commonReorgTest(ExecMode.SINGLE_NODE);
-       }
-       
-       @Test
-       public void testDiagV2MSP() {
-               commonReorgTest(ExecMode.SPARK);
-       }
 }
-
diff --git a/src/test/scripts/functions/reorg/DiagV2MTest.dml b/src/test/scripts/functions/reorg/DiagM2VTest.R
similarity index 83%
copy from src/test/scripts/functions/reorg/DiagV2MTest.dml
copy to src/test/scripts/functions/reorg/DiagM2VTest.R
index 2fdf5c9..fe7554b 100644
--- a/src/test/scripts/functions/reorg/DiagV2MTest.dml
+++ b/src/test/scripts/functions/reorg/DiagM2VTest.R
@@ -19,10 +19,9 @@
 #
 #-------------------------------------------------------------
 
-
-A=read($1, rows=$2, cols=1, format="text")
-B=diag(A)
-C=matrix(1, rows=nrow(B), cols=ncol(B));
-D=B%*%C
-C=B+D
-write(C, $3, format="text")
\ No newline at end of file
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+A = readMM(paste(args[1], "A.mtx", sep=""))
+C = diag(A)
+writeMM(as(C,"CsparseMatrix"), paste(args[2], "C", sep=""))
diff --git a/src/test/scripts/functions/reorg/DiagV2MTest.dml b/src/test/scripts/functions/reorg/DiagM2VTest.dml
similarity index 87%
copy from src/test/scripts/functions/reorg/DiagV2MTest.dml
copy to src/test/scripts/functions/reorg/DiagM2VTest.dml
index 2fdf5c9..fb93478 100644
--- a/src/test/scripts/functions/reorg/DiagV2MTest.dml
+++ b/src/test/scripts/functions/reorg/DiagM2VTest.dml
@@ -19,10 +19,6 @@
 #
 #-------------------------------------------------------------
 
-
-A=read($1, rows=$2, cols=1, format="text")
-B=diag(A)
-C=matrix(1, rows=nrow(B), cols=ncol(B));
-D=B%*%C
-C=B+D
-write(C, $3, format="text")
\ No newline at end of file
+A = read($1)
+C = diag(A)
+write(C, $2, format="text")
diff --git a/src/test/scripts/functions/reorg/DiagV2MTest.dml b/src/test/scripts/functions/reorg/DiagV2MTest.dml
index 2fdf5c9..6bf206b 100644
--- a/src/test/scripts/functions/reorg/DiagV2MTest.dml
+++ b/src/test/scripts/functions/reorg/DiagV2MTest.dml
@@ -20,9 +20,9 @@
 #-------------------------------------------------------------
 
 
-A=read($1, rows=$2, cols=1, format="text")
+A=read($1)
 B=diag(A)
 C=matrix(1, rows=nrow(B), cols=ncol(B));
 D=B%*%C
 C=B+D
-write(C, $3, format="text")
\ No newline at end of file
+write(C, $2, format="text")
\ No newline at end of file

Reply via email to