This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 2bc40ed35e9be42e8442ca5227c61ec10220b7d4 Author: baunsgaard <[email protected]> AuthorDate: Sat Mar 13 11:00:01 2021 +0100 [SYSTEMDS-2897] CLA decompressing write --- .../runtime/compress/CompressedMatrixBlock.java | 58 ++++++++++------------ .../sysds/runtime/compress/colgroup/AColGroup.java | 2 + .../runtime/compress/colgroup/ADictionary.java | 2 + .../compress/colgroup/ColGroupUncompressed.java | 5 ++ .../runtime/compress/colgroup/ColGroupValue.java | 6 +++ .../runtime/compress/colgroup/Dictionary.java | 16 ++++++ .../runtime/compress/colgroup/QDictionary.java | 16 ++++++ .../sysds/runtime/compress/lib/CLALibAppend.java | 9 ++-- .../sysds/runtime/compress/lib/CLALibCompAgg.java | 2 +- .../runtime/compress/lib/CLALibLeftMultBy.java | 50 +++---------------- .../sysds/runtime/compress/lib/CLALibReExpand.java | 2 +- .../runtime/compress/lib/CLALibRightMultBy.java | 4 +- .../sysds/runtime/compress/lib/CLALibScalar.java | 6 ++- .../sysds/runtime/compress/lib/CLALibSquash.java | 2 +- .../controlprogram/caching/CacheableData.java | 2 +- .../controlprogram/caching/MatrixObject.java | 7 ++- .../apache/sysds/runtime/util/DataConverter.java | 7 ++- 17 files changed, 107 insertions(+), 89 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index f2958da..58688b1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -205,32 +205,16 @@ public class CompressedMatrixBlock extends MatrixBlock { // preallocation sparse rows to avoid repeated reallocations MatrixBlock ret = new MatrixBlock(rlen, clen, false, -1); + if(nonZeros == -1) + ret.setNonZeros(this.recomputeNonZeros()); + else + ret.setNonZeros(nonZeros); ret.allocateDenseBlock(); - // (nonZeros == -1) ? - // .allocateBlock() : new MatrixBlock(rlen, clen, sparse, - // nonZeros).allocateBlock(); - - // if(ret.isInSparseFormat()) { - // int[] rnnz = new int[rlen]; - // // for(ColGroup grp : _colGroups) - // // grp.countNonZerosPerRow(rnnz, 0, rlen); - // ret.allocateSparseRowsBlock(); - // SparseBlock rows = ret.getSparseBlock(); - // for(int i = 0; i < rlen; i++) - // rows.allocate(i, rnnz[i]); - // } + // todo Add sparse decompress. - // core decompression (append if sparse) for(AColGroup grp : _colGroups) grp.decompressToBlockUnSafe(ret, 0, rlen, 0, grp.getValues()); - // post-processing (for append in decompress) - if(ret.getNonZeros() == -1 || nonZeros == -1) { - ret.recomputeNonZeros(); - } - else { - ret.setNonZeros(nonZeros); - } if(ret.isInSparseFormat()) ret.sortSparseRows(); @@ -256,8 +240,10 @@ public class CompressedMatrixBlock extends MatrixBlock { Timing time = new Timing(true); MatrixBlock ret = new MatrixBlock(rlen, clen, false, -1).allocateBlock(); - - nonZeros = 0; + if(nonZeros == -1) + ret.setNonZeros(this.recomputeNonZeros()); + else + ret.setNonZeros(nonZeros); boolean overlapping = isOverlapping(); try { ExecutorService pool = CommonThreadPool.get(k); @@ -272,20 +258,13 @@ public class CompressedMatrixBlock extends MatrixBlock { List<Future<Long>> rtasks = pool.invokeAll(tasks); pool.shutdown(); for(Future<Long> rt : rtasks) - nonZeros += rt.get(); // error handling + rt.get(); // error handling } catch(InterruptedException | ExecutionException ex) { LOG.error("Parallel decompression failed defaulting to non parallel implementation " + ex.getMessage()); - nonZeros = -1; ex.printStackTrace(); return decompress(); } - if(overlapping) { - ret.recomputeNonZeros(); - } - else { - ret.setNonZeros(nonZeros); - } if(DMLScript.STATISTICS || LOG.isDebugEnabled()) { double t = time.stop(); @@ -299,6 +278,22 @@ public class CompressedMatrixBlock extends MatrixBlock { return CLALibSquash.squash(this, k); } + @Override + public long recomputeNonZeros() { + if(overlappingColGroups) { + nonZeros = clen * rlen; + } + else { + long nnz = 0; + for(AColGroup g : _colGroups) { + nnz += g.getNumberNonZeros(); + } + nonZeros = nnz; + } + return nonZeros; + + } + /** * Obtain an upper bound on the memory used to store the compressed block. * @@ -497,6 +492,7 @@ public class CompressedMatrixBlock extends MatrixBlock { CLALibLeftMultBy.leftMultByMatrixTransposed(this, tmp, out, k); out = LibMatrixReorg.transposeInPlace(out, k); + out.recomputeNonZeros(); return out; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index 808d43b..33e9d11 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -646,6 +646,8 @@ public abstract class AColGroup implements Serializable { public abstract boolean containsValue(double pattern); + public abstract long getNumberNonZeros(); + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictionary.java index 8422a3f..10b069e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictionary.java @@ -250,4 +250,6 @@ public abstract class ADictionary { public abstract ADictionary reExpandColumns(int max); public abstract boolean containsValue(double pattern); + + public abstract long getNumberNonZeros(int[] counts, int nCol); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index 824ff8c..1722ffd 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -543,4 +543,9 @@ public class ColGroupUncompressed extends AColGroup { public boolean containsValue(double pattern){ return _data.containsValue(pattern); } + + @Override + public long getNumberNonZeros(){ + return _data.getNonZeros(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java index d0a6ee9..085dd1d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java @@ -1030,4 +1030,10 @@ public abstract class ColGroupValue extends AColGroup implements Cloneable { public boolean containsValue(double pattern){ return _dict.containsValue(pattern); } + + @Override + public long getNumberNonZeros(){ + int[] counts = getCounts(); + return _dict.getNumberNonZeros(counts, _colIndexes.length); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/Dictionary.java index 658decb..34fea3d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/Dictionary.java @@ -387,4 +387,20 @@ public class Dictionary extends ADictionary { return false; } + + @Override + public long getNumberNonZeros(int[] counts, int nCol){ + long nnz = 0; + final int nRow = _values.length / nCol; + for(int i = 0; i < nRow; i++){ + long rowCount = 0; + final int off = i * nCol; + for(int j = off; j < off + nCol; j++){ + if(_values[j] != 0) + rowCount ++; + } + nnz += rowCount * counts[i]; + } + return nnz; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java index 05b1817..8986d71 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/QDictionary.java @@ -452,4 +452,20 @@ public class QDictionary extends ADictionary { return false; throw new NotImplementedException("Not contains value on Q Dictionary"); } + + @Override + public long getNumberNonZeros(int[] counts, int nCol){ + long nnz = 0; + final int nRow = _values.length / nCol; + for(int i = 0; i < nRow; i++){ + long rowCount = 0; + final int off = i * nCol; + for(int j = off; j < off + nCol; j++){ + if(_values[j] != 0) + rowCount ++; + } + nnz += rowCount * counts[i]; + } + return nnz; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java index c6bf61d..2ea5397 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAppend.java @@ -42,9 +42,7 @@ public class CLALibAppend { // return left; final int m = left.getNumRows(); final int n = left.getNumColumns() + right.getNumColumns(); - long nnz = left.getNonZeros() + right.getNonZeros(); - if(left.getNonZeros() < 0 || right.getNonZeros() < 0) - nnz = -1; + // try to compress both sides (if not already compressed). if(!(left instanceof CompressedMatrixBlock) && m > 1000){ @@ -85,8 +83,11 @@ public class CLALibAppend { ret.getColGroups().add(tmp); } + long nnzl = (leftC.getNonZeros() <= -1 ) ? leftC.recomputeNonZeros() : leftC.getNonZeros() ; + long nnzr = (rightC.getNonZeros() <= -1 ) ? rightC.recomputeNonZeros() : rightC.getNonZeros() ; + // meta data maintenance - ret.setNonZeros(nnz); + ret.setNonZeros(nnzl + nnzr); return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 2eb6888..f7b288d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -56,7 +56,7 @@ public class CLALibCompAgg { // private static final Log LOG = LogFactory.getLog(LibCompAgg.class.getName()); // private static final long MIN_PAR_AGG_THRESHOLD = 8 * 1024 * 1024; - private static final long MIN_PAR_AGG_THRESHOLD = 8; + private static final long MIN_PAR_AGG_THRESHOLD = 8 * 1024 ; private static ThreadLocal<MatrixBlock> memPool = new ThreadLocal<MatrixBlock>() { @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java index f6390c6..3ef03f3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java @@ -59,7 +59,9 @@ public class CLALibLeftMultBy { public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k){ MatrixBlock transposed = new MatrixBlock(m2.getNumColumns(), m2.getNumRows(), false); LibMatrixReorg.transpose(m2, transposed); - return leftMultByMatrix(m1, transposed, ret, k ); + ret = leftMultByMatrix(m1, transposed, ret, k ); + ret.recomputeNonZeros(); + return ret; // return LibMatrixReorg.transpose(ret, new MatrixBlock(ret.getNumColumns(), ret.getNumRows(), false)); } @@ -75,8 +77,10 @@ public class CLALibLeftMultBy { public static MatrixBlock leftMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) { prepareReturnMatrix(m1, m2, ret, false); - return leftMultByMatrix(m1 + ret = leftMultByMatrix(m1 .getColGroups(), m2, ret, false, m1.getNumColumns(), m1.isOverlapping(), k, m1.getMaxNumValues()); + ret.recomputeNonZeros(); + return ret; } private static MatrixBlock leftMultByMatrix(List<AColGroup> groups, MatrixBlock that, MatrixBlock ret, @@ -172,48 +176,6 @@ public class CLALibLeftMultBy { } } - // public static MatrixBlock leftMultByVectorTranspose(List<AColGroup> colGroups, MatrixBlock vector, - // MatrixBlock result, boolean doTranspose, int k, Pair<Integer, int[]> v, boolean overlap) { - - // // transpose vector if required - // MatrixBlock rowVector = vector; - // if(doTranspose) { - // rowVector = new MatrixBlock(1, vector.getNumRows(), false); - // LibMatrixReorg.transpose(vector, rowVector); - // } - - // result.reset(); - // result.allocateDenseBlock(); - - // // multi-threaded execution - // try { - // // compute uncompressed column group in parallel - // // ColGroupUncompressed uc = getUncompressedColGroup(); - // // if(uc != null) - // // uc.leftMultByRowVector(rowVector, result, k); - - // // compute remaining compressed column groups in parallel - // ExecutorService pool = CommonThreadPool.get(Math.min(colGroups.size(), k)); - // ArrayList<LeftMatrixVectorMultTask> tasks = new ArrayList<>(); - - // tasks.add(new LeftMatrixVectorMultTask(colGroups, rowVector, result, v)); - - // List<Future<Object>> ret = pool.invokeAll(tasks); - // pool.shutdown(); - // for(Future<Object> tmp : ret) - // tmp.get(); - - // } - // catch(InterruptedException | ExecutionException e) { - // throw new DMLRuntimeException(e); - // } - - // // post-processing - // result.recomputeNonZeros(); - - // return result; - // } - private static MatrixBlock leftMultByCompressedTransposedMatrix(List<AColGroup> colGroups, CompressedMatrixBlock that, MatrixBlock ret, int k, int numColumns, Pair<Integer, int[]> v, boolean overlapping) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReExpand.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReExpand.java index 97cf128..2b4e318 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReExpand.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibReExpand.java @@ -68,8 +68,8 @@ public class CLALibReExpand { ret.allocateColGroupList(newColGroups); ret.setOverlapping(true); - ret.setNonZeros(-1); + ret.recomputeNonZeros(); return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java index 8ce20f2..ec28497 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java @@ -42,7 +42,9 @@ public class CLALibRightMultBy { private static final Log LOG = LogFactory.getLog(CLALibRightMultBy.class.getName()); public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k, boolean allowOverlap){ - return rightMultByMatrix(m1.getColGroups(), m2, ret, k, m1.getMaxNumValues(), allowOverlap); + ret = rightMultByMatrix(m1.getColGroups(), m2, ret, k, m1.getMaxNumValues(), allowOverlap); + ret.recomputeNonZeros(); + return ret; } private static MatrixBlock rightMultByMatrix(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int k, diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java index 97a7f67..1c0a980 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java @@ -57,7 +57,9 @@ public class CLALibScalar { public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixBlock m1, MatrixValue result) { // Special case handling of overlapping relational operations if(CLALibRelationalOp.isValidForRelationalOperation(sop, m1)) { - return CLALibRelationalOp.overlappingRelativeRelationalOperation(sop, m1); + MatrixBlock ret = CLALibRelationalOp.overlappingRelativeRelationalOperation(sop, m1); + ret.recomputeNonZeros(); + return ret; } if(isInvalidForCompressedOutput(m1, sop)) { @@ -96,7 +98,7 @@ public class CLALibScalar { ret.setOverlapping(m1.isOverlapping()); } - ret.setNonZeros(-1); + ret.recomputeNonZeros(); return ret; diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSquash.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSquash.java index b4438b2..63177c5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSquash.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSquash.java @@ -59,7 +59,7 @@ public class CLALibSquash { ret.allocateColGroupList(retCg); ret.setOverlapping(false); - ret.setNonZeros(-1); + ret.recomputeNonZeros(); if(ret.isOverlapping()) throw new DMLCompressionException("Squash should output compressed nonOverlapping matrix"); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index 06de8f7..cf6ab3f 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -1040,7 +1040,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data protected void writeMetaData (String filePathAndName, String outputFormat, FileFormatProperties formatProperties) throws IOException - { + { MetaDataFormat iimd = (MetaDataFormat) _metaData; if (iimd == null) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java index d0bce6e..e55509b 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java @@ -35,6 +35,7 @@ import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.lops.Lop; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; @@ -583,13 +584,15 @@ public class MatrixObject extends CacheableData<MatrixBlock> begin = System.currentTimeMillis(); } - MetaDataFormat iimd = (MetaDataFormat) _metaData; - if(this.isFederated() && FileFormat.safeValueOf(ofmt) == FileFormat.FEDERATED){ ReaderWriterFederated.write(fname,this._fedMapping); } else if (_data != null) { + if(_data instanceof CompressedMatrixBlock) + _data = CompressedMatrixBlock.getUncompressed(_data); + + MetaDataFormat iimd = (MetaDataFormat) _metaData; // Get the dimension information from the metadata stored within MatrixObject DataCharacteristics mc = iimd.getDataCharacteristics(); // Write the matrix to HDFS in requested format diff --git a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java index c181b5c..51ad590 100644 --- a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java +++ b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java @@ -31,6 +31,8 @@ import java.util.Map.Entry; import java.util.StringTokenizer; import org.apache.commons.lang.StringUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.BlockRealMatrix; import org.apache.commons.math3.linear.RealMatrix; @@ -79,7 +81,7 @@ import org.apache.sysds.runtime.meta.DataCharacteristics; * */ public class DataConverter { - // private static final Log LOG = LogFactory.getLog(DataConverter.class.getName()); + private static final Log LOG = LogFactory.getLog(DataConverter.class.getName()); private static final String DELIM = " "; ////////////// @@ -100,6 +102,9 @@ public class DataConverter { public static void writeMatrixToHDFS(MatrixBlock mat, String dir, FileFormat fmt, DataCharacteristics dc, int replication, FileFormatProperties formatProperties, boolean diag) throws IOException { MatrixWriter writer = MatrixWriterFactory.createMatrixWriter( fmt, replication, formatProperties ); + if(mat instanceof CompressedMatrixBlock) + mat = CompressedMatrixBlock.getUncompressed(mat); + LOG.error(mat.getNonZeros()); writer.writeMatrixToHDFS(mat, dir, dc.getRows(), dc.getCols(), dc.getBlocksize(), dc.getNonZeros(), diag); }
