This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 743c9022f6 [SYSTEMDS-3775] Improved test coverage sparse blocks and
various fixes
743c9022f6 is described below
commit 743c9022f6a2a7c6701e7431f09d53210f8853ea
Author: Jessica Priebe <[email protected]>
AuthorDate: Sun Mar 1 17:26:34 2026 +0100
[SYSTEMDS-3775] Improved test coverage sparse blocks and various fixes
Closes #2406.
---
.../org/apache/sysds/runtime/data/SparseBlock.java | 60 ++-
.../apache/sysds/runtime/data/SparseBlockCOO.java | 35 +-
.../apache/sysds/runtime/data/SparseBlockCSC.java | 83 ++--
.../apache/sysds/runtime/data/SparseBlockCSR.java | 20 +-
.../apache/sysds/runtime/data/SparseBlockDCSR.java | 64 ++-
.../sysds/runtime/data/SparseBlockFactory.java | 10 +-
.../apache/sysds/runtime/data/SparseBlockMCSC.java | 87 +++-
.../apache/sysds/runtime/data/SparseBlockMCSR.java | 42 +-
.../apache/sysds/runtime/data/SparseRowVector.java | 6 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 2 +-
.../component/sparse/SparseBlockAlignment.java | 6 +
.../sparse/SparseBlockCheckValidityTest.java | 544 +++++++++++++++++++++
.../test/component/sparse/SparseBlockColTest.java | 267 ++++++++++
.../component/sparse/SparseBlockCompactTest.java | 134 +++++
.../component/sparse/SparseBlockContainsTest.java | 316 ++++++++++++
.../component/sparse/SparseBlockEqualsTest.java | 222 +++++++++
.../sparse/SparseBlockInitializationTest.java | 484 ++++++++++++++++++
.../test/component/sparse/SparseBlockIterator.java | 300 ++++++++++--
.../sysds/test/component/sparse/SparseRowTest.java | 216 ++++++++
19 files changed, 2713 insertions(+), 185 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
index 864569358f..ceb0b15adf 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
@@ -26,6 +26,7 @@ import java.util.Iterator;
import java.util.List;
import org.apache.sysds.runtime.matrix.data.IJV;
+import org.apache.sysds.runtime.util.UtilFunctions;
/**
* This SparseBlock is an abstraction for different sparse matrix formats.
@@ -94,10 +95,23 @@ public abstract class SparseBlock implements Serializable,
Block
* @param r row index
*/
public abstract void compact(int r);
+
+ /**
+ * In-place compaction of non-zero-entries; removes zero entries
+ * and shifts non-zero entries to the left if necessary.
+ */
+ public abstract void compact();
////////////////////////
//obtain basic meta data
-
+
+ /**
+ * Get the type of the sparse block.
+ *
+ * @return sparse block type
+ */
+ public abstract SparseBlock.Type getSparseBlockType();
+
/**
* Get the number of rows in the sparse block.
*
@@ -501,18 +515,26 @@ public abstract class SparseBlock implements
Serializable, Block
}
public List<Integer> contains(double[] pattern, boolean earlyAbort) {
+ int pNnz = UtilFunctions.computeNnz(pattern, 0, pattern.length);
List<Integer> ret = new ArrayList<>();
int rlen = numRows();
+
for( int i=0; i<rlen; i++ ) {
int apos = pos(i);
int alen = size(i);
+ if(pNnz > alen) continue;
+
int[] aix = indexes(i);
double[] avals = values(i);
boolean lret = true;
+ int rNnz = 0;
+
//safe comparison on long representations, incl NaN
- for(int k=apos; k<apos+alen & !lret; k++)
+ for(int k=apos; k<apos+alen && lret; k++) {
lret &= Double.compare(avals[k],
pattern[aix[k]]) == 0;
- if( lret )
+ if(avals[k] != 0) rNnz++;
+ }
+ if(lret && rNnz == pNnz)
ret.add(i);
if(earlyAbort && ret.size()>0)
return ret;
@@ -764,17 +786,31 @@ public abstract class SparseBlock implements
Serializable, Block
* values are available.
*/
private void findNextNonZeroRow(int cl) {
- while( _curRow<_rlen && (isEmpty(_curRow)
- || (cl>0 && posFIndexGTE(_curRow, cl) < 0)) )
+ while(_curRow < _rlen){
+ if(isEmpty(_curRow)){
+ _curRow++;
+ continue;
+ }
+
+ int pos = (cl == 0)? 0 : posFIndexGTE(_curRow,
cl);
+ if(pos < 0){
+ _curRow++;
+ continue;
+ }
+
+ int sizeRow = size(_curRow);
+ int endPos = (_cu == Integer.MAX_VALUE)?
sizeRow : posFIndexGTE(_curRow, _cu);
+ if(endPos < 0) endPos = sizeRow;
+
+ if(pos < endPos){
+ _curColIx = pos(_curRow)+pos;
+ _curIndexes = indexes(_curRow);
+ _curValues = values(_curRow);
+ return;
+ }
_curRow++;
- if(_curRow >= _rlen)
- _noNext = true;
- else {
- _curColIx = (cl==0) ?
- pos(_curRow) : posFIndexGTE(_curRow,
cl);
- _curIndexes = indexes(_curRow);
- _curValues = values(_curRow);
}
+ _noNext = true;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
index c4e60c10cf..91a309931d 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCOO.java
@@ -193,6 +193,25 @@ public class SparseBlockCOO extends SparseBlock
//do nothing everything preallocated
}
+ @Override
+ public void compact() { // drop stored zeros from the live entries, keeping row-major order
+ int pos = 0;
+ for(int i=0; i< _size; i++) { // scan only the live prefix [0,_size); capacity may exceed it
+ if(_values[i] != 0){
+ _values[pos] = _values[i];
+ _rindexes[pos] = _rindexes[i];
+ _cindexes[pos] = _cindexes[i];
+ pos++;
+ }
+ }
+ _size = pos;
+ }
+
+ @Override
+ public SparseBlock.Type getSparseBlockType() {
+ return Type.COO;
+ }
+
@Override
public int numRows() {
return _rlen;
@@ -221,12 +240,12 @@ public class SparseBlockCOO extends SparseBlock
}
//2. correct array lengths
- if(_size != nnz && _cindexes.length < nnz && _rindexes.length <
nnz && _values.length < nnz) {
+ if(_size != nnz || _cindexes.length < nnz || _rindexes.length <
nnz || _values.length < nnz) {
throw new RuntimeException("Incorrect array lengths.");
}
//3.1. sort order of row indices
- for( int i=1; i<=nnz; i++ ) {
+ for( int i=1; i<nnz; i++ ) {
if(_rindexes[i] < _rindexes[i-1])
throw new RuntimeException("Wrong sorted order
of row indices");
}
@@ -235,14 +254,10 @@ public class SparseBlockCOO extends SparseBlock
for( int i=0; i<rlen; i++ ) {
int apos = pos(i);
int alen = size(i);
- for(int k=apos+i; k<apos+alen; k++)
- if( _cindexes[k+1] >= _cindexes[k] )
+ for(int k=apos+1; k<apos+alen; k++)
+ if(_cindexes[k-1] > _cindexes[k])
throw new RuntimeException("Wrong
sparse row ordering: "
+ k + "
"+_cindexes[k-1]+" "+_cindexes[k]);
- for( int k=apos; k<apos+alen; k++ )
- if(_values[k] == 0)
- throw new RuntimeException("Wrong
sparse row: zero at "
- + k + " at col index "
+ _cindexes[k]);
}
//4. non-existing zero values
@@ -250,11 +265,13 @@ public class SparseBlockCOO extends SparseBlock
if( _values[i] == 0)
throw new RuntimeException("The values array
should not contain zeros."
+ " The " + i + "th value is
"+_values[i]);
+ if(_cindexes[i] < 0 || _rindexes[i] < 0)
+ throw new RuntimeException("Invalid index at
pos=" + i);
}
//5. a capacity that is no larger than nnz times the resize
factor
int capacity = _values.length;
- if( capacity > nnz*RESIZE_FACTOR1 ) {
+ if( capacity > INIT_CAPACITY && capacity > nnz*RESIZE_FACTOR1 )
{
throw new RuntimeException("Capacity is larger than the
nnz times a resize factor."
+ " Current size: "+capacity+ ", while
Expected size:"+nnz*RESIZE_FACTOR1);
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java
index b38c3525c9..9674c276fa 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSC.java
@@ -64,15 +64,16 @@ public class SparseBlockCSC extends SparseBlock {
_size = 0;
}
- public SparseBlockCSC(int clen, int capacity, int size) {
+ public SparseBlockCSC(int rlen, int clen, int capacity) {
+ _rlen = rlen;
_ptr = new int[clen + 1]; //ix0=0
_indexes = new int[capacity];
_values = new double[capacity];
- _size = size;
+ _size = 0;
}
- public SparseBlockCSC(int[] rowPtr, int[] rowInd, double[] values, int
nnz) {
- _ptr = rowPtr;
+ public SparseBlockCSC(int[] colPtr, int[] rowInd, double[] values, int
nnz) {
+ _ptr = colPtr;
_indexes = rowInd;
_values = values;
_size = nnz;
@@ -94,8 +95,9 @@ public class SparseBlockCSC extends SparseBlock {
private void initialize(SparseBlock sblock) {
- if(_size > Integer.MAX_VALUE)
- throw new RuntimeException("SparseBlockCSC supports
nnz<=Integer.MAX_VALUE but got " + _size);
+ long size = sblock.size();
+ if(size > Integer.MAX_VALUE)
+ throw new RuntimeException("SparseBlockCSC supports
nnz<=Integer.MAX_VALUE but got " + size);
//special case SparseBlockCSC
if(sblock instanceof SparseBlockCSC) {
@@ -223,27 +225,6 @@ public class SparseBlockCSC extends SparseBlock {
}
- public SparseBlockCSC(int cols, int nnz, int[] rowInd) {
-
- _clenInferred = cols;
- _ptr = new int[cols + 1];
- _indexes = Arrays.copyOf(rowInd, nnz);
- _values = new double[nnz];
- Arrays.fill(_values, 1);
- _size = nnz;
-
- //single-pass construction of col pointers
- //and copy of row indexes if necessary
- for(int i = 0, pos = 0; i < cols; i++) {
- if(rowInd[i] >= 0) {
- if(cols > nnz)
- _indexes[pos] = rowInd[i];
- pos++;
- }
- _ptr[i + 1] = pos;
- }
- }
-
/**
* Initializes the CSC sparse block from an ordered input stream of
ultra-sparse ijv triples.
*
@@ -288,7 +269,6 @@ public class SparseBlockCSC extends SparseBlock {
// Allocate space if necessary
if(_values.length < nnz) {
resize(newCapacity(nnz));
- System.out.println("hallo");
}
// Read sparse columns, append and update pointers
_ptr[0] = 0;
@@ -377,12 +357,36 @@ public class SparseBlockCSC extends SparseBlock {
//do nothing everything preallocated
}
+ @Override
+ public void compact() {
+ int pos = 0;
+ for(int i=0; i<numCols(); i++) {
+ int apos = posCol(i);
+ int alen = sizeCol(i);
+ _ptr[i] = pos;
+ for(int j=apos; j<apos+alen; j++) {
+ if(_values[j] != 0){
+ _values[pos] = _values[j];
+ _indexes[pos] = _indexes[j];
+ pos++;
+ }
+ }
+ }
+ _ptr[numCols()] = pos;
+ _size = pos;
+ }
+
+ @Override
+ public SparseBlock.Type getSparseBlockType() {
+ return Type.CSC;
+ }
+
@Override
public int numRows() {
if(_rlen > -1)
return _rlen;
else {
- int rlen = Arrays.stream(_indexes).max().getAsInt();
+ int rlen = Arrays.stream(_indexes).max().getAsInt()+1;
_rlen = rlen;
return rlen;
}
@@ -550,12 +554,12 @@ public class SparseBlockCSC extends SparseBlock {
}
//2. correct array lengths
- if(_size != nnz && _ptr.length < clen + 1 && _values.length <
nnz && _indexes.length < nnz) {
+ if(_size != nnz || _ptr.length < clen + 1 || _values.length <
nnz || _indexes.length < nnz) {
throw new RuntimeException("Incorrect array lengths.");
}
- //3. non-decreasing row pointers
- for(int i = 1; i < clen; i++) {
+ //3. non-decreasing col pointers
+ for(int i = 1; i <= clen; i++) {
if(_ptr[i - 1] > _ptr[i] && strict)
throw new RuntimeException(
"Column pointers are decreasing at
column: " + i + ", with pointers " + _ptr[i - 1] + " > " +
@@ -569,10 +573,9 @@ public class SparseBlockCSC extends SparseBlock {
for(int k = apos + 1; k < apos + alen; k++)
if(_indexes[k - 1] >= _indexes[k])
throw new RuntimeException(
- "Wrong sparse column ordering:
" + k + " " + _indexes[k - 1] + " " + _indexes[k]);
- for(int k = apos; k < apos + alen; k++)
- if(_values[k] == 0)
- throw new RuntimeException("Wrong
sparse column: zero at " + k + " at row index " + _indexes[k]);
+ "Wrong sparse column ordering,
at column=" + i + ", pos=" + k + " with row indexes " +
+ _indexes[k - 1] + ">="
+ _indexes[k]
+ );
}
//5. non-existing zero values
@@ -581,11 +584,13 @@ public class SparseBlockCSC extends SparseBlock {
throw new RuntimeException(
"The values array should not contain
zeros." + " The " + i + "th value is " + _values[i]);
}
+ if(_indexes[i] < 0)
+ throw new RuntimeException("Invalid index at
pos=" + i);
}
//6. a capacity that is no larger than nnz times resize factor.
int capacity = _values.length;
- if(capacity > nnz * RESIZE_FACTOR1) {
+ if(capacity > INIT_CAPACITY && capacity > nnz * RESIZE_FACTOR1)
{
throw new RuntimeException(
"Capacity is larger than the nnz times a resize
factor." + " Current size: " + capacity +
", while Expected size:" + nnz *
RESIZE_FACTOR1);
@@ -938,7 +943,7 @@ public class SparseBlockCSC extends SparseBlock {
int len = sizeCol(c);
int end = internPosFIndexGTECol(ru, c);
if(end < 0) //delete all remaining
- end = start + len;
+ end = posCol(c) + len;
//overlapping array copy (shift rhs values left)
System.arraycopy(_indexes, end, _indexes, start, _size - end);
@@ -1059,7 +1064,7 @@ public class SparseBlockCSC extends SparseBlock {
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
- sb.append("SparseBlockCSR: clen=");
+ sb.append("SparseBlockCSC: clen=");
sb.append(numCols());
sb.append(", nnz=");
sb.append(size());
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
index a40c567dfb..a1f25cb37e 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
@@ -350,7 +350,8 @@ public class SparseBlockCSR extends SparseBlock
public void compact(int r) {
//do nothing everything preallocated
}
-
+
+ @Override
public void compact() {
int pos = 0;
for(int i=0; i<numRows(); i++) {
@@ -369,6 +370,11 @@ public class SparseBlockCSR extends SparseBlock
_size = pos; //adjust logical size
}
+ @Override
+ public SparseBlock.Type getSparseBlockType() {
+ return Type.CSR;
+ }
+
@Override
public int numRows() {
return _ptr.length-1;
@@ -937,12 +943,12 @@ public class SparseBlockCSR extends SparseBlock
}
//2. correct array lengths
- if(_size != nnz && _ptr.length < rlen+1 && _values.length < nnz
&& _indexes.length < nnz ) {
+ if( _size != nnz || _ptr.length < rlen+1 || _values.length <
nnz || _indexes.length < nnz ) {
throw new RuntimeException("Incorrect array lengths.");
}
//3. non-decreasing row pointers
- for( int i=1; i<rlen; i++ ) {
+ for( int i=1; i<=rlen; i++ ) {
if(_ptr[i-1] > _ptr[i] && strict)
throw new RuntimeException("Row pointers are
decreasing at row: "+i
+ ", with pointers "+_ptr[i-1]+" >
"+_ptr[i]);
@@ -956,10 +962,6 @@ public class SparseBlockCSR extends SparseBlock
if( _indexes[k-1] >= _indexes[k] )
throw new RuntimeException("Wrong
sparse row ordering: "
+ k + " "+_indexes[k-1]+"
"+_indexes[k]);
- for( int k=apos; k<apos+alen; k++ )
- if( _values[k] == 0 )
- throw new RuntimeException("Wrong
sparse row: zero at "
- + k + " at col index " +
_indexes[k]);
}
//5. non-existing zero values
@@ -968,11 +970,13 @@ public class SparseBlockCSR extends SparseBlock
throw new RuntimeException("The values array
should not contain zeros."
+ " The " + i + "th value is
"+_values[i]);
}
+ if(_indexes[i] < 0)
+ throw new RuntimeException("Invalid index at
pos=" + i);
}
//6. a capacity that is no larger than nnz times resize factor.
int capacity = _values.length;
- if(capacity > nnz*RESIZE_FACTOR1 ) {
+ if(capacity > INIT_CAPACITY && capacity > nnz*RESIZE_FACTOR1 ) {
throw new RuntimeException("Capacity is larger than the
nnz times a resize factor."
+ " Current size: "+capacity+ ", while Expected
size:"+nnz*RESIZE_FACTOR1);
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
index b369992efa..6cad53dac6 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
@@ -63,17 +63,6 @@ public class SparseBlockDCSR extends SparseBlock
_nnzr = 0;
}
- public SparseBlockDCSR(int rlen, int capacity, int size, int nnzr){
- LOG.warn("Allocating a DCSR-block using row-length. This will
lead to significant overhead!");
- _rowidx = new int[rlen];
- _rowptr = new int[rlen + 1];
- _colidx = new int[capacity];
- _values = new double[capacity];
- _rlen = rlen;
- _size = size;
- _nnzr = nnzr;
- }
-
public SparseBlockDCSR(int[] rowIdx, int[] rowPtr, int[] colIdx,
double[] values, int rlen, int nnz, int nnzr){
LOG.warn("Allocating a DCSR-block using row-length. This will
lead to significant overhead!");
_rowidx = rowIdx;
@@ -210,6 +199,36 @@ public class SparseBlockDCSR extends SparseBlock
//do nothing everything preallocated
}
+ @Override
+ public void compact() { // drop stored zeros and rows that become empty, in place
+ int idx = 0;
+ int pos = 0;
+ for(int i=0; i<_nnzr; i++) {
+ int apos = _rowptr[i]; // _rowptr[i], _rowptr[i+1] are still unwritten here (writes go to idx <= i)
+ int alen = _rowptr[i+1] - apos;
+ _rowptr[idx] = pos;
+ for(int j=apos; j<apos+alen; j++) {
+ if(_values[j] != 0){
+ _values[pos] = _values[j];
+ _colidx[pos] = _colidx[j];
+ pos++;
+ }
+ }
+ if(_rowptr[idx]<pos){
+ _rowidx[idx] = _rowidx[i];
+ idx++;
+ }
+ }
+ _size = pos;
+ _nnzr = idx;
+ _rowptr[_nnzr] = pos;
+ }
+
+ @Override
+ public SparseBlock.Type getSparseBlockType() {
+ return Type.DCSR;
+ }
+
@Override
public int numRows() {
return _rlen;
@@ -670,7 +689,7 @@ public class SparseBlockDCSR extends SparseBlock
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
- sb.append("SparseBlockCSR: rlen=");
+ sb.append("SparseBlockDCSR: rlen=");
sb.append(numRows());
sb.append(", nnz=");
sb.append(size());
@@ -705,18 +724,18 @@ public class SparseBlockDCSR extends SparseBlock
}
//2. correct array lengths
- if (_size != nnz && _rowptr.length != _rowidx.length + 1 &&
_values.length < nnz && _colidx.length < nnz ) {
+ if ( _size != nnz || _rowptr.length != _rowidx.length + 1 ||
_values.length < nnz || _colidx.length < nnz ) {
throw new RuntimeException("Incorrect array lengths.");
}
//3. non-decreasing row pointers
- for ( int i=1; i <_rowidx.length; i++ ) {
+ for ( int i=1; i < _nnzr; i++ ) {
if (_rowidx[i-1] > _rowidx[i])
throw new RuntimeException("Row indices are
decreasing at row: " + i
+ ", with indices " +
_rowidx[i-1] + " > " +_rowidx[i]);
}
- for (int i = 1; i < _rowptr.length; i++ ) {
+ for (int i = 1; i < _nnzr+1; i++ ) {
if (_rowptr[i - 1] > _rowptr[i]) {
throw new RuntimeException("Row pointers are
decreasing at row: " + i
+ ", with pointers " +
_rowptr[i-1] + " > " +_rowptr[i]);
@@ -724,19 +743,14 @@ public class SparseBlockDCSR extends SparseBlock
}
//4. sorted column indexes per row
- for ( int rowIdx = 0; rowIdx < _rowidx.length; rowIdx++ ) {
- int apos = _rowidx[rowIdx];
- int alen = _rowidx[rowIdx+1] - apos;
+ for (int i = 0; i < _nnzr; i++) {
+ int apos = _rowptr[i];
+ int alen = _rowptr[i+1] - apos;
for( int k = apos + 1; k < apos + alen; k++)
if( _colidx[k-1] >= _colidx[k] )
throw new RuntimeException("Wrong
sparse row ordering: "
+ k + " " +
_colidx[k-1] + " " + _colidx[k]);
-
- for( int k=apos; k<apos+alen; k++ )
- if( _values[k] == 0 )
- throw new RuntimeException("Wrong
sparse row: zero at "
- + k + " at col index "
+ _colidx[k]);
}
//5. non-existing zero values
@@ -745,11 +759,13 @@ public class SparseBlockDCSR extends SparseBlock
throw new RuntimeException("The values array
should not contain zeros."
+ " The " + i + "th value is
"+_values[i]);
}
+ if(_colidx[i] < 0)
+ throw new RuntimeException("Invalid index at
pos=" + i);
}
//6. a capacity that is no larger than nnz times resize factor.
int capacity = _values.length;
- if(capacity > nnz*RESIZE_FACTOR1 ) {
+ if(capacity > INIT_CAPACITY && capacity > nnz*RESIZE_FACTOR1 ) {
throw new RuntimeException("Capacity is larger than the
nnz times a resize factor."
+ " Current size: "+capacity+ ", while
Expected size:"+nnz*RESIZE_FACTOR1);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
index 22dfe5417e..3f939fbcc0 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
@@ -37,6 +37,8 @@ public abstract class SparseBlockFactory{
case CSR: return new SparseBlockCSR(rlen);
case COO: return new SparseBlockCOO(rlen);
case DCSR: return new SparseBlockDCSR(rlen);
+ case MCSC: return new SparseBlockMCSC(rlen);
+ case CSC: return new SparseBlockCSC(rlen, 0);
default:
throw new RuntimeException("Unexpected sparse
block type: "+type.toString());
}
@@ -78,13 +80,7 @@ public abstract class SparseBlockFactory{
}
public static boolean isSparseBlockType(SparseBlock sblock,
SparseBlock.Type type) {
- return (getSparseBlockType(sblock) == type);
- }
-
- public static SparseBlock.Type getSparseBlockType(SparseBlock sblock) {
- return (sblock instanceof SparseBlockMCSR) ?
SparseBlock.Type.MCSR :
- (sblock instanceof SparseBlockCSR) ?
SparseBlock.Type.CSR :
- (sblock instanceof SparseBlockCOO) ?
SparseBlock.Type.COO : null;
+ return (sblock.getSparseBlockType() == type);
}
public static long estimateSizeSparseInMemory(SparseBlock.Type type,
long nrows, long ncols, double sparsity) {
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java
index fd0b3906bc..a1eecff371 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSC.java
@@ -69,15 +69,15 @@ public class SparseBlockMCSC extends SparseBlock {
else if(sblock instanceof SparseBlockMCSR) {
SparseRow[] originalRows = ((SparseBlockMCSR)
sblock).getRows();
Map<Integer, Integer> columnSizes = new HashMap<>();
- if(_clenInferred == -1) {
- for(SparseRow row : originalRows) {
- if(row != null && !row.isEmpty()) {
- for(int i = 0; i < row.size();
i++) {
- int rowIndex =
row.indexes()[i];
-
columnSizes.put(rowIndex, columnSizes.getOrDefault(rowIndex, 0) + 1);
- }
+ for(SparseRow row : originalRows) {
+ if(row != null && !row.isEmpty()) {
+ for(int i = 0; i < row.size(); i++) {
+ int rowIndex = row.indexes()[i];
+ columnSizes.put(rowIndex,
columnSizes.getOrDefault(rowIndex, 0) + 1);
}
}
+ }
+ if(_clenInferred == -1) {
clen =
columnSizes.keySet().stream().max(Integer::compare).orElseThrow(NoSuchElementException::new);
_columns = new SparseRow[clen + 1];
}
@@ -113,7 +113,14 @@ public class SparseBlockMCSC extends SparseBlock {
}
rowPosition++;
}
-
+ }
+ else if(sblock instanceof SparseBlockCSC) {
+ clen = ((SparseBlockCSC) sblock).numCols();
+ _columns = new SparseRow[clen];
+ for(int i = 0; i < clen; i++) {
+ if(!((SparseBlockCSC) sblock).isEmptyCol(i))
+ _columns[i] = ((SparseBlockCSC)
sblock).getCol(i);
+ }
}
// general case SparseBlock
else {
@@ -259,7 +266,7 @@ public class SparseBlockMCSC extends SparseBlock {
}
public void allocateCol(int c, int nnz) {
- if(!isAllocated(c)) {
+ if(!isAllocatedCol(c)) {
_columns[c] = (nnz == 1) ? new SparseRowScalar() : new
SparseRowVector(nnz);
}
}
@@ -270,7 +277,7 @@ public class SparseBlockMCSC extends SparseBlock {
}
public void allocateCol(int c, int ennz, int maxnnz) {
- if(!isAllocated(c)) {
+ if(!isAllocatedCol(c)) {
_columns[c] = (ennz == 1) ? new SparseRowScalar() : new
SparseRowVector(ennz, maxnnz);
}
}
@@ -283,7 +290,7 @@ public class SparseBlockMCSC extends SparseBlock {
}
public void compactCol(int c) {
- if(isAllocated(c)) {
+ if(isAllocatedCol(c)) {
if(_columns[c] instanceof SparseRowVector &&
_columns[c].size() > SparseBlock.INIT_CAPACITY &&
_columns[c].size() * SparseBlock.RESIZE_FACTOR1
< ((SparseRowVector) _columns[c]).capacity()) {
((SparseRowVector) _columns[c]).compact();
@@ -296,6 +303,29 @@ public class SparseBlockMCSC extends SparseBlock {
}
}
+ @Override
+ public void compact() {
+ for(int i = 0; i < numCols(); i++) {
+ if(isAllocatedCol(i)) {
+ if(_columns[i] instanceof SparseRowVector) {
+ _columns[i].compact();
+ if(_columns[i].isEmpty())
+ _columns[i] = null;
+ }
+ else if(_columns[i] instanceof SparseRowScalar)
{
+ SparseRowScalar s = (SparseRowScalar)
_columns[i];
+ if(s.getValue() == 0)
+ _columns[i] = null;
+ }
+ }
+ }
+ }
+
+ @Override
+ public SparseBlock.Type getSparseBlockType() {
+ return Type.MCSC;
+ }
+
@Override
public int numRows() {
return _rlen;
@@ -386,7 +416,7 @@ public class SparseBlockMCSC extends SparseBlock {
public int sizeCol(int c) {
//prior check with isEmpty(r) expected
- return isAllocated(c) ? _columns[c].size() : 0;
+ return isAllocatedCol(c) ? _columns[c].size() : 0;
}
@Override
@@ -404,7 +434,7 @@ public class SparseBlockMCSC extends SparseBlock {
public long sizeCol(int cl, int cu) {
long nnz = 0;
for(int i = cl; i < cu; i++) {
- nnz += isAllocated(i) ? _columns[i].size() : 0;
+ nnz += isAllocatedCol(i) ? _columns[i].size() : 0;
}
return nnz;
}
@@ -449,31 +479,34 @@ public class SparseBlockMCSC extends SparseBlock {
//3. Sorted column indices per row
for(int i = 0; i < clen; i++) {
- if(isEmpty(i))
- continue;
+ if(isEmptyCol(i)) continue;
int apos = pos(i);
- int alen = size(i);
- int[] aix = indexes(i);
- double[] avals = values(i);
- for(int k = apos + 1; k < apos + alen; k++) {
- if(aix[k - 1] >= aix[k] | aix[k - 1] < 0) {
+ int alen = sizeCol(i);
+ int[] aix = indexesCol(i);
+ double[] avals = valuesCol(i);
+
+ int prevRow = -1;
+ for(int k = apos; k < apos + alen; k++) {
+ if(aix[k] < 0)
+ throw new RuntimeException("Invalid
index, at column=" + i + ", pos=" + k);
+ if(aix[k] <= prevRow)
throw new RuntimeException(
"Wrong sparse column ordering,
at column=" + i + ", pos=" + k + " with row indexes " +
- aix[k - 1] + ">=" +
aix[k]);
- }
- if(avals[k] == 0) {
+ prevRow + ">=" +
aix[k]);
+ if(avals[k] == 0)
throw new RuntimeException(
- "The values are expected to be
non zeros " + "but zero at column: " + i + ", row pos: " + k);
- }
+ "The values array should not
contain zeros " + "but zero at column: " + i + ", row pos: " + k);
+ prevRow = aix[k];
}
}
+
//4. A capacity that is no larger than nnz times resize factor
for(int i = 0; i < clen; i++) {
long max_size = (long) Math.max(nnz * RESIZE_FACTOR1,
INIT_CAPACITY);
- if(!isEmpty(i) && values(i).length > max_size) {
+ if(!isEmptyCol(i) && valuesCol(i).length > max_size) {
throw new RuntimeException(
"The capacity is larger than nnz times
a resize factor(=2). " + "Actual length = " +
- values(i).length + ", should
not exceed " + max_size);
+ valuesCol(i).length + ", should
not exceed " + max_size);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
index f94b6bf7f4..be7e6638db 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
@@ -199,6 +199,27 @@ public class SparseBlockMCSR extends SparseBlock
}
}
}
+
+ @Override
+ public void compact() {
+ for(int i = 0; i < numRows(); i++) {
+ if(isAllocated(i)) {
+ if(_rows[i] instanceof SparseRowVector) {
+ _rows[i].compact();
+ if(_rows[i].isEmpty()) _rows[i] = null;
+ }
+ else if(_rows[i] instanceof SparseRowScalar) {
+ SparseRowScalar s = (SparseRowScalar)
_rows[i];
+ if(s.getValue() == 0) _rows[i] = null;
+ }
+ }
+ }
+ }
+
+ @Override
+ public SparseBlock.Type getSparseBlockType() {
+ return Type.MCSR;
+ }
@Override
public int numRows() {
@@ -238,13 +259,20 @@ public class SparseBlockMCSR extends SparseBlock
int alen = size(i);
int[] aix = indexes(i);
double[] avals = values(i);
- for (int k = apos + 1; k < apos + alen; k++) {
- if (aix[k-1] >= aix[k] | aix[k-1] < 0 )
- throw new RuntimeException("Wrong
sparse row ordering, at row="+i+", pos="+k
- + " with column indexes " +
aix[k-1] + ">=" + aix[k]);
- if (avals[k] == 0)
- throw new RuntimeException("The values
are expected to be non zeros "
- + "but zero at row: "+ i + ",
col pos: " + k);
+
+ int prevCol = -1;
+ for (int k = apos; k < apos + alen; k++) {
+ if(aix[k] < 0)
+ throw new RuntimeException(
+ "Invalid index, at row=" + i +
", pos=" + k);
+ if(aix[k] <= prevCol)
+ throw new RuntimeException(
+ "Wrong sparse row ordering, at
row=" + i + ", pos=" + k + " with column indexes " +
+ prevCol + ">=" +
aix[k]);
+ if(avals[k] == 0)
+ throw new RuntimeException(
+ "The values array should not
contain zeros " + "but zero at row: " + i + ", column pos: " + k);
+ prevCol = aix[k];
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java
b/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java
index 50229e15df..e59bf2402e 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java
@@ -190,6 +190,10 @@ public final class SparseRowVector extends SparseRow {
estimatedNzs = estnnz;
}
+ public int getEstimatedNzs(){
+ return estimatedNzs;
+ }
+
private void recap(int newCap) {
if( newCap<=values.length )
return;
@@ -314,7 +318,7 @@ public final class SparseRowVector extends SparseRow {
//search lt col index (see binary search)
index = Math.abs( index+1 );
- return (index-1 < size) ? index-1 : -1;
+ return (index-1 >= 0) ? index-1 : -1;
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index f19fe075c9..b1c06cdd51 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2935,7 +2935,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
//in-memory size of dense/sparse representation
return !sparse ? estimateSizeDenseInMemory(rlen, clen) :
estimateSizeSparseInMemory(rlen, clen, getSparsity(),
- SparseBlockFactory.getSparseBlockType(sparseBlock));
+ sparseBlock.getSparseBlockType());
}
@Override
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java
index 3c2ed30adc..006fac8f03 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockAlignment.java
@@ -270,6 +270,12 @@ public class SparseBlockAlignment extends AutomatedTestBase
Assert.fail("Wrong row alignment indicated:
"+rowsAligned37+", expected: "+positive);
if( !rowsAlignedRest )
Assert.fail("Wrong row alignment rest
indicated: false.");
+
+ //init third sparse block with different number of rows
+ SparseBlock sblock3
=SparseBlockFactory.createSparseBlock(btype, rows+1);
+ if (sblock.isAligned(sblock3)) {
+ Assert.fail("Wrong alignment different rows
indicated: true.");
+ }
}
catch(Exception ex) {
ex.printStackTrace();
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java
new file mode 100644
index 0000000000..d07729316f
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCheckValidityTest.java
@@ -0,0 +1,544 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.sparse;
+
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockCOO;
+import org.apache.sysds.runtime.data.SparseBlockCSC;
+import org.apache.sysds.runtime.data.SparseBlockCSR;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
+import org.apache.sysds.runtime.data.SparseBlockFactory;
+import org.apache.sysds.runtime.data.SparseBlockMCSC;
+import org.apache.sysds.runtime.data.SparseBlockMCSR;
+
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.lang.reflect.Field;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+public class SparseBlockCheckValidityTest extends AutomatedTestBase
+{
+ private final static int _rows = 123;
+ private final static int _cols = 97;
+ private final static double _sparsity = 0.22;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ }
+
+ @Test
+ public void testSparseBlockCOOValid() {
+ runSparseBlockValidTest(SparseBlock.Type.COO);
+ }
+
+ @Test
+ public void testSparseBlockCSCValid() {
+ runSparseBlockValidTest(SparseBlock.Type.CSC);
+ }
+
+ @Test
+ public void testSparseBlockCSRValid() {
+ runSparseBlockValidTest(SparseBlock.Type.CSR);
+ }
+
+ @Test
+ public void testSparseBlockDCSRValid() {
+ runSparseBlockValidTest(SparseBlock.Type.DCSR);
+ }
+
+ @Test
+ public void testSparseBlockMCSCValid() {
+ runSparseBlockValidTest(SparseBlock.Type.MCSC);
+ }
+
+ @Test
+ public void testSparseBlockMCSRValid() {
+ runSparseBlockValidTest(SparseBlock.Type.MCSR);
+ }
+
+ @Test
+ public void testSparseBlockCOOInvalidDimensions() {
+ runSparseBlockInvalidDimensionsTest(new SparseBlockCOO(-1, 0));
+ }
+
+ @Test
+ public void testSparseBlockCSCInvalidDimensions() {
+ runSparseBlockInvalidDimensionsTest(new SparseBlockCSC(-1, 0));
+ }
+
+ @Test
+ public void testSparseBlockCSRInvalidDimensions() {
+ runSparseBlockInvalidDimensionsTest(new SparseBlockCSR(-1, 0));
+ }
+
+ @Test
+ public void testSparseBlockDCSRInvalidDimensions() {
+ runSparseBlockInvalidDimensionsTest(new SparseBlockDCSR(0, 0));
+ }
+
+ @Test
+ public void testSparseBlockMCSCInvalidDimensions() {
+ runSparseBlockInvalidDimensionsTest(new SparseBlockMCSC(-1, 0));
+ }
+
+ @Test
+ public void testSparseBlockMCSRInvalidDimensions() {
+ runSparseBlockInvalidDimensionsTest(new SparseBlockMCSR(0, -1));
+ }
+
+ @Test
+ public void testSparseBlockCOOIncorrectArrayLengths() {
+ SparseBlockCOO sblock = new
SparseBlockCOO(getFixedSparseBlock());
+
+ int size = (int) sblock.size();
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, size+2, false));
+ assertEquals("Incorrect array lengths.", ex.getMessage());
+
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_cindexes", new int[size-1]);
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_rindexes", new int[size-1]);
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_values", new double[size-1]);
+ }
+
+ @Test
+ public void testSparseBlockCSCIncorrectArrayLengths() {
+ SparseBlockCSC sblock = new
SparseBlockCSC(getFixedSparseBlock());
+
+ int size = (int) sblock.size();
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, size+2, false));
+ assertEquals("Incorrect array lengths.", ex.getMessage());
+
+ int clen = 4;
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_ptr", new int[clen]); // should be clen+1
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_values", new double[size-1]);
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_indexes", new int[size-1]);
+ }
+
+ @Test
+ public void testSparseBlockCSRIncorrectArrayLengths() {
+ SparseBlockCSR sblock = new
SparseBlockCSR(getFixedSparseBlock());
+
+ int size = (int) sblock.size();
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, size+2, false));
+ assertEquals("Incorrect array lengths.", ex.getMessage());
+
+ int rlen = sblock.numRows();
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_ptr", new int[rlen]); // should be rlen+1
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_values", new double[size-1]);
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_indexes", new int[size-1]);
+ }
+
+ @Test
+ public void testSparseBlockDCSRIncorrectArrayLengths() {
+ SparseBlockDCSR sblock = new
SparseBlockDCSR(getFixedSparseBlock());
+
+ int size = (int) sblock.size();
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, size+2, false));
+ assertEquals("Incorrect array lengths.", ex.getMessage());
+
+ int rows = sblock.numRows();
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_rowptr", new int[rows]); // should be rows+1
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_colidx", new int[size-1]);
+ checkValidityFailsWhenArrayLengthIsTemporarilyModified(sblock,
"_values", new double[size-1]);
+ }
+
+ @Test
+ public void testSparseBlockMCSCIncorrectArrayLengths() {
+ SparseBlockMCSC sblock = new
SparseBlockMCSC(getFixedSparseBlock());
+
+ int size = (int) sblock.size();
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, size+2, false));
+ assertTrue(ex.getMessage().startsWith("Incorrect size"));
+ }
+
+ @Test
+ public void testSparseBlockMCSRIncorrectArrayLengths() {
+ SparseBlockMCSR sblock = new
SparseBlockMCSR(getFixedSparseBlock());
+
+ int size = (int) sblock.size();
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, size+2, false));
+ assertTrue(ex.getMessage().startsWith("Incorrect size"));
+ }
+
+ @Test
+ public void testSparseBlockCOOUnsortedRowIndices() {
+ SparseBlockCOO sblock = new
SparseBlockCOO(getFixedSparseBlock());
+ int[] r = new int[]{0, 2, 1, 2, 3, 3}; // unsorted
+ setField(sblock, "_rindexes", r);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertEquals("Wrong sorted order of row indices",
ex.getMessage());
+ }
+
+ @Test
+ public void testSparseBlockCSCDecreasingColPointers() {
+ SparseBlockCSC sblock = new
SparseBlockCSC(getFixedSparseBlock());
+ int[] ptr = new int[]{0, 2, 1, 4, 6}; // unsorted
+ setField(sblock, "_ptr", ptr);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, true));
+ assertTrue(ex.getMessage().startsWith("Column pointers are
decreasing at column"));
+ }
+
+ @Test
+ public void testSparseBlockCSRDecreasingRowPointers() {
+ SparseBlockCSR sblock = new
SparseBlockCSR(getFixedSparseBlock());
+ int[] ptr = new int[]{0, 2, 1, 4, 6}; // unsorted
+ setField(sblock, "_ptr", ptr);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, true));
+ assertTrue(ex.getMessage().startsWith("Row pointers are
decreasing at row"));
+ }
+
+ @Test
+ public void testSparseBlockDCSRDecreasingRowIndices() {
+ SparseBlockDCSR sblock = new
SparseBlockDCSR(getFixedSparseBlock());
+ int[] rowIdxs = new int[]{0, 2, 1, 3}; // unsorted
+ setField(sblock, "_rowidx", rowIdxs);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertTrue(ex.getMessage().startsWith("Row indices are
decreasing at row"));
+ }
+
+ @Test
+ public void testSparseBlockDCSRDecreasingRowPointers() {
+ SparseBlockDCSR sblock = new
SparseBlockDCSR(getFixedSparseBlock());
+ int[] rowPtr = new int[]{0, 1, 2, 6, 4}; // unsorted
+ setField(sblock, "_rowptr", rowPtr);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertTrue(ex.getMessage().startsWith("Row pointers are
decreasing at row"));
+ }
+
+ @Test
+ public void testSparseBlockCOOUnsortedColumnIndicesWithinRow() {
+ SparseBlockCOO sblock = new
SparseBlockCOO(getFixedSparseBlock());
+ int[] c = new int[]{0, 1, 3, 4, 4, 3}; // unsorted for last row
+ setField(sblock, "_cindexes", c);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertTrue(ex.getMessage().startsWith("Wrong sparse row
ordering"));
+ }
+
+ @Test
+ public void testSparseBlockCSCUnsortedRowIndicesWithinColumn() {
+ SparseBlockCSC sblock = new
SparseBlockCSC(getFixedSparseBlock());
+ int[] idxs = new int[]{0, 1, 2, 3, 3, 2}; // unsorted for last
col
+ setField(sblock, "_indexes", idxs);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertTrue(ex.getMessage().startsWith("Wrong sparse column
ordering"));
+ }
+
+ @Test
+ public void testSparseBlockCSRUnsortedColumnIndicesWithinRow() {
+ SparseBlockCSR sblock = new
SparseBlockCSR(getFixedSparseBlock());
+ int[] idxs = new int[]{0, 1, 2, 3, 3, 2}; // unsorted for last
row
+ setField(sblock, "_indexes", idxs);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertTrue(ex.getMessage().startsWith("Wrong sparse row
ordering"));
+ }
+
+ @Test
+ public void testSparseBlockDCSRUnsortedColumnIndicesWithinRow() {
+ SparseBlockDCSR sblock = new
SparseBlockDCSR(getFixedSparseBlock());
+ int[] colIdxs = new int[]{0, 1, 2, 3, 3, 2}; // unsorted for
last row
+ setField(sblock, "_colidx", colIdxs);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertTrue(ex.getMessage().startsWith("Wrong sparse row
ordering"));
+ }
+
+ @Test
+ public void testSparseBlockMCSCUnsortedRowIndicesWithinColumn() {
+ SparseBlockMCSC sblock = new
SparseBlockMCSC(getFixedSparseBlock());
+ int[] indexes = new int[]{3, 2}; // unsorted
+ setField(sblock.getCols()[3], "indexes", indexes);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertTrue(ex.getMessage().startsWith("Wrong sparse column
ordering"));
+ }
+
+ @Test
+ public void testSparseBlockMCSRUnsortedColumnIndicesWithinRow() {
+ SparseBlockMCSR sblock = new
SparseBlockMCSR(getFixedSparseBlock());
+ int[] indexes = new int[]{3, 2}; // unsorted
+ setField(sblock.getRows()[3], "indexes", indexes);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertTrue(ex.getMessage().startsWith("Wrong sparse row
ordering"));
+ }
+
+ @Test
+ public void testSparseBlockMCSCInvalidIndices() {
+ SparseBlockMCSC sblock = new
SparseBlockMCSC(getFixedSparseBlock());
+ int[] indexes = sblock.getCols()[3].indexes();
+ indexes[0] = -1;
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertTrue(ex.getMessage().startsWith("Invalid index"));
+ }
+
+ @Test
+ public void testSparseBlockMCSRInvalidIndices() {
+ SparseBlockMCSR sblock = new
SparseBlockMCSR(getFixedSparseBlock());
+ int[] indexes = sblock.getRows()[3].indexes();
+ indexes[0] = -1;
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertTrue(ex.getMessage().startsWith("Invalid index"));
+ }
+
+ @Test
+ public void testSparseBlockCOOInvalidValue() {
+ runSparseBlockInvalidValueTest(SparseBlock.Type.COO);
+ }
+
+ @Test
+ public void testSparseBlockCSCInvalidValue() {
+ runSparseBlockInvalidValueTest(SparseBlock.Type.CSC);
+ }
+
+ @Test
+ public void testSparseBlockCSRInvalidValue() {
+ runSparseBlockInvalidValueTest(SparseBlock.Type.CSR);
+ }
+
+ @Test
+ public void testSparseBlockDCSRInvalidValue() {
+ runSparseBlockInvalidValueTest(SparseBlock.Type.DCSR);
+ }
+
+ @Test
+ public void testSparseBlockMCSCInvalidValue() {
+ SparseBlockMCSC sblock = new
SparseBlockMCSC(getFixedSparseBlock());
+ double[] values = sblock.valuesCol(3);
+ values[0] = 0;
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, true));
+ assertTrue(ex.getMessage().startsWith("The values array should
not contain zeros"));
+ }
+
+ @Test
+ public void testSparseBlockMCSRInvalidValue() {
+ SparseBlockMCSR sblock = new
SparseBlockMCSR(getFixedSparseBlock());
+ double[] values = sblock.values(3);
+ values[0] = 0;
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, true));
+ assertTrue(ex.getMessage().startsWith("The values array should
not contain zeros"));
+ }
+
+ @Test
+ public void testSparseBlockCOOInvalidRIndex() {
+ runSparseBlockInvalidIndexTest(SparseBlock.Type.COO,
"_rindexes");
+ }
+
+ @Test
+ public void testSparseBlockCOOInvalidCIndex() {
+ runSparseBlockInvalidIndexTest(SparseBlock.Type.COO,
"_cindexes");
+ }
+
+
+ @Test
+ public void testSparseBlockCSCInvalidIndex() {
+ runSparseBlockInvalidIndexTest(SparseBlock.Type.CSC,
"_indexes");
+ }
+
+ @Test
+ public void testSparseBlockCSRInvalidIndex() {
+ runSparseBlockInvalidIndexTest(SparseBlock.Type.CSR,
"_indexes");
+ }
+
+ @Test
+ public void testSparseBlockDCSRInvalidIndex() {
+ runSparseBlockInvalidIndexTest(SparseBlock.Type.DCSR,
"_colidx");
+ }
+
+ @Test
+ public void testSparseBlockCOOCapacityExceedsAllowedLimit() {
+ SparseBlockCOO sblock = new SparseBlockCOO(3, 50);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(3, 3, 0, false));
+ assertTrue(ex.getMessage().startsWith("Capacity is larger than
the nnz times a resize factor"));
+ }
+
+ @Test
+ public void testSparseBlockCSCCapacityExceedsAllowedLimit() {
+ SparseBlockCSC sblock = new SparseBlockCSC(3, 3, 50);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(3, 3, 0, false));
+ assertTrue(ex.getMessage().startsWith("Capacity is larger than
the nnz times a resize factor"));
+ }
+
+ @Test
+ public void testSparseBlockCSRCapacityExceedsAllowedLimit() {
+ SparseBlockCSR sblock = new SparseBlockCSR(3, 50, 0);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(3, 3, 0, false));
+ assertTrue(ex.getMessage().startsWith("Capacity is larger than
the nnz times a resize factor"));
+ }
+
+ @Test
+ public void testSparseBlockDCSRCapacityExceedsAllowedLimit() {
+ SparseBlockDCSR sblock = new SparseBlockDCSR(3, 50);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(3, 3, 0, false));
+ assertTrue(ex.getMessage().startsWith("Capacity is larger than
the nnz times a resize factor"));
+ }
+
+ @Test
+ public void testSparseBlockMCSCCapacityExceedsAllowedLimit() {
+ double[][] A = getRandomMatrix(_rows, _cols, -10, 10,
_sparsity, 13);
+
+ MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+ SparseBlock srtmp = mbtmp.getSparseBlock();
+ SparseBlockMCSC sblock = new SparseBlockMCSC(srtmp);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(_rows, _cols, 2, true));
+ assertTrue(ex.getMessage().startsWith("The capacity is larger
than nnz times a resize factor"));
+ }
+
+ @Test
+ public void testSparseBlockMCSRCapacityExceedsAllowedLimit() {
+ double[][] A = getRandomMatrix(_rows, _cols, -10, 10,
_sparsity, 13);
+
+ MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+ SparseBlock srtmp = mbtmp.getSparseBlock();
+ SparseBlockMCSR sblock = new SparseBlockMCSR(srtmp);
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(_rows, _cols, 2, true));
+ assertTrue(ex.getMessage().startsWith("The capacity is larger
than nnz times a resize factor"));
+ }
+
+ private void runSparseBlockValidTest(SparseBlock.Type btype) {
+ double[][] A = getRandomMatrix(_rows, _cols, -10, 10,
_sparsity, 13);
+
+ MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+ SparseBlock srtmp = mbtmp.getSparseBlock();
+ SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype,
srtmp, true);
+
+ assertTrue("should pass checkValidity",
sblock.checkValidity(_rows, _cols, sblock.size(), true));
+ }
+
+ private void runSparseBlockInvalidDimensionsTest(SparseBlock sblock) {
+ RuntimeException ex1 = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(-1, 1, 0, false));
+ assertTrue(ex1.getMessage().startsWith("Invalid block
dimensions"));
+
+ RuntimeException ex2 = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(1, -1, 0, false));
+ assertTrue(ex2.getMessage().startsWith("Invalid block
dimensions"));
+ }
+
+ private void runSparseBlockInvalidIndexTest(SparseBlock.Type btype,
String indexName) {
+ SparseBlock srtmp = getFixedSparseBlock();
+ SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype,
srtmp, true);
+
+ int[] indexes = (int[]) getField(sblock, indexName);
+ indexes[0] = -1;
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, true));
+ assertTrue(ex.getMessage().startsWith("Invalid index at pos"));
+ }
+
+ private void runSparseBlockInvalidValueTest(SparseBlock.Type btype) {
+ SparseBlock srtmp = getFixedSparseBlock();
+ SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype,
srtmp, true);
+
+ double[] values = (double[]) getField(sblock, "_values");
+ values[0] = 0;
+
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, true));
+ assertTrue(ex.getMessage().startsWith("The values array should
not contain zeros"));
+ }
+
+ private void
checkValidityFailsWhenArrayLengthIsTemporarilyModified(SparseBlock sblock,
String name, Object value){
+ Object old = getField(sblock, name);
+ setField(sblock, name, value);
+ RuntimeException ex = assertThrows(RuntimeException.class,
+ () -> sblock.checkValidity(4, 4, 6, false));
+ assertEquals("Incorrect array lengths.", ex.getMessage());
+ setField(sblock, name, old);
+ }
+
+ private SparseBlock getFixedSparseBlock(){
+ double[][] A = new double[][] {{1, 0, 0, 0}, {0, 1, 0, 0}, {0,
0, 1, 1}, {0, 0, 1, 1}};
+ MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+ return mbtmp.getSparseBlock();
+ }
+
+ private static void setField(Object obj, String name, Object value) {
+ try {
+ Field f = obj.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ f.set(obj, value);
+ } catch (Exception ex) {
+ throw new RuntimeException("Reflection failed: " +
ex.getMessage());
+ }
+ }
+
+ private static Object getField(Object obj, String name) {
+ try {
+ Field f = obj.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ return f.get(obj);
+ } catch (Exception ex) {
+ throw new RuntimeException("Reflection failed: " +
ex.getMessage());
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java
new file mode 100644
index 0000000000..19a313e025
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockColTest.java
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.sparse;
+
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockCSC;
+import org.apache.sysds.runtime.data.SparseBlockMCSC;
+import org.apache.sysds.runtime.data.SparseRow;
+import org.apache.sysds.runtime.data.SparseRowScalar;
+import org.apache.sysds.runtime.data.SparseRowVector;
+
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+public class SparseBlockColTest extends AutomatedTestBase
+{
+ private final static int _rows = 324;
+ private final static int _cols = 132;
+ private final static double _sparsity = 0.3;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ }
+
+ @Test
+ public void testSparseBlockCSCGetReset() {
+ double[][] A = getRandomMatrix(_rows, _cols, -10, 10,
_sparsity, 1234);
+ MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+ SparseBlockColWrapper b = wrap(new
SparseBlockCSC(mbtmp.getSparseBlock()));
+ runSparseBlockGetResetTest(b, SparseBlock.Type.CSC);
+ }
+
+ @Test
+ public void testSparseBlockMCSCGetReset() {
+ double[][] A = getRandomMatrix(_rows, _cols, -10, 10,
_sparsity, 1234);
+ MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+ SparseBlockColWrapper b = wrap(new
SparseBlockMCSC(mbtmp.getSparseBlock()));
+ runSparseBlockGetResetTest(b, SparseBlock.Type.MCSC);
+ }
+
+ @Test
+ public void testSparseBlockCSCSetSort() {
+ double[][] A = getRandomMatrix(_rows, _cols, -10, 10,
_sparsity, 1234);
+ MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+ SparseBlockColWrapper b = wrap(new
SparseBlockCSC(mbtmp.getSparseBlock()));
+ SparseRow[] cols = (new
SparseBlockMCSC(mbtmp.getSparseBlock())).getCols();
+ runSparseBlockSetSortTest(b, cols);
+ }
+
+ @Test
+ public void testSparseBlockMCSCSetSort() {
+ double[][] A = getRandomMatrix(_rows, _cols, -10, 10,
_sparsity, 1234);
+ MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+ SparseBlockColWrapper b = wrap(new
SparseBlockMCSC(mbtmp.getSparseBlock()));
+ SparseRow[] cols = (new
SparseBlockMCSC(mbtmp.getSparseBlock())).getCols();
+ runSparseBlockSetSortTest(b, cols);
+ }
+
+ @Test
+ public void testSparseBlockCSCSetDelIdxRange() {
+ double[][] A = getRandomMatrix(_rows, _cols, -10, 10,
_sparsity, 1234);
+ MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+ SparseBlockColWrapper b = wrap(new
SparseBlockCSC(mbtmp.getSparseBlock()));
+ SparseRow[] cols = (new
SparseBlockMCSC(mbtmp.getSparseBlock())).getCols();
+ runSparseBlockSetDelIdxRangeTest(b, cols);
+ }
+
+ @Test
+ public void testSparseBlockMCSCSetDelIdxRange() {
+ double[][] A = getRandomMatrix(_rows, _cols, -10, 10,
_sparsity, 1234);
+ MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+ SparseBlockColWrapper b = wrap(new
SparseBlockMCSC(mbtmp.getSparseBlock()));
+ SparseRow[] cols = (new
SparseBlockMCSC(mbtmp.getSparseBlock())).getCols();
+ runSparseBlockSetDelIdxRangeTest(b, cols);
+ }
+
+ private void runSparseBlockGetResetTest(SparseBlockColWrapper sblock,
SparseBlock.Type btype) {
+ int c = _cols/3;
+ SparseRow col = sblock.getCol(c);
+ int size = sblock.sizeCol(c);
+ Assert.assertEquals(col.size(), size);
+
+ sblock.resetCol(c);
+ col = sblock.getCol(c);
+ size = sblock.sizeCol(c);
+ Assert.assertEquals(0, size);
+ Assert.assertTrue(col.isEmpty());
+ if(btype == SparseBlock.Type.CSC) Assert.assertTrue(col
instanceof SparseRowScalar);
+
+ // nothing changes
+ SparseBlockColWrapper sblock2 = sblock.copy();
+ sblock.resetCol(c);
+ SparseRow col2 = sblock.getCol(c);
+ Assert.assertArrayEquals(col.indexes(), col2.indexes());
+ Assert.assertArrayEquals(col.values(), col2.values(), 0);
+ Assert.assertEquals(sblock.getObject(), sblock2.getObject());
+ }
+
+ private void runSparseBlockSetSortTest(SparseBlockColWrapper sblock,
SparseRow[] cols) {
+ int c = _cols/3;
+ SparseRow col = cols[c];
+ double[] values = col.values().clone();
+ int[] indexes = col.indexes().clone();
+ int size = col.size();
+
+ // reverse
+ for (int i = 0; i < size/2; i++) {
+ double t = values[i];
+ values[i] = values[size-1-i];
+ values[size-1-i] = t;
+ int t2 = indexes[i];
+ indexes[i] = indexes[size-1-i];
+ indexes[size-1-i] = t2;
+ }
+ Assert.assertFalse(Arrays.equals(col.values(), values));
+ Assert.assertFalse(Arrays.equals(col.indexes(), indexes));
+
+ SparseRow col2 = new SparseRowVector(values, indexes);
+ sblock.resetCol(c);
+ sblock.setCol(c, col2, true);
+ Assert.assertArrayEquals(col2.indexes(),
sblock.getCol(c).indexes());
+ Assert.assertArrayEquals(col2.values(),
sblock.getCol(c).values(), 0);
+
+ int nnz = (int) ((SparseBlock) sblock.getObject()).size();
+ int rlen = ((SparseBlock) sblock.getObject()).numRows();
+ int clen = cols.length;
+ RuntimeException ex =
Assert.assertThrows(RuntimeException.class,
+ () -> ((SparseBlock)
sblock.getObject()).checkValidity(rlen, clen, nnz, true));
+ Assert.assertTrue(ex.getMessage().startsWith("Wrong sparse
column ordering"));
+
+ sblock.sortCol(c);
+
Assert.assertTrue(((SparseBlock)sblock.getObject()).checkValidity(rlen, clen,
nnz, true));
+ Assert.assertArrayEquals(col.indexes(),
sblock.getCol(c).indexes());
+ Assert.assertArrayEquals(col.values(),
sblock.getCol(c).values(), 0);
+ }
+
+ private void runSparseBlockSetDelIdxRangeTest(SparseBlockColWrapper
sblock, SparseRow[] cols) {
+ int c = _cols/3;
+ int rl = _rows/4;
+ int ru = _rows/2;
+
+ SparseRow[] cols2 = Arrays.copyOf(cols, cols.length);
+ double[] v = getRandomMatrix(1, _rows, -10, 10, 1, 1234)[0];
+ for(int i=0; i<rl; i++) v[i] = cols[c].get(i);
+ cols2[c] = new SparseRowVector(v);
+ SparseBlock sblock2 = new SparseBlockMCSC(cols2, false, _rows);
+
+ sblock.setIndexRangeCol(c, rl, _rows, v, rl, _rows-rl);
+ Assert.assertEquals(sblock2, sblock.getObject());
+
+ sblock.deleteIndexRangeCol(c, rl, ru);
+ for(int i=rl; i<ru; i++) cols2[c].set(i, 0);
+ Assert.assertEquals(sblock2, sblock.getObject());
+
+ sblock.deleteIndexRangeCol(c, rl, _rows+1);
+ for(int i=ru; i<_rows; i++) cols2[c].set(i, 0);
+ Assert.assertEquals(sblock2, sblock.getObject());
+ }
+
+ private interface SparseBlockColWrapper {
+ SparseRow getCol(int c);
+ void setCol(int c, SparseRow col, boolean deep);
+ void setIndexRangeCol(int c, int rl, int ru, double[] v, int
vix, int vlen);
+ void deleteIndexRangeCol(int c, int rl, int ru);
+ int sizeCol(int c);
+ void sortCol(int c);
+ void resetCol(int c);
+ SparseBlockColWrapper copy();
+ Object getObject();
+ }
+
+ private SparseBlockColWrapper wrap(SparseBlockCSC b) {
+ return new SparseBlockColWrapper() {
+ @Override
+ public SparseRow getCol(int c) { return b.getCol(c); }
+
+ @Override
+ public void setCol(int c, SparseRow col, boolean deep) {
+ b.setCol(c, col, deep); }
+
+ @Override
+ public void setIndexRangeCol(int c, int rl, int ru,
double[] v, int vix, int vlen){
+ b.setIndexRangeCol(c, rl, ru, v, vix, vlen);
+ }
+
+ @Override
+ public void deleteIndexRangeCol(int c, int rl, int ru){
+ b.deleteIndexRangeCol(c, rl, ru);
+ }
+
+ @Override
+ public int sizeCol(int c) { return b.sizeCol(c); }
+
+ @Override
+ public void sortCol(int c) { b.sortCol(c); }
+
+ @Override
+ public void resetCol(int c) { b.resetCol(c); }
+
+ @Override
+ public SparseBlockColWrapper copy() { return wrap(new
SparseBlockCSC(b)); }
+
+ @Override
+ public Object getObject() { return b; }
+ };
+ }
+
+ private SparseBlockColWrapper wrap(SparseBlockMCSC b) {
+ return new SparseBlockColWrapper() {
+ @Override
+ public SparseRow getCol(int c) { return b.getCol(c); }
+
+ @Override
+ public void setCol(int c, SparseRow col, boolean deep) {
+ b.setCol(c, col, deep); }
+
+ @Override
+ public void setIndexRangeCol(int c, int rl, int ru,
double[] v, int vix, int vlen){
+ b.setIndexRangeCol(c, rl, ru, v, vix, vlen);
+ }
+
+ @Override
+ public void deleteIndexRangeCol(int c, int rl, int ru){
+ b.deleteIndexRangeCol(c, rl, ru);
+ }
+
+ @Override
+ public int sizeCol(int c) { return b.sizeCol(c); }
+
+ @Override
+ public void sortCol(int c) { b.sortCol(c); }
+
+ @Override
+ public void resetCol(int c) { b.resetCol(c, 0, 0); }
+
+ @Override
+ public SparseBlockColWrapper copy() { return wrap(new
SparseBlockMCSC(b)); }
+
+ @Override
+ public Object getObject() { return b; }
+ };
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java
new file mode 100644
index 0000000000..436d724b0e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockCompactTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.sparse;
+
+import java.lang.reflect.Field;
+
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockFactory;
+import org.apache.sysds.runtime.data.SparseRow;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Checks that every sparse block format rejects stored zeros in
+ * checkValidity and that compact() removes them again.
+ */
+public class SparseBlockCompactTest extends AutomatedTestBase
+{
+	private final static int _rows = 324;
+	private final static int _cols = 132;
+	private final static double _sparsity = 0.22;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+	}
+
+	@Test
+	public void testSparseBlockCompactCOO() {
+		runSparseBlockCompactZerosTest(SparseBlock.Type.COO);
+	}
+
+	@Test
+	public void testSparseBlockCompactCSC() {
+		runSparseBlockCompactZerosTest(SparseBlock.Type.CSC);
+	}
+
+	@Test
+	public void testSparseBlockCompactCSR() {
+		runSparseBlockCompactZerosTest(SparseBlock.Type.CSR);
+	}
+
+	@Test
+	public void testSparseBlockCompactDCSR() {
+		runSparseBlockCompactZerosTest(SparseBlock.Type.DCSR);
+	}
+
+	@Test
+	public void testSparseBlockCompactMCSC() {
+		runSparseBlockModifiedCompactZerosTest(SparseBlock.Type.MCSC, "_columns");
+	}
+
+	@Test
+	public void testSparseBlockCompactMCSR() {
+		runSparseBlockModifiedCompactZerosTest(SparseBlock.Type.MCSR, "_rows");
+	}
+
+	/**
+	 * Formats with a single flat {@code _values} array: overwrite the first and
+	 * the last stored value with zero, then verify validation and compaction.
+	 */
+	private void runSparseBlockCompactZerosTest(SparseBlock.Type btype) {
+		SparseBlock sblock = createSparseBlock(btype);
+
+		double[] values = (double[]) getField(sblock, "_values");
+		// index by the number of stored entries, not values.length: the backing
+		// array may carry spare capacity past the last stored value
+		int nnz = (int) sblock.size();
+		values[0] = 0.0;
+		values[nnz - 1] = 0.0;
+
+		assertCompactRemovesTwoZeros(sblock);
+	}
+
+	/**
+	 * Formats holding one SparseRow per row/column (MCSR/MCSC): overwrite the
+	 * first and the last stored value of the first row/column with zero, then
+	 * verify validation and compaction.
+	 *
+	 * @param field name of the SparseRow[] field inside the block
+	 */
+	private void runSparseBlockModifiedCompactZerosTest(SparseBlock.Type btype, String field) {
+		SparseBlock sblock = createSparseBlock(btype);
+
+		SparseRow[] sr = (SparseRow[]) getField(sblock, field);
+		SparseRow first = sr[0];
+		// two distinct entries are required, otherwise both writes hit the same slot
+		assertTrue("first row/column needs at least two entries", first != null && first.size() >= 2);
+		double[] values = first.values();
+		values[0] = 0.0;
+		values[first.size() - 1] = 0.0;
+
+		assertCompactRemovesTwoZeros(sblock);
+	}
+
+	/** Creates a deterministic random sparse block of the requested type (deep copy). */
+	private SparseBlock createSparseBlock(SparseBlock.Type btype) {
+		double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 13);
+		MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+		return SparseBlockFactory.copySparseBlock(btype, mbtmp.getSparseBlock(), true);
+	}
+
+	/**
+	 * Asserts that checkValidity rejects the injected zeros, and that after
+	 * compact() the block is valid again with exactly two fewer entries.
+	 */
+	private void assertCompactRemovesTwoZeros(SparseBlock sblock) {
+		RuntimeException ex = assertThrows(RuntimeException.class,
+			() -> sblock.checkValidity(_rows, _cols, sblock.size(), true));
+		assertTrue(ex.getMessage().startsWith("The values array should not contain zeros"));
+		long size = sblock.size();
+
+		sblock.compact();
+
+		assertTrue("should pass checkValidity", sblock.checkValidity(_rows, _cols, sblock.size(), true));
+		assertEquals(size - 2, sblock.size());
+	}
+
+	/** Reads a (possibly private) field declared directly on the object's runtime class. */
+	private static Object getField(Object obj, String name) {
+		try {
+			Field f = obj.getClass().getDeclaredField(name);
+			f.setAccessible(true);
+			return f.get(obj);
+		}
+		catch(Exception ex) {
+			// keep the original exception as cause so the failure is diagnosable
+			throw new RuntimeException("Reflection failed: " + ex.getMessage(), ex);
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockContainsTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockContainsTest.java
new file mode 100644
index 0000000000..2df94b4336
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockContainsTest.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.sparse;
+
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockFactory;
+import org.apache.sysds.runtime.data.SparseBlockMCSC;
+import org.apache.sysds.runtime.data.SparseBlockMCSR;
+import org.apache.sysds.runtime.data.SparseRow;
+import org.apache.sysds.runtime.data.SparseRowVector;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests SparseBlock#contains(double[], boolean) across all sparse formats:
+ * no match, NaN handling, early abort, over-long patterns, patterns with
+ * zeros, and blocks that store explicit zeros.
+ */
+public class SparseBlockContainsTest extends AutomatedTestBase
+{
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+	}
+
+	// ---- no row equals the pattern ----
+
+	@Test
+	public void testSparseBlockContainsNoMatchCOO() {
+		runSparseBlockContainsNoMatchTest(SparseBlock.Type.COO);
+	}
+
+	@Test
+	public void testSparseBlockContainsNoMatchCSC() {
+		runSparseBlockContainsNoMatchTest(SparseBlock.Type.CSC);
+	}
+
+	@Test
+	public void testSparseBlockContainsNoMatchCSR() {
+		runSparseBlockContainsNoMatchTest(SparseBlock.Type.CSR);
+	}
+
+	@Test
+	public void testSparseBlockContainsNoMatchDCSR() {
+		runSparseBlockContainsNoMatchTest(SparseBlock.Type.DCSR);
+	}
+
+	@Test
+	public void testSparseBlockContainsNoMatchMCSC() {
+		runSparseBlockContainsNoMatchTest(SparseBlock.Type.MCSC);
+	}
+
+	@Test
+	public void testSparseBlockContainsNoMatchMCSR() {
+		runSparseBlockContainsNoMatchTest(SparseBlock.Type.MCSR);
+	}
+
+	// ---- pattern containing NaN ----
+
+	@Test
+	public void testSparseBlockContainsNaNCOO() {
+		runSparseBlockContainsNaNTest(SparseBlock.Type.COO);
+	}
+
+	@Test
+	public void testSparseBlockContainsNaNCSC() {
+		runSparseBlockContainsNaNTest(SparseBlock.Type.CSC);
+	}
+
+	@Test
+	public void testSparseBlockContainsNaNCSR() {
+		runSparseBlockContainsNaNTest(SparseBlock.Type.CSR);
+	}
+
+	@Test
+	public void testSparseBlockContainsNaNDCSR() {
+		runSparseBlockContainsNaNTest(SparseBlock.Type.DCSR);
+	}
+
+	@Test
+	public void testSparseBlockContainsNaNMCSC() {
+		runSparseBlockContainsNaNTest(SparseBlock.Type.MCSC);
+	}
+
+	@Test
+	public void testSparseBlockContainsNaNMCSR() {
+		runSparseBlockContainsNaNTest(SparseBlock.Type.MCSR);
+	}
+
+	// ---- early abort after the first match ----
+
+	@Test
+	public void testSparseBlockContainsEarlyAbortCOO() {
+		runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.COO);
+	}
+
+	@Test
+	public void testSparseBlockContainsEarlyAbortCSC() {
+		runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.CSC);
+	}
+
+	@Test
+	public void testSparseBlockContainsEarlyAbortCSR() {
+		runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.CSR);
+	}
+
+	@Test
+	public void testSparseBlockContainsEarlyAbortDCSR() {
+		runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.DCSR);
+	}
+
+	@Test
+	public void testSparseBlockContainsEarlyAbortMCSC() {
+		runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.MCSC);
+	}
+
+	@Test
+	public void testSparseBlockContainsEarlyAbortMCSR() {
+		runSparseBlockContainsEarlyAbortTest(SparseBlock.Type.MCSR);
+	}
+
+	// ---- pattern longer than a row ----
+
+	@Test
+	public void testSparseBlockContainsPatternLongerThanRowsCOO() {
+		runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.COO);
+	}
+
+	@Test
+	public void testSparseBlockContainsPatternLongerThanRowsCSC() {
+		runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.CSC);
+	}
+
+	@Test
+	public void testSparseBlockContainsPatternLongerThanRowsCSR() {
+		runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.CSR);
+	}
+
+	@Test
+	public void testSparseBlockContainsPatternLongerThanRowsDCSR() {
+		runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.DCSR);
+	}
+
+	@Test
+	public void testSparseBlockContainsPatternLongerThanRowsMCSC() {
+		runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.MCSC);
+	}
+
+	@Test
+	public void testSparseBlockContainsPatternLongerThanRowsMCSR() {
+		runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type.MCSR);
+	}
+
+	// ---- pattern that itself contains zeros ----
+
+	@Test
+	public void testSparseBlockContainsPatternContainsZeroCOO() {
+		runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.COO);
+	}
+
+	@Test
+	public void testSparseBlockContainsPatternContainsZeroCSC() {
+		runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.CSC);
+	}
+
+	@Test
+	public void testSparseBlockContainsPatternContainsZeroCSR() {
+		runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.CSR);
+	}
+
+	@Test
+	public void testSparseBlockContainsPatternContainsZeroDCSR() {
+		runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.DCSR);
+	}
+
+	@Test
+	public void testSparseBlockContainsPatternContainsZeroMCSC() {
+		runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.MCSC);
+	}
+
+	@Test
+	public void testSparseBlockContainsPatternContainsZeroMCSR() {
+		runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type.MCSR);
+	}
+
+	// ---- blocks that store explicit zeros (non-compact) ----
+
+	@Test
+	public void testSparseBlockNonCompactContainsPatternCOO() {
+		runSparseBlockNonCompactContainsPatternTest(SparseBlock.Type.COO);
+	}
+
+	@Test
+	public void testSparseBlockNonCompactContainsPatternCSC() {
+		runSparseBlockNonCompactContainsPatternTest(SparseBlock.Type.CSC);
+	}
+
+	@Test
+	public void testSparseBlockNonCompactContainsPatternCSR() {
+		runSparseBlockNonCompactContainsPatternTest(SparseBlock.Type.CSR);
+	}
+
+	@Test
+	public void testSparseBlockNonCompactContainsPatternDCSR() {
+		runSparseBlockNonCompactContainsPatternTest(SparseBlock.Type.DCSR);
+	}
+
+	@Test
+	public void testSparseBlockNonCompactContainsPatternMCSC() {
+		double[] pattern = new double[]{0., 1., 2.};
+		// the three columns of a 6x3 matrix, every cell stored (zeros included);
+		// rows 0 and 4 read [0, 1, 2]
+		SparseRow[] columns = new SparseRow[] {
+			new SparseRowVector(new double[]{0., 1., 1., 0., 0., 0.}, new int[]{0, 1, 2, 3, 4, 5}),
+			new SparseRowVector(new double[]{1., 2., 2., 0., 1., 0.}, new int[]{0, 1, 2, 3, 4, 5}),
+			new SparseRowVector(new double[]{2., 0., 0., 0., 2., 0.}, new int[]{0, 1, 2, 3, 4, 5})
+		};
+		SparseBlock sblock = new SparseBlockMCSC(columns, true, 6);
+
+		RuntimeException ex = assertThrows(RuntimeException.class,
+			() -> sblock.checkValidity(6, 3, sblock.size(), true));
+		assertTrue(ex.getMessage().startsWith("The values array should not contain zeros"));
+
+		assertEquals(List.of(0, 4), sblock.contains(pattern, false));
+	}
+
+	@Test
+	public void testSparseBlockNonCompactContainsPatternMCSR() {
+		runSparseBlockNonCompactContainsPatternTest(SparseBlock.Type.MCSR);
+	}
+
+	/** Converts a dense matrix into a deep-copied sparse block of the requested type. */
+	private static SparseBlock toSparseBlock(SparseBlock.Type btype, double[][] A) {
+		MatrixBlock mb = DataConverter.convertToMatrixBlock(A);
+		return SparseBlockFactory.copySparseBlock(btype, mb.getSparseBlock(), true);
+	}
+
+	private void runSparseBlockContainsNoMatchTest(SparseBlock.Type btype) {
+		SparseBlock sblock = toSparseBlock(btype, new double[][]{
+			{4., 5., 6.}, {7., 8., 9.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}});
+		assertEquals(List.of(), sblock.contains(new double[]{1., 2., 3.}, false));
+	}
+
+	private void runSparseBlockContainsNaNTest(SparseBlock.Type btype) {
+		SparseBlock sblock = toSparseBlock(btype, new double[][]{
+			{Double.NaN, 2., 3.}, {1., 2., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}});
+		assertEquals(List.of(0), sblock.contains(new double[]{Double.NaN, 2., 3.}, false));
+	}
+
+	private void runSparseBlockContainsEarlyAbortTest(SparseBlock.Type btype) {
+		// rows 1 and 2 both match; with early abort only row 1 is expected
+		SparseBlock sblock = toSparseBlock(btype, new double[][]{
+			{0., 0., 0.}, {1., 2., 3.}, {1., 2., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}});
+		assertEquals(List.of(1), sblock.contains(new double[]{1., 2., 3.}, true));
+	}
+
+	private void runSparseBlockContainsPatternLongerThanRowsTest(SparseBlock.Type btype) {
+		SparseBlock sblock = toSparseBlock(btype, new double[][]{
+			{0., 0., 0.}, {1., 2., 3.}, {1., 2., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}});
+		assertEquals(List.of(), sblock.contains(new double[]{1., 2., 3., 4.}, false));
+	}
+
+	private void runSparseBlockContainsPatternContainsZeroTest(SparseBlock.Type btype) {
+		SparseBlock sblock = toSparseBlock(btype, new double[][]{
+			{0., 1., 2.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}, {0., 1., 2.}, {1., 2., 0.}});
+		assertEquals(List.of(0, 4), sblock.contains(new double[]{0., 1., 2.}, false));
+	}
+
+	private void runSparseBlockNonCompactContainsPatternTest(SparseBlock.Type btype) {
+		double[] pattern = new double[]{0., 1., 2.};
+		// rows with every cell stored, zeros included
+		SparseRow match = new SparseRowVector(new double[]{0., 1., 2.}, new int[]{0, 1, 2});
+		SparseRow nonMatch = new SparseRowVector(new double[]{1., 2., 0.}, new int[]{0, 1, 2});
+		SparseRow zeros = new SparseRowVector(new double[]{0., 0., 0.}, new int[]{0, 1, 2});
+
+		SparseBlock mcsr = new SparseBlockMCSR(new SparseRow[] {match, nonMatch, nonMatch, zeros, match, zeros}, true);
+		SparseBlock sblock = SparseBlockFactory.copySparseBlock(btype, mcsr, true);
+
+		RuntimeException ex = assertThrows(RuntimeException.class,
+			() -> sblock.checkValidity(6, 3, sblock.size(), true));
+		assertTrue(ex.getMessage().startsWith("The values array should not contain zeros"));
+
+		assertEquals(List.of(0, 4), sblock.contains(pattern, false));
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java
new file mode 100644
index 0000000000..43dbae3004
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockEqualsTest.java
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.sparse;
+
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockFactory;
+import org.apache.sysds.runtime.data.SparseRowVector;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+
+import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+
+/**
+ * Equality tests for sparse blocks: block-vs-block across every pair of
+ * formats (checked in both directions, as Object#equals must be symmetric),
+ * block-vs-non-block, and block-vs-dense-values.
+ */
+@RunWith(Enclosed.class)
+public class SparseBlockEqualsTest {
+
+	@RunWith(Parameterized.class)
+	public static class SparseBlockEqualsSparseBlockTest extends AutomatedTestBase {
+
+		@Override
+		public void setUp() {
+			TestUtils.clearAssertionInformation();
+		}
+
+		private final SparseBlock.Type type1;
+		private final SparseBlock.Type type2;
+
+		public SparseBlockEqualsSparseBlockTest(SparseBlock.Type type1, SparseBlock.Type type2) {
+			this.type1 = type1;
+			this.type2 = type2;
+		}
+
+		/** Every unordered pair of types (i <= j); both directions are checked in the tests. */
+		@Parameterized.Parameters(name = "{0} vs {1}")
+		public static Iterable<Object[]> types() {
+			SparseBlock.Type[] types = SparseBlock.Type.values();
+			ArrayList<Object[]> params = new ArrayList<>();
+
+			for (int i = 0; i < types.length; i++) {
+				for (int j = i; j < types.length; j++) {
+					params.add(new Object[]{types[i], types[j]});
+				}
+			}
+
+			return params;
+		}
+
+		@Test
+		public void testSparseBlockEquals() {
+			runSparseBlockEqualsTest(type1, type2);
+		}
+
+		@Test
+		public void testSparseBlockNotEqualsColIdx() {
+			runSparseBlockNotEqualsColIdxTest(type1, type2);
+		}
+
+		@Test
+		public void testSparseBlockNotEqualsEmptyRow() {
+			runSparseBlockNotEqualsEmptyRowTest(type1, type2);
+		}
+	}
+
+	@RunWith(Parameterized.class)
+	public static class SparseBlockEqualsDenseValuesTest extends AutomatedTestBase {
+
+		@Override
+		public void setUp() {
+			TestUtils.clearAssertionInformation();
+		}
+
+		private final SparseBlock.Type type;
+
+		public SparseBlockEqualsDenseValuesTest(SparseBlock.Type type) {
+			this.type = type;
+		}
+
+		@Parameterized.Parameters(name = "{0}")
+		public static Iterable<Object[]> types() {
+			ArrayList<Object[]> params = new ArrayList<>();
+			for (SparseBlock.Type t : SparseBlock.Type.values()) {
+				params.add(new Object[]{t});
+			}
+			return params;
+		}
+
+		@Test
+		public void testSparseBlockNotEqualsNonSparseBlock() {
+			runSparseBlockNotEqualsNonSparseBlockTest(type);
+		}
+
+		@Test
+		public void testSparseBlockNotEqualsDenseValuesEmptyRow() {
+			runSparseBlockNotEqualsDenseValuesEmptyRowTest(type);
+		}
+
+		@Test
+		public void testSparseBlockNotEqualsDenseValuesNonZero() {
+			runSparseBlockNotEqualsDenseValuesNonZeroTest(type);
+		}
+
+		@Test
+		public void testSparseBlockNotEqualsDenseValuesAdditionalNonZero() {
+			runSparseBlockNotEqualsDenseValuesAdditionalNonZeroTest(type);
+		}
+	}
+
+	/** Converts a dense matrix into a deep-copied sparse block of the given type. */
+	private static SparseBlock toSparseBlock(SparseBlock.Type type, double[][] A) {
+		MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+		return SparseBlockFactory.copySparseBlock(type, mbtmp.getSparseBlock(), true);
+	}
+
+	/** Asserts that a and b are unequal when compared from either side. */
+	private static void assertNotEqualsBothWays(SparseBlock a, SparseBlock.Type typeA,
+		SparseBlock b, SparseBlock.Type typeB)
+	{
+		assertNotEquals("should not be equal: " + typeA + " " + typeB, a, b);
+		assertNotEquals("should not be equal: " + typeB + " " + typeA, b, a);
+	}
+
+	private static void runSparseBlockEqualsTest(SparseBlock.Type type1, SparseBlock.Type type2) {
+		double[][] A = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 4., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}};
+
+		SparseBlock sblock1 = toSparseBlock(type1, A);
+		SparseBlock sblock2 = toSparseBlock(type2, A);
+
+		// equals must be symmetric, so compare from both sides
+		assertEquals(type1 + " vs " + type2, sblock1, sblock2);
+		assertEquals(type2 + " vs " + type1, sblock2, sblock1);
+	}
+
+	private static void runSparseBlockNotEqualsColIdxTest(SparseBlock.Type type1, SparseBlock.Type type2) {
+		// row 2 holds the same value in a different column
+		double[][] A = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 4., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}};
+		double[][] B = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 0., 4.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}};
+
+		assertNotEqualsBothWays(toSparseBlock(type1, A), type1, toSparseBlock(type2, B), type2);
+	}
+
+	private static void runSparseBlockNotEqualsEmptyRowTest(SparseBlock.Type type1, SparseBlock.Type type2) {
+		// the empty row is row 1 in A but row 2 in B
+		double[][] A = new double[][]{{1., 2., 3.}, {0., 0., 0.}, {0., 4., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}};
+		double[][] B = new double[][]{{1., 2., 3.}, {0., 4., 0.}, {0., 0., 0.}, {0., 0., 5.}, {6., 0., 0.}, {0., 0., 7.}};
+
+		assertNotEqualsBothWays(toSparseBlock(type1, A), type1, toSparseBlock(type2, B), type2);
+	}
+
+	private static void runSparseBlockNotEqualsNonSparseBlockTest(SparseBlock.Type type) {
+		double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 0.}};
+
+		SparseBlock sblock = toSparseBlock(type, A);
+		SparseRowVector srv = new SparseRowVector(A[0], new int[]{0, 1, 2});
+
+		assertNotEquals("should not be equal: " + type, sblock, srv);
+	}
+
+	private static void runSparseBlockNotEqualsDenseValuesEmptyRowTest(SparseBlock.Type type) {
+		double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {4., 0., 6.}};
+		SparseBlock sblock = toSparseBlock(type, A);
+
+		// dense row 2 is non-empty while the sparse row 2 is empty
+		double[] denseValues = new double[]{1., 0., 3., 0., 0., 0., 1., 1., 1., 4., 0., 6.};
+
+		assertFalse("should not be equal: " + type, sblock.equals(denseValues, 3, 1e-10));
+	}
+
+	private static void runSparseBlockNotEqualsDenseValuesNonZeroTest(SparseBlock.Type type) {
+		double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {0., 0., 1.}, {4., 0., 6.}};
+		SparseBlock sblock = toSparseBlock(type, A);
+
+		// dense row 3 has a non-zero at column 1 where the sparse row has none
+		double[] denseValues = new double[]{1., 0., 3., 0., 0., 0., 0., 0., 0., 0., 1., 1., 4., 0., 6.};
+
+		assertFalse("should not be equal: " + type, sblock.equals(denseValues, 3, 1e-10));
+	}
+
+	private static void runSparseBlockNotEqualsDenseValuesAdditionalNonZeroTest(SparseBlock.Type type) {
+		double[][] A = new double[][]{{1., 0., 3.}, {0., 0., 0.}, {0., 0., 0.}, {4., 0., 0.}};
+		SparseBlock sblock = toSparseBlock(type, A);
+
+		// dense row 3 carries an extra trailing non-zero
+		double[] denseValues = new double[]{1., 0., 3., 0., 0., 0., 0., 0., 0., 4., 0., 6.};
+
+		assertFalse("should not be equal: " + type, sblock.equals(denseValues, 3, 1e-10));
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java
new file mode 100644
index 0000000000..3280038606
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockInitializationTest.java
@@ -0,0 +1,484 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.sparse;
+
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockCOO;
+import org.apache.sysds.runtime.data.SparseBlockCSC;
+import org.apache.sysds.runtime.data.SparseBlockCSR;
+import org.apache.sysds.runtime.data.SparseBlockDCSR;
+import org.apache.sysds.runtime.data.SparseBlockFactory;
+import org.apache.sysds.runtime.data.SparseBlockMCSC;
+import org.apache.sysds.runtime.data.SparseBlockMCSR;
+import org.apache.sysds.runtime.data.SparseRow;
+import org.apache.sysds.runtime.data.SparseRowVector;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotSame;
+
+
+public class SparseBlockInitializationTest extends AutomatedTestBase
+{
+	// dimensions and density of the random test matrices
+	private static final int _rows = 324;
+	private static final int _cols = 132;
+	private static final double _sparsity = 0.22;
+
+	/** Test setup hook: clears the test base's assertion information. */
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+	}
+
+	@Test
+	public void testSparseBlockCreationCOO() {
+		runSparseBlockCreationTest(SparseBlock.Type.COO);
+	}
+
+	@Test
+	public void testSparseBlockCreationCSC() {
+		runSparseBlockCreationTest(SparseBlock.Type.CSC);
+	}
+
+	@Test
+	public void testSparseBlockCreationCSR() {
+		runSparseBlockCreationTest(SparseBlock.Type.CSR);
+	}
+
+	@Test
+	public void testSparseBlockCreationDCSR() {
+		runSparseBlockCreationTest(SparseBlock.Type.DCSR);
+	}
+
+	@Test
+	public void testSparseBlockCreationMCSC() {
+		runSparseBlockCreationTest(SparseBlock.Type.MCSC);
+	}
+
+	@Test
+	public void testSparseBlockCreationMCSR() {
+		runSparseBlockCreationTest(SparseBlock.Type.MCSR);
+	}
+
+	/**
+	 * Creates a sparse block of the given type through the factory and checks
+	 * that it reports that type.
+	 */
+	private void runSparseBlockCreationTest(SparseBlock.Type type) {
+		// NOTE(review): the second factory argument is presumably the row count;
+		// _cols is passed here — confirm intent
+		SparseBlock sblock = SparseBlockFactory.createSparseBlock(type, _cols);
+		// JUnit's assertEquals takes (expected, actual); keep that order so a
+		// failure message names the right value as expected
+		assertEquals(type, sblock.getSparseBlockType());
+	}
+
+	/** A freshly constructed COO block exposes a values array of the initial capacity. */
+	@Test
+	public void testSparseBlockCOOInitCapacity() {
+		final int expectedCapacity = 4;
+		SparseBlockCOO sblock = new SparseBlockCOO(_cols);
+		double[] backing = sblock.values(1);
+		assertEquals("INIT_CAPACITY should be 4", expectedCapacity, backing.length);
+	}
+
+	/** A COO block built from per-row SparseRowVectors equals one converted from the matrix. */
+	@Test
+	public void testSparseBlockCOORows() {
+		double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlock expected = new SparseBlockCOO(DataConverter.convertToMatrixBlock(A).getSparseBlock());
+
+		// rebuild the same content row by row from the dense rows
+		SparseRow[] sparseRows = new SparseRow[A.length];
+		int nnz = 0;
+		for(int r = 0; r < A.length; r++) {
+			sparseRows[r] = new SparseRowVector(A[r]);
+			nnz += sparseRows[r].size();
+		}
+
+		SparseBlockCOO actual = new SparseBlockCOO(sparseRows, nnz);
+		assertEquals(expected, actual);
+	}
+
+	/**
+	 * A COO block built from rows given as (all values, explicit indexes) and
+	 * then compacted equals one converted from the matrix.
+	 */
+	@Test
+	public void testSparseBlockCOORowsValuesIndexes() {
+		double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlock expected = new SparseBlockCOO(DataConverter.convertToMatrixBlock(A).getSparseBlock());
+
+		SparseRow[] sparseRows = new SparseRow[A.length];
+		int nnz = 0;
+		for(int r = 0; r < A.length; r++) {
+			// indexes 0..n-1 for every cell; compact() then drops the zeros
+			int[] ix = new int[A[r].length];
+			for(int c = 0; c < ix.length; c++)
+				ix[c] = c;
+			SparseRow row = new SparseRowVector(A[r], ix);
+			row.compact();
+			sparseRows[r] = row;
+			nnz += row.size();
+		}
+
+		SparseBlockCOO actual = new SparseBlockCOO(sparseRows, nnz);
+		assertEquals(expected, actual);
+	}
+
+	/** The (rlen, clen, capacity) constructor sizes the pointer, index and value arrays. */
+	@Test
+	public void testSparseBlockCSCInitCapacity() {
+		final int rlen = 4, clen = 5, capacity = 4;
+		SparseBlockCSC sblock = new SparseBlockCSC(rlen, clen, capacity);
+
+		assertEquals("num rows should be equal to rlen", rlen, sblock.numRows());
+		// one start pointer per column plus a trailing end pointer
+		assertEquals("length ptr should be equal to clen+1", clen + 1, sblock.colPointers().length);
+		assertEquals("length values should be equal to capacity", capacity, sblock.valuesCol(0).length);
+		assertEquals("length indexes should be equal to capacity", capacity, sblock.indexesCol(0).length);
+	}
+
+	/**
+	 * A CSC block constructed from raw (colPtr, rowInd, values, nnz) arrays
+	 * equals the block those arrays were taken from.
+	 */
+	@Test
+	public void testSparseBlockCSCInitPointer() {
+		double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+		SparseBlockCSC sblock = new SparseBlockCSC(mbtmp.getSparseBlock());
+
+		// copies, so the new block cannot trivially share storage with sblock
+		int[] colPtr = sblock.colPointers().clone();
+		int[] rowInd = sblock.indexesCol(0).clone();
+		double[] values = sblock.valuesCol(0).clone();
+		// total number of non-zeros; sizeCol(0) would only count column 0
+		int nnz = (int) sblock.size();
+		SparseBlockCSC sblock2 = new SparseBlockCSC(colPtr, rowInd, values, nnz);
+
+		assertEquals(sblock, sblock2);
+		assertEquals(sblock.size(), sblock2.size());
+	}
+
+	/** Converting an MCSC block into CSC preserves its content. */
+	@Test
+	public void testSparseBlockCSCInitMSCS() {
+		double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlock mcsc = new SparseBlockMCSC(DataConverter.convertToMatrixBlock(A).getSparseBlock());
+
+		SparseBlockCSC csc = new SparseBlockCSC(mcsc);
+		assertEquals(mcsc, csc);
+	}
+
+	/** A CSC block built from MCSC column rows plus the total nnz equals the MCSC block. */
+	@Test
+	public void testSparseBlockCSCInitCols() {
+		double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlockMCSC mcsc = new SparseBlockMCSC(DataConverter.convertToMatrixBlock(A).getSparseBlock());
+
+		SparseBlock csc = new SparseBlockCSC(mcsc.getCols(), (int) mcsc.size());
+		assertEquals(mcsc, csc);
+	}
+
+	/** A CSC block built from coordinate arrays (row and column index per entry) equals the source. */
+	@Test
+	public void testSparseBlockCSCInitRowColInd() {
+		double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlockCSC sblock = new SparseBlockCSC(DataConverter.convertToMatrixBlock(A).getSparseBlock());
+
+		int[] ptr = sblock.colPointers();
+		int[] rowInd = sblock.indexesCol(0);
+		double[] values = sblock.valuesCol(0);
+		int numCols = ptr.length - 1;
+
+		// expand the column pointers into one column index per stored entry
+		int[] colInd = new int[rowInd.length];
+		for(int col = 0; col < numCols; col++)
+			for(int k = ptr[col]; k < ptr[col + 1]; k++)
+				colInd[k] = col;
+
+		SparseBlock rebuilt = new SparseBlockCSC(numCols, rowInd, colInd, values);
+		assertEquals(sblock, rebuilt);
+	}
+
+	/**
+	 * initUltraSparse reading (row, col, value) triples in column-major order
+	 * reproduces the CSC block converted from the same matrix.
+	 */
+	@Test
+	public void testSparseBlockCSCInitUltraSparse() throws Exception {
+		double ultraSparsity = 0.001;
+		double[][] A = getRandomMatrix(_rows, _cols, -10, 10, ultraSparsity, 1234);
+		MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+		SparseBlockCSC sblock = new SparseBlockCSC(mbtmp.getSparseBlock());
+
+		// stream of ijv triples, column by column
+		ByteArrayOutputStream baos = new ByteArrayOutputStream();
+		int nnz = 0;
+		try(DataOutputStream dos = new DataOutputStream(baos)) {
+			for(int c = 0; c < _cols; c++) {
+				for(int r = 0; r < _rows; r++) {
+					double v = A[r][c];
+					if(v != 0) {
+						dos.writeInt(r);
+						dos.writeInt(c);
+						dos.writeDouble(v);
+						nnz++;
+					}
+				}
+			}
+		}
+
+		SparseBlockCSC sblock2 = new SparseBlockCSC(_rows, _cols);
+		try(DataInputStream dis = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()))) {
+			sblock2.initUltraSparse(nnz, dis);
+		}
+		assertEquals(sblock, sblock2);
+	}
+
+	/**
+	 * Serializes a matrix column by column as records of the form
+	 * (count, then count x (row, value)), reads them back with
+	 * {@code initSparse}, and compares the result with a CSC block built
+	 * directly from the matrix.
+	 */
+	@Test
+	public void testSparseBlockCSCInitSparse() throws Exception {
+		double[][] A = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A);
+		SparseBlockCSC sblock = new SparseBlockCSC(mbtmp.getSparseBlock());
+
+		// ijv-stream in CSC order
+		ByteArrayOutputStream baos = new ByteArrayOutputStream();
+		int nnz = 0;
+		try(DataOutputStream dos = new DataOutputStream(baos)) {
+			for(int c = 0; c < _cols; c++) {
+				// per-column header: number of non-zeros in column c
+				int lnnz = 0;
+				for(int r = 0; r < _rows; r++)
+					if(A[r][c] != 0)
+						lnnz++;
+				dos.writeInt(lnnz);
+				nnz += lnnz;
+
+				// followed by the (row, value) pairs of column c
+				for(int r = 0; r < _rows; r++) {
+					double v = A[r][c];
+					if(v != 0) {
+						dos.writeInt(r);
+						dos.writeDouble(v);
+					}
+				}
+			}
+		} // close() flushes dos before the bytes are read below
+
+		SparseBlockCSC sblock2 = new SparseBlockCSC(_rows, _cols);
+		try(DataInputStream dis = new DataInputStream(
+			new ByteArrayInputStream(baos.toByteArray()))) {
+			sblock2.initSparse(_cols, nnz, dis);
+		}
+		assertEquals(sblock, sblock2);
+	}
+
+	/**
+	 * Rebuilds a CSR block from the row array of an MCSR block plus its
+	 * total non-zero count and compares it with the MCSR source.
+	 */
+	@Test
+	public void testSparseBlockCSRInitRows() {
+		double[][] dense = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlockMCSR mcsr = new SparseBlockMCSR(
+			DataConverter.convertToMatrixBlock(dense).getSparseBlock());
+		SparseBlock csr = new SparseBlockCSR(mcsr.getRows(), (int) mcsr.size());
+		assertEquals(mcsr, csr);
+	}
+
+	/**
+	 * Expands the CSR row pointers into an explicit row id per stored entry,
+	 * rebuilds a CSR block from the (row, col, value) arrays and compares it
+	 * with the original block.
+	 */
+	@Test
+	public void testSparseBlockCSRInitRowColInd() {
+		double[][] dense = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		MatrixBlock mb = DataConverter.convertToMatrixBlock(dense);
+		SparseBlockCSR expected = new SparseBlockCSR(mb.getSparseBlock());
+
+		int[] rowPtr = expected.rowPointers();
+		int[] colIdx = expected.indexes();
+		double[] vals = expected.values();
+		int numRows = rowPtr.length - 1;
+
+		// row id of every stored entry, taken from the pointer ranges
+		int[] rowIdx = new int[colIdx.length];
+		for(int r = 0; r < numRows; r++) {
+			int k = rowPtr[r];
+			while(k < rowPtr[r + 1])
+				rowIdx[k++] = r;
+		}
+
+		SparseBlock rebuilt = new SparseBlockCSR(numRows, rowIdx, colIdx, vals);
+		assertEquals(expected, rebuilt);
+	}
+
+	/**
+	 * A freshly constructed DCSR block is expected to expose a values array
+	 * of length 4, i.e. its initial capacity.
+	 */
+	@Test
+	public void testSparseBlockDCSRInitCapacity() {
+		final int expectedCapacity = 4;
+		SparseBlockDCSR dcsr = new SparseBlockDCSR(_rows);
+		double[] vals = dcsr.values(1);
+		assertEquals("INIT_CAPACITY should be 4", expectedCapacity, vals.length);
+	}
+
+	/**
+	 * Derives the compressed row arrays (ids of non-empty rows and their end
+	 * offsets) from a DCSR block, rebuilds a DCSR block from the raw arrays
+	 * and compares it with the original block.
+	 */
+	@Test
+	public void testSparseBlockDCSRInitRowColInd() {
+		double[][] dense = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		MatrixBlock mb = DataConverter.convertToMatrixBlock(dense);
+		if(!mb.isInSparseFormat())
+			mb.denseToSparse(true);
+		SparseBlockDCSR expected = new SparseBlockDCSR(mb.getSparseBlock());
+
+		int[] colIdx = expected.indexes(0);
+		double[] vals = expected.values(0);
+		int numRows = expected.numRows();
+		int nnz = (int) expected.size();
+
+		// keep only non-empty rows: record each row id and the running
+		// end offset of its entries (rowPtr[0] stays 0)
+		int[] rowIdx = new int[numRows];
+		int[] rowPtr = new int[numRows + 1];
+		int nonEmpty = 0;
+		int offset = 0;
+		for(int r = 0; r < numRows; r++) {
+			int rowNnz = expected.size(r);
+			if(rowNnz == 0)
+				continue;
+			offset += rowNnz;
+			rowIdx[nonEmpty] = r;
+			rowPtr[++nonEmpty] = offset;
+		}
+
+		SparseBlock rebuilt = new SparseBlockDCSR(rowIdx, rowPtr, colIdx, vals, numRows, nnz, nonEmpty);
+		assertEquals(expected, rebuilt);
+	}
+
+ @Test
+ public void testSparseBlockDCSRInitCOO() {
+ testSparseBlockDCSRInitFromSparseBlock(SparseBlock.Type.COO);
+ }
+
+ @Test
+ public void testSparseBlockDCSRInitCSC() {
+ testSparseBlockDCSRInitFromSparseBlock(SparseBlock.Type.CSC);
+ }
+
+ @Test
+ public void testSparseBlockDCSRInitMCSC() {
+ testSparseBlockDCSRInitFromSparseBlock(SparseBlock.Type.MCSC);
+ }
+
+ @Test
+ public void testSparseBlockDCSRInitMCSR() {
+ testSparseBlockDCSRInitFromSparseBlock(SparseBlock.Type.MCSR);
+ }
+
+	/**
+	 * Checks that converting to DCSR yields the same block whether the input
+	 * is the original sparse block or a deep copy of it in format
+	 * {@code btype}.
+	 *
+	 * @param btype sparse block format of the intermediate copy
+	 */
+	public void testSparseBlockDCSRInitFromSparseBlock(SparseBlock.Type btype) {
+		double[][] dense = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlock source = DataConverter.convertToMatrixBlock(dense).getSparseBlock();
+		SparseBlock expected = new SparseBlockDCSR(source);
+		SparseBlock converted = SparseBlockFactory.copySparseBlock(btype, source, true);
+		assertEquals(expected, new SparseBlockDCSR(converted));
+	}
+
+	/**
+	 * Copy-constructs an MCSC block from an ultra-sparse MCSC block (where
+	 * many columns are presumably empty/null) and expects an equal block.
+	 */
+	@Test
+	public void testSparseBlockMCSCInitMCSCOriginalColNull() {
+		final double ultraSparsity = 0.001;
+		MatrixBlock mb = DataConverter.convertToMatrixBlock(
+			getRandomMatrix(_rows, _cols, -10, 10, ultraSparsity, 1234));
+		SparseBlock original = new SparseBlockMCSC(mb.getSparseBlock());
+		SparseBlock copy = new SparseBlockMCSC(original);
+		assertEquals(original, copy);
+	}
+
+	/**
+	 * Converts an ultra-sparse MCSR block to MCSC without passing the number
+	 * of columns (the constructor presumably has to infer it).
+	 */
+	@Test
+	public void testSparseBlockMCSCInitMCSRNoClenInferred() {
+		final double ultraSparsity = 0.001;
+		MatrixBlock mb = DataConverter.convertToMatrixBlock(
+			getRandomMatrix(_rows, _cols, -10, 10, ultraSparsity, 1234));
+		SparseBlock mcsr = new SparseBlockMCSR(mb.getSparseBlock());
+		SparseBlock mcsc = new SparseBlockMCSC(mcsr);
+		assertEquals(mcsr, mcsc);
+	}
+
+	/**
+	 * Converts an MCSR block to MCSC, passing the number of columns
+	 * explicitly, and expects an equal block.
+	 */
+	@Test
+	public void testSparseBlockMCSCInitMCSRClenInferred() {
+		double[][] dense = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlock mcsr = new SparseBlockMCSR(
+			DataConverter.convertToMatrixBlock(dense).getSparseBlock());
+		SparseBlock mcsc = new SparseBlockMCSC(mcsr, _cols);
+		assertEquals(mcsr, mcsc);
+	}
+
+	/** Converts a CSC block to MCSC and expects an equal block. */
+	@Test
+	public void testSparseBlockMCSCInitCSC() {
+		double[][] dense = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlock csc = new SparseBlockCSC(
+			DataConverter.convertToMatrixBlock(dense).getSparseBlock());
+		SparseBlock mcsc = new SparseBlockMCSC(csc);
+		assertEquals(csc, mcsc);
+	}
+
+	/**
+	 * Rebuilds an MCSC block from another block's column array with the
+	 * deep flag set to true and expects an equal block.
+	 */
+	@Test
+	public void testSparseBlockMCSCInitColsDeep() {
+		double[][] dense = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlockMCSC mcsc = new SparseBlockMCSC(
+			DataConverter.convertToMatrixBlock(dense).getSparseBlock());
+		SparseBlock rebuilt = new SparseBlockMCSC(mcsc.getCols(), true, mcsc.numRows());
+		assertEquals(mcsc, rebuilt);
+	}
+
+	/**
+	 * Rebuilds an MCSC block from another block's column array with the
+	 * deep flag set to false and expects an equal block.
+	 */
+	@Test
+	public void testSparseBlockMCSCInitColsNonDeep() {
+		double[][] dense = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlockMCSC mcsc = new SparseBlockMCSC(
+			DataConverter.convertToMatrixBlock(dense).getSparseBlock());
+		SparseBlock rebuilt = new SparseBlockMCSC(mcsc.getCols(), false, mcsc.numRows());
+		assertEquals(mcsc, rebuilt);
+	}
+
+	/** An MCSC block created with a column count reports that count back. */
+	@Test
+	public void testSparseBlockMCSCInitClen() {
+		SparseBlockMCSC mcsc = new SparseBlockMCSC(_cols);
+		assertEquals(_cols, mcsc.numCols());
+	}
+
+	/**
+	 * Deep-copies an MCSR block from its row array: the copy must be equal
+	 * but must not share the row array with the source.
+	 */
+	@Test
+	public void testSparseBlockMCSRInitRows() {
+		double[][] dense = getRandomMatrix(_rows, _cols, -10, 10, _sparsity, 1234);
+		SparseBlockMCSR original = new SparseBlockMCSR(
+			DataConverter.convertToMatrixBlock(dense).getSparseBlock());
+		SparseBlockMCSR deepCopy = new SparseBlockMCSR(original.getRows(), true);
+		assertEquals(original, deepCopy);
+		assertNotSame(original.getRows(), deepCopy.getRows());
+	}
+
+	/**
+	 * Creates a CSR block through the (rlen, capacity, size) constructor,
+	 * appends two values to row 0, compacts, and expects a size of 2.
+	 */
+	@Test
+	public void testSparseBlockCSRInitSize() {
+		SparseBlockCSR csr = new SparseBlockCSR(3, 7, 2); // rlen, capacity, size
+		csr.append(0, 1, 1.0);
+		csr.append(0, 3, 3.0);
+		csr.compact();
+		assertEquals("size should be 2", 2, csr.size());
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
index 57cee617fb..f1429c4d65 100644
---
a/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
+++
b/src/test/java/org/apache/sysds/test/component/sparse/SparseBlockIterator.java
@@ -42,7 +42,10 @@ import org.apache.sysds.test.TestUtils;
public class SparseBlockIterator extends AutomatedTestBase {
private final static int rows = 324;
private final static int cols = 100;
- private final static int rlPartial = 134;
+ private final static int rlVal = 134;
+ private final static int ruVal = 253;
+ private final static int clVal = 34;
+ private final static int cuVal = 53;
private final static double sparsity1 = 0.1;
private final static double sparsity2 = 0.2;
private final static double sparsity3 = 0.3;
@@ -54,187 +57,367 @@ public class SparseBlockIterator extends
AutomatedTestBase {
@Test
public void testSparseBlockMCSR1Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1,
false, false);
}
@Test
public void testSparseBlockMCSR2Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2,
false, false);
}
@Test
public void testSparseBlockMCSR3Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3,
false, false);
+ }
+
+ @Test
+ public void testSparseBlockMCSR1RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockMCSR2RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockMCSR3RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockMCSR1RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockMCSR2RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockMCSR3RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3,
false, true);
}
@Test
public void testSparseBlockMCSR1Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity1,
true, true);
}
@Test
public void testSparseBlockMCSR2Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity2,
true, true);
}
@Test
public void testSparseBlockMCSR3Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSR, sparsity3,
true, true);
}
@Test
public void testSparseBlockCSR1Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1,
false, false);
}
@Test
public void testSparseBlockCSR2Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2,
false, false);
}
@Test
public void testSparseBlockCSR3Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3,
false, false);
+ }
+
+ @Test
+ public void testSparseBlockCSR1RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockCSR2RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockCSR3RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockCSR1RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockCSR2RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockCSR3RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3,
false, true);
}
@Test
public void testSparseBlockCSR1Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity1,
true, true);
}
@Test
public void testSparseBlockCSR2Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity2,
true, true);
}
@Test
public void testSparseBlockCSR3Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSR, sparsity3,
true, true);
}
@Test
public void testSparseBlockCOO1Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1,
false, false);
}
@Test
public void testSparseBlockCOO2Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2,
false, false);
}
@Test
public void testSparseBlockCOO3Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3,
false, false);
+ }
+
+ @Test
+ public void testSparseBlockCOO1RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockCOO2RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockCOO3RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockCOO1RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockCOO2RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockCOO3RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3,
false, true);
}
@Test
public void testSparseBlockCOO1Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity1,
true, true);
}
@Test
public void testSparseBlockCOO2Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity2,
true, true);
}
@Test
public void testSparseBlockCOO3Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.COO, sparsity3,
true, true);
}
@Test
public void testSparseBlockDCSR1Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1,
false, false);
}
@Test
public void testSparseBlockDCSR2Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2,
false, false);
}
@Test
public void testSparseBlockDCSR3Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3,
false, false);
+ }
+
+ @Test
+ public void testSparseBlockDCSR1RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockDCSR1RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockDCSR2RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockDCSR3RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3,
false, true);
}
@Test
public void testSparseBlockDCSR1Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity1,
true, true);
}
@Test
public void testSparseBlockDCSR2Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity2,
true, true);
}
@Test
public void testSparseBlockDCSR3Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.DCSR, sparsity3,
true, true);
}
@Test
public void testSparseBlockMCSC1Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1,
false, false);
}
@Test
public void testSparseBlockMCSC2Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2,
false, false);
}
@Test
public void testSparseBlockMCSC3Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3,
false, false);
+ }
+
+ @Test
+ public void testSparseBlockMCSC1RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockMCSC2RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockMCSC3RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockMCSC1RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockMCSC2RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockMCSC3RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3,
false, true);
}
@Test
public void testSparseBlockMCSC1Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity1,
true, true);
}
@Test
public void testSparseBlockMCSC2Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity2,
true, true);
}
@Test
public void testSparseBlockMCSC3Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.MCSC, sparsity3,
true, true);
}
@Test
public void testSparseBlockCSC1Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1,
false, false);
}
@Test
public void testSparseBlockCSC2Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2,
false, false);
}
@Test
public void testSparseBlockCSC3Full() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3,
false);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3,
false, false);
+ }
+
+ @Test
+ public void testSparseBlockCSC1RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockCSC2RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockCSC3RlPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3,
true, false);
+ }
+
+ @Test
+ public void testSparseBlockCSC1RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockCSC2RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2,
false, true);
+ }
+
+ @Test
+ public void testSparseBlockCSC3RuPartial() {
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3,
false, true);
}
@Test
public void testSparseBlockCSC1Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity1,
true, true);
}
@Test
public void testSparseBlockCSC2Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity2,
true, true);
}
@Test
public void testSparseBlockCSC3Partial() {
- runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3,
true);
+ runSparseBlockIteratorTest(SparseBlock.Type.CSC, sparsity3,
true, true);
}
- private void runSparseBlockIteratorTest(SparseBlock.Type btype, double
sparsity, boolean partial) {
+ private void runSparseBlockIteratorTest(SparseBlock.Type btype, double
sparsity, boolean rlPartial, boolean ruPartial) {
try {
//data generation
double[][] A = getRandomMatrix(rows, cols, -10, 10,
sparsity, 8765432);
@@ -247,25 +430,34 @@ public class SparseBlockIterator extends
AutomatedTestBase {
//check for correct number of non-zeros
int[] rnnz = new int[rows];
int nnz = 0;
- int rl = partial ? rlPartial : 0;
- for(int i = rl; i < rows; i++) {
- for(int j = 0; j < cols; j++)
+ int rl = rlPartial ? rlVal : 0;
+ int ru = ruPartial ? ruVal : rows;
+ int cl = rlPartial && ruPartial ? clVal : 0;
+ int cu = rlPartial && ruPartial ? cuVal : cols;
+ for(int i = rl; i < ru; i++) {
+ for(int j = cl; j < cu; j++)
rnnz[i] += (A[i][j] != 0) ? 1 : 0;
nnz += rnnz[i];
}
- if(!partial && nnz != sblock.size())
+ if(!rlPartial && !ruPartial && nnz != sblock.size()) //
no restriction
Assert.fail("Wrong number of non-zeros: " +
sblock.size() + ", expected: " + nnz);
//check correct isEmpty return
- for(int i = rl; i < rows; i++)
- if(sblock.isEmpty(i) != (rnnz[i] == 0))
- Assert.fail("Wrong isEmpty(row) result
for row nnz: " + rnnz[i]);
+ if(!(rlPartial && ruPartial)) { // cols not restricted
+ for(int i = rl; i < ru; i++)
+ if(sblock.isEmpty(i) != (rnnz[i] == 0))
+ Assert.fail("Wrong isEmpty(row)
result for row nnz: " + rnnz[i]);
+ }
//check correct values
- Iterator<IJV> iter = !partial ? sblock.getIterator() :
sblock.getIterator(rl, rows);
+ Iterator<IJV> iter = rlPartial && ruPartial ?
sblock.getIterator(rl, ru, cl, cu): rlPartial? sblock.getIterator(rl, rows) :
ruPartial? sblock.getIterator(ru) : sblock.getIterator();
int count = 0;
while(iter.hasNext()) {
IJV cell = iter.next();
+ if(cell.getI() < rl || cell.getI() >= ru)
+ Assert.fail("iterator row outside of
range");
+ if(cell.getJ() < cl || cell.getJ() >= cu)
+ Assert.fail("iterator column outside of
range");
if(cell.getV() != A[cell.getI()][cell.getJ()])
Assert.fail("Wrong value returned by
iterator: " + cell.getV() + ", expected: " +
A[cell.getI()][cell.getJ()]);
@@ -277,11 +469,9 @@ public class SparseBlockIterator extends AutomatedTestBase
{
// check iterator over non-zero rows
List<Integer> manualNonZeroRows = new ArrayList<>();
List<Integer> iteratorNonZeroRows = new ArrayList<>();
- Iterator<Integer> iterRows = !partial
- ? sblock.getNonEmptyRowsIterator(0, rows)
- : sblock.getNonEmptyRowsIterator(rl, rows);
+ Iterator<Integer> iterRows =
sblock.getNonEmptyRowsIterator(rl, ru);
- for(int i = rl; i < rows; i++)
+ for(int i = rl; i < ru; i++)
if(!sblock.isEmpty(i))
manualNonZeroRows.add(i);
while(iterRows.hasNext()) {
@@ -293,6 +483,16 @@ public class SparseBlockIterator extends AutomatedTestBase
{
Assert.fail("Verification of iterator over
non-zero rows failed.");
}
+ // check second iterator over non-zero rows
+ Iterator<Integer> iterRows2 = !rlPartial && !ruPartial?
sblock.getNonEmptyRows().iterator() : sblock.getNonEmptyRows(rl, ru).iterator();
+ List<Integer> iter2NonZeroRows = new ArrayList<>();
+
+ while(iterRows2.hasNext()) {
+ iter2NonZeroRows.add(iterRows2.next());
+ }
+ if(!manualNonZeroRows.equals(iter2NonZeroRows)) {
+ Assert.fail("Verification of second iterator
over non-zero rows failed.");
+ }
}
catch(Exception ex) {
ex.printStackTrace();
diff --git
a/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java
b/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java
new file mode 100644
index 0000000000..307b4335c4
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/sparse/SparseRowTest.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.sparse;
+
+import org.apache.sysds.runtime.data.SparseRow;
+import org.apache.sysds.runtime.data.SparseRowScalar;
+import org.apache.sysds.runtime.data.SparseRowVector;
+import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+
+/**
+ * Component tests for the sparse row implementations
+ * {@link SparseRowScalar} (a single index/value pair) and
+ * {@link SparseRowVector} (parallel index and value arrays).
+ */
+public class SparseRowTest extends AutomatedTestBase
+{
+	// length and value range of the randomly generated dense input rows
+	private final static int cols = 121;
+	private final static int minVal = -10;
+	private final static int maxVal = 10;
+	private final static double sparsity = 0.3;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+	}
+
+	/** An empty scalar row renders as the empty string. */
+	@Test
+	public void testSparseRowEmptyToString() {
+		SparseRowScalar srs = new SparseRowScalar();
+		assertEquals("", srs.toString());
+	}
+
+	/** A scalar row initialized with a zero value is reset by compact(). */
+	@Test
+	public void testSparseRowScalarInitZeroVal() {
+		SparseRowScalar srs = new SparseRowScalar(5, 0);
+		srs.compact();
+		assertEquals(-1, srs.getIndex());
+	}
+
+	/** Setting a value on an empty scalar row returns true. */
+	@Test
+	public void testSparseRowScalarSetNewVal() {
+		SparseRowScalar srs = new SparseRowScalar();
+		assertTrue(srs.set(3, 5.0));
+	}
+
+	/** A scalar row already holding column 1 rejects a set on column 3. */
+	@Test
+	public void testSparseRowScalarInvalidSet() {
+		SparseRowScalar srs = new SparseRowScalar(1, 1.0);
+		RuntimeException ex = assertThrows(RuntimeException.class, () -> srs.set(3, 5.0));
+		assertEquals("Invalid set to sparse row scalar.", ex.getMessage());
+	}
+
+	/** Appending a zero leaves the scalar row and its stored value unchanged. */
+	@Test
+	public void testSparseRowScalarAppendZero() {
+		SparseRowScalar srs = new SparseRowScalar(1, 1.0);
+		SparseRow srs2 = srs.append(2, 0.0);
+		assertEquals(srs, srs2);
+		// compare as doubles: assertNotEquals(0, double) boxes to Integer vs.
+		// Double, which are never equal, so that check could never fail
+		assertEquals(1.0, srs2.values()[0], 0.0);
+	}
+
+	/** compact() clears a scalar row whose value is zero. */
+	@Test
+	public void testSparseRowScalarCompactZero() {
+		SparseRowScalar srs = new SparseRowScalar(1, 0.0);
+		srs.compact();
+		assertEquals(-1, srs.getIndex());
+	}
+
+	/** compact() keeps a scalar row whose value is non-zero. */
+	@Test
+	public void testSparseRowScalarCompactNonZero() {
+		SparseRowScalar srs = new SparseRowScalar(1, 1.0);
+		srs.compact();
+		assertEquals(1, srs.getIndex());
+	}
+
+	/** A deep copy of a scalar row carries the same index and value. */
+	@Test
+	public void testSparseRowScalarCopy() {
+		SparseRowScalar srs = new SparseRowScalar(1, 1.0);
+		SparseRowScalar srs2 = (SparseRowScalar) srs.copy(true);
+		assertEquals(srs.getIndex(), srs2.getIndex());
+		assertEquals(srs.getValue(), srs2.getValue(), 0.0);
+		// NOTE(review): presumably relies on identity equals(), i.e. checks
+		// that the copy is a distinct object
+		assertNotEquals(srs, srs2);
+	}
+
+	/** setValues() replaces the value array of a compacted vector row. */
+	@Test
+	public void testSparseRowVectorSetValues() {
+		double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0];
+		SparseRowVector srv = new SparseRowVector(UtilFunctions.computeNnz(v, 0, v.length), v, v.length);
+
+		srv.compact();
+		int nnz = srv.size();
+		double[] w = getRandomMatrix(1, nnz, minVal, maxVal, 1, 13)[0];
+		srv.setValues(w);
+
+		assertArrayEquals(w, srv.values(), 0.0);
+		assertEquals(srv.indexes().length, srv.values().length);
+	}
+
+	/** After setIndexes(0..nnz-1), column i is found at position i. */
+	@Test
+	public void testSparseRowVectorSetIndexes() {
+		double[] v = getRandomMatrix(1, cols, minVal, maxVal, 1, 7)[0];
+		int nnz = UtilFunctions.computeNnz(v, 0, v.length);
+		SparseRowVector srv = new SparseRowVector(nnz, v, v.length);
+
+		int[] indexes = new int[nnz];
+		for(int i = 0; i < nnz; i++)
+			indexes[i] = i;
+		srv.setIndexes(indexes);
+
+		// check every position instead of a Math.random() probe, which made
+		// the test nondeterministic
+		for(int i = 0; i < nnz; i++)
+			assertEquals(i, srv.getIndex(i));
+		assertEquals(-1, srv.getIndex(nnz));
+		assertEquals(srv.values().length, srv.indexes().length);
+	}
+
+	/** copy(other) takes over the indexes and values of a longer row. */
+	@Test
+	public void testSparseRowVectorCopyFromLargerArray() {
+		double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0];
+		double[] w = getRandomMatrix(1, 2*cols, minVal, maxVal, sparsity, 7)[0];
+		SparseRowVector srv = new SparseRowVector(UtilFunctions.computeNnz(v, 0, v.length), v, v.length);
+		SparseRowVector other = new SparseRowVector(UtilFunctions.computeNnz(w, 0, w.length), w, w.length);
+		srv.copy(other);
+
+		assertArrayEquals(other.indexes(), srv.indexes());
+		assertArrayEquals(other.values(), srv.values(), 0.0);
+		// NOTE(review): presumably relies on identity equals(), i.e. checks
+		// that srv and other remain distinct objects
+		assertNotEquals(other, srv);
+	}
+
+	/** setEstimatedNzs() is reflected by getEstimatedNzs(). */
+	@Test
+	public void testSparseRowVectorSetEstimatedNzs() {
+		double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0];
+		int nnz = UtilFunctions.computeNnz(v, 0, v.length);
+		SparseRowVector srv = new SparseRowVector(nnz, v, v.length);
+		srv.setEstimatedNzs(nnz+1);
+		assertEquals(nnz+1, srv.getEstimatedNzs());
+	}
+
+	/** setAtPos() writes both the column index and the value at a position. */
+	@Test
+	public void testSparseRowVectorSetAtPos() {
+		double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0];
+		int nnz = UtilFunctions.computeNnz(v, 0, v.length);
+		SparseRowVector srv = new SparseRowVector(nnz, v, v.length);
+
+		int pos = nnz-1;
+		int col = 2;
+		// value deliberately differs from col so an index/value mix-up fails
+		double val = 7.5;
+		srv.setAtPos(pos, col, val);
+
+		assertEquals(col, srv.indexes()[pos]);
+		assertEquals(val, srv.values()[pos], 0.0);
+	}
+
+	/** getIndex(col) returns the position of col, or -1 if col is absent. */
+	@Test
+	public void testSparseRowVectorGetIndex() {
+		double[] v = getRandomMatrix(1, cols, minVal, maxVal, sparsity, 7)[0];
+		int nnz = UtilFunctions.computeNnz(v, 0, v.length);
+		SparseRowVector srv = new SparseRowVector(nnz, v, v.length);
+
+		// NOTE(review): this can leave the indexes unsorted if the second
+		// stored column is < 5; confirm getIndex tolerates that for seed 7
+		int pos = 0;
+		srv.setAtPos(pos, 5, 2.0);
+		int index = srv.getIndex(5);
+		assertEquals(pos, index);
+
+		int col2 = cols+1;
+		int index2 = srv.getIndex(col2);
+		assertEquals(-1, index2);
+	}
+
+	/** searchIndexesFirstLTE on an empty row returns -1. */
+	@Test
+	public void testSparseRowVectorSearchIndexesFirstLTESizeZero() {
+		SparseRowVector srv = new SparseRowVector();
+		int index = srv.searchIndexesFirstLTE(1);
+		assertEquals(-1, index);
+	}
+
+	/**
+	 * searchIndexesFirstLTE returns -1 below the smallest index, otherwise
+	 * the position of the largest stored index that is <= the query.
+	 */
+	@Test
+	public void testSparseRowVectorSearchIndexesFirstLTENotFound() {
+		SparseRowVector srv = new SparseRowVector(new double[] {1.0, 3.0}, new int[] {1, 3});
+		int index = srv.searchIndexesFirstLTE(0);
+		assertEquals(-1, index);
+		int index2 = srv.searchIndexesFirstLTE(2);
+		assertEquals(0, index2);
+		int index3 = srv.searchIndexesFirstLTE(5);
+		assertEquals(1, index3);
+	}
+
+	/** Filling exactly the current capacity does not grow the row. */
+	@Test
+	public void testSparseRowVectorSetIndexRangeWithoutRecap() {
+		SparseRowVector srv = new SparseRowVector();
+		int capacity = srv.capacity();
+
+		double[] v = getRandomMatrix(1, capacity, minVal, maxVal, sparsity, 7)[0];
+		srv.setIndexRange(0, capacity, v, 0, capacity);
+		assertEquals(capacity, srv.capacity());
+	}
+}