Repository: systemml
Updated Branches:
  refs/heads/master 240297bd5 -> a3ab19768


[SYSTEMML-2080] Fix correctness of sparse relu-backward CP operations

This patch fixes a correctness issue of relu backward over sparse
inputs that was introduced by a recent refactoring and cleanup. In
detail, the binary inplace multiply was mistakenly replaced by a binary
inplace plus.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a3ab1976
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a3ab1976
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a3ab1976

Branch: refs/heads/master
Commit: a3ab197686b96873b0a7686710e083906903aa69
Parents: 240297b
Author: Matthias Boehm <[email protected]>
Authored: Wed Jan 24 22:01:28 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Wed Jan 24 22:14:34 2018 -0800

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixDNNRelu.java   | 39 +++++++++++++-------
 1 file changed, 25 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a3ab1976/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
index f1c3ecb..c8f2f41 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.java
@@ -79,7 +79,7 @@ public class LibMatrixDNNRelu
                        }
                        else {
                                scalarOperations(_params.input1, out, n, _rl, 
_ru, GT0);      // (X > 0)
-                               binaryOperationInPlacePlus(_params.input2, out, 
n, _rl, _ru); // (X > 0) * dout
+                               binaryOperationInPlaceMult(_params.input2, out, 
n, _rl, _ru); // (X > 0) * dout
                        }
                        return 0L;
                }
@@ -113,24 +113,35 @@ public class LibMatrixDNNRelu
                }
        }
        
-       private static void binaryOperationInPlacePlus(MatrixBlock src,
+       private static void binaryOperationInPlaceMult(MatrixBlock src,
                DenseBlock c, int destNumCols, int src_rl, int src_ru)
                throws DMLRuntimeException 
        {
-               if( src.isEmptyBlock(false) )
-                       return; //do nothing (add 0);
+               if( src.isEmptyBlock(false) ) {
+                       c.set(src_rl, src_rl, 0, destNumCols, 0);
+                       return;
+               }
                
                if(src.isInSparseFormat()) {
                        for(int i = src_rl; i < src_ru; i++) {
-                               if( src.getSparseBlock().isEmpty(i) ) continue;
-                               int apos = src.getSparseBlock().pos(i);
-                               int alen = src.getSparseBlock().size(i);
-                               int[] aix = src.getSparseBlock().indexes(i);
-                               double[] avals = src.getSparseBlock().values(i);
-                               double[] cvals = c.values(i);
-                               int cix = c.pos(i);
-                               for(int j = apos; j < apos+alen; j++)
-                                       cvals[ cix+aix[j] ] += avals[j];
+                               if( !src.getSparseBlock().isEmpty(i) ) {
+                                       int apos = src.getSparseBlock().pos(i);
+                                       int alen = src.getSparseBlock().size(i);
+                                       int[] aix = 
src.getSparseBlock().indexes(i);
+                                       double[] avals = 
src.getSparseBlock().values(i);
+                                       double[] cvals = c.values(i);
+                                       int cix = c.pos(i);
+                                       int prevDestIndex = 0;
+                                       for(int j = apos; j < apos+alen; j++) {
+                                               c.set(i, i+1, prevDestIndex, 
aix[j], 0);
+                                               prevDestIndex = aix[j]+1;
+                                               cvals[ cix+aix[j] ] *= avals[j];
+                                       }
+                                       c.set(i, i+1, prevDestIndex, 
destNumCols, 0);
+                               }
+                               else {
+                                       c.set(i, i+1, 0, destNumCols, 0);
+                               }
                        }
                }
                else { //DENSE
@@ -139,7 +150,7 @@ public class LibMatrixDNNRelu
                                double[] avals = a.values(i), cvals = 
c.values(i);
                                int aix = a.pos(i), cix = c.pos(i);
                                for(int j=0; j<destNumCols; j++)
-                                       cvals[cix+j] += avals[aix+j];
+                                       cvals[cix+j] *= avals[aix+j];
                        }
                }
        }

Reply via email to