Repository: systemml Updated Branches: refs/heads/master 6bf3e7836 -> e0006a272
[SYSTEMML-2050] Performance ifelse w/ constant expression (shallow copy) This patch significantly improves the performance of ifelse ternary operations when the predicate is a scalar or a constant (all-zero or all-nonzero) matrix. Specifically, we now use a conditional shallow copy of the selected input (or fill the output with the selected scalar value) if it's safe to do so. On a scenario of 20 iterations, inputs of size 10K x 10K (dense), and ifelse(TRUE, X, Y), the total runtime improved from 23.3s to 6ms (~4000x). Similarly, for ifelse(TRUE, TRUE, Y), the total runtime was reduced to 5.7s (~4x). Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/280d9c9a Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/280d9c9a Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/280d9c9a Branch: refs/heads/master Commit: 280d9c9a5cdbf7a38145903e0df231cebfe3977d Parents: 6bf3e78 Author: Matthias Boehm <[email protected]> Authored: Thu Jan 18 17:53:21 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Thu Jan 18 17:53:21 2018 -0800 ---------------------------------------------------------------------- .../sysml/runtime/matrix/data/MatrixBlock.java | 50 ++++++++++++++------ 1 file changed, 36 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/280d9c9a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index 5a96bcb..56a3f50 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -47,6 +47,7 @@ import org.apache.sysml.runtime.functionobjects.CM; import org.apache.sysml.runtime.functionobjects.CTable; 
import org.apache.sysml.runtime.functionobjects.DiagIndex; import org.apache.sysml.runtime.functionobjects.Divide; +import org.apache.sysml.runtime.functionobjects.IfElse; import org.apache.sysml.runtime.functionobjects.KahanFunction; import org.apache.sysml.runtime.functionobjects.KahanPlus; import org.apache.sysml.runtime.functionobjects.KahanPlusSq; @@ -2786,8 +2787,6 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab public MatrixBlock ternaryOperations(TernaryOperator op, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret) throws DMLRuntimeException { - //TODO perf for special cases like ifelse - //prepare inputs final boolean s1 = (rlen==1 && clen==1); final boolean s2 = (m2.rlen==1 && m2.clen==1); @@ -2797,6 +2796,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab final double d3 = s3 ? m3.quickGetValue(0, 0) : Double.NaN; final int m = Math.max(Math.max(rlen, m2.rlen), m3.rlen); final int n = Math.max(Math.max(clen, m2.clen), m3.clen); + final long nnz = nonZeros; //error handling if( (!s1 && (rlen != m || clen != n)) @@ -2808,19 +2808,41 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab //prepare result ret.reset(m, n, false); - ret.allocateDenseBlock(); - - //basic ternary operations - for( int i=0; i<m; i++ ) - for( int j=0; j<n; j++ ) { - double in1 = s1 ? d1 : quickGetValue(i, j); - double in2 = s2 ? d2 : m2.quickGetValue(i, j); - double in3 = s3 ? d3 : m3.quickGetValue(i, j); - ret.appendValue(i, j, op.fn.execute(in1, in2, in3)); - } - //ensure correct output representation - ret.examSparsity(); + if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) ) + { + //special case for shallow-copy if-else + boolean expr = s1 ? (d1 != 0) : (nnz==(long)m*n); + MatrixBlock tmp = expr ? 
m2 : m3; + if( tmp.rlen==m && tmp.clen==n ) { + //shallow copy incl meta data + ret.copyShallow(tmp); + } + else { + //fill output with given scalar value + double tmpVal = tmp.quickGetValue(0, 0); + if( tmpVal != 0 ) { + ret.allocateDenseBlock(); + ret.denseBlock.set(tmpVal); + ret.nonZeros = (long)m * n; + } + } + } + else { + ret.allocateDenseBlock(); + + //basic ternary operations + for( int i=0; i<m; i++ ) + for( int j=0; j<n; j++ ) { + double in1 = s1 ? d1 : quickGetValue(i, j); + double in2 = s2 ? d2 : m2.quickGetValue(i, j); + double in3 = s3 ? d3 : m3.quickGetValue(i, j); + ret.appendValue(i, j, op.fn.execute(in1, in2, in3)); + } + + //ensure correct output representation + ret.examSparsity(); + } return ret; }
