This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a11c14a4a4 [SYSTEMDS-3683] SparseBlock Non-empty Row Iterator
a11c14a4a4 is described below
commit a11c14a4a42107eb5057789cbd087b55d7400f4a
Author: Rene Enjilian <[email protected]>
AuthorDate: Mon Mar 25 15:04:12 2024 +0100
[SYSTEMDS-3683] SparseBlock Non-empty Row Iterator
Closes #2005.
---
.../org/apache/sysds/runtime/data/SparseBlock.java | 119 +++++++++++++
.../apache/sysds/runtime/data/SparseBlockCOO.java | 37 ++++
.../apache/sysds/runtime/data/SparseBlockCSR.java | 31 ++++
.../apache/sysds/runtime/data/SparseBlockDCSR.java | 33 ++++
.../apache/sysds/runtime/data/SparseBlockMCSR.java | 30 ++++
.../test/component/sparse/SparseBlockIterator.java | 195 ++++++++++++---------
6 files changed, 361 insertions(+), 84 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
index bd3468531d..c2fd193d7c 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
@@ -306,6 +306,33 @@ public abstract class SparseBlock implements Serializable, Block
* @return starting position of row r
*/
public abstract int pos(int r);
+
+ /**
+ * Get the row index of the next non-empty row, starting at the given
+ * search position. The meaning of {@code r} is format-specific: for COO
+ * and DCSR it is a position in the internal row-index array (as produced
+ * by setSearchIndex/updateSearchIndex), for CSR and MCSR it is a row index.
+ * The value returned when no non-empty row exists in the remaining range
+ * differs between formats and should not be relied upon.
+ *
+ * @param r format-specific search position (see above)
+ * @param ru exclusive upper row index starting at 0
+ * @return next non-empty row index
+ */
+ public abstract int nextNonZeroRowIndex(int r, int ru);
+
+ /**
+ * Get the initial search position for a scan over the non-empty rows
+ * starting at row {@code r}. For COO and DCSR the result is a position in
+ * the internal row-index array, for CSR and MCSR it is a row index.
+ *
+ * @param r inclusive lower row index starting at 0
+ * @param ru exclusive upper row index starting at 0
+ * @return starting search position, or -1 if no non-empty row was found
+ */
+ public abstract int setSearchIndex(int r, int ru);
+
+ /**
+ * Advance a search position (see setSearchIndex) past the current row.
+ * COO, DCSR and MCSR return {@code r} unchanged when no further non-empty
+ * row exists before {@code ru}; CSR instead generally returns {@code r+1}
+ * and relies on setSearchIndex to detect exhaustion.
+ *
+ * @param r current format-specific search position
+ * @param ru exclusive upper row index starting at 0
+ * @return next search position, or {@code r} if no further row was found
+ */
+ public abstract int updateSearchIndex(int r, int ru);
////////////////////////
@@ -553,6 +580,30 @@ public abstract class SparseBlock implements Serializable, Block
//default generic iterator, override if necessary
return new SparseBlockIterator(rl, Math.min(ru,numRows()));
}
+
+	/**
+	 * Get an iterator over the indices of non-empty rows within the entire
+	 * sparse block. Rows without any non-zero value are skipped; the returned
+	 * integers are row indexes in ascending order.
+	 *
+	 * @return iterator over non-empty row indexes
+	 */
+	public Iterator<Integer> getNonEmptyRowIterator() {
+		return new SparseNonEmptyRowIterator(0, numRows());
+	}
+
+	/**
+	 * Get an iterator over the indices of non-empty rows within the
+	 * sub-block [rl,ru). Rows without any non-zero value are skipped; the
+	 * returned integers are row indexes in ascending order.
+	 *
+	 * @param rl inclusive lower row index starting at 0
+	 * @param ru exclusive upper row index starting at 0 (clamped to numRows())
+	 * @return iterator over non-empty row indexes
+	 */
+	public Iterator<Integer> getNonEmptyRowIterator(int rl, int ru) {
+		//clamp like getIterator(rl, ru) to never read past the last row
+		return new SparseNonEmptyRowIterator(rl, Math.min(ru, numRows()));
+	}
@Override
public abstract String toString();
@@ -717,4 +768,72 @@ public abstract class SparseBlock implements Serializable, Block
}
}
}
+
+	//TODO: move to individual sparse blocks for performance/separation -> MB
+	/**
+	 * Generic iterator over the indexes of non-empty rows in the row range
+	 * [rl, ru). It relies only on {@link #isEmpty(int)} and therefore behaves
+	 * identically for every sparse block format, without dispatching on
+	 * concrete subclass types. Formats that can enumerate non-empty rows
+	 * faster may override getNonEmptyRowIterator (see TODO above).
+	 */
+	private class SparseNonEmptyRowIterator implements Iterator<Integer> {
+		private final int _ru; //exclusive row upper bound
+		private int _nextRow;  //next non-empty row, or >= _ru if exhausted
+
+		protected SparseNonEmptyRowIterator(int rl, int ru) {
+			_ru = ru;
+			//rl >= ru yields an empty iterator
+			_nextRow = seek(rl);
+		}
+
+		@Override
+		public boolean hasNext() {
+			return _nextRow < _ru;
+		}
+
+		@Override
+		public Integer next() {
+			//Iterator contract: fail instead of returning a bogus row
+			if( !hasNext() )
+				throw new java.util.NoSuchElementException();
+			int ret = _nextRow;
+			_nextRow = seek(ret + 1);
+			return ret;
+		}
+
+		@Override
+		public void remove() {
+			throw new UnsupportedOperationException("SparseBlockIterator is unsupported!");
+		}
+
+		/**
+		 * Moves from row r to the first non-empty row in [r, _ru).
+		 *
+		 * @param r inclusive start row
+		 * @return first non-empty row, or a value >= _ru if there is none
+		 */
+		private int seek(int r) {
+			while( r < _ru && isEmpty(r) )
+				r++;
+			return r;
+		}
+	}
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
index 124c942122..b7028e1e1f 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
@@ -354,6 +354,43 @@ public class SparseBlockCOO extends SparseBlock
return index;
}
+	@Override
+	public int nextNonZeroRowIndex(int r, int ru) {
+		//r is a position in the row-index array, not a row index
+		return _rindexes[r];
+	}
+
+	@Override
+	public int setSearchIndex(int r, int ru) {
+		//only the first size() entries are valid; the arrays may carry spare capacity
+		int len = (int) size();
+		if( len == 0 )
+			return -1;
+		int pos = Arrays.binarySearch(_rindexes, 0, len, r);
+		if( pos < 0 )
+			pos = -pos - 1; //insertion point: first entry with a row index > r
+		else
+			while( pos > 0 && _rindexes[pos - 1] == r )
+				pos--; //binarySearch may hit any duplicate, rewind to the row's first entry
+		//no entry left, or the next stored row lies at or beyond ru
+		return (pos >= len || _rindexes[pos] >= ru) ? -1 : pos;
+	}
+
+	@Override
+	public int updateSearchIndex(int r, int ru) {
+		//advance to the first entry of the next stored row below ru, or stay at r
+		int len = (int) size();
+		int currentRow = _rindexes[r];
+		for( int i = r + 1; i < len && _rindexes[i] < ru; i++ )
+			if( _rindexes[i] != currentRow )
+				return i;
+		return r;
+	}
+
@Override
public boolean set(int r, int c, double v) {
int pos = pos(r);
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
index ed00f15564..18caf806d7 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
@@ -465,6 +465,37 @@ public class SparseBlockCSR extends SparseBlock
return _ptr[r];
}
+	@Override
+	public int nextNonZeroRowIndex(int r, int ru) {
+		//skip rows whose pointer range [_ptr[row], _ptr[row+1]) is empty
+		int row = r;
+		while( row < ru && _ptr[row] >= _ptr[row + 1] )
+			row++;
+		//r - 1 signals that no non-empty row exists in [r, ru)
+		return (row < ru) ? row : r - 1;
+	}
+
+	@Override
+	public int setSearchIndex(int r, int ru) {
+		//identical row pointers mean rows [r, ru) hold no values (zero matrix)
+		return (_ptr[r] != _ptr[ru]) ? r : -1;
+	}
+
+	@Override
+	public int updateSearchIndex(int r, int ru) {
+		//stay on r if it is the last row, or if the single remaining row is empty
+		boolean lastRow = (r + 1 == ru);
+		boolean onlyEmptyRowLeft = (r + 2 == ru) && _ptr[r + 1] == _ptr[r + 2];
+		return (lastRow || onlyEmptyRowLeft) ? r : r + 1;
+	}
+
@Override
public boolean set(int r, int c, double v) {
int pos = pos(r);
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
index c5d4717e11..5daa4a18d1 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
@@ -341,6 +341,39 @@ public class SparseBlockDCSR extends SparseBlock
return _rowptr[idx];
}
+	@Override
+	public int nextNonZeroRowIndex(int r, int ru) {
+		//r is a position in the row-index array, not a row index
+		return _rowidx[r];
+	}
+
+	@Override
+	public int setSearchIndex(int r, int ru) {
+		//restrict the search to the _nnzr used entries (as set() does);
+		//_rowidx may carry spare capacity beyond that
+		int pos = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
+		if( pos < 0 )
+			pos = -pos - 1; //insertion point: first stored row > r
+		//no stored row left, or the next stored row lies at or beyond ru
+		return (pos >= _nnzr || _rowidx[pos] >= ru) ? -1 : pos;
+	}
+
+	@Override
+	public int updateSearchIndex(int r, int ru) {
+		//row indexes are unique, so the next position is the next stored row;
+		//stay at r if that row is out of range or does not exist
+		int next = r + 1;
+		return (next < _nnzr && _rowidx[next] < ru) ? next : r;
+	}
+
@Override
public boolean set(int r, int c, double v) {
int rowIndex = Arrays.binarySearch(_rowidx, 0, _nnzr, r);
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
index f6fa157448..62f4480ff0 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
@@ -339,6 +339,36 @@ public class SparseBlockMCSR extends SparseBlock
return 0;
}
+	@Override
+	public int nextNonZeroRowIndex(int r, int ru) {
+		//use isEmpty instead of a null check: an allocated row may hold no values
+		for( int i = r; i < ru; i++ )
+			if( !isEmpty(i) )
+				return i;
+		return r;
+	}
+
+	@Override
+	public int setSearchIndex(int r, int ru) {
+		//first non-empty row in [r, ru), or -1 if there is none
+		for( int i = r; i < ru; i++ )
+			if( !isEmpty(i) )
+				return i;
+		return -1;
+	}
+
+	@Override
+	public int updateSearchIndex(int r, int ru) {
+		//first non-empty row in (r, ru), or r if there is none
+		for( int i = r + 1; i < ru; i++ )
+			if( !isEmpty(i) )
+				return i;
+		return r;
+	}
+
@Override
public boolean set(int r, int c, double v) {
if( !isAllocated(r) )
diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
index 068bedf78e..523e7d27db 100644
--- a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,7 +19,9 @@
package org.apache.sysds.test.component.sparse;
+import java.util.ArrayList;
import java.util.Iterator;
+import java.util.List;
import org.apache.sysds.runtime.data.SparseBlockDCSR;
import org.junit.Assert;
@@ -35,195 +37,220 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
/**
- * This is a sparse matrix block component test for sparse block iterator
- * functionality. In order to achieve broad coverage, we test against
- * full and partial iterators as well as different sparsity values.
- *
+ * This is a component test for sparse matrix block, focusing on the iterator
functionality for both general iteration
+ * over non-zero cells and specific iteration over non-zero rows. To ensure
comprehensive coverage, the tests encompass
+ * full and partial iterators, different sparsity values, and explicitly
verify the correct identification and iteration
+ * over non-zero rows in the matrix.
*/
-public class SparseBlockIterator extends AutomatedTestBase
-{
+public class SparseBlockIterator extends AutomatedTestBase {
private final static int rows = 324;
- private final static int cols = 100;
+ private final static int cols = 100;
private final static int rlPartial = 134;
private final static double sparsity1 = 0.1;
private final static double sparsity2 = 0.2;
private final static double sparsity3 = 0.3;
-
+
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
}
@Test
- public void testSparseBlockMCSR1Full() {
+ public void testSparseBlockMCSR1Full() {
runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1,
false);
}
-
+
@Test
- public void testSparseBlockMCSR2Full() {
+ public void testSparseBlockMCSR2Full() {
runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2,
false);
}
-
+
@Test
- public void testSparseBlockMCSR3Full() {
+ public void testSparseBlockMCSR3Full() {
runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3,
false);
}
-
+
@Test
- public void testSparseBlockMCSR1Partial() {
+ public void testSparseBlockMCSR1Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1,
true);
}
-
+
@Test
- public void testSparseBlockMCSR2Partial() {
+ public void testSparseBlockMCSR2Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2,
true);
}
-
+
@Test
- public void testSparseBlockMCSR3Partial() {
+ public void testSparseBlockMCSR3Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3,
true);
}
-
+
@Test
- public void testSparseBlockCSR1Full() {
+ public void testSparseBlockCSR1Full() {
runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1,
false);
}
-
+
@Test
- public void testSparseBlockCSR2Full() {
+ public void testSparseBlockCSR2Full() {
runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2,
false);
}
-
+
@Test
- public void testSparseBlockCSR3Full() {
+ public void testSparseBlockCSR3Full() {
runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3,
false);
}
-
+
@Test
- public void testSparseBlockCSR1Partial() {
+ public void testSparseBlockCSR1Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1,
true);
}
-
+
@Test
- public void testSparseBlockCSR2Partial() {
+ public void testSparseBlockCSR2Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2,
true);
}
-
+
@Test
- public void testSparseBlockCSR3Partial() {
+ public void testSparseBlockCSR3Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3,
true);
}
-
+
@Test
- public void testSparseBlockCOO1Full() {
+ public void testSparseBlockCOO1Full() {
runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1,
false);
}
-
+
@Test
- public void testSparseBlockCOO2Full() {
+ public void testSparseBlockCOO2Full() {
runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2,
false);
}
-
+
@Test
- public void testSparseBlockCOO3Full() {
+ public void testSparseBlockCOO3Full() {
runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3,
false);
}
-
+
@Test
- public void testSparseBlockCOO1Partial() {
+ public void testSparseBlockCOO1Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1,
true);
}
-
+
@Test
- public void testSparseBlockCOO2Partial() {
+ public void testSparseBlockCOO2Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2,
true);
}
-
+
@Test
- public void testSparseBlockCOO3Partial() {
+ public void testSparseBlockCOO3Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3,
true);
}
@Test
- public void testSparseBlockDCSR1Full() {
+ public void testSparseBlockDCSR1Full() {
runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1,
false);
}
@Test
- public void testSparseBlockDCSR2Full() {
+ public void testSparseBlockDCSR2Full() {
runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2,
false);
}
@Test
- public void testSparseBlockDCSR3Full() {
+ public void testSparseBlockDCSR3Full() {
runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3,
false);
}
@Test
- public void testSparseBlockDCSR1Partial() {
+ public void testSparseBlockDCSR1Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1,
true);
}
@Test
- public void testSparseBlockDCSR2Partial() {
+ public void testSparseBlockDCSR2Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2,
true);
}
@Test
- public void testSparseBlockDCSR3Partial() {
+ public void testSparseBlockDCSR3Partial() {
runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3,
true);
}
-
- private void runSparseBlockIteratorTest( SparseBlock.Type btype, double
sparsity, boolean partial)
- {
- try
- {
+
+ private void runSparseBlockIteratorTest(SparseBlock.Type btype, double
sparsity, boolean partial) {
+ try {
//data generation
- double[][] A = getRandomMatrix(rows, cols, -10, 10,
sparsity, 8765432);
-
+ double[][] A = getRandomMatrix(rows, cols, -10, 10,
sparsity, 8765432);
+
//init sparse block
SparseBlock sblock = null;
MatrixBlock mbtmp =
DataConverter.convertToMatrixBlock(A);
- SparseBlock srtmp = mbtmp.getSparseBlock();
- switch( btype ) {
- case MCSR: sblock = new SparseBlockMCSR(srtmp);
break;
- case CSR: sblock = new SparseBlockCSR(srtmp);
break;
- case COO: sblock = new SparseBlockCOO(srtmp);
break;
- case DCSR: sblock = new SparseBlockDCSR(srtmp);
break;
+ SparseBlock srtmp = mbtmp.getSparseBlock();
+ switch(btype) {
+ case MCSR:
+ sblock = new SparseBlockMCSR(srtmp);
+ break;
+ case CSR:
+ sblock = new SparseBlockCSR(srtmp);
+ break;
+ case COO:
+ sblock = new SparseBlockCOO(srtmp);
+ break;
+ case DCSR:
+ sblock = new SparseBlockDCSR(srtmp);
+ break;
}
-
+
//check for correct number of non-zeros
- int[] rnnz = new int[rows]; int nnz = 0;
+ int[] rnnz = new int[rows];
+ int nnz = 0;
int rl = partial ? rlPartial : 0;
- for( int i=rl; i<rows; i++ ) {
- for( int j=0; j<cols; j++ )
- rnnz[i] += (A[i][j]!=0) ? 1 : 0;
+ for(int i = rl; i < rows; i++) {
+ for(int j = 0; j < cols; j++)
+ rnnz[i] += (A[i][j] != 0) ? 1 : 0;
nnz += rnnz[i];
}
- if( !partial && nnz != sblock.size() )
- Assert.fail("Wrong number of non-zeros:
"+sblock.size()+", expected: "+nnz);
-
+ if(!partial && nnz != sblock.size())
+ Assert.fail("Wrong number of non-zeros: " +
sblock.size() + ", expected: " + nnz);
+
//check correct isEmpty return
- for( int i=rl; i<rows; i++ )
- if( sblock.isEmpty(i) != (rnnz[i]==0) )
- Assert.fail("Wrong isEmpty(row) result
for row nnz: "+rnnz[i]);
-
+ for(int i = rl; i < rows; i++)
+ if(sblock.isEmpty(i) != (rnnz[i] == 0))
+ Assert.fail("Wrong isEmpty(row) result
for row nnz: " + rnnz[i]);
+
//check correct values
- Iterator<IJV> iter = !partial ? sblock.getIterator() :
- sblock.getIterator(rl, rows);
+ Iterator<IJV> iter = !partial ? sblock.getIterator() :
sblock.getIterator(rl, rows);
int count = 0;
- while( iter.hasNext() ) {
+ while(iter.hasNext()) {
IJV cell = iter.next();
- if( cell.getV() != A[cell.getI()][cell.getJ()] )
- Assert.fail("Wrong value returned by
iterator: "+cell.getV()+", expected: "+A[cell.getI()][cell.getJ()]);
+ if(cell.getV() != A[cell.getI()][cell.getJ()])
+ Assert.fail("Wrong value returned by
iterator: " + cell.getV() + ", expected: " +
+ A[cell.getI()][cell.getJ()]);
count++;
}
- if( count != nnz )
- Assert.fail("Wrong number of values returned by
iterator: "+count+", expected: "+nnz);
+ if(count != nnz)
+ Assert.fail("Wrong number of values returned by
iterator: " + count + ", expected: " + nnz);
+
+ // check iterator over non-zero rows
+ List<Integer> manualNonZeroRows = new ArrayList<>();
+ List<Integer> iteratorNonZeroRows = new ArrayList<>();
+ Iterator<Integer> iterRows = !partial ?
+ sblock.getNonEmptyRowIterator() :
+ sblock.getNonEmptyRowIterator(rl, rows);
+
+ for(int i = rl; i < rows; i++)
+ if(!sblock.isEmpty(i))
+ manualNonZeroRows.add(i);
+ while(iterRows.hasNext()) {
+ iteratorNonZeroRows.add(iterRows.next());
+ }
+
+ // Compare the results
+ if(!manualNonZeroRows.equals(iteratorNonZeroRows)) {
+ Assert.fail("Verification of iterator over
non-zero rows failed.");
+ }
}
catch(Exception ex) {
ex.printStackTrace();
throw new RuntimeException(ex);
}
}
-}
\ No newline at end of file
+}