[HOTFIX][SYSTEMML-2219] Fix ultra-sparse/ultra-sparse matrix multiply

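This patch fixes the improved ultra-sparse matrix multiply for the special
case of ultra-sparse x ultra-sparse matrix multiplication where the
right-hand side (rhs) has entirely empty rows. The number of non-zeros
copied for each selected row is now tracked. The optional in-place scaling
is skipped when the corresponding rhs row is empty. The patch also adds
Spark variants of the ultra-sparse matrix multiplication tests.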


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/2b3aefe7
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/2b3aefe7
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/2b3aefe7

Branch: refs/heads/master
Commit: 2b3aefe79446b3b1ab13640566a37e11f620ee96
Parents: 015b273
Author: Matthias Boehm <[email protected]>
Authored: Sat Mar 31 15:05:59 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Mar 31 15:05:59 2018 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java      |  7 +-
 ...FullMatrixMultiplicationUltraSparseTest.java | 75 ++++++++++++--------
 2 files changed, 51 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/2b3aefe7/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 8919ab6..ef273f6 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -1512,17 +1512,18 @@ public class LibMatrixMult
                                if( alen==1 ) { 
                                        //row selection (now aggregation) with 
potential scaling
                                        int aix = aixs[apos];
+                                       int lnnz = 0;
                                        if( rightSparse ) { //sparse right 
matrix (full row copy)
                                                if( 
!m2.sparseBlock.isEmpty(aix) ) {
                                                        ret.rlen=m;
                                                        
ret.allocateSparseRowsBlock(false); //allocation on demand
                                                        boolean ldeep = 
(m2.sparseBlock instanceof SparseBlockMCSR);
                                                        ret.sparseBlock.set(i, 
m2.sparseBlock.get(aix), ldeep);
-                                                       ret.nonZeros += 
ret.sparseBlock.size(i);
+                                                       ret.nonZeros += (lnnz = 
ret.sparseBlock.size(i));
                                                }
                                        }
                                        else { //dense right matrix (append all 
values)
-                                               int lnnz = 
(int)m2.recomputeNonZeros(aix, aix, 0, n-1);
+                                               lnnz = 
(int)m2.recomputeNonZeros(aix, aix, 0, n-1);
                                                if( lnnz > 0 ) {
                                                        c.allocate(i, lnnz); 
//allocate once
                                                        double[] bvals = 
m2.getDenseBlock().values(aix);
@@ -1532,7 +1533,7 @@ public class LibMatrixMult
                                                }
                                        }
                                        //optional scaling if not pure selection
-                                       if( avals[apos] != 1 )
+                                       if( avals[apos] != 1 && lnnz > 0 )
                                                
vectMultiplyInPlace(avals[apos], c.values(i), c.pos(i), c.size(i));
                                }
                                else //GENERAL CASE

http://git-wip-us.apache.org/repos/asf/systemml/blob/2b3aefe7/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullMatrixMultiplicationUltraSparseTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullMatrixMultiplicationUltraSparseTest.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullMatrixMultiplicationUltraSparseTest.java
index 2a621fb..538ac1b 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullMatrixMultiplicationUltraSparseTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullMatrixMultiplicationUltraSparseTest.java
@@ -24,7 +24,7 @@ import java.util.HashMap;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
-
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.lops.LopProperties.ExecType;
 import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
@@ -34,7 +34,6 @@ import org.apache.sysml.test.utils.TestUtils;
 
 public class FullMatrixMultiplicationUltraSparseTest extends AutomatedTestBase 
 {
-       
        private final static String TEST_NAME = "FullMatrixMultiplication";
        private final static String TEST_DIR = 
"functions/binary/matrix_full_other/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FullMatrixMultiplicationUltraSparseTest.class.getSimpleName() + "/";
@@ -83,62 +82,77 @@ public class FullMatrixMultiplicationUltraSparseTest 
extends AutomatedTestBase
        }
 
        @Test
-       public void testMMDenseUltraSparseCP() 
-       {
+       public void testMMDenseUltraSparseCP() {
                runMatrixMatrixMultiplicationTest(SparsityType.DENSE, 
SparsityType.ULTRA_SPARSE, ExecType.CP);
        }
        
        @Test
-       public void testMMSparseUltraSparseCP() 
-       {
+       public void testMMSparseUltraSparseCP() {
                runMatrixMatrixMultiplicationTest(SparsityType.SPARSE, 
SparsityType.ULTRA_SPARSE, ExecType.CP);
        }
 
        @Test
-       public void testMMUltraSparseDenseCP() 
-       {
+       public void testMMUltraSparseDenseCP() {
                runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, 
SparsityType.DENSE, ExecType.CP);
        }
        
        @Test
-       public void testMMUltraSparseSparseCP() 
-       {
+       public void testMMUltraSparseSparseCP() {
                runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, 
SparsityType.SPARSE, ExecType.CP);
        }
        
        @Test
-       public void testMMUltraSparseUltraSparseCP() 
-       {
+       public void testMMUltraSparseUltraSparseCP() {
                runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, 
SparsityType.ULTRA_SPARSE, ExecType.CP);
        }
        
        @Test
-       public void testMMDenseUltraSparseMR() 
-       {
+       public void testMMDenseUltraSparseSP() {
+               runMatrixMatrixMultiplicationTest(SparsityType.DENSE, 
SparsityType.ULTRA_SPARSE, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testMMSparseUltraSparseSP() {
+               runMatrixMatrixMultiplicationTest(SparsityType.SPARSE, 
SparsityType.ULTRA_SPARSE, ExecType.SPARK);
+       }
+
+       @Test
+       public void testMMUltraSparseDenseSP() {
+               runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, 
SparsityType.DENSE, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testMMUltraSparseSparseSP() {
+               runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, 
SparsityType.SPARSE, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testMMUltraSparseUltraSparseSP() {
+               runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, 
SparsityType.ULTRA_SPARSE, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testMMDenseUltraSparseMR() {
                runMatrixMatrixMultiplicationTest(SparsityType.DENSE, 
SparsityType.ULTRA_SPARSE, ExecType.MR);
        }
        
        @Test
-       public void testMMSparseUltraSparseMR() 
-       {
+       public void testMMSparseUltraSparseMR() {
                runMatrixMatrixMultiplicationTest(SparsityType.SPARSE, 
SparsityType.ULTRA_SPARSE, ExecType.MR);
        }
 
        @Test
-       public void testMMUltraSparseDenseMR() 
-       {
+       public void testMMUltraSparseDenseMR() {
                runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, 
SparsityType.DENSE, ExecType.MR);
        }
        
        @Test
-       public void testMMUltraSparseSparseMR() 
-       {
+       public void testMMUltraSparseSparseMR() {
                runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, 
SparsityType.SPARSE, ExecType.MR);
        }
        
        @Test
-       public void testMMUltraSparseUltraSparseMR() 
-       {
+       public void testMMUltraSparseUltraSparseMR() {
                runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, 
SparsityType.ULTRA_SPARSE, ExecType.MR);
        }
 
@@ -150,12 +164,17 @@ public class FullMatrixMultiplicationUltraSparseTest 
extends AutomatedTestBase
         */
        private void runMatrixMatrixMultiplicationTest( SparsityType sparseM1, 
SparsityType sparseM2, ExecType instType)
        {
-               //setup exec type, rows, cols
-
-               //rtplatform for MR
                RUNTIME_PLATFORM platformOld = rtplatform;
-               rtplatform = (instType==ExecType.MR) ? RUNTIME_PLATFORM.HADOOP 
: RUNTIME_PLATFORM.HYBRID;
+               switch( instType ){
+                       case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+                       case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+               }
        
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if( rtplatform == RUNTIME_PLATFORM.SPARK )
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
                try
                {
                        TestConfiguration config = 
getTestConfiguration(TEST_NAME);
@@ -194,8 +213,8 @@ public class FullMatrixMultiplicationUltraSparseTest 
extends AutomatedTestBase
                        HashMap<CellIndex, Double> rfile  = 
readRMatrixFromFS("C");
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
                }
-               finally
-               {
+               finally {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                        rtplatform = platformOld;
                }
        }

Reply via email to