This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 6a5d35ac96bd9e20c0aa39a12d4f4b733a855c04 Author: baunsgaard <[email protected]> AuthorDate: Wed Jan 27 14:18:02 2021 +0100 [SYSTEMDS-2811] Compressed slice This PR adds support for slicing compressed matrices, so that a slice no longer decompresses the entire matrix, but only the sliced parts. The implementation handles the following cases: - Single value slices, which are replaced by a getValue() call whose result is placed into a new matrix output. - Row slices, which only decompress the selected rows. - Column slices, which maintain compressed outputs. - Selective row/col slices, which first apply the column slice, followed by a decompressing row slice. A further improvement would be to maintain compression if the row slice is large enough to do so, but this would require further work. Closes #1173 --- .../compress/AbstractCompressedMatrixBlock.java | 1 - .../runtime/compress/CompressedMatrixBlock.java | 81 +++++++++--- .../runtime/compress/colgroup/ADictionary.java | 9 ++ .../sysds/runtime/compress/colgroup/ColGroup.java | 23 +++- .../runtime/compress/colgroup/ColGroupDDC.java | 7 +- .../runtime/compress/colgroup/ColGroupOLE.java | 32 +++-- .../runtime/compress/colgroup/ColGroupRLE.java | 141 ++++++++++++--------- .../compress/colgroup/ColGroupUncompressed.java | 5 + .../runtime/compress/colgroup/ColGroupValue.java | 91 ++++++++++--- .../runtime/compress/colgroup/Dictionary.java | 16 +++ .../runtime/compress/colgroup/QDictionary.java | 15 +++ .../spark/MatrixIndexingSPInstruction.java | 9 +- .../sysds/runtime/matrix/data/MatrixBlock.java | 111 ++++++++-------- .../sysds/runtime/matrix/data/MatrixValue.java | 15 +++ .../matrix/data/OperationsOnMatrixValues.java | 38 ++++-- src/test/java/org/apache/sysds/test/TestUtils.java | 2 +- .../component/compress/CompressedMatrixTest.java | 9 +- .../component/compress/CompressedTestBase.java | 71 ++++++++++- .../sysds/test/component/matrix/SliceTest.java | 63 +++++++++ 19 files changed, 546 insertions(+), 193 deletions(-) diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java index a9208f1..ffbd017 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java @@ -479,7 +479,6 @@ public abstract class AbstractCompressedMatrixBlock extends MatrixBlock { protected void printDecompressWarning(String operation) { LOG.warn("Operation '" + operation + "' not supported yet - decompressing for ULA operations."); - } protected void printDecompressWarning(String operation, MatrixBlock m2) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 074ec80..6186d1a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -277,17 +277,20 @@ public class CompressedMatrixBlock extends AbstractCompressedMatrixBlock { public double quickGetValue(int r, int c) { // TODO Optimize Quick Get Value, to located the correct column group without having to search for it - double v = 0.0; - for(ColGroup group : _colGroups) { - if(Arrays.binarySearch(group.getColIndices(), c) >= 0) { - v += group.get(r, c); - if(!isOverlapping()) - break; - } + if(isOverlapping()) { + double v = 0.0; + for(ColGroup group : _colGroups) + if(Arrays.binarySearch(group.getColIndices(), c) >= 0) + v += group.get(r, c); + return v; + } + else { + for(ColGroup group : _colGroups) + if(Arrays.binarySearch(group.getColIndices(), c) >= 0) + return group.get(r, c); + return 0; } - // find row value - return v; } ////////////////////////////////////////// @@ -736,21 +739,59 @@ public class CompressedMatrixBlock extends AbstractCompressedMatrixBlock { @Override public MatrixBlock 
slice(int rl, int ru, int cl, int cu, boolean deep, CacheBlock ret) { - printDecompressWarning("slice"); - MatrixBlock tmp = decompress(); - return tmp.slice(rl, ru, cl, cu, ret); + validateSliceArgument(rl, ru, cl, cu); + MatrixBlock tmp; + if(rl == ru && cl == cu) { + // get a single index, and return in a matrixBlock + tmp = new MatrixBlock(1, 1, 0); + tmp.appendValue(0, 0, getValue(rl, cl)); + return tmp; + } + else if(cl == 0 && cu == getNumColumns() - 1) { + // Row Slice. Potential optimization if the slice contains enough rows. + // +1 since the implementation arguments for slice is inclusive values for ru and cu. + // and it is not inclusive in decompression, and construction of MatrixBlock. + tmp = new MatrixBlock(ru + 1 - rl, getNumColumns(), false).allocateDenseBlock(); + for(ColGroup g : getColGroups()) + g.decompressToBlock(tmp, rl, ru + 1, 0); + return tmp; + } + else if(rl == 0 && ru == getNumRows() - 1) { + tmp = sliceColumns(cl, cu); + } + else { + // In the case where an internal matrix is sliced out, then first slice out the columns + // to an compressed intermediate. + tmp = sliceColumns(cl, cu); + // Then call slice recursively, to do the row slice. + // Since we do not copy the index structure but simply maintain a pointer to the original + // this is fine. 
+ tmp = tmp.slice(rl, ru, 0, tmp.getNumColumns() - 1, ret); + } + ret = tmp; + return tmp; + } + + private CompressedMatrixBlock sliceColumns(int cl, int cu) { + CompressedMatrixBlock ret = new CompressedMatrixBlock(this.getNumRows(), cu + 1 - cl); + + List<ColGroup> newColGroups = new ArrayList<>(); + for(ColGroup grp : getColGroups()) { + ColGroup slice = grp.sliceColumns(cl, cu + 1); + if(slice != null) + newColGroups.add(slice); + } + ret.allocateColGroupList(newColGroups); + + return ret; } @Override public void slice(ArrayList<IndexedMatrixValue> outlist, IndexRange range, int rowCut, int colCut, int blen, int boundaryRlen, int boundaryClen) { - printDecompressWarning("slice"); - try { - MatrixBlock tmp = decompress(); - tmp.slice(outlist, range, rowCut, colCut, blen, boundaryRlen, boundaryClen); - } - catch(DMLRuntimeException ex) { - throw new RuntimeException(ex); - } + printDecompressWarning( + "slice for distribution to spark. (Could be implemented such that it does not decompress)"); + MatrixBlock tmp = decompress(); + tmp.slice(outlist, range, rowCut, colCut, blen, boundaryRlen, boundaryClen); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictionary.java index 43267e3..1e07289 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictionary.java @@ -210,4 +210,13 @@ public abstract class ADictionary { */ protected abstract void addMaxAndMin(double[] ret, int[] colIndexes); + /** + * Modify the dictionary by removing columns not within the index range. + * + * @param idxStart The column index to start at. + * @param idxEnd The column index to end at (not inclusive) + * @param previousNumberOfColumns The number of columns contained in the dictionary. + * @return A dictionary containing the sliced out columns values only. 
+ */ + public abstract ADictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroup.java index c408199..20e3804 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroup.java @@ -325,29 +325,26 @@ public abstract class ColGroup implements Serializable { */ public abstract void decompressColumnToBlock(MatrixBlock target, int colpos); - /** * Decompress to block. * * @param target dense output vector * @param colpos column to decompress, error if larger or equal numCols - * @param rl the Row to start decompression from - * @param ru the Row to end decompression at + * @param rl the Row to start decompression from + * @param ru the Row to end decompression at */ public abstract void decompressColumnToBlock(MatrixBlock target, int colpos, int rl, int ru); - /** * Decompress to dense array. * * @param target dense output vector double array. * @param colpos column to decompress, error if larger or equal numCols - * @param rl the Row to start decompression from - * @param ru the Row to end decompression at + * @param rl the Row to start decompression from + * @param ru the Row to end decompression at */ public abstract void decompressColumnToBlock(double[] target, int colpos, int rl, int ru); - /** * Serializes column group to data output. * @@ -560,4 +557,16 @@ public abstract class ColGroup implements Serializable { * @return returns if the colgroup is allocated in a dense fashion. */ public abstract boolean isDense(); + + /** + * Slice out the columns within the range of cl and cu to remove the dictionary values related to these columns. + * If the ColGroup slicing from does not contain any columns within the range null is returned. 
+ * + * @param cl The lower bound of the columns to select + * @param cu the upper bound of the columns to select (not inclusive). + * @return A cloned Column Group, with a copied pointer to the old column groups index structure, but reduced + * dictionary and _columnIndexes correctly aligned with the expected sliced compressed matrix. + */ + public abstract ColGroup sliceColumns(int cl, int cu); + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index c187fe4..f673c89 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -156,12 +156,11 @@ public abstract class ColGroupDDC extends ColGroupValue { // get value int index = getIndex(r); - if(index < getNumValues()) { + if(index < getNumValues()) return _dict.getValue(index * _colIndexes.length + ix); - } - else { + else return 0.0; - } + } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index e8288b9..f530ee1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -21,7 +21,6 @@ package org.apache.sysds.runtime.compress.colgroup; import java.util.Arrays; -import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.DMLCompressionException; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.utils.ABitmap; @@ -113,7 +112,6 @@ public class ColGroupOLE extends ColGroupOffset { int rc = rix * target.getNumColumns(); for(int j = 0; j < numCols; j++) { if(safe) { - double v = c[rc + _colIndexes[j]]; double nv = c[rc + _colIndexes[j]] + values[off + j]; if(v == 0.0 && nv != 0.0) { @@ 
-903,6 +901,30 @@ public class ColGroupOLE extends ColGroupOffset { } } + @Override + public double get(int r, int c) { + final int blksz = CompressionSettings.BITMAP_BLOCK_SZ; + final int numVals = getNumValues(); + int idColOffset = Arrays.binarySearch(_colIndexes, c); + if(idColOffset < 0) + return 0; + int[] apos = skipScan(numVals, r); + int offset = r % blksz; + for(int k = 0; k < numVals; k++) { + int boff = _ptr[k]; + int blen = len(k); + int bix = apos[k]; + int slen = _data[boff + bix]; + for(int blckIx = 1; blckIx <= slen && blckIx < blen; blckIx++) { + if(_data[boff + bix + blckIx] == offset) + return _dict.getValue(k * _colIndexes.length + idColOffset); + else if(_data[boff + bix + blckIx] > offset) + continue; + } + } + return 0; + } + ///////////////////////////////// // internal helper functions @@ -1010,10 +1032,4 @@ public class ColGroupOLE extends ColGroupOffset { return encodedBlocks; } - @Override - public double get(int r, int c) { - throw new NotImplementedException("Not Implemented get(r,c) after removal of iterators in colgroups"); - // TODO Auto-generated method stub - // return 0; - } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index ed581e7..5a81fa2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -23,7 +23,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.utils.ABitmap; import org.apache.sysds.runtime.compress.utils.LinearAlgebraUtils; @@ -108,24 +107,24 @@ public class ColGroupRLE extends ColGroupOffset { for(; bix < blen & start < bimax; bix += 2) { start += _data[boff + bix]; int len = _data[boff + bix + 1]; 
- for(int i = Math.max(rl, start) - (rl - offT); i < Math.min(start + len, ru) - (rl - offT); i++){ + for(int i = Math.max(rl, start) - (rl - offT); i < Math.min(start + len, ru) - (rl - offT); i++) { - int rc = i * target.getNumColumns(); + int rc = i * target.getNumColumns(); for(int j = 0; j < numCols; j++) { - if(values[off + j] != 0) { - if(safe) { - double v = c[rc + _colIndexes[j]]; - double nv = c[rc + _colIndexes[j]] + values[off + j]; - if(v == 0.0 && nv != 0.0) { - target.setNonZeros(target.getNonZeros() + 1); - } - c[rc + _colIndexes[j]] = nv; - } - else { - c[rc + _colIndexes[j]] += values[off + j]; + if(values[off + j] != 0) { + if(safe) { + double v = c[rc + _colIndexes[j]]; + double nv = c[rc + _colIndexes[j]] + values[off + j]; + if(v == 0.0 && nv != 0.0) { + target.setNonZeros(target.getNonZeros() + 1); } + c[rc + _colIndexes[j]] = nv; + } + else { + c[rc + _colIndexes[j]] += values[off + j]; } } + } } start += len; } @@ -138,50 +137,50 @@ public class ColGroupRLE extends ColGroupOffset { @Override public void decompressToBlock(MatrixBlock target, int[] colixTargets) { // if(getNumValues() > 1) { - final int blksz = CompressionSettings.BITMAP_BLOCK_SZ; - final int numCols = getNumCols(); - final int numVals = getNumValues(); - final double[] values = getValues(); + final int blksz = CompressionSettings.BITMAP_BLOCK_SZ; + final int numCols = getNumCols(); + final int numVals = getNumValues(); + final double[] values = getValues(); - // position and start offset arrays - int[] apos = new int[numVals]; - int[] astart = new int[numVals]; - int[] cix = new int[numCols]; + // position and start offset arrays + int[] apos = new int[numVals]; + int[] astart = new int[numVals]; + int[] cix = new int[numCols]; - // prepare target col indexes - for(int j = 0; j < numCols; j++) - cix[j] = colixTargets[_colIndexes[j]]; + // prepare target col indexes + for(int j = 0; j < numCols; j++) + cix[j] = colixTargets[_colIndexes[j]]; - // cache conscious append via 
horizontal scans - for(int bi = 0; bi < _numRows; bi += blksz) { - int bimax = Math.min(bi + blksz, _numRows); - for(int k = 0, off = 0; k < numVals; k++, off += numCols) { - int boff = _ptr[k]; - int blen = len(k); - int bix = apos[k]; - if(bix >= blen) - continue; - int start = astart[k]; - for(; bix < blen & start < bimax; bix += 2) { - start += _data[boff + bix]; - int len = _data[boff + bix + 1]; - for(int i = start; i < start + len; i++) - for(int j = 0; j < numCols; j++) - if(values[off + j] != 0) { - double v = target.quickGetValue(i, _colIndexes[j]); - target.setValue(i, _colIndexes[j], values[off + j] + v); - } + // cache conscious append via horizontal scans + for(int bi = 0; bi < _numRows; bi += blksz) { + int bimax = Math.min(bi + blksz, _numRows); + for(int k = 0, off = 0; k < numVals; k++, off += numCols) { + int boff = _ptr[k]; + int blen = len(k); + int bix = apos[k]; + if(bix >= blen) + continue; + int start = astart[k]; + for(; bix < blen & start < bimax; bix += 2) { + start += _data[boff + bix]; + int len = _data[boff + bix + 1]; + for(int i = start; i < start + len; i++) + for(int j = 0; j < numCols; j++) + if(values[off + j] != 0) { + double v = target.quickGetValue(i, _colIndexes[j]); + target.setValue(i, _colIndexes[j], values[off + j] + v); + } - start += len; - } - apos[k] = bix; - astart[k] = start; + start += len; } + apos[k] = bix; + astart[k] = start; } + } // } // else { - // // call generic decompression with decoder - // super.decompressToBlock(target, colixTargets); + // // call generic decompression with decoder + // super.decompressToBlock(target, colixTargets); // } } @@ -250,7 +249,7 @@ public class ColGroupRLE extends ColGroupOffset { for(; bix < blen & start < bimax; bix += 2) { start += _data[boff + bix]; int len = _data[boff + bix + 1]; - if(start + len >= rl){ + if(start + len >= rl) { int offsetStart = Math.max(start, rl); for(int i = offsetStart; i < Math.min(start + len, bimax); i++) c[i - rl] += values[off + colpos]; 
@@ -290,7 +289,7 @@ public class ColGroupRLE extends ColGroupOffset { for(; bix < blen & start < bimax; bix += 2) { start += _data[boff + bix]; int len = _data[boff + bix + 1]; - if(start + len >= rl){ + if(start + len >= rl) { int offsetStart = Math.max(start, rl); for(int i = offsetStart; i < Math.min(start + len, bimax); i++) c[i - rl] += values[off + colpos]; @@ -303,7 +302,6 @@ public class ColGroupRLE extends ColGroupOffset { } } - @Override public int[] getCounts(int[] counts) { final int numVals = getNumValues(); @@ -466,7 +464,8 @@ public class ColGroupRLE extends ColGroupOffset { Math.max(rl, start + lstart), Math.min(start + lstart + llen, ru), outputColumns, - thatNrColumns,k); + thatNrColumns, + k); if(start + lstart + llen >= ru) break; start += lstart + llen; @@ -883,6 +882,35 @@ public class ColGroupRLE extends ColGroupOffset { } } + @Override + public double get(int r, int c) { + + final int numVals = getNumValues(); + int idColOffset = Arrays.binarySearch(_colIndexes, c); + if(idColOffset < 0) + return 0; + int[] astart = new int[numVals]; + int[] apos = skipScan(numVals, r, astart); + for(int k = 0; k < numVals; k++) { + int boff = _ptr[k]; + int blen = len(k); + int bix = apos[k]; + int start = astart[k]; + for(; bix < blen && start <= r; bix += 2) { + int lstart = _data[boff + bix]; + int llen = _data[boff + bix + 1]; + int from = start + lstart; + int to = start + lstart + llen; + if(r >= from && r < to) + return _dict.getValue(k * _colIndexes.length + idColOffset); + start += lstart + llen; + } + + } + + return 0; + } + ///////////////////////////////// // internal helper functions @@ -1039,9 +1067,4 @@ public class ColGroupRLE extends ColGroupOffset { ret[i] = buf.get(i); return ret; } - - @Override - public double get(int r, int c) { - throw new NotImplementedException("Not Implemented get(r,c) after removal of iterators in colgroups"); - } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index e432054..b0aa9df 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -509,4 +509,9 @@ public class ColGroupUncompressed extends ColGroup { // they are dense in the sense of compression. return true; } + + @Override + public ColGroup sliceColumns(int cl, int cu){ + throw new NotImplementedException("Not implemented slice columns"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java index 0fb9dac..7d53ac2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java @@ -134,8 +134,7 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable { return _dict; } - - public void addMinMax(double[] ret){ + public void addMinMax(double[] ret) { _dict.addMaxAndMin(ret, _colIndexes); } @@ -258,14 +257,14 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable { return ret; } - private int[] getAggregateColumnsSetDense(double[] b, int cl, int cu, int cut){ + private int[] getAggregateColumnsSetDense(double[] b, int cl, int cu, int cut) { Set<Integer> aggregateColumnsSet = new HashSet<>(); final int retCols = (cu - cl); - for(int k = 0; k < _colIndexes.length; k ++) { + for(int k = 0; k < _colIndexes.length; k++) { int rowIdxOffset = _colIndexes[k] * cut; for(int h = cl; h < cu; h++) { double v = b[rowIdxOffset + h]; - if(v != 0.0){ + if(v != 0.0) { aggregateColumnsSet.add(h); } } @@ -284,7 +283,9 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable { int[] aggregateColumns = getAggregateColumnsSetDense(b, cl, cu, cut); double[] ret = new double[numVals * 
aggregateColumns.length]; - for(int k = 0, off = 0; k < numVals * _colIndexes.length; k += _colIndexes.length, off += aggregateColumns.length) { + for(int k = 0, off = 0; + k < numVals * _colIndexes.length; + k += _colIndexes.length, off += aggregateColumns.length) { for(int h = 0; h < _colIndexes.length; h++) { int idb = _colIndexes[h] * cut; double v = dictVals[k + h]; @@ -298,7 +299,7 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable { return new ImmutablePair<>(aggregateColumns, ret); } - private int[] getAggregateColumnsSetSparse(SparseBlock b){ + private int[] getAggregateColumnsSetSparse(SparseBlock b) { Set<Integer> aggregateColumnsSet = new HashSet<>(); for(int h = 0; h < _colIndexes.length; h++) { @@ -321,7 +322,7 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable { int[] aggregateColumns = getAggregateColumnsSetSparse(b); - double[] ret = new double[numVals *aggregateColumns.length]; + double[] ret = new double[numVals * aggregateColumns.length]; for(int h = 0; h < _colIndexes.length; h++) { int colIdx = _colIndexes[h]; @@ -332,15 +333,17 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable { for(int i = b.pos(colIdx); i < b.size(colIdx) + b.pos(colIdx); i++) { while(aggregateColumns[retIdx] < sIndexes[i]) retIdx++; - if(sIndexes[i] == aggregateColumns[retIdx] ) - for(int j = 0, offOrg = h; j < numVals * aggregateColumns.length; j += aggregateColumns.length, offOrg += _colIndexes.length) { + if(sIndexes[i] == aggregateColumns[retIdx]) + for(int j = 0, offOrg = h; + j < numVals * aggregateColumns.length; + j += aggregateColumns.length, offOrg += _colIndexes.length) { ret[j + retIdx] += dictVals[offOrg] * sValues[i]; } } } } - return new ImmutablePair<> (aggregateColumns, ret); + return new ImmutablePair<>(aggregateColumns, ret); } public Pair<int[], double[]> preaggValues(int numVals, MatrixBlock b, double[] dictVals, int cl, int cu, int cut) { @@ -479,8 +482,8 @@ public abstract 
class ColGroupValue extends ColGroup implements Cloneable { } public static void setupThreadLocalMemory(int len) { - if(memPool.get() == null || memPool.get().getLeft().length < len){ - Pair<int[], double[]> p = new ImmutablePair<>(new int[len],new double[len] ); + if(memPool.get() == null || memPool.get().getLeft().length < len) { + Pair<int[], double[]> p = new ImmutablePair<>(new int[len], new double[len]); memPool.set(p); } } @@ -497,7 +500,7 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable { return new double[len]; } - if(p.getValue().length < len){ + if(p.getValue().length < len) { setupThreadLocalMemory(len); return p.getValue(); } @@ -515,8 +518,8 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable { // sanity check for missing setup if(p == null) return new int[len + 1]; - - if(p.getKey().length < len){ + + if(p.getKey().length < len) { setupThreadLocalMemory(len); return p.getKey(); } @@ -623,7 +626,7 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable { * * @return a shallow copy of the colGroup. */ - public ColGroup copy() { + public ColGroupValue copy() { try { ColGroupValue clone = (ColGroupValue) this.clone(); return clone; @@ -641,4 +644,58 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable { leftMultByRowVector(a, result, numVals, values); } + + @Override + public ColGroup sliceColumns(int cl, int cu) { + if(cu - cl == 1) + return sliceSingleColumn(cl); + else + return sliceMultiColumns(cl, cu); + } + + private ColGroup sliceSingleColumn(int col) { + ColGroupValue ret = (ColGroupValue) copy(); + + int idx = Arrays.binarySearch(_colIndexes, col); + // Binary search returns negative value if the column is not found. + // Therefore return null, if the column is not inside this colGroup. 
+ if(idx >= 0) { + ret._colIndexes = new int[1]; + ret._colIndexes[0] = _colIndexes[idx] - col; + if(_colIndexes.length == 1) + ret._dict = ret._dict.clone(); + else + ret._dict = ret._dict.sliceOutColumnRange(idx, idx + 1, _colIndexes.length); + + return ret; + } + else + return null; + } + + private ColGroup sliceMultiColumns(int cl, int cu) { + ColGroupValue ret = (ColGroupValue) copy(); + int idStart = 0; + int idEnd = 0; + for(int i = 0; i < _colIndexes.length; i++) { + if(_colIndexes[i] < cl) + idStart++; + if(_colIndexes[i] < cu) + idEnd++; + } + int numberOfOutputColumns = idEnd - idStart; + if(numberOfOutputColumns > 0) { + ret._dict = ret._dict.sliceOutColumnRange(idStart, idEnd, _colIndexes.length); + + ret._colIndexes = new int[numberOfOutputColumns]; + // Incrementing idStart here so make sure that the dictionary is extracted before + for(int i = 0; i < numberOfOutputColumns; i++) { + ret._colIndexes[i] = _colIndexes[idStart++] - cl; + } + return ret; + } + else + return null; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/Dictionary.java index 5957e5d..a74be56 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/Dictionary.java @@ -291,4 +291,20 @@ public class Dictionary extends ADictionary { sb.append("]"); return sb; } + + + public ADictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns){ + int numberTuples = getNumberOfValues(previousNumberOfColumns); + int tupleLengthAfter = idxEnd - idxStart; + double[] newDictValues = new double[tupleLengthAfter * numberTuples]; + int orgOffset = idxStart; + int targetOffset = 0; + for(int v = 0; v < numberTuples; v++){ + for(int c = 0; c< tupleLengthAfter; c++, orgOffset++, targetOffset++){ + newDictValues[targetOffset] = _values[orgOffset]; + } + orgOffset += 
previousNumberOfColumns - idxEnd + idxStart; + } + return new Dictionary(newDictValues); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java index bf3687a..43ab8a6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java @@ -393,4 +393,19 @@ public class QDictionary extends ADictionary { double[] doubleValues = getValues(); return new Dictionary(doubleValues); } + + public ADictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns){ + int numberTuples = getNumberOfValues(previousNumberOfColumns); + int tupleLengthAfter = idxEnd - idxStart; + byte[] newDictValues = new byte[tupleLengthAfter * numberTuples]; + int orgOffset = idxStart; + int targetOffset = 0; + for(int v = 0; v < numberTuples; v++){ + for(int c = 0; c< tupleLengthAfter; c++, orgOffset++, targetOffset++){ + newDictValues[targetOffset] = _values[orgOffset]; + } + orgOffset += previousNumberOfColumns - idxEnd + idxStart; + } + return new QDictionary(newDictValues, _scale); + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java index 63e40c6..5d0711b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java @@ -508,8 +508,7 @@ public class MatrixIndexingSPInstruction extends IndexingSPInstruction { throws Exception { IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(kv); - ArrayList<IndexedMatrixValue> outlist = new ArrayList<>(); - OperationsOnMatrixValues.performSlice(in, _ixrange, _blen, outlist); + ArrayList<IndexedMatrixValue> outlist = 
OperationsOnMatrixValues.performSlice(in, _ixrange, _blen); return SparkUtils.fromIndexedMatrixBlock(outlist).iterator(); } } @@ -534,8 +533,7 @@ public class MatrixIndexingSPInstruction extends IndexingSPInstruction { throws Exception { IndexedMatrixValue in = new IndexedMatrixValue(kv._1(), kv._2()); - ArrayList<IndexedMatrixValue> outlist = new ArrayList<>(); - OperationsOnMatrixValues.performSlice(in, _ixrange, _blen, outlist); + ArrayList<IndexedMatrixValue> outlist = OperationsOnMatrixValues.performSlice(in, _ixrange, _blen); return SparkUtils.fromIndexedMatrixBlock(outlist.get(0)); } } @@ -571,8 +569,7 @@ public class MatrixIndexingSPInstruction extends IndexingSPInstruction { { IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg); - ArrayList<IndexedMatrixValue> outlist = new ArrayList<>(); - OperationsOnMatrixValues.performSlice(in, _ixrange, _blen, outlist); + ArrayList<IndexedMatrixValue> outlist = OperationsOnMatrixValues.performSlice(in, _ixrange, _blen); assert(outlist.size() == 1); //1-1 row/column block indexing return SparkUtils.fromIndexedMatrixBlock(outlist.get(0)); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 42857e0..42e65da 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -3884,10 +3884,20 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab (int)ixrange.colStart, (int)ixrange.colEnd, true, ret); } + /** + * Slice out a row block + * @param rl The row lower to start from + * @param ru The row lower to end at + * @return The sliced out matrix block. 
+ */ public MatrixBlock slice(int rl, int ru) { - return slice(rl, ru, 0, clen-1, true, new MatrixBlock()); + return slice(rl, ru, 0, clen-1, true, null); } + public MatrixBlock slice(int rl, int ru, int cl, int cu){ + return slice(rl, ru, cl, cu, true, null); + } + @Override public MatrixBlock slice(int rl, int ru, int cl, int cu, CacheBlock ret) { return slice(rl, ru, cl, cu, true, ret); @@ -3897,21 +3907,19 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab * Method to perform rightIndex operation for a given lower and upper bounds in row and column dimensions. * Extracted submatrix is returned as "result". Note: This operation is now 0-based. * - * @param rl row lower - * @param ru row upper - * @param cl column lower - * @param cu column upper - * @param deep should perform deep copy - * @param ret output matrix block + * This means that if you call with rl == ru then you get 1 row output. + * + * @param rl row lower if this value is bellow 0 or above the number of rows contained in the matrix an execption is thrown + * @param ru row upper if this value is bellow rl or above the number of rows contained in the matrix an exception is thrown + * @param cl column lower if this value us bellow 0 or above the number of columns contained in the matrix an exception is thrown + * @param cu column upper if this value us bellow cl or above the number of columns contained in the matrix an exception is thrown + * @param deep should perform deep copy, this is relelvant in cases where the matrix is in sparse format, + * or the entire matrix is sliced out + * @param ret output sliced out matrix block * @return matrix block output matrix block */ public MatrixBlock slice(int rl, int ru, int cl, int cu, boolean deep, CacheBlock ret) { - // check the validity of bounds - if ( rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() - || cl < 0 || cl >= getNumColumns() || cu < cl || cu >= getNumColumns() ) { - throw new 
DMLRuntimeException("Invalid values for matrix indexing: ["+(rl+1)+":"+(ru+1)+"," + (cl+1)+":"+(cu+1)+"] " - + "must be within matrix dimensions ["+getNumRows()+","+getNumColumns()+"]"); - } + validateSliceArgument(rl, ru, cl, cu); // Output matrix will have the same sparsity as that of the input matrix. // (assuming a uniform distribution of non-zeros in the input) @@ -3943,6 +3951,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab return result; } + protected void validateSliceArgument(int rl, int ru, int cl, int cu){ + // check the validity of bounds + if ( rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() + || cl < 0 || cl >= getNumColumns() || cu < cl || cu >= getNumColumns() ) { + throw new DMLRuntimeException("Invalid values for matrix indexing: ["+(rl+1)+":"+(ru+1)+"," + (cl+1)+":"+(cu+1)+"] " + + "must be within matrix dimensions ["+getNumRows()+","+getNumColumns()+"]"); + } + } + private void sliceSparse(int rl, int ru, int cl, int cu, boolean deep, MatrixBlock dest) { //check for early abort if( isEmptyBlock(false) ) @@ -4025,14 +4042,14 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab int rowCut, int colCut, int blen, int boundaryRlen, int boundaryClen) { MatrixBlock topleft=null, topright=null, bottomleft=null, bottomright=null; - Iterator<IndexedMatrixValue> p=outlist.iterator(); + Iterator<IndexedMatrixValue> p = outlist.iterator(); int blockRowFactor=blen, blockColFactor=blen; if(rowCut>range.rowEnd) blockRowFactor=boundaryRlen; if(colCut>range.colEnd) blockColFactor=boundaryClen; - int minrowcut=(int)Math.min(rowCut,range.rowEnd); + int minrowcut=(int)Math.min(rowCut, range.rowEnd); int mincolcut=(int)Math.min(colCut, range.colEnd); int maxrowcut=(int)Math.max(rowCut, range.rowStart); int maxcolcut=(int)Math.max(colCut, range.colStart); @@ -4040,9 +4057,6 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab if(range.rowStart<rowCut && 
range.colStart<colCut) { topleft=(MatrixBlock) p.next().getValue(); - //topleft.reset(blockRowFactor, blockColFactor, - // checkSparcityOnSlide(rowCut-(int)range.rowStart, colCut-(int)range.colStart, blockRowFactor, blockColFactor)); - topleft.reset(blockRowFactor, blockColFactor, estimateSparsityOnSlice(minrowcut-(int)range.rowStart, mincolcut-(int)range.colStart, blockRowFactor, blockColFactor)); } @@ -4065,43 +4079,36 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab estimateSparsityOnSlice((int)range.rowEnd-maxrowcut+1, (int)range.colEnd-maxcolcut+1, boundaryRlen, boundaryClen)); } - if(sparse) - { - if(sparseBlock!=null) + if(sparse && sparseBlock!=null){ + int r=(int)range.rowStart; + for(; r<Math.min(Math.min(rowCut, sparseBlock.numRows()), range.rowEnd+1); r++) + sliceHelp(r, range, colCut, topleft, topright, blen-rowCut, blen, blen); + + for(; r<=Math.min(range.rowEnd, sparseBlock.numRows()-1); r++) + sliceHelp(r, range, colCut, bottomleft, bottomright, -rowCut, blen, blen); + } + else if(denseBlock!=null){ + double[] a = getDenseBlockValues(); + int i=((int)range.rowStart)*clen; + int r=(int) range.rowStart; + for(; r<Math.min(rowCut, range.rowEnd+1); r++) { - int r=(int)range.rowStart; - for(; r<Math.min(Math.min(rowCut, sparseBlock.numRows()), range.rowEnd+1); r++) - sliceHelp(r, range, colCut, topleft, topright, blen-rowCut, blen, blen); - - for(; r<=Math.min(range.rowEnd, sparseBlock.numRows()-1); r++) - sliceHelp(r, range, colCut, bottomleft, bottomright, -rowCut, blen, blen); + int c=(int) range.colStart; + for(; c<Math.min(colCut, range.colEnd+1); c++) + topleft.appendValue(r+blen-rowCut, c+blen-colCut, a[i+c]); + for(; c<=range.colEnd; c++) + topright.appendValue(r+blen-rowCut, c-colCut, a[i+c]); + i+=clen; } - } - else { - if(denseBlock!=null) + + for(; r<=range.rowEnd; r++) { - double[] a = getDenseBlockValues(); - int i=((int)range.rowStart)*clen; - int r=(int) range.rowStart; - for(; r<Math.min(rowCut, 
range.rowEnd+1); r++) - { - int c=(int) range.colStart; - for(; c<Math.min(colCut, range.colEnd+1); c++) - topleft.appendValue(r+blen-rowCut, c+blen-colCut, a[i+c]); - for(; c<=range.colEnd; c++) - topright.appendValue(r+blen-rowCut, c-colCut, a[i+c]); - i+=clen; - } - - for(; r<=range.rowEnd; r++) - { - int c=(int) range.colStart; - for(; c<Math.min(colCut, range.colEnd+1); c++) - bottomleft.appendValue(r-rowCut, c+blen-colCut, a[i+c]); - for(; c<=range.colEnd; c++) - bottomright.appendValue(r-rowCut, c-colCut, a[i+c]); - i+=clen; - } + int c=(int) range.colStart; + for(; c<Math.min(colCut, range.colEnd+1); c++) + bottomleft.appendValue(r-rowCut, c+blen-colCut, a[i+c]); + for(; c<=range.colEnd; c++) + bottomright.appendValue(r-rowCut, c-colCut, a[i+c]); + i+=clen; } } } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java index 9b213ec..d90e844 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java @@ -132,6 +132,21 @@ public abstract class MatrixValue implements WritableComparable public abstract MatrixValue zeroOutOperations(MatrixValue result, IndexRange range, boolean complementary); + /** + * Slice out up to 4 matrixBlocks that are separated by the row and col Cuts. + * + * This is used in the context of spark execution to distribute sliced out matrix blocks of correct block size. + * + * @param outlist The output matrix blocks that are extracted from the matrix + * @param range An index range containing overlapping information. + * @param rowCut The row to cut and split the matrix. + * @param colCut The column to cut and split the matrix. + * @param blen The Block size of the output matrices. + * @param boundaryRlen The row length of the edge case matrix block, used for the final blocks + * that do not have enough rows to construct a full block.
+ * @param boundaryClen The col length of the edge case matrix block, used for the final blocks + * that do not have enough cols to construct a full block. + */ public abstract void slice(ArrayList<IndexedMatrixValue> outlist, IndexRange range, int rowCut, int colCut, int blen, int boundaryRlen, int boundaryClen); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/OperationsOnMatrixValues.java b/src/main/java/org/apache/sysds/runtime/matrix/data/OperationsOnMatrixValues.java index 33f638f..0d4b49d 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/OperationsOnMatrixValues.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/OperationsOnMatrixValues.java @@ -241,6 +241,19 @@ public class OperationsOnMatrixValues return value1.aggregateBinaryOperations(value1, value2, valueOut, op); } + /** + * Slice used in broadcasting matrix blocks for spark, since this slices up a given matrix + * into blocks. + * + * The slice call here returns a single block inside the idx range specified + * + * @param ixrange Index range containing the row lower, row upper, col lower and col upper bounds + * @param blen The block size specified, this should align with the ranges.
+ * @param iix The block row index + * @param jix The block column index + * @param in The Cache block to slice out + * @return A List containing pairs of MatrixIndices and CacheBlocks either containing MatrixBlock or FrameBlocks + */ @SuppressWarnings("rawtypes") public static List performSlice(IndexRange ixrange, int blen, int iix, int jix, CacheBlock in) { if( in instanceof MatrixBlock ) @@ -253,12 +266,13 @@ public class OperationsOnMatrixValues @SuppressWarnings("rawtypes") public static List performSlice(IndexRange ixrange, int blen, int iix, int jix, MatrixBlock in) { IndexedMatrixValue imv = new IndexedMatrixValue(new MatrixIndexes(iix, jix), in); - ArrayList<IndexedMatrixValue> outlist = new ArrayList<>(); - performSlice(imv, ixrange, blen, outlist); + ArrayList<IndexedMatrixValue> outlist = performSlice(imv, ixrange, blen); return SparkUtils.fromIndexedMatrixBlockToPair(outlist); } - public static void performSlice(IndexedMatrixValue in, IndexRange ixrange, int blen, ArrayList<IndexedMatrixValue> outlist) { + public static ArrayList<IndexedMatrixValue> performSlice(IndexedMatrixValue in, IndexRange ixrange, int blen) { + + ArrayList<IndexedMatrixValue> outlist = new ArrayList<>(); long cellIndexTopRow = UtilFunctions.computeCellIndex(in.getIndexes().getRowIndex(), blen, 0); long cellIndexBottomRow = UtilFunctions.computeCellIndex(in.getIndexes().getRowIndex(), blen, in.getValue().getNumRows()-1); long cellIndexLeftCol = UtilFunctions.computeCellIndex(in.getIndexes().getColumnIndex(), blen, 0); @@ -271,7 +285,7 @@ public class OperationsOnMatrixValues //check if block is outside the indexing range if(cellIndexOverlapTop>cellIndexOverlapBottom || cellIndexOverlapLeft>cellIndexOverlapRight) { - return; + return outlist; } IndexRange tmpRange = new IndexRange( @@ -308,13 +322,14 @@ public class OperationsOnMatrixValues for(long r=resultBlockIndexTop; r<=resultBlockIndexBottom; r++) for(long c=resultBlockIndexLeft; c<=resultBlockIndexRight; c++) { - 
IndexedMatrixValue out=new IndexedMatrixValue(new MatrixIndexes(), new MatrixBlock()); + IndexedMatrixValue out = new IndexedMatrixValue(new MatrixIndexes(), new MatrixBlock()); out.getIndexes().setIndexes(r, c); outlist.add(out); } //execute actual slice operation in.getValue().slice(outlist, tmpRange, rowCut, colCut, blen, boundaryRlen, boundaryClen); + return outlist; } public static void performShift(IndexedMatrixValue in, IndexRange ixrange, int blen, long rlen, long clen, ArrayList<IndexedMatrixValue> outlist) { @@ -399,8 +414,7 @@ public class OperationsOnMatrixValues @SuppressWarnings("rawtypes") public static ArrayList performSlice(IndexRange ixrange, int blen, int iix, int jix, FrameBlock in) { Pair<Long, FrameBlock> lfp = new Pair<>(new Long(((iix-1)*blen)+1), in); - ArrayList<Pair<Long, FrameBlock>> outlist = new ArrayList<>(); - performSlice(lfp, ixrange, blen, outlist); + ArrayList<Pair<Long, FrameBlock>> outlist = performSlice(lfp, ixrange, blen); return outlist; } @@ -409,12 +423,13 @@ public class OperationsOnMatrixValues /** * This function will get slice of the input frame block overlapping in overall slice(Range), slice has requested for. * - * @param in ? + * @param in A Pair of row index to assign the sliced block and input frame block to slice. 
* @param ixrange index range * @param blen block length - * @param outlist list of pairs of frame blocks + * @return Returns an ArrayList containing pairs of long ids and FrameBlocks */ - public static void performSlice(Pair<Long,FrameBlock> in, IndexRange ixrange, int blen, ArrayList<Pair<Long,FrameBlock>> outlist) { + public static ArrayList<Pair<Long, FrameBlock>> performSlice(Pair<Long,FrameBlock> in, IndexRange ixrange, int blen) { + ArrayList<Pair<Long, FrameBlock>> outlist = new ArrayList<>(); long index = in.getKey(); FrameBlock block = in.getValue(); @@ -432,7 +447,7 @@ public class OperationsOnMatrixValues //check if block is outside the indexing range if(cellIndexOverlapTop>cellIndexOverlapBottom || cellIndexOverlapLeft>cellIndexOverlapRight) { - return; + return outlist; } // Create IndexRange for the slice to be performed on this block. @@ -460,6 +475,7 @@ public class OperationsOnMatrixValues //execute actual slice operation block.slice(outlist, tmpRange, rowCut); + return outlist; } public static void performShift(Pair<Long,FrameBlock> in, IndexRange ixrange, int blenLeft, long rlen, long clen, ArrayList<Pair<Long,FrameBlock>> outlist) { diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 8a844b6..0e4883c 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -832,7 +832,7 @@ public class TestUtils * @param y value 2 * @return Percent distance */ - private static double getPercentDistance(double x, double y, boolean ignoreZero){ + public static double getPercentDistance(double x, double y, boolean ignoreZero){ if (Double.isNaN(x) && Double.isNaN(y)) return 1.0; if (Double.isInfinite(x) && Double.isInfinite(y)) diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java index 4072fe3..2ade5b3 100644 
--- a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java @@ -61,23 +61,22 @@ public class CompressedMatrixTest extends AbstractCompressedUnaryTests { } @Test - @Ignore public void testGetValue() { try { if(!(cmb instanceof CompressedMatrixBlock)) return; // Input was not compressed then just pass test - for(int i = 0; i < rows; i++) for(int j = 0; j < cols; j++) { double ulaVal = mb.quickGetValue(i, j); double claVal = cmb.getValue(i, j); // calls quickGetValue internally if(compressionSettings.lossy || overlappingType == OverLapping.SQUEEZE) TestUtils.compareCellValue(ulaVal, claVal, lossyTolerance, false); - else if(OverLapping.effectOnOutput(overlappingType)) - TestUtils.compareScalarBitsJUnit(ulaVal, claVal, 32768); + else if(OverLapping.effectOnOutput(overlappingType)){ + double percentDistance = TestUtils.getPercentDistance(ulaVal, claVal, true); + assertTrue(percentDistance > .99); + } else TestUtils.compareScalarBitsJUnit(ulaVal, claVal, 0); // Should be exactly same value - } } catch(Exception e) { diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java index 80013d0..6367ecf 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java @@ -31,6 +31,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.lops.MMTSJ.MMTSJType; import org.apache.sysds.lops.MapMultChain.ChainType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; import org.apache.sysds.runtime.compress.CompressionSettings; @@ -105,8 +106,8 @@ public 
abstract class CompressedTestBase extends TestBase { // CLA TESTS! new CompressionSettingsBuilder().setSamplingRatio(0.1).setSeed(compressionSeed) .setValidCompressions(EnumSet.of(CompressionType.DDC)).setInvestigateEstimate(true), - // new CompressionSettingsBuilder().setSamplingRatio(0.1).setSeed(compressionSeed) - // .setValidCompressions(EnumSet.of(CompressionType.OLE)).setInvestigateEstimate(true), + new CompressionSettingsBuilder().setSamplingRatio(0.1).setSeed(compressionSeed) + .setValidCompressions(EnumSet.of(CompressionType.OLE)).setInvestigateEstimate(true), new CompressionSettingsBuilder().setSamplingRatio(0.1).setSeed(compressionSeed) .setValidCompressions(EnumSet.of(CompressionType.RLE)).setInvestigateEstimate(true), new CompressionSettingsBuilder().setSamplingRatio(0.1).setSeed(compressionSeed).setInvestigateEstimate(true), @@ -936,6 +937,72 @@ public abstract class CompressedTestBase extends TestBase { } } + @Test + public void testSliceRows() { + testSlice(rows / 5, Math.min(rows - 1, (rows / 5) * 2), 0, cols - 1); + } + + @Test + public void testSliceFirstColumn() { + testSlice(0, rows - 1, 0, 0); + } + + @Test + public void testSliceLastColumn() { + testSlice(0, rows - 1, cols - 1, cols - 1); + } + + @Test + public void testSliceAllButFirstColumn() { + testSlice(0, rows - 1, Math.min(1,cols-1), cols - 1); + } + + @Test + public void testSliceInternal() { + testSlice(rows / 5, + Math.min(rows - 1, (rows / 5) * 2), + Math.min(cols - 1, cols / 5), + Math.min(cols - 1, cols / 5 + 1)); + } + + @Test + public void testSliceFirstValue() { + testSlice(0, 0, 0, 0); + } + + @Test + public void testSliceEntireMatrix() { + testSlice(0, rows - 1, 0, cols - 1); + } + + @Test(expected = DMLRuntimeException.class) + public void TestSliceInvalid_01() { + testSlice(-1, 0, 0, 0); + } + + @Test(expected = DMLRuntimeException.class) + public void TestSliceInvalid_02() { + testSlice(rows, rows, 0, 0); + } + + @Test(expected = DMLRuntimeException.class) + public void 
TestSliceInvalid_03() { + testSlice(0, 0, cols, cols); + } + + @Test(expected = DMLRuntimeException.class) + public void TestSliceInvalid_04() { + testSlice(0, 0, -1, 0); + } + + public void testSlice(int rl, int ru, int cl, int cu) { + if(!(cmb instanceof CompressedMatrixBlock)) + return; + MatrixBlock ret2 = cmb.slice(rl, ru, cl, cu); + MatrixBlock ret1 = mb.slice(rl, ru, cl, cu); + compareResultMatrices(ret1, ret2, 1); + } + protected void compareResultMatrices(double[][] d1, double[][] d2, double toleranceMultiplier) { if(compressionSettings.lossy) TestUtils.compareMatricesPercentageDistance(d1, d2, 0.25, 0.83, this.toString()); diff --git a/src/test/java/org/apache/sysds/test/component/matrix/SliceTest.java b/src/test/java/org/apache/sysds/test/component/matrix/SliceTest.java new file mode 100644 index 0000000..f69df13 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/matrix/SliceTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.matrix; + +import static org.junit.Assert.assertEquals; + +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.junit.Test; + +public class SliceTest { + MatrixBlock a = genIncMatrix(10, 10); + + @Test + public void sliceTest_01() { + MatrixBlock b = a.slice(0, 4); + assertEquals(5, b.getNumRows()); + } + + @Test + public void sliceTest_02() { + MatrixBlock b = a.slice(0, 9); + assertEquals(10, b.getNumRows()); + } + + @Test + public void sliceTest_03() { + MatrixBlock b = a.slice(9, 9); + assertEquals(1, b.getNumRows()); + } + + private static MatrixBlock gen(int[][] v) { + return DataConverter.convertToMatrixBlock(v); + } + + private static MatrixBlock genIncMatrix(int rows, int cols) { + int[][] ret = new int[rows][cols]; + int x = 0; + for(int i = 0; i < rows; i++) { + for(int j = 0; j < cols; j++) { + ret[i][j] = x++; + } + } + return gen(ret); + } +}
