Repository: systemml Updated Branches: refs/heads/master ab8cccdff -> cca6356f8
[MINOR] Utility to obtain the exact output sparsity of sparse products Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/cca6356f Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/cca6356f Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/cca6356f Branch: refs/heads/master Commit: cca6356f8de49dcb6aeb1f23cefd53930309fedb Parents: ab8cccd Author: Matthias Boehm <mboe...@gmail.com> Authored: Sun Oct 21 02:30:25 2018 +0200 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Sun Oct 21 02:30:25 2018 +0200 ---------------------------------------------------------------------- .../sysml/hops/estim/EstimationUtils.java | 56 ++++++++++++++++++++ 1 file changed, 56 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/cca6356f/src/main/java/org/apache/sysml/hops/estim/EstimationUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimationUtils.java b/src/main/java/org/apache/sysml/hops/estim/EstimationUtils.java index d4d30bb..29de2f5 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimationUtils.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimationUtils.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops.estim; import java.util.Arrays; +import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.matrix.data.DenseBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.SparseBlock; @@ -106,4 +107,59 @@ public abstract class EstimationUtils } return retNnz; } + + public static long getSparseProductOutputNnz(MatrixBlock m1, MatrixBlock m2) { + if( !m1.isInSparseFormat() || !m2.isInSparseFormat() ) + throw new DMLRuntimeException("Invalid call to sparse output nnz estimation."); + + final int m = m1.getNumRows(); + final int n2 = m2.getNumColumns(); + long retNnz = 0; + + SparseBlock a = m1.getSparseBlock(); + SparseBlock b = m2.getSparseBlock(); + + SparseRowVector tmpS = new SparseRowVector(1024); + double[] tmpD = null; + + for( int i=0; i<m; i++ ) { + if( a.isEmpty(i) ) continue; + int alen = a.size(i); + int apos = a.pos(i); + int[] aix = a.indexes(i); + double[] avals = a.values(i); + + //compute number of aggregated non-zeros for input row + int nnz1 = (int) Math.min(UtilFunctions.computeNnz(b, aix, apos, alen), n2); + boolean ldense = nnz1 > n2 / 128; + + //perform vector-matrix multiply w/ dense or sparse output + if( ldense ) { //init dense tmp row + tmpD = (tmpD == null) ? new double[n2] : tmpD; + Arrays.fill(tmpD, 0); + } + else { + tmpS.setSize(0); + } + for( int k=apos; k<apos+alen; k++ ) { + if( b.isEmpty(aix[k]) ) continue; + int blen = b.size(aix[k]); + int bpos = b.pos(aix[k]); + int[] bix = b.indexes(aix[k]); + double aval = avals[k]; + double[] bvals = b.values(aix[k]); + if( ldense ) { //dense aggregation + for( int j=bpos; j<bpos+blen; j++ ) + tmpD[bix[j]] += aval * bvals[j]; + } + else { //sparse aggregation + for( int j=bpos; j<bpos+blen; j++ ) + tmpS.add(bix[j], aval * bvals[j]); + } + } + retNnz += !ldense ? tmpS.size() : + UtilFunctions.computeNnz(tmpD, 0, n2); + } + return retNnz; + } }