This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit c42a629d1d9d3896fa22ec9793bb6e21df3a76b3 Author: Sebastian Baunsgaard <[email protected]> AuthorDate: Sun Dec 29 22:06:42 2024 +0100 [MINOR] Fused decompression in CLALibScalar Closes #2169 --- .../sysds/runtime/compress/lib/CLALibScalar.java | 84 +++++++++++++++++++++- 1 file changed, 81 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java index a8b1a8ad22..5588a538aa 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java @@ -31,12 +31,14 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupOLE; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Multiply; @@ -57,10 +59,13 @@ public final class CLALibScalar { } public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixBlock m1, MatrixValue result) { + // Timing time = new Timing(true); if(isInvalidForCompressedOutput(m1, sop)) { LOG.warn("scalar overlapping not supported for op: " + sop.fn.getClass().getSimpleName()); - MatrixBlock m1d = m1.decompress(sop.getNumThreads()); - return m1d.scalarOperations(sop, result); + + return fusedScalarAndDecompress(m1, sop); + // MatrixBlock m1d = m1.decompress(sop.getNumThreads()); + // return m1d.scalarOperations(sop, result); } CompressedMatrixBlock ret = setupRet(m1, result); @@ -89,11 +94,84 @@ public final class CLALibScalar { ret.setOverlapping(m1.isOverlapping()); } - ret.recomputeNonZeros(); + if(sop.fn instanceof Divide) { + ret.setNonZeros(m1.getNonZeros()); + } + else { + ret.recomputeNonZeros(); + } + // System.out.println("CLA Scalar: " + sop + " " + m1.getNumRows() + ", " + m1.getNumColumns() + ", " + + // m1.getColGroups().size() + // + " -- " + "\t\t" + time.stop()); return ret; } + private static MatrixBlock fusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop) { + int k = sop.getNumThreads(); + ExecutorService pool = CommonThreadPool.get(k); + try { + final int nRow = in.getNumRows(); + final int nCol = in.getNumColumns(); + final MatrixBlock out = new MatrixBlock(nRow, nCol, false); + final List<AColGroup> groups = in.getColGroups(); + out.allocateDenseBlock(); + final DenseBlock db = out.getDenseBlock(); + final int blkz = Math.max((int)(Math.ceil((double)nRow / k)), 256); + final List<Future<Long>> tasks = new ArrayList<>(); + for(int i = 0; i < nRow; i += blkz) { + final int start = i; + final int end = Math.min(i + blkz, nRow); + tasks.add(pool.submit(() -> fusedDecompressAndScalar(groups, nCol, start, end, db, sop))); + } + long nnz = 0; + for(Future<Long> t : tasks) { + nnz += t.get(); + } + out.setNonZeros(nnz); + out.examSparsity(true, k); + return out; + } + catch(Exception e) { + throw new DMLCompressionException("failed fused scalar operation", e); + } + finally { + pool.shutdown(); + } + + // MatrixBlock m1d = m1.decompress(sop.getNumThreads()); + // return m1d.scalarOperations(sop, result); + } + + private static long fusedDecompressAndScalar(final List<AColGroup> groups, int nCol, int start, int end, + DenseBlock db, ScalarOperator sop) { + long nnz = 0; + for(int b = start; b < end; b += 32) { + int bs = b; + int be = Math.min(b + 32, end); + nnz += fusedDecompressAndScalarBlock(groups, nCol, bs, be, db, sop); + } + return nnz; + } + + private static long fusedDecompressAndScalarBlock(final List<AColGroup> groups, int nCol, int bs, int be, + DenseBlock db, ScalarOperator sop) { + long nnz = 0; + for(AColGroup g : groups) { + // main block to optimize is decompression speed since it is most likely an overlapping input + g.decompressToDenseBlock(db, bs, be); + } + for(int r = bs; r < be; r++) { + double[] vals = db.values(r); + int off = db.pos(r); + for(int c = off; c < nCol + off; c++) { + vals[c] = sop.executeScalar(vals[c]); + nnz += vals[c] == 0 ? 0 : 1; + } + } + return nnz; + } + private static CompressedMatrixBlock setupRet(CompressedMatrixBlock m1, MatrixValue result) { CompressedMatrixBlock ret; if(result == null || !(result instanceof CompressedMatrixBlock))
