Repository: systemml Updated Branches: refs/heads/master 8ca61ae26 -> 98ee9b7d8
[HOTFIX][SYSTEMML-1959] Fix CSR/COO index search (internal/external pos) The recent change to use CSR more aggressively for shallow serialization revealed a hidden issue with the CSR/COO index searches, which are used by various operations that require column-wise partitioning. For our default MCSR, the internal and external indexes are equivalent because each row is stored as a separate object. This patch fixes the CSR and COO index searches to return external indexes (relative to the start of the row), while the other sparse-block operations that rely on these primitives now use internal indexes (absolute positions across rows); callers add pos(r) to an external index to obtain the array position. Furthermore, this includes a number of smaller fixes, such as error handling in grouped aggregate operations and correct use of sparse block position offsets in various operations. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/98ee9b7d Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/98ee9b7d Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/98ee9b7d Branch: refs/heads/master Commit: 98ee9b7d8f606efc4bd9a00e038c49e50806f129 Parents: 8ca61ae Author: Matthias Boehm <[email protected]> Authored: Sat Oct 14 18:12:20 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Oct 14 19:19:32 2017 -0700 ---------------------------------------------------------------------- .../sysml/runtime/matrix/data/LibMatrixAgg.java | 43 +++++----- .../runtime/matrix/data/LibMatrixMult.java | 36 ++++---- .../runtime/matrix/data/LibMatrixReorg.java | 4 +- .../sysml/runtime/matrix/data/MatrixBlock.java | 11 +-- .../runtime/matrix/data/SparseBlockCOO.java | 31 +++++-- .../runtime/matrix/data/SparseBlockCSR.java | 90 ++++++++++++-------- .../sparse/SparseBlockGetFirstIndex.java | 15 ++-- 7 files changed, 132 insertions(+), 98 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/systemml/blob/98ee9b7d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java index 4a99adf..f1ed00d 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java @@ -486,7 +486,7 @@ public class LibMatrixAgg else if( op instanceof AggregateOperator ) { AggregateOperator aggop = (AggregateOperator) op; groupedAggregateKahanPlus(groups, target, weights, result, numGroups, aggop, 0, target.clen); - } + } } public static void groupedAggregate(MatrixBlock groups, MatrixBlock target, MatrixBlock weights, MatrixBlock result, int numGroups, Operator op, int k) @@ -515,8 +515,10 @@ public class LibMatrixAgg int blklen = (int)(Math.ceil((double)target.clen/k)); for( int i=0; i<k & i*blklen<target.clen; i++ ) tasks.add( new GrpAggTask(groups, target, weights, result, numGroups, op, i*blklen, Math.min((i+1)*blklen, target.clen)) ); - pool.invokeAll(tasks); + List<Future<Object>> taskret = pool.invokeAll(tasks); pool.shutdown(); + for(Future<Object> task : taskret) + task.get(); //error handling } catch(Exception ex) { throw new DMLRuntimeException(ex); @@ -809,18 +811,17 @@ public class LibMatrixAgg int pos = target.sparseBlock.pos(0); int len = target.sparseBlock.size(0); int[] aix = target.sparseBlock.indexes(0); - double[] avals = target.sparseBlock.values(0); + double[] avals = target.sparseBlock.values(0); for( int j=pos; j<pos+len; j++ ) //for each nnz { - int g = (int) groups.quickGetValue(aix[j], 0); + int g = (int) groups.quickGetValue(aix[j], 0); if ( g > numGroups ) continue; if ( weights != null ) w = weights.quickGetValue(aix[j],0); - aggop.increOp.fn.execute(buffer[g-1][0], avals[j]*w); + 
aggop.increOp.fn.execute(buffer[g-1][0], avals[j]*w); } } - } else //DENSE target { @@ -828,7 +829,7 @@ public class LibMatrixAgg double d = target.denseBlock[ i ]; if( d != 0 ) //sparse-safe { - int g = (int) groups.quickGetValue(i, 0); + int g = (int) groups.quickGetValue(i, 0); if ( g > numGroups ) continue; if ( weights != null ) @@ -847,7 +848,7 @@ public class LibMatrixAgg for( int i=0; i < groups.getNumRows(); i++ ) { - int g = (int) groups.quickGetValue(i, 0); + int g = (int) groups.quickGetValue(i, 0); if ( g > numGroups ) continue; @@ -856,15 +857,15 @@ public class LibMatrixAgg int pos = a.pos(i); int len = a.size(i); int[] aix = a.indexes(i); - double[] avals = a.values(i); - int j = (cl==0) ? pos : a.posFIndexGTE(i,cl); - j = (j>=0) ? j : len; + double[] avals = a.values(i); + int j = (cl==0) ? 0 : a.posFIndexGTE(i,cl); + j = (j >= 0) ? pos+j : pos+len; for( ; j<pos+len && aix[j]<cu; j++ ) //for each nnz { if ( weights != null ) w = weights.quickGetValue(aix[j],0); - aggop.increOp.fn.execute(buffer[g-1][aix[j]-cl], avals[j]*w); + aggop.increOp.fn.execute(buffer[g-1][aix[j]-cl], avals[j]*w); } } } @@ -875,7 +876,7 @@ public class LibMatrixAgg for( int i=0, aix=0; i < groups.getNumRows(); i++, aix+=numCols ) { - int g = (int) groups.quickGetValue(i, 0); + int g = (int) groups.quickGetValue(i, 0); if ( g > numGroups ) continue; @@ -918,7 +919,7 @@ public class LibMatrixAgg for( int i=0; i < groups.getNumRows(); i++ ) { - int g = (int) groups.quickGetValue(i, 0); + int g = (int) groups.quickGetValue(i, 0); if ( g > numGroups ) continue; @@ -927,15 +928,15 @@ public class LibMatrixAgg int pos = a.pos(i); int len = a.size(i); int[] aix = a.indexes(i); - double[] avals = a.values(i); - int j = (cl==0) ? pos : a.posFIndexGTE(i,cl); - j = (j>=0) ? j : pos+len; + double[] avals = a.values(i); + int j = (cl==0) ? 0 : a.posFIndexGTE(i,cl); + j = (j >= 0) ? 
pos+j : pos+len; for( ; j<pos+len && aix[j]<cu; j++ ) //for each nnz { if ( weights != null ) w = weights.quickGetValue(i, 0); - cmFn.execute(cmValues[g-1][aix[j]-cl], avals[j], w); + cmFn.execute(cmValues[g-1][aix[j]-cl], avals[j], w); } //TODO sparse unsafe correction } @@ -947,7 +948,7 @@ public class LibMatrixAgg for( int i=0, aix=0; i < groups.getNumRows(); i++, aix+=target.clen ) { - int g = (int) groups.quickGetValue(i, 0); + int g = (int) groups.quickGetValue(i, 0); if ( g > numGroups ) continue; @@ -966,7 +967,7 @@ public class LibMatrixAgg for( int j=0; j < numCols2; j++ ) { // result is 0-indexed, so is cmValues result.appendValue(i, j, cmValues[i][j+cl].getRequiredResult(cmOp)); - } + } } private static void groupedAggregateVecCount( MatrixBlock groups, MatrixBlock result, int numGroups ) @@ -982,7 +983,7 @@ public class LibMatrixAgg //compute counts for( int i = 0; i < m; i++ ) { - int g = (int) a[i]; + int g = (int) a[i]; if ( g > numGroups ) continue; tmp[g-1]++; http://git-wip-us.apache.org/repos/asf/systemml/blob/98ee9b7d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java index dbb65fe..fee73c5 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java @@ -1240,10 +1240,10 @@ public class LibMatrixMult int[] aix = a.indexes(i); double[] avals = a.values(i); - int k1 = (rl==0) ? apos : a.posFIndexGTE(i, rl); - k1 = (k1>=0) ? k1 : apos+alen; - int k2 = (ru==cd) ? apos+alen : a.posFIndexGTE(i, ru); - k2 = (k2>=0) ? k2 : apos+alen; + int k1 = (rl==0) ? 0 : a.posFIndexGTE(i, rl); + k1 = (k1>=0) ? apos+k1 : apos+alen; + int k2 = (ru==cd) ? alen : a.posFIndexGTE(i, ru); + k2 = (k2>=0) ? 
apos+k2 : apos+alen; //rest not aligned to blocks of 4 rows final int bn = (k2-k1) % 4; @@ -1797,9 +1797,9 @@ public class LibMatrixMult int apos = a.pos(r); int alen = a.size(r); int[] aix = a.indexes(r); - double[] avals = a.values(r); - int rlix = (rl==0) ? apos : a.posFIndexGTE(r, rl); - rlix = (rlix>=0) ? rlix : apos+alen; + double[] avals = a.values(r); + int rlix = (rl==0) ? 0 : a.posFIndexGTE(r, rl); + rlix = (rlix>=0) ? apos+rlix : apos+alen; for(int i = rlix; i < apos+alen && aix[i]<ru; i++) { @@ -1819,9 +1819,9 @@ public class LibMatrixMult int apos = a.pos(r); int alen = a.size(r); int[] aix = a.indexes(r); - double[] avals = a.values(r); - int rlix = (rl==0) ? apos : a.posFIndexGTE(r, rl); - rlix = (rlix>=0) ? rlix : apos+alen; + double[] avals = a.values(r); + int rlix = (rl==0) ? 0 : a.posFIndexGTE(r, rl); + rlix = (rlix>=0) ? apos+rlix : apos+alen; for(int i = rlix; i < apos+alen && aix[i]<ru; i++) { @@ -1861,9 +1861,9 @@ public class LibMatrixMult int apos = a.pos(r); int alen = a.size(r); int[] aix = a.indexes(r); - double[] avals = a.values(r); - int rlix = (rl==0) ? apos : a.posFIndexGTE(r, rl); - rlix = (rlix>=0) ? rlix : apos+alen; + double[] avals = a.values(r); + int rlix = (rl==0) ? 0 : a.posFIndexGTE(r, rl); + rlix = (rlix>=0) ? apos+rlix : apos+alen; for(int i = rlix; i < apos+alen && aix[i]<ru; i++) { @@ -1883,9 +1883,9 @@ public class LibMatrixMult int apos = a.pos(r); int alen = a.size(r); int[] aix = a.indexes(r); - double[] avals = a.values(r); - int rlix = (rl==0) ? apos : a.posFIndexGTE(r,rl); - rlix = (rlix>=0) ? rlix : apos+alen; + double[] avals = a.values(r); + int rlix = (rl==0) ? 0 : a.posFIndexGTE(r,rl); + rlix = (rlix>=0) ? apos+rlix : apos+alen; for(int i = rlix; i < apos+alen && aix[i]<ru; i++) { @@ -2618,8 +2618,8 @@ public class LibMatrixMult int wlen = w.size(i); int[] wix = w.indexes(i); double[] wval = w.values(i); - int k = (cl==0) ? wpos : w.posFIndexGTE(i,cl); - k = (k>=0) ? k : wpos+wlen; + int k = (cl==0) ? 
0 : w.posFIndexGTE(i,cl); + k = (k>=0) ? wpos+k : wpos+wlen; for( ; k<wpos+wlen && wix[k]<cu; k++ ) { if( basic ) { double uvij = dotProductGeneric(mU,mV, i, wix[k], cd); http://git-wip-us.apache.org/repos/asf/systemml/blob/98ee9b7d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java index 8d7d4a5..3ae07c5 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java @@ -853,7 +853,7 @@ public class LibMatrixReorg //blocked execution for( int bi=rl; bi<ru; bi+=blocksizeI ) { - Arrays.fill(ix, 0); + Arrays.fill(ix, 0); //find column starting positions int bimin = Math.min(bi+blocksizeI, ru); if( cl > 0 ) { @@ -934,7 +934,7 @@ public class LibMatrixReorg int j = ix[iix]; //last block boundary for( ; j<alen && aix[apos+j]<bjmin; j++ ) c[ aix[apos+j]*n2+i ] = avals[ apos+j ]; - ix[iix] = j; //keep block boundary + ix[iix] = j; //keep block boundary } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/98ee9b7d/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index dac4315..8ee6f8d 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -3930,9 +3930,9 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab int alen = sparseBlock.size(i); int[] aix = sparseBlock.indexes(i); double[] avals = sparseBlock.values(i); - int astart = (cl>0)?sparseBlock.posFIndexGTE(i, 
cl) : apos; + int astart = (cl>0)?sparseBlock.posFIndexGTE(i, cl) : 0; if( astart != -1 ) - for( int j=astart; j<apos+alen && aix[j] <= cu; j++ ) + for( int j=apos+astart; j<apos+alen && aix[j] <= cu; j++ ) dest.appendValue(i-rl, aix[j]-cl, avals[j]); } } @@ -4080,11 +4080,12 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab return; //actual slice operation + int pos = sparseBlock.pos(r); for(int i=start; i<=end; i++) { - if(cols[i]<colCut) - left.appendValue(r+rowOffset, cols[i]+normalBlockColFactor-colCut, values[i]); + if(cols[pos+i]<colCut) + left.appendValue(r+rowOffset, cols[pos+i]+normalBlockColFactor-colCut, values[pos+i]); else - right.appendValue(r+rowOffset, cols[i]-colCut, values[i]); + right.appendValue(r+rowOffset, cols[pos+i]-colCut, values[pos+i]); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/98ee9b7d/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java index 63b8071..295a545 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java @@ -239,8 +239,8 @@ public class SparseBlockCOO extends SparseBlock long nnz = 0; for(int i=rl; i<ru; i++) if( !isEmpty(i) ) { - int start = posFIndexGTE(i, cl); - int end = posFIndexGTE(i, cu); + int start = internPosFIndexGTE(i, cl); + int end = internPosFIndexGTE(i, cu); nnz += (start!=-1) ? 
(end-start) : 0; } return nnz; @@ -346,7 +346,7 @@ public class SparseBlockCOO extends SparseBlock int lsize = _size+lnnz; if( _values.length < lsize ) resize(lsize); - int index = posFIndexGT(r, cl); + int index = internPosFIndexGT(r, cl); shiftRightByN((index>0)?index:pos(r+1), lnnz); //insert values @@ -361,12 +361,12 @@ public class SparseBlockCOO extends SparseBlock @Override public void deleteIndexRange(int r, int cl, int cu) { - int start = posFIndexGTE(r,cl); + int start = internPosFIndexGTE(r,cl); if( start < 0 ) //nothing to delete - return; + return; int len = size(r); - int end = posFIndexGTE(r, cu); + int end = internPosFIndexGTE(r, cu); if( end < 0 ) //delete all remaining end = start+len; @@ -374,7 +374,7 @@ public class SparseBlockCOO extends SparseBlock System.arraycopy(_rindexes, end, _rindexes, start, _size-end); System.arraycopy(_cindexes, end, _cindexes, start, _size-end); System.arraycopy(_values, end, _values, start, _size-end); - _size -= (end-start); + _size -= (end-start); } @Override @@ -409,7 +409,7 @@ public class SparseBlockCOO extends SparseBlock int len = size(r); //search for existing col index in [pos,pos+len) - int index = Arrays.binarySearch(_cindexes, pos, pos+len, c); + int index = Arrays.binarySearch(_cindexes, pos, pos+len, c); return (index >= 0) ? _values[index] : 0; } @@ -428,6 +428,11 @@ public class SparseBlockCOO extends SparseBlock @Override public int posFIndexLTE(int r, int c) { + int index = internPosFIndexLTE(r, c); + return (index>=0) ? index-pos(r) : index; + } + + private int internPosFIndexLTE(int r, int c) { int pos = pos(r); int len = size(r); @@ -443,6 +448,11 @@ public class SparseBlockCOO extends SparseBlock @Override public int posFIndexGTE(int r, int c) { + int index = internPosFIndexGTE(r, c); + return (index>=0) ? 
index-pos(r) : index; + } + + private int internPosFIndexGTE(int r, int c) { int pos = pos(r); int len = size(r); @@ -458,6 +468,11 @@ public class SparseBlockCOO extends SparseBlock @Override public int posFIndexGT(int r, int c) { + int index = internPosFIndexGT(r, c); + return (index>=0) ? index-pos(r) : index; + } + + private int internPosFIndexGT(int r, int c) { int pos = pos(r); int len = size(r); http://git-wip-us.apache.org/repos/asf/systemml/blob/98ee9b7d/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java index 35ffedf..19fbe50 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java @@ -108,7 +108,7 @@ public class SparseBlockCSR extends SparseBlock pos += alen; } _ptr[i+1]=pos; - } + } } } @@ -276,6 +276,34 @@ public class SparseBlockCSR extends SparseBlock return (long) Math.min(size, Long.MAX_VALUE); } + + /** + * Get raw access to underlying array of row pointers + * For use in GPU code + * @return array of row pointers + */ + public int[] rowPointers() { + return _ptr; + } + + /** + * Get raw access to underlying array of column indices + * For use in GPU code + * @return array of column indexes + */ + public int[] indexes() { + return indexes(0); + } + + /** + * Get raw access to underlying array of values + * For use in GPU code + * @return array of values + */ + public double[] values() { + return values(0); + } + /////////////////// //SparseBlock implementation @@ -359,8 +387,8 @@ public class SparseBlockCSR extends SparseBlock long nnz = 0; for(int i=rl; i<ru; i++) if( !isEmpty(i) ) { - int start = posFIndexGTE(i, cl); - int end = posFIndexGTE(i, cu); + int start = internPosFIndexGTE(i, cl); + int end 
= internPosFIndexGTE(i, cu); nnz += (start!=-1) ? (end-start) : 0; } return nnz; @@ -482,7 +510,7 @@ public class SparseBlockCSR extends SparseBlock int lsize = _size+lnnz; if( _values.length < lsize ) resize(lsize); - int index = posFIndexGT(r, cl); + int index = internPosFIndexGT(r, cl); int index2 = (index>0)?index:pos(r+1); shiftRightByN(index2, lnnz); @@ -645,12 +673,12 @@ public class SparseBlockCSR extends SparseBlock @Override public void deleteIndexRange(int r, int cl, int cu) { - int start = posFIndexGTE(r,cl); + int start = internPosFIndexGTE(r,cl); if( start < 0 ) //nothing to delete - return; + return; int len = size(r); - int end = posFIndexGTE(r, cu); + int end = internPosFIndexGTE(r, cu); if( end < 0 ) //delete all remaining end = start+len; @@ -703,6 +731,11 @@ public class SparseBlockCSR extends SparseBlock @Override public int posFIndexLTE(int r, int c) { + int index = internPosFIndexLTE(r, c); + return (index>=0) ? index-pos(r) : index; + } + + private int internPosFIndexLTE(int r, int c) { int pos = pos(r); int len = size(r); @@ -718,6 +751,11 @@ public class SparseBlockCSR extends SparseBlock @Override public int posFIndexGTE(int r, int c) { + int index = internPosFIndexGTE(r, c); + return (index>=0) ? index-pos(r) : index; + } + + private int internPosFIndexGTE(int r, int c) { int pos = pos(r); int len = size(r); @@ -733,6 +771,11 @@ public class SparseBlockCSR extends SparseBlock @Override public int posFIndexGT(int r, int c) { + int index = internPosFIndexGT(r, c); + return (index>=0) ? index-pos(r) : index; + } + + private int internPosFIndexGT(int r, int c) { int pos = pos(r); int len = size(r); @@ -768,7 +811,7 @@ public class SparseBlockCSR extends SparseBlock sb.append("\t"); } sb.append("\n"); - } + } return sb.toString(); } @@ -783,7 +826,7 @@ public class SparseBlockCSR extends SparseBlock tmpCap *= (tmpCap <= 1024) ? 
RESIZE_FACTOR1 : RESIZE_FACTOR2; } - + return (int)Math.min(tmpCap, Integer.MAX_VALUE); } @@ -825,7 +868,7 @@ public class SparseBlockCSR extends SparseBlock insert(ix, c, v); } - private void shiftRightAndInsert(int ix, int c, double v) { + private void shiftRightAndInsert(int ix, int c, double v) { //overlapping array copy (shift rhs values right by 1) System.arraycopy(_indexes, ix, _indexes, ix+1, _size-ix); System.arraycopy(_values, ix, _values, ix+1, _size-ix); @@ -883,31 +926,4 @@ public class SparseBlockCSR extends SparseBlock for( int i=rl; i<rlen+1; i++ ) _ptr[i]-=cnt; } - - /** - * Get raw access to underlying array of row pointers - * For use in GPU code - * @return array of row pointers - */ - public int[] rowPointers() { - return _ptr; - } - - /** - * Get raw access to underlying array of column indices - * For use in GPU code - * @return array of column indexes - */ - public int[] indexes() { - return _indexes; - } - - /** - * Get raw access to underlying array of values - * For use in GPU code - * @return array of values - */ - public double[] values() { - return _values; - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/98ee9b7d/src/test/java/org/apache/sysml/test/integration/functions/sparse/SparseBlockGetFirstIndex.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/sparse/SparseBlockGetFirstIndex.java b/src/test/java/org/apache/sysml/test/integration/functions/sparse/SparseBlockGetFirstIndex.java index a105adc..16ffbcd 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/sparse/SparseBlockGetFirstIndex.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/sparse/SparseBlockGetFirstIndex.java @@ -206,7 +206,7 @@ public class SparseBlockGetFirstIndex extends AutomatedTestBase //init sparse block SparseBlock sblock = null; MatrixBlock mbtmp = DataConverter.convertToMatrixBlock(A); - SparseBlock srtmp = 
mbtmp.getSparseBlock(); + SparseBlock srtmp = mbtmp.getSparseBlock(); switch( btype ) { case MCSR: sblock = new SparseBlockMCSR(srtmp); break; case CSR: sblock = new SparseBlockCSR(srtmp); break; @@ -228,7 +228,7 @@ public class SparseBlockGetFirstIndex extends AutomatedTestBase if( sblock.isEmpty(i) != (rnnz[i]==0) ) Assert.fail("Wrong isEmpty(row) result for row nnz: "+rnnz[i]); - //check correct index values + //check correct index values for( int i=0; i<rows; i++ ) { int ix = getFirstIx(A, i, i, itype); int sixpos = -1; @@ -237,10 +237,11 @@ public class SparseBlockGetFirstIndex extends AutomatedTestBase case GTE: sixpos = sblock.posFIndexGTE(i, i); break; case LTE: sixpos = sblock.posFIndexLTE(i, i); break; } - int six = (sixpos>=0) ? sblock.indexes(i)[sixpos] : -1; + int six = (sixpos>=0) ? + sblock.indexes(i)[sblock.pos(i)+sixpos] : -1; if( six != ix ) { Assert.fail("Wrong index returned by index probe ("+ - itype.toString()+","+i+"): "+six+", expected: "+ix); + itype.toString()+","+i+"): "+six+", expected: "+ix); } } } @@ -255,19 +256,19 @@ public class SparseBlockGetFirstIndex extends AutomatedTestBase for( int j=cix+1; j<cols; j++ ) if( A[rix][j] != 0 ) return j; - return -1; + return -1; } else if( type==IndexType.GTE ) { for( int j=cix; j<cols; j++ ) if( A[rix][j] != 0 ) return j; - return -1; + return -1; } else if( type==IndexType.LTE ) { for( int j=cix; j>=0; j-- ) if( A[rix][j] != 0 ) return j; - return -1; + return -1; } return -1;
