Baunsgaard commented on a change in pull request #1480:
URL: https://github.com/apache/systemds/pull/1480#discussion_r768109770



##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
##########
@@ -5096,34 +5081,52 @@ public final MatrixBlock 
aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m
                return aggregateBinaryOperations(m1, m2, null, op);
        }
 
-       public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+       public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
                //check input types, dimensions, configuration
-               if( m1.clen != m2.rlen ) {
+               if( m1.clen != m2.rlen )
                        throw new RuntimeException("Dimensions do not match for 
matrix multiplication ("+m1.clen+"!="+m2.rlen+").");
-               }
-               if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn 
instanceof Plus) ) {
+               if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn 
instanceof Plus) )
                        throw new DMLRuntimeException("Unsupported binary 
aggregate operation: ("+op.binaryFn+", "+op.aggOp+").");
+               if(!(m1 == this || m2 == this))
+                       throw new DMLRuntimeException("Invalid 
aggregateBinaryOperatio: one of either input should be this");
+               return matrixMult(m1, m2, ret, op.getNumThreads());
+       }
+
+       public MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, int k){
+
+               final int rl = m1.rlen;
+               final int cl = m2.clen;
+
+               if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
+                       if(ret == null)
+                               return new MatrixBlock(rl, cl, true);
+                       else {
+                               ret.reset(rl, cl, true);
+                               ret.setNonZeros(0);
+                               ret.cleanupBlock(true, true);
+                               return ret;
+                       }
                }
-               
-               //setup meta data (dimensions, sparsity)
-               int rl = m1.rlen;
-               int cl = m2.clen;
-               SparsityEstimate sp = estimateSparsityOnAggBinary(m1, m2, op);
-               
-               //create output matrix block
-               if( ret==null )
-                       ret = new MatrixBlock(rl, cl, sp.sparse, 
sp.estimatedNonZeros);
+
+               final boolean m1Perm = m1.isSparsePermutationMatrix();
+               final boolean ultraSparse = 
LibMatrixMult.isUltraSparseMatrixMult(m1, m2, m1Perm);
+               final boolean sparse = !m1Perm && !ultraSparse && 
LibMatrixMult.isSparseOutputMatrixMult(m1, m2);
+               final boolean sparseRet = ultraSparse | sparse;
+
+               // create output matrix block
+               if(ret == null)
+                       ret = new MatrixBlock(rl, cl, sparseRet);
                else
-                       ret.reset(rl, cl, sp.sparse, sp.estimatedNonZeros);
-               
-               //compute matrix multiplication (only supported binary 
aggregate operation)
-               if( NativeHelper.isNativeLibraryLoaded() )
-                       LibMatrixNative.matrixMult(m1, m2, ret, 
op.getNumThreads());
-               else if( op.getNumThreads() > 1 )
-                       LibMatrixMult.matrixMult(m1, m2, ret, 
op.getNumThreads());
+                       ret.reset(rl, cl, sparseRet);
+               ret.allocateBlock();
+
+               if(!sparseRet && NativeHelper.isNativeLibraryLoaded())
+                       LibMatrixNative.matrixMult(m1, m2, ret, k);
+               else if(k > 1)
+                       LibMatrixMult.matrixMult(m1, m2, ret, k, m1Perm, 
ultraSparse, sparse);
                else
-                       LibMatrixMult.matrixMult(m1, m2, ret);
-               
+                       LibMatrixMult.matrixMult(m1, m2, ret, m1Perm, 
ultraSparse, sparse);

Review comment:
       All calls to the kernel now skip some of the steps that were previously duplicated.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscr...@systemds.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to