Baunsgaard commented on code in PR #1955: URL: https://github.com/apache/systemds/pull/1955#discussion_r1420746310
########## src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java: ########## @@ -2506,39 +2506,64 @@ private static void matrixMultTransposeSelfSparse( MatrixBlock m1, MatrixBlock r } private static void matrixMultTransposeSelfUltraSparse( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose, int rl, int ru ) { - if( leftTranspose ) - throw new DMLRuntimeException("Left tsmm with sparse output not supported"); - - // Operation X%*%t(X), sparse input and output - SparseBlock a = m1.sparseBlock; - SparseBlock c = ret.sparseBlock; + SparseBlock a = m1.sparseBlock; + SparseBlock c = ret.sparseBlock; int m = m1.rlen; - - final int blocksize = 256; - for(int bi=rl; bi<ru; bi+=blocksize) { //blocking rows in X - int bimin = Math.min(bi+blocksize, ru); - for(int i=bi; i<bimin; i++) //preallocation - if( !a.isEmpty(i) ) - c.allocate(i, 8*SparseRowVector.initialCapacity); //heuristic - for(int bj=bi; bj<m; bj+=blocksize ) { //blocking cols in t(X) - int bjmin = Math.min(bj+blocksize, m); - for(int i=bi; i<bimin; i++) { //rows in X - if( a.isEmpty(i) ) continue; - int apos = a.pos(i); - int alen = a.size(i); - int[] aix = a.indexes(i); - double[] avals = a.values(i); - for(int j=Math.max(bj,i); j<bjmin; j++) { //cols in t(X) - if( a.isEmpty(j) ) continue; - int bpos = a.pos(j); - int blen = a.size(j); - int[] bix = a.indexes(j); - double[] bvals = a.values(j); - - //compute sparse dot product and append - double v = dotProduct(avals, aix, apos, alen, bvals, bix, bpos, blen); - if( v != 0 ) - c.append(i, j, v); + + if(leftTranspose) { + // Operation t(X)%*%X, sparse input and output + for(int i=0; i<m; i++) + c.allocate(i, 8*SparseRowVector.initialCapacity); + SparseRow[] sr = ((SparseBlockMCSR) c).getRows(); + for( int r=0; r<a.numRows(); r++ ) { + if( a.isEmpty(r) ) continue; + final int alen = a.size(r); + final double[] avals = a.values(r); + final int apos = a.pos(r); + int[] aix = a.indexes(r); + int rlix = (rl==0) ? 
0 : a.posFIndexGTE(r, rl); + + if(rlix>=0) + rlix = apos+rlix;
Review Comment: What I meant with the if statement:
```
if(rlix >= 0){
  int len = apos + alen;
  for(int i = rlix; i < len && aix[i] < ru; i++) {
    for (int k = a.posFIndexGTE(r, aix[i]); k < len; k++) {
      sr[aix[i]].add(c.pos(k) + aix[k], avals[i] * avals[k]);
    }
  }
}
```
You do not need the else case, because in that case the for loop would terminate on its first condition check anyway.
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: dev-unsubscr...@systemds.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org