[SYSTEMML-1861] Performance sparse-sparse binary mult operations. This patch improves the performance of sparse-sparse binary multiply operations. Instead of using a merge join with outer join semantics, we now use a dedicated case that realizes multiply via a merge join with inner join semantics and branchless position maintenance.
On a scenario of X * Y, with 1M x 1K, sparsity=0.1 inputs, this patch improved performance from 330ms to 235ms. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/65e2a46d Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/65e2a46d Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/65e2a46d Branch: refs/heads/master Commit: 65e2a46d2bccfebe4ed5a566d02c34a1cb816da5 Parents: 06fa73a Author: Matthias Boehm <[email protected]> Authored: Tue Aug 22 22:13:05 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Aug 23 12:41:47 2017 -0700 ---------------------------------------------------------------------- .../runtime/matrix/data/LibMatrixBincell.java | 79 ++++++++++---------- 1 file changed, 39 insertions(+), 40 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/65e2a46d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java index e188b4e..9489225 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java @@ -1184,52 +1184,51 @@ public class LibMatrixBincell } } - /** - * like a merge sort - * - * @param op binary operator - * @param values1 array of double values - * @param cols1 ? - * @param pos1 ? - * @param size1 ? - * @param values2 array of double values - * @param cols2 ? - * @param pos2 ? - * @param size2 ? - * @param resultRow ? 
- * @param result matrix block - * @throws DMLRuntimeException if DMLRuntimeException occurs - */ private static void mergeForSparseBinary(BinaryOperator op, double[] values1, int[] cols1, int pos1, int size1, - double[] values2, int[] cols2, int pos2, int size2, int resultRow, MatrixBlock result) + double[] values2, int[] cols2, int pos2, int size2, int resultRow, MatrixBlock result) throws DMLRuntimeException { - int p1=0, p2=0, column; - while( p1<size1 && p2< size2 ) - { - double value = 0; - if(cols1[pos1+p1]<cols2[pos2+p2]) { - value = op.fn.execute(values1[pos1+p1], 0); - column = cols1[pos1+p1]; - p1++; + int p1 = 0, p2 = 0; + if( op.fn instanceof Multiply ) { //skip empty + //skip empty: merge-join (with inner join semantics) + //similar to sorted list intersection + SparseBlock sblock = result.getSparseBlock(); + sblock.allocate(resultRow, Math.min(size1, size2), result.clen); + while( p1 < size1 && p2 < size2 ) { + int colPos1 = cols1[pos1+p1]; + int colPos2 = cols2[pos2+p2]; + if( colPos1 == colPos2 ) + sblock.append(resultRow, colPos1, + op.fn.execute(values1[pos1+p1], values2[pos2+p2])); + p1 += (colPos1 <= colPos2) ? 1 : 0; + p2 += (colPos1 >= colPos2) ? 
1 : 0; } - else if(cols1[pos1+p1]==cols2[pos2+p2]) { - value = op.fn.execute(values1[pos1+p1], values2[pos2+p2]); - column = cols1[pos1+p1]; - p1++; - p2++; - } - else { - value = op.fn.execute(0, values2[pos2+p2]); - column = cols2[pos2+p2]; - p2++; + result.nonZeros += sblock.size(resultRow); + } + else { + //general case: merge-join (with outer join semantics) + while( p1 < size1 && p2 < size2 ) { + if(cols1[pos1+p1]<cols2[pos2+p2]) { + result.appendValue(resultRow, cols1[pos1+p1], + op.fn.execute(values1[pos1+p1], 0)); + p1++; + } + else if(cols1[pos1+p1]==cols2[pos2+p2]) { + result.appendValue(resultRow, cols1[pos1+p1], + op.fn.execute(values1[pos1+p1], values2[pos2+p2])); + p1++; + p2++; + } + else { + result.appendValue(resultRow, cols2[pos2+p2], + op.fn.execute(0, values2[pos2+p2])); + p2++; + } } - result.appendValue(resultRow, column, value); + //add left over + appendLeftForSparseBinary(op, values1, cols1, pos1, size1, p1, resultRow, result); + appendRightForSparseBinary(op, values2, cols2, pos2, size2, p2, resultRow, result); } - - //add left over - appendLeftForSparseBinary(op, values1, cols1, pos1, size1, p1, resultRow, result); - appendRightForSparseBinary(op, values2, cols2, pos2, size2, p2, resultRow, result); } private static void appendLeftForSparseBinary(BinaryOperator op, double[] values1, int[] cols1, int pos1, int size1,
