[MINOR] Reduced instruction footprint of ctable and outer operations

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/223066ee
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/223066ee
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/223066ee
Branch: refs/heads/master Commit: 223066eebdf86a89dc2feb72ff4bd32ca2ed5155 Parents: 6b4eaa6 Author: Matthias Boehm <[email protected]> Authored: Fri Nov 10 16:42:58 2017 -0800 Committer: Matthias Boehm <[email protected]> Committed: Fri Nov 10 16:42:58 2017 -0800 ---------------------------------------------------------------------- .../sysml/runtime/functionobjects/CTable.java | 9 ++ .../runtime/matrix/data/LibMatrixBincell.java | 97 +++++++-------- .../sysml/runtime/matrix/data/MatrixBlock.java | 123 +++++++------------ .../binary/matrix/UaggOuterChainTest.java | 4 +- 4 files changed, 95 insertions(+), 138 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/223066ee/src/main/java/org/apache/sysml/runtime/functionobjects/CTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/CTable.java b/src/main/java/org/apache/sysml/runtime/functionobjects/CTable.java index fdb6b85..af6d8a0 100644 --- a/src/main/java/org/apache/sysml/runtime/functionobjects/CTable.java +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/CTable.java @@ -42,6 +42,15 @@ public class CTable extends ValueFunction return singleObj; } + public void execute(double v1, double v2, double w, boolean ignoreZeros, CTableMap resultMap, MatrixBlock resultBlock) + throws DMLRuntimeException + { + if( resultBlock != null ) + execute(v1, v2, w, ignoreZeros, resultBlock); + else + execute(v1, v2, w, ignoreZeros, resultMap); + } + public void execute(double v1, double v2, double w, boolean ignoreZeros, CTableMap resultMap) throws DMLRuntimeException { http://git-wip-us.apache.org/repos/asf/systemml/blob/223066ee/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java index 2878b3b..27bb340 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixBincell.java @@ -775,7 +775,7 @@ public class LibMatrixBincell * */ private static void performBinOuterOperation(MatrixBlock mbLeft, MatrixBlock mbRight, MatrixBlock mbOut, BinaryOperator bOp) - throws DMLRuntimeException + throws DMLRuntimeException { int rlen = mbLeft.rlen; int clen = mbOut.clen; @@ -784,42 +784,37 @@ public class LibMatrixBincell mbOut.allocateDenseBlock(); double c[] = mbOut.getDenseBlock(); + //pre-materialize various types used in inner loop + boolean scanType1 = (bOp.fn instanceof LessThan || bOp.fn instanceof Equals + || bOp.fn instanceof NotEquals || bOp.fn instanceof GreaterThanEquals); + boolean scanType2 = (bOp.fn instanceof LessThanEquals || bOp.fn instanceof Equals + || bOp.fn instanceof NotEquals || bOp.fn instanceof GreaterThan); + boolean lt = (bOp.fn instanceof LessThan), lte = (bOp.fn instanceof LessThanEquals); + boolean gt = (bOp.fn instanceof GreaterThan), gte = (bOp.fn instanceof GreaterThanEquals); + boolean eqNeq = (bOp.fn instanceof Equals || bOp.fn instanceof NotEquals); + long lnnz = 0; - for(int r=0, off=0; r<rlen; r++, off+=clen) { - double value = mbLeft.quickGetValue(r, 0); + for( int r=0, off=0; r<rlen; r++, off+=clen ) { + double value = mbLeft.quickGetValue(r, 0); int ixPos1 = Arrays.binarySearch(b, value); int ixPos2 = ixPos1; - - if( ixPos1 >= 0 ){ //match, scan to next val - if(bOp.fn instanceof LessThan || bOp.fn instanceof GreaterThanEquals - || bOp.fn instanceof Equals || bOp.fn instanceof NotEquals) - while( ixPos1<b.length && value==b[ixPos1] ) ixPos1++; - if(bOp.fn instanceof GreaterThan || bOp.fn instanceof LessThanEquals - || bOp.fn instanceof Equals || bOp.fn instanceof NotEquals) - 
while( ixPos2 > 0 && value==b[ixPos2-1]) --ixPos2; - } else { + if( ixPos1 >= 0 ) { //match, scan to next val + if(scanType1) while( ixPos1<b.length && value==b[ixPos1] ) ixPos1++; + if(scanType2) while( ixPos2 > 0 && value==b[ixPos2-1]) --ixPos2; + } + else ixPos2 = ixPos1 = Math.abs(ixPos1) - 1; + int start = lt ? ixPos1 : (lte||eqNeq) ? ixPos2 : 0; + int end = gt ? ixPos2 : (gte||eqNeq) ? ixPos1 : clen; + + if (bOp.fn instanceof NotEquals) { + Arrays.fill(c, off, off+start, 1.0); + Arrays.fill(c, off+end, off+clen, 1.0); + lnnz += (start+(clen-end)); } - - int start = 0, end = clen; - if(bOp.fn instanceof LessThan || bOp.fn instanceof LessThanEquals) - start = (bOp.fn instanceof LessThan) ? ixPos1 : ixPos2; - else if(bOp.fn instanceof GreaterThan || bOp.fn instanceof GreaterThanEquals) - end = (bOp.fn instanceof GreaterThan) ? ixPos2 : ixPos1; - else if(bOp.fn instanceof Equals || bOp.fn instanceof NotEquals) { - start = ixPos2; - end = ixPos1; - } - if(start < end || bOp.fn instanceof NotEquals) { - if (bOp.fn instanceof NotEquals) { - Arrays.fill(c, off, off+start, 1.0); - Arrays.fill(c, off+end, off+clen, 1.0); - lnnz += (start+(clen-end)); - } - else { - Arrays.fill(c, off+start, off+end, 1.0); - lnnz += (end-start); - } + else if( start < end ) { + Arrays.fill(c, off+start, off+end, 1.0); + lnnz += (end-start); } } mbOut.setNonZeros(lnnz); @@ -835,14 +830,10 @@ public class LibMatrixBincell if( atype == BinaryAccessType.MATRIX_COL_VECTOR ) //MATRIX - COL_VECTOR { - for(int r=0; r<rlen; r++) - { - //replicated value + for(int r=0; r<rlen; r++) { double v2 = m2.quickGetValue(r, 0); - - for(int c=0; c<clen; c++) - { - double v1 = m1.quickGetValue(r, c); + for(int c=0; c<clen; c++) { + double v1 = m1.quickGetValue(r, c); double v = op.fn.execute( v1, v2 ); ret.appendValue(r, c, v); } @@ -851,9 +842,8 @@ public class LibMatrixBincell else if( atype == BinaryAccessType.MATRIX_ROW_VECTOR ) //MATRIX - ROW_VECTOR { for(int r=0; r<rlen; r++) - for(int c=0; c<clen; 
c++) - { - double v1 = m1.quickGetValue(r, c); + for(int c=0; c<clen; c++) { + double v1 = m1.quickGetValue(r, c); double v2 = m2.quickGetValue(0, c); double v = op.fn.execute( v1, v2 ); ret.appendValue(r, c, v); @@ -869,12 +859,11 @@ public class LibMatrixBincell } else { for(int r=0; r<rlen; r++) { - double v1 = m1.quickGetValue(r, 0); - for(int c=0; c<clen2; c++) - { + double v1 = m1.quickGetValue(r, 0); + for(int c=0; c<clen2; c++) { double v2 = m2.quickGetValue(0, c); double v = op.fn.execute( v1, v2 ); - ret.appendValue(r, c, v); + ret.appendValue(r, c, v); } } } @@ -882,25 +871,25 @@ public class LibMatrixBincell else // MATRIX - MATRIX { //dense non-empty vectors - if( m1.clen==1 && !m1.sparse && !m1.isEmptyBlock(false) + if( m1.clen==1 && !m1.sparse && !m1.isEmptyBlock(false) && !m2.sparse && !m2.isEmptyBlock(false) ) { ret.allocateDenseBlock(); double[] a = m1.denseBlock; double[] b = m2.denseBlock; double[] c = ret.denseBlock; + int lnnz = 0; for( int i=0; i<rlen; i++ ) { c[i] = op.fn.execute( a[i], b[i] ); - if( c[i] != 0 ) - ret.nonZeros++; + lnnz += (c[i] != 0) ? 
1 : 0; } + ret.nonZeros = lnnz; } //general case else { for(int r=0; r<rlen; r++) - for(int c=0; c<clen; c++) - { + for(int c=0; c<clen; c++) { double v1 = m1.quickGetValue(r, c); double v2 = m2.quickGetValue(r, c); double v = op.fn.execute( v1, v2 ); @@ -923,6 +912,8 @@ public class LibMatrixBincell throw new DMLRuntimeException("Unsupported safe binary scalar operations over different input/output representation: "+m1.sparse+" "+ret.sparse); boolean copyOnes = (op.fn instanceof NotEquals && op.getConstant()==0); + boolean allocExact = (op.fn instanceof Multiply + || op.fn instanceof Multiply2 || op.fn instanceof Power2); if( m1.sparse ) //SPARSE <- SPARSE { @@ -954,10 +945,8 @@ public class LibMatrixBincell } else { //GENERAL CASE //create sparse row without repeated resizing for specific ops - if( op.fn instanceof Multiply || op.fn instanceof Multiply2 - || op.fn instanceof Power2 ) { + if( allocExact ) c.allocate(r, alen); - } for(int j=apos; j<apos+alen; j++) { double val = op.executeScalar(avals[j]); http://git-wip-us.apache.org/repos/asf/systemml/blob/223066ee/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index 924d6c5..91248d2 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -5210,25 +5210,16 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab //sparse-unsafe ctable execution //(because input values of 0 are invalid and have to result in errors) - if ( resultBlock == null ) { - for( int i=0; i<rlen; i++ ) - for( int j=0; j<clen; j++ ) - { - double v1 = this.quickGetValue(i, j); - double w = that2.quickGetValue(i, j); - ctable.execute(v1, v2, w, false, resultMap); - } - } - else { 
- for( int i=0; i<rlen; i++ ) - for( int j=0; j<clen; j++ ) - { - double v1 = this.quickGetValue(i, j); - double w = that2.quickGetValue(i, j); - ctable.execute(v1, v2, w, false, resultBlock); - } + for( int i=0; i<rlen; i++ ) + for( int j=0; j<clen; j++ ) { + double v1 = this.quickGetValue(i, j); + double w = that2.quickGetValue(i, j); + ctable.execute(v1, v2, w, false, resultMap, resultBlock); + } + + //maintain nnz (if necessary) + if( resultBlock!=null ) resultBlock.recomputeNonZeros(); - } } /** @@ -5250,23 +5241,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab //sparse-unsafe ctable execution //(because input values of 0 are invalid and have to result in errors) - if ( resultBlock == null ) { - for( int i=0; i<rlen; i++ ) - for( int j=0; j<clen; j++ ) - { - double v1 = this.quickGetValue(i, j); - ctable.execute(v1, v2, w, false, resultMap); - } - } - else { - for( int i=0; i<rlen; i++ ) - for( int j=0; j<clen; j++ ) - { - double v1 = this.quickGetValue(i, j); - ctable.execute(v1, v2, w, false, resultBlock); - } + for( int i=0; i<rlen; i++ ) + for( int j=0; j<clen; j++ ) { + double v1 = this.quickGetValue(i, j); + ctable.execute(v1, v2, w, false, resultMap, resultBlock); + } + + //maintain nnz (if necessary) + if( resultBlock!=null ) resultBlock.recomputeNonZeros(); - } } /** @@ -5286,29 +5269,18 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab //sparse-unsafe ctable execution //(because input values of 0 are invalid and have to result in errors) - if( resultBlock == null) { - for( int i=0; i<rlen; i++ ) - for( int j=0; j<clen; j++ ) - { - double v1 = this.quickGetValue(i, j); - if( left ) - ctable.execute(offset+i+1, v1, w, false, resultMap); - else - ctable.execute(v1, offset+i+1, w, false, resultMap); - } - } - else { - for( int i=0; i<rlen; i++ ) - for( int j=0; j<clen; j++ ) - { - double v1 = this.quickGetValue(i, j); - if( left ) - ctable.execute(offset+i+1, v1, w, false, 
resultBlock); - else - ctable.execute(v1, offset+i+1, w, false, resultBlock); - } + for( int i=0; i<rlen; i++ ) + for( int j=0; j<clen; j++ ) { + double v1 = this.quickGetValue(i, j); + if( left ) + ctable.execute(offset+i+1, v1, w, false, resultMap, resultBlock); + else + ctable.execute(v1, offset+i+1, w, false, resultMap, resultBlock); + } + + //maintain nnz (if necessary) + if( resultBlock!=null ) resultBlock.recomputeNonZeros(); - } } /** @@ -5344,40 +5316,27 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab SparseBlock a = this.sparseBlock; SparseBlock b = that.sparseBlock; - for( int i=0; i<rlen; i++ ) - { - if( !a.isEmpty(i) ) - { - int alen = a.size(i); - int apos = a.pos(i); - double[] avals = a.values(i); - int bpos = b.pos(i); - double[] bvals = b.values(i); - - if( resultBlock == null ) { - for( int j=0; j<alen; j++ ) - ctable.execute(avals[apos+j], bvals[bpos+j], w, ignoreZeros, resultMap); - } - else { - for( int j=0; j<alen; j++ ) - ctable.execute(avals[apos+j], bvals[bpos+j], w, ignoreZeros, resultBlock); - } - } - } + for( int i=0; i<rlen; i++ ) { + if( a.isEmpty(i) ) continue; + int alen = a.size(i); + int apos = a.pos(i); + double[] avals = a.values(i); + int bpos = b.pos(i); + double[] bvals = b.values(i); + for( int j=0; j<alen; j++ ) + ctable.execute(avals[apos+j], bvals[bpos+j], + w, ignoreZeros, resultMap, resultBlock); + } } else //SPARSE-UNSAFE | GENERIC INPUTS { //sparse-unsafe ctable execution //(because input values of 0 are invalid and have to result in errors) for( int i=0; i<rlen; i++ ) - for( int j=0; j<clen; j++ ) - { + for( int j=0; j<clen; j++ ) { double v1 = this.quickGetValue(i, j); double v2 = that.quickGetValue(i, j); - if( resultBlock == null ) - ctable.execute(v1, v2, w, ignoreZeros, resultMap); - else - ctable.execute(v1, v2, w, ignoreZeros, resultBlock); + ctable.execute(v1, v2, w, ignoreZeros, resultMap, resultBlock); } } 
http://git-wip-us.apache.org/repos/asf/systemml/blob/223066ee/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java index 64d4c2a..04a00c9 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java @@ -1359,10 +1359,10 @@ public class UaggOuterChainTest extends AutomatedTestBase //check statistics for right operator in cp and spark if( instType == ExecType.CP ) { - Assert.assertTrue("Missing opcode sp_uaggouerchain", Statistics.getCPHeavyHitterOpCodes().contains(UAggOuterChain.OPCODE)); + Assert.assertTrue("Missing opcode sp_uaggouterchain", Statistics.getCPHeavyHitterOpCodes().contains(UAggOuterChain.OPCODE)); } else if( instType == ExecType.SPARK ) { - Assert.assertTrue("Missing opcode sp_uaggouerchain", + Assert.assertTrue("Missing opcode sp_uaggouterchain", Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+UAggOuterChain.OPCODE)); } }
