This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 53f72edfbd5361199b27ee04d844b5a3c93dcc0e Author: Sebastian Baunsgaard <[email protected]> AuthorDate: Sun Dec 29 22:07:52 2024 +0100 [MINOR] Update decompression for zeros This commit adds a check when decompressing overlapping matrices so that values within a very small epsilon of zero are rounded to exactly zero. Previously, these rounding errors could cause a sparse matrix to become dense when decompressed from its compressed state. Closes #2170 --- .../runtime/compress/lib/CLALibDecompress.java | 211 ++++++++++++++------- .../sysds/runtime/compress/lib/CLALibUtils.java | 43 +++++ 2 files changed, 183 insertions(+), 71 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java index e77db7cad7..cc585ed58e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java @@ -26,9 +26,11 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; @@ -65,6 +67,11 @@ public final class CLALibDecompress { public static void decompressTo(CompressedMatrixBlock cmb, MatrixBlock ret, int rowOffset, int colOffset, int k, boolean countNNz) { + decompressTo(cmb, ret, rowOffset, colOffset, k, countNNz, false); + } + + public static void decompressTo(CompressedMatrixBlock cmb, MatrixBlock ret, int rowOffset, int colOffset, int k, + boolean countNNz, boolean reset) { Timing
time = new Timing(true); if(cmb.getNumColumns() + colOffset > ret.getNumColumns() || cmb.getNumRows() + rowOffset > ret.getNumRows()) { LOG.warn( @@ -78,12 +85,12 @@ public final class CLALibDecompress { final boolean outSparse = ret.isInSparseFormat(); if(!cmb.isEmpty()) { - if(outSparse && cmb.isOverlapping()) + if(outSparse && (cmb.isOverlapping() || reset)) throw new DMLCompressionException("Not supported decompression into sparse block from overlapping state"); else if(outSparse) decompressToSparseBlock(cmb, ret, rowOffset, colOffset); else - decompressToDenseBlock(cmb, ret.getDenseBlock(), rowOffset, colOffset); + decompressToDenseBlock(cmb, ret.getDenseBlock(), rowOffset, colOffset, k, reset); } if(DMLScript.STATISTICS) { @@ -94,7 +101,7 @@ public final class CLALibDecompress { } if(countNNz) - ret.recomputeNonZeros(); + ret.recomputeNonZeros(k); } private static void decompressToSparseBlock(CompressedMatrixBlock cmb, MatrixBlock ret, int rowOffset, @@ -115,23 +122,67 @@ public final class CLALibDecompress { ret.checkSparseRows(); } - private static void decompressToDenseBlock(CompressedMatrixBlock cmb, DenseBlock ret, int rowOffset, int colOffset) { - final List<AColGroup> groups = cmb.getColGroups(); + private static void decompressToDenseBlock(CompressedMatrixBlock cmb, DenseBlock ret, int rowOffset, int colOffset, + int k, boolean reset) { + List<AColGroup> groups = cmb.getColGroups(); // final int nCols = cmb.getNumColumns(); final int nRows = cmb.getNumRows(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); - if(shouldFilter) { + if(shouldFilter && !CLALibUtils.alreadyPreFiltered(groups, cmb.getNumColumns())) { final double[] constV = new double[cmb.getNumColumns()]; - final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(groups, constV); - for(AColGroup g : filteredGroups) - g.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset); + groups = CLALibUtils.filterGroups(groups, constV); AColGroup cRet = 
ColGroupConst.create(constV); - cRet.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset); + groups.add(cRet); } - else { - for(AColGroup g : groups) - g.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset); + + if(k > 1 && nRows > 1000) + decompressToDenseBlockParallel(ret, groups, rowOffset, colOffset, nRows, k, reset); + else + decompressToDenseBlockSingleThread(ret, groups, rowOffset, colOffset, nRows, reset); + } + + private static void decompressToDenseBlockSingleThread(DenseBlock ret, List<AColGroup> groups, int rowOffset, + int colOffset, int nRows, boolean reset) { + decompressToDenseBlockBlock(ret, groups, rowOffset, colOffset, 0, nRows, reset); + } + + private static void decompressToDenseBlockBlock(DenseBlock ret, List<AColGroup> groups, int rowOffset, int colOffset, + int rl, int ru, boolean reset) { + if(reset) { + if(ret.isContiguous()) { + final int nCol = ret.getDim(1); + ret.fillBlock(0, rl * nCol, ru * nCol, 0.0); + } + else + throw new NotImplementedException(); + } + for(AColGroup g : groups) + g.decompressToDenseBlock(ret, rl, ru, rowOffset, colOffset); + } + + private static void decompressToDenseBlockParallel(DenseBlock ret, List<AColGroup> groups, int rowOffset, + int colOffset, int nRows, int k, boolean reset) { + + final int blklen = Math.max(nRows / k, 512); + final ExecutorService pool = CommonThreadPool.get(k); + try { + List<Future<?>> tasks = new ArrayList<>(nRows / blklen); + for(int r = 0; r < nRows; r += blklen) { + final int start = r; + final int end = Math.min(nRows, r + blklen); + tasks.add( + pool.submit(() -> decompressToDenseBlockBlock(ret, groups, rowOffset, colOffset, start, end, reset))); + } + + for(Future<?> t : tasks) + t.get(); + } + catch(Exception e) { + throw new DMLCompressionException("Failed parallel decompress to"); + } + finally { + pool.shutdown(); } } @@ -148,7 +199,7 @@ public final class CLALibDecompress { MatrixBlock ret = getUncompressedColGroupAndRemoveFromListOfColGroups(groups, 
overlapping, nRows, nCols); if(ret != null && groups.size() == 0) { - ret.setNonZeros(ret.recomputeNonZeros()); + ret.setNonZeros(ret.recomputeNonZeros(k)); return ret; // if uncompressedColGroup is only colGroup. } @@ -182,23 +233,18 @@ public final class CLALibDecompress { constV = null; final double eps = getEps(constV); - if(k == 1) { - if(ret.isInSparseFormat()) { + if(ret.isInSparseFormat()) decompressSparseSingleThread(ret, filteredGroups, nRows, blklen); - } - else { + else decompressDenseSingleThread(ret, filteredGroups, nRows, blklen, constV, eps, nonZeros, overlapping); - } } - else if(ret.isInSparseFormat()) { + else if(ret.isInSparseFormat()) decompressSparseMultiThread(ret, filteredGroups, nRows, blklen, k); - } - else { + else decompressDenseMultiThread(ret, filteredGroups, nRows, blklen, constV, eps, k, overlapping); - } - ret.recomputeNonZeros(); + ret.recomputeNonZeros(k); ret.examSparsity(); return ret; @@ -249,29 +295,40 @@ public final class CLALibDecompress { private static void decompressDenseSingleThread(MatrixBlock ret, List<AColGroup> filteredGroups, int rlen, int blklen, double[] constV, double eps, long nonZeros, boolean overlapping) { + + final DenseBlock db = ret.getDenseBlock(); + final int nCol = ret.getNumColumns(); for(int i = 0; i < rlen; i += blklen) { final int rl = i; final int ru = Math.min(i + blklen, rlen); for(AColGroup grp : filteredGroups) - grp.decompressToDenseBlock(ret.getDenseBlock(), rl, ru); + grp.decompressToDenseBlock(db, rl, ru); if(constV != null && !ret.isInSparseFormat()) - addVector(ret, constV, eps, rl, ru); + addVector(db, nCol, constV, eps, rl, ru); } } - protected static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> groups, double[] constV, int k, - boolean overlapping) { - final int nRows = ret.getNumRows(); - final double eps = getEps(constV); - final int blklen = Math.max(nRows / k, 512); - decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping); - } + // 
private static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> groups, double[] constV, int k, + // boolean overlapping) { + // final int nRows = ret.getNumRows(); + // final double eps = getEps(constV); + // final int blklen = Math.max(nRows / k, 512); + // decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping); + // } protected static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> groups, double[] constV, double eps, int k, boolean overlapping) { + + Timing time = new Timing(true); final int nRows = ret.getNumRows(); final int blklen = Math.max(nRows / k, 512); decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping); + if(DMLScript.STATISTICS) { + final double t = time.stop(); + DMLCompressionStatistics.addDecompressTime(t, k); + if(LOG.isTraceEnabled()) + LOG.trace("decompressed block w/ k=" + k + " in " + t + "ms."); + } } private static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> filteredGroups, int rlen, int blklen, @@ -297,7 +354,7 @@ public final class CLALibDecompress { catch(InterruptedException | ExecutionException ex) { throw new DMLCompressionException("Parallel decompression failed", ex); } - finally{ + finally { pool.shutdown(); } } @@ -310,13 +367,14 @@ public final class CLALibDecompress { for(int i = 0; i < rlen; i += blklen) tasks.add(new DecompressSparseTask(filteredGroups, ret, i, Math.min(i + blklen, rlen))); + LOG.error("tasks:" + tasks); for(Future<Object> rt : pool.invokeAll(tasks)) rt.get(); } catch(InterruptedException | ExecutionException ex) { throw new DMLCompressionException("Parallel decompression failed", ex); } - finally{ + finally { pool.shutdown(); } } @@ -360,22 +418,23 @@ public final class CLALibDecompress { _eps = eps; _rl = rl; _ru = ru; - _blklen = 32768 / ret.getNumColumns(); + _blklen = Math.max(32768 / ret.getNumColumns(), 128); _constV = constV; } @Override public Long call() { try { - + final DenseBlock db = 
_ret.getDenseBlock(); + final int nCol = _ret.getNumColumns(); long nnz = 0; for(int b = _rl; b < _ru; b += _blklen) { final int e = Math.min(b + _blklen, _ru); for(AColGroup grp : _colGroups) - grp.decompressToDenseBlock(_ret.getDenseBlock(), b, e); + grp.decompressToDenseBlock(db, b, e); if(_constV != null) - addVector(_ret, _constV, _eps, b, e); + addVector(db, nCol, _constV, _eps, b, e); nnz += _ret.recomputeNonZeros(b, e - 1); } @@ -404,23 +463,22 @@ public final class CLALibDecompress { _eps = eps; _rl = rl; _ru = ru; - _blklen = 32768 / ret.getNumColumns(); + _blklen = Math.max(32768 / ret.getNumColumns(), 128); _constV = constV; } @Override public Long call() { try { - + final DenseBlock db = _ret.getDenseBlock(); + final int nCol = _ret.getNumColumns(); long nnz = 0; for(int b = _rl; b < _ru; b += _blklen) { final int e = Math.min(b + _blklen, _ru); - // for(AColGroup grp : _colGroups) - _grp.decompressToDenseBlock(_ret.getDenseBlock(), b, e); + _grp.decompressToDenseBlock(db, b, e); if(_constV != null) - addVector(_ret, _constV, _eps, b, e); - // nnz += _ret.recomputeNonZeros(b, e - 1); + addVector(db, nCol, _constV, _eps, b, e); } return nnz; @@ -446,14 +504,21 @@ public final class CLALibDecompress { } @Override - public Object call() { - final SparseBlock sb = _ret.getSparseBlock(); - for(AColGroup grp : _colGroups) - grp.decompressToSparseBlock(_ret.getSparseBlock(), _rl, _ru); - for(int i = _rl; i < _ru; i++) - if(!sb.isEmpty(i)) - sb.sort(i); - return null; + public Object call() throws Exception{ + try{ + + final SparseBlock sb = _ret.getSparseBlock(); + for(AColGroup grp : _colGroups) + grp.decompressToSparseBlock(_ret.getSparseBlock(), _rl, _ru); + for(int i = _rl; i < _ru; i++) + if(!sb.isEmpty(i)) + sb.sort(i); + return null; + } + catch(Exception e){ + e.printStackTrace(); + throw new DMLRuntimeException(e); + } } } @@ -467,28 +532,32 @@ public final class CLALibDecompress { * @param rl The row to start at * @param ru The row to end at */ - 
private static void addVector(final MatrixBlock ret, final double[] rowV, final double eps, final int rl, - final int ru) { - final int nCols = ret.getNumColumns(); - final DenseBlock db = ret.getDenseBlock(); + private static final void addVector(final DenseBlock db, final int nCols, final double[] rowV, final double eps, + final int rl, final int ru) { + if(eps == 0) + addVectorEps(db, nCols, rowV, eps, rl, ru); + else + addVectorNoEps(db, nCols, rowV, eps, rl, ru); + } - if(nCols == 1) { - if(eps == 0) - addValue(db.values(0), rowV[0], rl, ru); - else - addValueEps(db.values(0), rowV[0], eps, rl, ru); - } - else if(db.isContiguous()) { - if(eps == 0) - addVectorContiguousNoEps(db.values(0), rowV, nCols, rl, ru); - else - addVectorContiguousEps(db.values(0), rowV, nCols, eps, rl, ru); - } - else if(eps == 0) + private static final void addVectorEps(final DenseBlock db, final int nCols, final double[] rowV, final double eps, + final int rl, final int ru) { + if(nCols == 1) + addValue(db.values(0), rowV[0], rl, ru); + else if(db.isContiguous()) + addVectorContiguousNoEps(db.values(0), rowV, nCols, rl, ru); + else addVectorNoEps(db, rowV, nCols, rl, ru); + } + + private static final void addVectorNoEps(final DenseBlock db, final int nCols, final double[] rowV, final double eps, + final int rl, final int ru) { + if(nCols == 1) + addValueEps(db.values(0), rowV[0], eps, rl, ru); + else if(db.isContiguous()) + addVectorContiguousEps(db.values(0), rowV, nCols, eps, rl, ru); else addVectorEps(db, rowV, nCols, eps, rl, ru); - } private static void addValue(final double[] retV, final double v, final int rl, final int ru) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java index 485599e382..b397825d5a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java @@ -97,6 +97,49 @@ 
public final class CLALibUtils { return false; } + protected static boolean alreadyPreFiltered(List<AColGroup> groups, int nCol) { + boolean constFound = false; + for(AColGroup g : groups) { + if(g instanceof AMorphingMMColGroup || g instanceof ColGroupEmpty || g.isEmpty() || + (constFound && g instanceof ColGroupConst)) + return false; + else if(g instanceof ColGroupConst){ + if(g.getNumCols() != nCol) + return false; + + constFound = true; + } + } + + return true; + } + + protected static double[] filterGroupsAndSplitPreAggOneConst(List<AColGroup> groups, List<AColGroup> noPreAggGroups, + List<APreAgg> preAggGroups) { + double[] consts = null; + for(AColGroup g : groups) { + if(g instanceof ColGroupConst) + consts = ((ColGroupConst) g).getValues(); + else if(g instanceof APreAgg) + preAggGroups.add((APreAgg) g); + else + noPreAggGroups.add(g); + } + + return consts; + } + + protected static double[] filterGroupsAndSplitPreAggOneConst(List<AColGroup> groups, List<AColGroup> out) { + double[] consts = null; + for(AColGroup g : groups) { + if(g instanceof ColGroupConst) + consts = ((ColGroupConst) g).getValues(); + else + out.add(g); + } + return consts; + } + /** * Helper method to determine if the column groups contains Morphing or Frame of reference groups. *
