This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 53f72edfbd5361199b27ee04d844b5a3c93dcc0e
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sun Dec 29 22:07:52 2024 +0100

    [MINOR] Update decompression for zeros
    
    This commit adds a check when decompressing overlapping matrices
    that rounds values within a very small epsilon of zero to exactly
    zero. Previously, these rounding errors in the compressed state
    could make a sparse matrix dense.
    
    Closes #2170
---
 .../runtime/compress/lib/CLALibDecompress.java     | 211 ++++++++++++++-------
 .../sysds/runtime/compress/lib/CLALibUtils.java    |  43 +++++
 2 files changed, 183 insertions(+), 71 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
index e77db7cad7..cc585ed58e 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
@@ -26,9 +26,11 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 
+import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.DMLCompressionException;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup;
@@ -65,6 +67,11 @@ public final class CLALibDecompress {
 
        public static void decompressTo(CompressedMatrixBlock cmb, MatrixBlock 
ret, int rowOffset, int colOffset, int k,
                boolean countNNz) {
+               decompressTo(cmb, ret, rowOffset, colOffset, k, countNNz, 
false); // reset = false: the target region of ret is not zeroed before decompression
+       }
+
+       public static void decompressTo(CompressedMatrixBlock cmb, MatrixBlock 
ret, int rowOffset, int colOffset, int k,
+               boolean countNNz, boolean reset) {
                Timing time = new Timing(true);
                if(cmb.getNumColumns() + colOffset > ret.getNumColumns() || 
cmb.getNumRows() + rowOffset > ret.getNumRows()) {
                        LOG.warn(
@@ -78,12 +85,12 @@ public final class CLALibDecompress {
 
                final boolean outSparse = ret.isInSparseFormat();
                if(!cmb.isEmpty()) {
-                       if(outSparse && cmb.isOverlapping())
+                       if(outSparse && (cmb.isOverlapping() || reset))
                                throw new DMLCompressionException("Not 
supported decompression into sparse block from overlapping state");
                        else if(outSparse)
                                decompressToSparseBlock(cmb, ret, rowOffset, 
colOffset);
                        else
-                               decompressToDenseBlock(cmb, 
ret.getDenseBlock(), rowOffset, colOffset);
+                               decompressToDenseBlock(cmb, 
ret.getDenseBlock(), rowOffset, colOffset, k, reset);
                }
 
                if(DMLScript.STATISTICS) {
@@ -94,7 +101,7 @@ public final class CLALibDecompress {
                }
 
                if(countNNz)
-                       ret.recomputeNonZeros();
+                       ret.recomputeNonZeros(k);
        }
 
        private static void decompressToSparseBlock(CompressedMatrixBlock cmb, 
MatrixBlock ret, int rowOffset,
@@ -115,23 +122,95 @@ public final class CLALibDecompress {
                ret.checkSparseRows();
        }

-       private static void decompressToDenseBlock(CompressedMatrixBlock cmb, DenseBlock ret, int rowOffset, int colOffset) {
-               final List<AColGroup> groups = cmb.getColGroups();
-               // final int nCols = cmb.getNumColumns();
+       /**
+        * Decompress all column groups of cmb into the dense block ret, writing cell (i, j) of cmb to cell
+        * (rowOffset + i, colOffset + j) of ret.
+        *
+        * @param cmb       The compressed matrix to decompress
+        * @param ret       The dense target block
+        * @param rowOffset Row offset into ret
+        * @param colOffset Column offset into ret
+        * @param k         Parallelization degree; threads are only used if cmb has more than 1000 rows
+        * @param reset     If true, the target region of ret is zeroed before the groups are decompressed into it
+        */
+       private static void decompressToDenseBlock(CompressedMatrixBlock cmb, DenseBlock ret, int rowOffset, int colOffset,
+               int k, boolean reset) {
+               List<AColGroup> groups = cmb.getColGroups();
+               final int nCols = cmb.getNumColumns();
                final int nRows = cmb.getNumRows();

                final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
-               if(shouldFilter) {
+               if(shouldFilter && !CLALibUtils.alreadyPreFiltered(groups, nCols)) {
                        final double[] constV = new double[cmb.getNumColumns()];
-                       final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(groups, constV);
-                       for(AColGroup g : filteredGroups)
-                               g.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset);
+                       // copy the filtered list so that appending cRet can never modify the group list of cmb
+                       groups = new ArrayList<>(CLALibUtils.filterGroups(groups, constV));
                        AColGroup cRet = ColGroupConst.create(constV);
-                       cRet.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset);
+                       groups.add(cRet);
                }
-               else {
-                       for(AColGroup g : groups)
-                               g.decompressToDenseBlock(ret, 0, nRows, rowOffset, colOffset);
+
+               if(k > 1 && nRows > 1000)
+                       decompressToDenseBlockParallel(ret, groups, rowOffset, colOffset, nRows, nCols, k, reset);
+               else
+                       decompressToDenseBlockSingleThread(ret, groups, rowOffset, colOffset, nRows, nCols, reset);
+       }
+
+       /** Decompress all rows [0, nRows) on the calling thread. */
+       private static void decompressToDenseBlockSingleThread(DenseBlock ret, List<AColGroup> groups, int rowOffset,
+               int colOffset, int nRows, int nCols, boolean reset) {
+               decompressToDenseBlockBlock(ret, groups, rowOffset, colOffset, nCols, 0, nRows, reset);
+       }
+
+       /**
+        * Decompress rows [rl, ru) of the groups into ret. With reset, only the target cells of these rows (shifted by
+        * rowOffset) in columns [colOffset, colOffset + nCols) are zeroed first; other parts of ret are not touched.
+        */
+       private static void decompressToDenseBlockBlock(DenseBlock ret, List<AColGroup> groups, int rowOffset, int colOffset,
+               int nCols, int rl, int ru, boolean reset) {
+               if(reset)
+                       zeroTargetRegion(ret, rowOffset + rl, rowOffset + ru, colOffset, nCols);
+               for(AColGroup g : groups)
+                       g.decompressToDenseBlock(ret, rl, ru, rowOffset, colOffset);
+       }
+
+       /** Set rows [rl, ru) x columns [cl, cl + nCols) of ret to zero; only contiguous dense blocks are supported. */
+       private static void zeroTargetRegion(DenseBlock ret, int rl, int ru, int cl, int nCols) {
+               if(!ret.isContiguous())
+                       throw new NotImplementedException();
+               final int retCols = ret.getDim(1);
+               if(cl == 0 && nCols == retCols) // full rows form one contiguous range
+                       ret.fillBlock(0, rl * retCols, ru * retCols, 0.0);
+               else {
+                       for(int r = rl; r < ru; r++) {
+                               final int off = r * retCols + cl;
+                               ret.fillBlock(0, off, off + nCols, 0.0);
+                       }
+               }
+       }
+
+       /** Decompress in row blocks of at least 512 rows, each submitted as one task to a pool of k threads. */
+       private static void decompressToDenseBlockParallel(DenseBlock ret, List<AColGroup> groups, int rowOffset,
+               int colOffset, int nRows, int nCols, int k, boolean reset) {
+               final int blklen = Math.max(nRows / k, 512);
+               final ExecutorService pool = CommonThreadPool.get(k);
+               try {
+                       final List<Future<?>> tasks = new ArrayList<>(nRows / blklen + 1);
+                       for(int r = 0; r < nRows; r += blklen) {
+                               final int start = r;
+                               final int end = Math.min(nRows, r + blklen);
+                               tasks.add(pool.submit(
+                                       () -> decompressToDenseBlockBlock(ret, groups, rowOffset, colOffset, nCols, start, end, reset)));
+                       }
+
+                       for(Future<?> t : tasks)
+                               t.get();
+               }
+               catch(InterruptedException | ExecutionException e) {
+                       if(e instanceof InterruptedException)
+                               Thread.currentThread().interrupt();
+                       throw new DMLCompressionException("Failed parallel decompress to", e);
+               }
+               finally {
+                       pool.shutdown();
                }
        }

@@ -148,7 +199,7 @@ public final class CLALibDecompress {
                MatrixBlock ret = 
getUncompressedColGroupAndRemoveFromListOfColGroups(groups, overlapping, nRows, 
nCols);
 
                if(ret != null && groups.size() == 0) {
-                       ret.setNonZeros(ret.recomputeNonZeros());
+                       ret.setNonZeros(ret.recomputeNonZeros(k));
                        return ret; // if uncompressedColGroup is only colGroup.
                }
 
@@ -182,23 +233,18 @@ public final class CLALibDecompress {
                        constV = null;
 
                final double eps = getEps(constV);
-
                if(k == 1) {
-                       if(ret.isInSparseFormat()) {
+                       if(ret.isInSparseFormat()) 
                                decompressSparseSingleThread(ret, 
filteredGroups, nRows, blklen);
-                       }
-                       else {
+                       else 
                                decompressDenseSingleThread(ret, 
filteredGroups, nRows, blklen, constV, eps, nonZeros, overlapping);
-                       }
                }
-               else if(ret.isInSparseFormat()) {
+               else if(ret.isInSparseFormat()) 
                        decompressSparseMultiThread(ret, filteredGroups, nRows, 
blklen, k);
-               }
-               else {
+               else 
                        decompressDenseMultiThread(ret, filteredGroups, nRows, 
blklen, constV, eps, k, overlapping);
-               }
 
-               ret.recomputeNonZeros();
+               ret.recomputeNonZeros(k);
                ret.examSparsity();
 
                return ret;
@@ -249,29 +295,41 @@ public final class CLALibDecompress {

+       /**
+        * Decompress the filtered groups into the dense block of ret on the calling thread, in row blocks of blklen rows,
+        * calling addVector with constV and eps on each row block if constV is not null.
+        * NOTE(review): the parameters nonZeros and overlapping are not read by this method.
+        */
        private static void decompressDenseSingleThread(MatrixBlock ret, List<AColGroup> filteredGroups, int rlen,
                int blklen, double[] constV, double eps, long nonZeros, boolean overlapping) {
+
+               final DenseBlock db = ret.getDenseBlock();
+               final int nCol = ret.getNumColumns();
                for(int i = 0; i < rlen; i += blklen) {
                        final int rl = i;
                        final int ru = Math.min(i + blklen, rlen);
                        for(AColGroup grp : filteredGroups)
-                               grp.decompressToDenseBlock(ret.getDenseBlock(), rl, ru);
+                               grp.decompressToDenseBlock(db, rl, ru);
                        if(constV != null && !ret.isInSparseFormat())
-                               addVector(ret, constV, eps, rl, ru);
+                               addVector(db, nCol, constV, eps, rl, ru);
                }
        }

-       protected static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> groups, double[] constV, int k,
-               boolean overlapping) {
-               final int nRows = ret.getNumRows();
-               final double eps = getEps(constV);
-               final int blklen = Math.max(nRows / k, 512);
-               decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping);
-       }
-
+       /**
+        * Decompress the groups into the dense block of ret with k threads, using row blocks of max(nRows / k, 512)
+        * rows. With DMLScript.STATISTICS enabled, the elapsed time is added to DMLCompressionStatistics.
+        */
        protected static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> groups, double[] constV,
                double eps, int k, boolean overlapping) {
+
+               Timing time = new Timing(true);
                final int nRows = ret.getNumRows();
                final int blklen = Math.max(nRows / k, 512);
                decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping);
+               if(DMLScript.STATISTICS) {
+                       final double t = time.stop();
+                       DMLCompressionStatistics.addDecompressTime(t, k);
+                       if(LOG.isTraceEnabled())
+                               LOG.trace("decompressed block w/ k=" + k + " in " + t + "ms.");
+               }
        }
 
        private static void decompressDenseMultiThread(MatrixBlock ret, 
List<AColGroup> filteredGroups, int rlen, int blklen,
@@ -297,7 +354,7 @@ public final class CLALibDecompress {
                catch(InterruptedException | ExecutionException ex) {
                        throw new DMLCompressionException("Parallel 
decompression failed", ex);
                }
-               finally{
+               finally {
                        pool.shutdown();
                }
        }
@@ -310,13 +367,13 @@ public final class CLALibDecompress {
                        for(int i = 0; i < rlen; i += blklen)
                                tasks.add(new DecompressSparseTask(filteredGroups, ret, i, Math.min(i + blklen, rlen)));

                        for(Future<Object> rt : pool.invokeAll(tasks))
                                rt.get();
                }
                catch(InterruptedException | ExecutionException ex) {
                        throw new DMLCompressionException("Parallel decompression failed", ex);
                }
-               finally{
+               finally {
                        pool.shutdown();
                }
        }
@@ -360,22 +418,23 @@ public final class CLALibDecompress {
                        _eps = eps;
                        _rl = rl;
                        _ru = ru;
-                       _blklen = 32768 / ret.getNumColumns();
+                       _blklen = Math.max(32768 / ret.getNumColumns(), 128);
                        _constV = constV;
                }
 
                @Override
                public Long call() {
                        try {
-
+                               final DenseBlock db = _ret.getDenseBlock();
+                               final int nCol = _ret.getNumColumns();
                                long nnz = 0;
                                for(int b = _rl; b < _ru; b += _blklen) {
                                        final int e = Math.min(b + _blklen, 
_ru);
                                        for(AColGroup grp : _colGroups)
-                                               
grp.decompressToDenseBlock(_ret.getDenseBlock(), b, e);
+                                               grp.decompressToDenseBlock(db, 
b, e);
 
                                        if(_constV != null)
-                                               addVector(_ret, _constV, _eps, 
b, e);
+                                               addVector(db, nCol, _constV, 
_eps, b, e);
                                        nnz += _ret.recomputeNonZeros(b, e - 1);
                                }
 
@@ -404,23 +463,22 @@ public final class CLALibDecompress {
                        _eps = eps;
                        _rl = rl;
                        _ru = ru;
-                       _blklen = 32768 / ret.getNumColumns();
+                       _blklen = Math.max(32768 / ret.getNumColumns(), 128);
                        _constV = constV;
                }
 
                @Override
                public Long call() {
                        try {
-
+                               final DenseBlock db = _ret.getDenseBlock();
+                               final int nCol = _ret.getNumColumns();
                                long nnz = 0;
                                for(int b = _rl; b < _ru; b += _blklen) {
                                        final int e = Math.min(b + _blklen, 
_ru);
-                                       // for(AColGroup grp : _colGroups)
-                                       
_grp.decompressToDenseBlock(_ret.getDenseBlock(), b, e);
+                                       _grp.decompressToDenseBlock(db, b, e);
 
                                        if(_constV != null)
-                                               addVector(_ret, _constV, _eps, 
b, e);
-                                       // nnz += _ret.recomputeNonZeros(b, e - 
1);
+                                               addVector(db, nCol, _constV, 
_eps, b, e);
                                }
 
                                return nnz;
@@ -446,14 +504,20 @@ public final class CLALibDecompress {
                }

                @Override
                public Object call() {
-                       final SparseBlock sb = _ret.getSparseBlock();
-                       for(AColGroup grp : _colGroups)
-                               grp.decompressToSparseBlock(_ret.getSparseBlock(), _rl, _ru);
-                       for(int i = _rl; i < _ru; i++)
-                               if(!sb.isEmpty(i))
-                                       sb.sort(i);
-                       return null;
+                       try {
+                               final SparseBlock sb = _ret.getSparseBlock();
+                               for(AColGroup grp : _colGroups)
+                                       grp.decompressToSparseBlock(sb, _rl, _ru);
+                               for(int i = _rl; i < _ru; i++)
+                                       if(!sb.isEmpty(i))
+                                               sb.sort(i);
+                               return null;
+                       }
+                       catch(Exception e) {
+                               // rethrow with the row range that failed, keeping the original cause
+                               throw new DMLRuntimeException("Failed sparse decompression of rows [" + _rl + ", " + _ru + ")", e);
+                       }
                }
        }

@@ -467,28 +532,34 @@ public final class CLALibDecompress {
         * @param rl   The row to start at
         * @param ru   The row to end at
         */
-       private static void addVector(final MatrixBlock ret, final double[] rowV, final double eps, final int rl,
-               final int ru) {
-               final int nCols = ret.getNumColumns();
-               final DenseBlock db = ret.getDenseBlock();
+       private static final void addVector(final DenseBlock db, final int nCols, final double[] rowV, final double eps,
+               final int rl, final int ru) {
+               if(eps == 0)
+                       addVectorExact(db, nCols, rowV, rl, ru);
+               else
+                       addVectorWithEps(db, nCols, rowV, eps, rl, ru);
+       }

-               if(nCols == 1) {
-                       if(eps == 0)
-                               addValue(db.values(0), rowV[0], rl, ru);
-                       else
-                               addValueEps(db.values(0), rowV[0], eps, rl, ru);
-               }
-               else if(db.isContiguous()) {
-                       if(eps == 0)
-                               addVectorContiguousNoEps(db.values(0), rowV, nCols, rl, ru);
-                       else
-                               addVectorContiguousEps(db.values(0), rowV, nCols, eps, rl, ru);
-               }
-               else if(eps == 0)
+       /** addVector case eps == 0: dispatch to the single-column, contiguous or generic kernel without eps. */
+       private static final void addVectorExact(final DenseBlock db, final int nCols, final double[] rowV, final int rl,
+               final int ru) {
+               if(nCols == 1)
+                       addValue(db.values(0), rowV[0], rl, ru);
+               else if(db.isContiguous())
+                       addVectorContiguousNoEps(db.values(0), rowV, nCols, rl, ru);
+               else
                        addVectorNoEps(db, rowV, nCols, rl, ru);
+       }
+
+       /** addVector case eps != 0: dispatch to the single-column, contiguous or generic kernel taking eps. */
+       private static final void addVectorWithEps(final DenseBlock db, final int nCols, final double[] rowV,
+               final double eps, final int rl, final int ru) {
+               if(nCols == 1)
+                       addValueEps(db.values(0), rowV[0], eps, rl, ru);
+               else if(db.isContiguous())
+                       addVectorContiguousEps(db.values(0), rowV, nCols, eps, rl, ru);
                else
                        addVectorEps(db, rowV, nCols, eps, rl, ru);
-
        }

        private static void addValue(final double[] retV, final double v, final 
int rl, final int ru) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
index 485599e382..b397825d5a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
@@ -97,6 +97,70 @@ public final class CLALibUtils {
                return false;
        }

+       /**
+        * Check whether groups contain no AMorphingMMColGroup, no ColGroupEmpty and no empty group, and at most one
+        * ColGroupConst, which then has to cover exactly nCol columns.
+        *
+        * @param groups The column groups to inspect
+        * @param nCol   The number of columns a const group has to cover
+        * @return true if all of these conditions hold
+        */
+       protected static boolean alreadyPreFiltered(List<AColGroup> groups, int nCol) {
+               int constCount = 0;
+               for(AColGroup group : groups) {
+                       if(group instanceof AMorphingMMColGroup || group instanceof ColGroupEmpty || group.isEmpty())
+                               return false;
+                       if(group instanceof ColGroupConst && (++constCount > 1 || group.getNumCols() != nCol))
+                               return false;
+               }
+               return true;
+       }
+
+       /**
+        * Split groups into preAggGroups (APreAgg instances) and noPreAggGroups (all other groups that are not const).
+        *
+        * NOTE(review): with several ColGroupConst in groups only getValues() of the last one is returned; presumably
+        * callers only pass groups for which alreadyPreFiltered holds - confirm.
+        *
+        * @param groups         The column groups to split
+        * @param noPreAggGroups Output list for groups that are neither ColGroupConst nor APreAgg
+        * @param preAggGroups   Output list for the APreAgg groups
+        * @return getValues() of the last ColGroupConst in groups, or null if there is none
+        */
+       protected static double[] filterGroupsAndSplitPreAggOneConst(List<AColGroup> groups, List<AColGroup> noPreAggGroups,
+               List<APreAgg> preAggGroups) {
+               double[] lastConst = null;
+               for(AColGroup group : groups) {
+                       if(group instanceof ColGroupConst) {
+                               lastConst = ((ColGroupConst) group).getValues();
+                               continue;
+                       }
+                       if(group instanceof APreAgg)
+                               preAggGroups.add((APreAgg) group);
+                       else
+                               noPreAggGroups.add(group);
+               }
+               return lastConst;
+       }
+
+       /**
+        * Copy every group that is not a ColGroupConst into out; unlike the three-list overload, no pre-aggregation split.
+        *
+        * @param groups The column groups to filter
+        * @param out    Output list for all groups that are not ColGroupConst
+        * @return getValues() of the last ColGroupConst in groups, or null if there is none
+        */
+       protected static double[] filterGroupsAndSplitPreAggOneConst(List<AColGroup> groups, List<AColGroup> out) {
+               double[] lastConst = null;
+               for(AColGroup group : groups) {
+                       if(!(group instanceof ColGroupConst))
+                               out.add(group);
+                       else
+                               lastConst = ((ColGroupConst) group).getValues();
+               }
+               return lastConst;
+       }
+
        /**
         * Helper method to determine if the column groups contains Morphing or 
Frame of reference groups.
         * 

Reply via email to