This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new e9a5d4d [SYSTEMDS-401] Fix spark diagM2V operations (wrong block
indexes, tests)
e9a5d4d is described below
commit e9a5d4d8577eb341473a406cf2aa74150cac599d
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri May 29 20:17:58 2020 +0200
[SYSTEMDS-401] Fix spark diagM2V operations (wrong block indexes, tests)
This patch fixes the spark reorg (rdiag) instruction so that it creates
correct output block indexes for matrix-to-vector ops, which extract the
diagonal of a matrix into a vector representation. So far, for an input
block index (r, c), we returned (r, c), which is now fixed to (r, 1).
---
dev/docs/Tasks.txt | 3 +
.../sysds/runtime/functionobjects/DiagIndex.java | 19 +++--
.../spark/functions/ReorgMapFunction.java | 6 +-
.../reorg/{DiagV2MTest.java => FullDiagTest.java} | 88 +++++++++++-----------
.../reorg/{DiagV2MTest.dml => DiagM2VTest.R} | 13 ++--
.../reorg/{DiagV2MTest.dml => DiagM2VTest.dml} | 10 +--
src/test/scripts/functions/reorg/DiagV2MTest.dml | 4 +-
7 files changed, 67 insertions(+), 76 deletions(-)
diff --git a/dev/docs/Tasks.txt b/dev/docs/Tasks.txt
index ecaaf00..cd90a66 100644
--- a/dev/docs/Tasks.txt
+++ b/dev/docs/Tasks.txt
@@ -311,5 +311,8 @@ SYSTEMDS-390 New Builtin Functions IV
* 394 Builtin for one-hot encoding of matrix (not frame), see table OK
* 395 SVM rework and utils (confusionMatrix, msvmPredict) OK
+SYSTEMDS-400 Spark Backend Improvements
+ * 401 Fix output block indexes of rdiag (diagM2V) OK
+
Others:
* Break append instruction to cbind and rbind
diff --git
a/src/main/java/org/apache/sysds/runtime/functionobjects/DiagIndex.java
b/src/main/java/org/apache/sysds/runtime/functionobjects/DiagIndex.java
index a56c415..143e754 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/DiagIndex.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/DiagIndex.java
@@ -26,25 +26,24 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
public class DiagIndex extends IndexFunction
{
-
private static final long serialVersionUID = -5294771266108903886L;
-
- private static DiagIndex singleObj = null;
+ private final boolean diagV2M;
- private DiagIndex() {
- // nothing to do here
+ private DiagIndex(boolean v2m) {
+ diagV2M = v2m;
}
public static DiagIndex getDiagIndexFnObject() {
- if ( singleObj == null )
- singleObj = new DiagIndex();
- return singleObj;
+ return getDiagIndexFnObject(true);
+ }
+
+ public static DiagIndex getDiagIndexFnObject(boolean v2m) {
+ return new DiagIndex(v2m);
}
@Override
public void execute(MatrixIndexes in, MatrixIndexes out) {
- //only used for V2M
- out.setIndexes(in.getRowIndex(), in.getRowIndex());
+ out.setIndexes(in.getRowIndex(), diagV2M ? in.getRowIndex() :
1);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ReorgMapFunction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ReorgMapFunction.java
index 7fcaa01..14df2a6 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ReorgMapFunction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ReorgMapFunction.java
@@ -32,7 +32,6 @@ import scala.Tuple2;
public class ReorgMapFunction implements PairFunction<Tuple2<MatrixIndexes,
MatrixBlock>, MatrixIndexes, MatrixBlock>
{
-
private static final long serialVersionUID = 31065772250744103L;
private ReorgOperator _reorgOp = null;
@@ -42,8 +41,8 @@ public class ReorgMapFunction implements
PairFunction<Tuple2<MatrixIndexes, Matr
if(opcode.equalsIgnoreCase("r'")) {
_indexFnObject = SwapIndex.getSwapIndexFnObject();
}
- else if(opcode.equalsIgnoreCase("rdiag")) {
- _indexFnObject = DiagIndex.getDiagIndexFnObject();
+ else if(opcode.equalsIgnoreCase("rdiag")) { //diagM2V
+ _indexFnObject = DiagIndex.getDiagIndexFnObject(false);
}
else {
throw new DMLRuntimeException("Incorrect opcode for
RDDReorgMapFunction:" + opcode);
@@ -65,6 +64,5 @@ public class ReorgMapFunction implements
PairFunction<Tuple2<MatrixIndexes, Matr
//output new tuple
return new Tuple2<>(ixOut,blkOut);
}
-
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/reorg/DiagV2MTest.java
b/src/test/java/org/apache/sysds/test/functions/reorg/FullDiagTest.java
similarity index 56%
rename from src/test/java/org/apache/sysds/test/functions/reorg/DiagV2MTest.java
rename to src/test/java/org/apache/sysds/test/functions/reorg/FullDiagTest.java
index 4a44e40..c3e8444 100644
--- a/src/test/java/org/apache/sysds/test/functions/reorg/DiagV2MTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/reorg/FullDiagTest.java
@@ -23,17 +23,18 @@ import java.util.HashMap;
import java.util.Random;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-public class DiagV2MTest extends AutomatedTestBase
+public class FullDiagTest extends AutomatedTestBase
{
private final static String TEST_DIR = "functions/reorg/";
- private static final String TEST_CLASS_DIR = TEST_DIR +
DiagV2MTest.class.getSimpleName() + "/";
+ private final static String TEST_NAME1 = "DiagV2MTest";
+ private final static String TEST_NAME2 = "DiagM2VTest";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FullDiagTest.class.getSimpleName() + "/";
private final static double epsilon=0.0000000001;
private final static int rows = 1059;
@@ -42,68 +43,63 @@ public class DiagV2MTest extends AutomatedTestBase
@Override
public void setUp() {
- addTestConfiguration("DiagV2MTest",
- new TestConfiguration(TEST_CLASS_DIR, "DiagV2MTest",
new String[] {"C"}));
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"C"}));
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"C"}));
}
- public void commonReorgTest(ExecMode platform)
+ @Test
+ public void testDiagV2MCP() {
+ commonReorgTest(TEST_NAME1, ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ public void testDiagV2MSP() {
+ commonReorgTest(TEST_NAME1, ExecMode.SPARK);
+ }
+
+ @Test
+ public void testDiagM2VCP() {
+ commonReorgTest(TEST_NAME2, ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ public void testDiagM2VSP() {
+ commonReorgTest(TEST_NAME2, ExecMode.SPARK);
+ }
+
+ public void commonReorgTest(String testname, ExecMode platform)
{
- TestConfiguration config = getTestConfiguration("DiagV2MTest");
-
- ExecMode prevPlfm=rtplatform;
+ TestConfiguration config = getTestConfiguration(testname);
+ ExecMode prevPlfm = setExecMode(platform);
- rtplatform = platform;
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
try {
- config.addVariable("rows", rows);
+ config.addVariable("rows", rows);
loadTestConfiguration(config);
-
- /* This is for running the junit test the new way,
i.e., construct the arguments directly */
+
String RI_HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = RI_HOME + "DiagV2MTest" + ".dml";
- programArgs = new String[]{"-explain", "-args",
input("A"), Long.toString(rows), output("C") };
-
- fullRScriptName = RI_HOME + "DiagV2MTest" + ".R";
+ fullDMLScriptName = RI_HOME + testname + ".dml";
+ programArgs = new String[]{"-explain", "-args",
input("A"), output("C") };
+ fullRScriptName = RI_HOME + testname + ".R";
rCmd = "Rscript" + " " + fullRScriptName + " " +
inputDir() + " " + expectedDir();
Random rand=new Random(System.currentTimeMillis());
- double sparsity=0.599200924665577;//rand.nextDouble();
- double[][] A = getRandomMatrix(rows, 1, min, max,
sparsity, 1397289950533L); // System.currentTimeMillis()
- writeInputMatrix("A", A, true);
- sparsity=rand.nextDouble();
-
- boolean exceptionExpected = false;
- int expectedNumberOfJobs = -1;
- runTest(true, exceptionExpected, null,
expectedNumberOfJobs);
+ int cols = testname.equals(TEST_NAME1) ? 1 : rows;
+ double sparsity=0.599200924665577;
+ double[][] A = getRandomMatrix(rows, cols, min, max,
sparsity, 1397289950533L);
+ writeInputMatrixWithMTD("A", A, true);
+ sparsity=rand.nextDouble();
+ runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
runRScript(true);
- for(String file: config.getOutputFiles())
- {
+ for(String file: config.getOutputFiles()) {
HashMap<CellIndex, Double> dmlfile =
readDMLMatrixFromHDFS(file);
HashMap<CellIndex, Double> rfile =
readRMatrixFromFS(file);
- // System.out.println(file+"-DML: "+dmlfile);
- // System.out.println(file+"-R: "+rfile);
TestUtils.compareMatrices(dmlfile, rfile,
epsilon, file+"-DML", file+"-R");
}
}
finally {
- rtplatform = prevPlfm;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ resetExecMode(prevPlfm);
}
}
-
- @Test
- public void testDiagV2MCP() {
- commonReorgTest(ExecMode.SINGLE_NODE);
- }
-
- @Test
- public void testDiagV2MSP() {
- commonReorgTest(ExecMode.SPARK);
- }
}
-
diff --git a/src/test/scripts/functions/reorg/DiagV2MTest.dml
b/src/test/scripts/functions/reorg/DiagM2VTest.R
similarity index 83%
copy from src/test/scripts/functions/reorg/DiagV2MTest.dml
copy to src/test/scripts/functions/reorg/DiagM2VTest.R
index 2fdf5c9..fe7554b 100644
--- a/src/test/scripts/functions/reorg/DiagV2MTest.dml
+++ b/src/test/scripts/functions/reorg/DiagM2VTest.R
@@ -19,10 +19,9 @@
#
#-------------------------------------------------------------
-
-A=read($1, rows=$2, cols=1, format="text")
-B=diag(A)
-C=matrix(1, rows=nrow(B), cols=ncol(B));
-D=B%*%C
-C=B+D
-write(C, $3, format="text")
\ No newline at end of file
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+A = readMM(paste(args[1], "A.mtx", sep=""))
+C = diag(A)
+writeMM(as(C,"CsparseMatrix"), paste(args[2], "C", sep=""))
diff --git a/src/test/scripts/functions/reorg/DiagV2MTest.dml
b/src/test/scripts/functions/reorg/DiagM2VTest.dml
similarity index 87%
copy from src/test/scripts/functions/reorg/DiagV2MTest.dml
copy to src/test/scripts/functions/reorg/DiagM2VTest.dml
index 2fdf5c9..fb93478 100644
--- a/src/test/scripts/functions/reorg/DiagV2MTest.dml
+++ b/src/test/scripts/functions/reorg/DiagM2VTest.dml
@@ -19,10 +19,6 @@
#
#-------------------------------------------------------------
-
-A=read($1, rows=$2, cols=1, format="text")
-B=diag(A)
-C=matrix(1, rows=nrow(B), cols=ncol(B));
-D=B%*%C
-C=B+D
-write(C, $3, format="text")
\ No newline at end of file
+A = read($1)
+C = diag(A)
+write(C, $2, format="text")
diff --git a/src/test/scripts/functions/reorg/DiagV2MTest.dml
b/src/test/scripts/functions/reorg/DiagV2MTest.dml
index 2fdf5c9..6bf206b 100644
--- a/src/test/scripts/functions/reorg/DiagV2MTest.dml
+++ b/src/test/scripts/functions/reorg/DiagV2MTest.dml
@@ -20,9 +20,9 @@
#-------------------------------------------------------------
-A=read($1, rows=$2, cols=1, format="text")
+A=read($1)
B=diag(A)
C=matrix(1, rows=nrow(B), cols=ncol(B));
D=B%*%C
C=B+D
-write(C, $3, format="text")
\ No newline at end of file
+write(C, $2, format="text")
\ No newline at end of file