Baunsgaard commented on a change in pull request #1480:
URL: https://github.com/apache/systemds/pull/1480#discussion_r768109239
##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
##########
@@ -5096,34 +5081,52 @@ public final MatrixBlock
aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m
return aggregateBinaryOperations(m1, m2, null, op);
}
- public MatrixBlock aggregateBinaryOperations(MatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+ public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
//check input types, dimensions, configuration
- if( m1.clen != m2.rlen ) {
+ if( m1.clen != m2.rlen )
throw new RuntimeException("Dimensions do not match for
matrix multiplication ("+m1.clen+"!="+m2.rlen+").");
- }
- if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn
instanceof Plus) ) {
+ if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn
instanceof Plus) )
throw new DMLRuntimeException("Unsupported binary
aggregate operation: ("+op.binaryFn+", "+op.aggOp+").");
+ if(!(m1 == this || m2 == this))
+ throw new DMLRuntimeException("Invalid
aggregateBinaryOperatio: one of either input should be this");
+ return matrixMult(m1, m2, ret, op.getNumThreads());
+ }
+
+ public MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2,
MatrixBlock ret, int k){
+
+ final int rl = m1.rlen;
+ final int cl = m2.clen;
+
+ if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
+ if(ret == null)
+ return new MatrixBlock(rl, cl, true);
+ else {
+ ret.reset(rl, cl, true);
+ ret.setNonZeros(0);
+ ret.cleanupBlock(true, true);
+ return ret;
+ }
}
-
- //setup meta data (dimensions, sparsity)
- int rl = m1.rlen;
- int cl = m2.clen;
- SparsityEstimate sp = estimateSparsityOnAggBinary(m1, m2, op);
-
- //create output matrix block
- if( ret==null )
- ret = new MatrixBlock(rl, cl, sp.sparse,
sp.estimatedNonZeros);
+
+ final boolean m1Perm = m1.isSparsePermutationMatrix();
+ final boolean ultraSparse =
LibMatrixMult.isUltraSparseMatrixMult(m1, m2, m1Perm);
+ final boolean sparse = !m1Perm && !ultraSparse &&
LibMatrixMult.isSparseOutputMatrixMult(m1, m2);
+ final boolean sparseRet = ultraSparse | sparse;
Review comment:
Copied the output-sparsity check from the matrix-multiplication library (LibMatrixMult: isUltraSparseMatrixMult / isSparseOutputMatrixMult).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]