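[HOTFIX][SYSTEMML-2219] Fix ultra-sparse/ultra-sparse matrix multiply

This patch fixes the improved ultra-sparse matrix multiply for the special case of ultra-sparse x ultra-sparse matrix multiplication where the right-hand side (rhs) has entirely empty rows. Previously, the optional scaling step (for non-1 lhs values) was applied even when the selected rhs row was empty and no output row had been populated. Now the number of copied non-zeros is tracked, and scaling is skipped when it is zero. The patch also adds Spark variants of the ultra-sparse matrix multiplication tests.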
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/2b3aefe7 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/2b3aefe7 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/2b3aefe7 Branch: refs/heads/master Commit: 2b3aefe79446b3b1ab13640566a37e11f620ee96 Parents: 015b273 Author: Matthias Boehm <[email protected]> Authored: Sat Mar 31 15:05:59 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Mar 31 15:05:59 2018 -0700 ---------------------------------------------------------------------- .../runtime/matrix/data/LibMatrixMult.java | 7 +- ...FullMatrixMultiplicationUltraSparseTest.java | 75 ++++++++++++-------- 2 files changed, 51 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/2b3aefe7/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java index 8919ab6..ef273f6 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java @@ -1512,17 +1512,18 @@ public class LibMatrixMult if( alen==1 ) { //row selection (now aggregation) with potential scaling int aix = aixs[apos]; + int lnnz = 0; if( rightSparse ) { //sparse right matrix (full row copy) if( !m2.sparseBlock.isEmpty(aix) ) { ret.rlen=m; ret.allocateSparseRowsBlock(false); //allocation on demand boolean ldeep = (m2.sparseBlock instanceof SparseBlockMCSR); ret.sparseBlock.set(i, m2.sparseBlock.get(aix), ldeep); - ret.nonZeros += ret.sparseBlock.size(i); + ret.nonZeros += (lnnz = ret.sparseBlock.size(i)); } } else { //dense right matrix (append all values) - int lnnz = 
(int)m2.recomputeNonZeros(aix, aix, 0, n-1); + lnnz = (int)m2.recomputeNonZeros(aix, aix, 0, n-1); if( lnnz > 0 ) { c.allocate(i, lnnz); //allocate once double[] bvals = m2.getDenseBlock().values(aix); @@ -1532,7 +1533,7 @@ public class LibMatrixMult } } //optional scaling if not pure selection - if( avals[apos] != 1 ) + if( avals[apos] != 1 && lnnz > 0 ) vectMultiplyInPlace(avals[apos], c.values(i), c.pos(i), c.size(i)); } else //GENERAL CASE http://git-wip-us.apache.org/repos/asf/systemml/blob/2b3aefe7/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullMatrixMultiplicationUltraSparseTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullMatrixMultiplicationUltraSparseTest.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullMatrixMultiplicationUltraSparseTest.java index 2a621fb..538ac1b 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullMatrixMultiplicationUltraSparseTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullMatrixMultiplicationUltraSparseTest.java @@ -24,7 +24,7 @@ import java.util.HashMap; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; - +import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.lops.LopProperties.ExecType; import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; @@ -34,7 +34,6 @@ import org.apache.sysml.test.utils.TestUtils; public class FullMatrixMultiplicationUltraSparseTest extends AutomatedTestBase { - private final static String TEST_NAME = "FullMatrixMultiplication"; private final static String TEST_DIR = "functions/binary/matrix_full_other/"; private final static String TEST_CLASS_DIR = TEST_DIR + 
FullMatrixMultiplicationUltraSparseTest.class.getSimpleName() + "/"; @@ -83,62 +82,77 @@ public class FullMatrixMultiplicationUltraSparseTest extends AutomatedTestBase } @Test - public void testMMDenseUltraSparseCP() - { + public void testMMDenseUltraSparseCP() { runMatrixMatrixMultiplicationTest(SparsityType.DENSE, SparsityType.ULTRA_SPARSE, ExecType.CP); } @Test - public void testMMSparseUltraSparseCP() - { + public void testMMSparseUltraSparseCP() { runMatrixMatrixMultiplicationTest(SparsityType.SPARSE, SparsityType.ULTRA_SPARSE, ExecType.CP); } @Test - public void testMMUltraSparseDenseCP() - { + public void testMMUltraSparseDenseCP() { runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, SparsityType.DENSE, ExecType.CP); } @Test - public void testMMUltraSparseSparseCP() - { + public void testMMUltraSparseSparseCP() { runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, SparsityType.SPARSE, ExecType.CP); } @Test - public void testMMUltraSparseUltraSparseCP() - { + public void testMMUltraSparseUltraSparseCP() { runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, SparsityType.ULTRA_SPARSE, ExecType.CP); } @Test - public void testMMDenseUltraSparseMR() - { + public void testMMDenseUltraSparseSP() { + runMatrixMatrixMultiplicationTest(SparsityType.DENSE, SparsityType.ULTRA_SPARSE, ExecType.SPARK); + } + + @Test + public void testMMSparseUltraSparseSP() { + runMatrixMatrixMultiplicationTest(SparsityType.SPARSE, SparsityType.ULTRA_SPARSE, ExecType.SPARK); + } + + @Test + public void testMMUltraSparseDenseSP() { + runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, SparsityType.DENSE, ExecType.SPARK); + } + + @Test + public void testMMUltraSparseSparseSP() { + runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, SparsityType.SPARSE, ExecType.SPARK); + } + + @Test + public void testMMUltraSparseUltraSparseSP() { + runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, SparsityType.ULTRA_SPARSE, ExecType.SPARK); + } + + 
@Test + public void testMMDenseUltraSparseMR() { runMatrixMatrixMultiplicationTest(SparsityType.DENSE, SparsityType.ULTRA_SPARSE, ExecType.MR); } @Test - public void testMMSparseUltraSparseMR() - { + public void testMMSparseUltraSparseMR() { runMatrixMatrixMultiplicationTest(SparsityType.SPARSE, SparsityType.ULTRA_SPARSE, ExecType.MR); } @Test - public void testMMUltraSparseDenseMR() - { + public void testMMUltraSparseDenseMR() { runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, SparsityType.DENSE, ExecType.MR); } @Test - public void testMMUltraSparseSparseMR() - { + public void testMMUltraSparseSparseMR() { runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, SparsityType.SPARSE, ExecType.MR); } @Test - public void testMMUltraSparseUltraSparseMR() - { + public void testMMUltraSparseUltraSparseMR() { runMatrixMatrixMultiplicationTest(SparsityType.ULTRA_SPARSE, SparsityType.ULTRA_SPARSE, ExecType.MR); } @@ -150,12 +164,17 @@ public class FullMatrixMultiplicationUltraSparseTest extends AutomatedTestBase */ private void runMatrixMatrixMultiplicationTest( SparsityType sparseM1, SparsityType sparseM2, ExecType instType) { - //setup exec type, rows, cols - - //rtplatform for MR RUNTIME_PLATFORM platformOld = rtplatform; - rtplatform = (instType==ExecType.MR) ? 
RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.HYBRID; + switch( instType ){ + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID; break; + } + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + try { TestConfiguration config = getTestConfiguration(TEST_NAME); @@ -194,8 +213,8 @@ public class FullMatrixMultiplicationUltraSparseTest extends AutomatedTestBase HashMap<CellIndex, Double> rfile = readRMatrixFromFS("C"); TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); } - finally - { + finally { + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; rtplatform = platformOld; } }
