Repository: systemml
Updated Branches:
  refs/heads/master 192fb5582 -> 45eec2d25


[MINOR] Cleanup tensor index computation in convolution operations

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c9977b73
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c9977b73
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c9977b73

Branch: refs/heads/master
Commit: c9977b736b38e8c2b9ec34645f062337e1273c94
Parents: 192fb55
Author: Matthias Boehm <[email protected]>
Authored: Tue Jan 9 20:40:31 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Tue Jan 9 20:40:31 2018 -0800

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixDNNHelper.java | 131 ++++++++++---------
 .../data/LibMatrixDNNPoolingBackwardHelper.java |  32 ++---
 .../data/LibMatrixDNNRotate180Helper.java       |  13 +-
 3 files changed, 92 insertions(+), 84 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c9977b73/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
index a5564cd..e99c2ef 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
@@ -38,6 +38,16 @@ import org.apache.sysml.utils.Statistics;
 
 public class LibMatrixDNNHelper {
        
+       protected static class CellIndex3 {
+               public int ix1;
+               public int ix2;
+               public int ix3;
+               @Override
+               public String toString() {
+                       return "("+ix1+", "+ix2+", "+ix3+")";
+               }
+       }
+       
        // *********************************** low-level runtime operator selection ***********************************************
        // *********************************** based on runtime properties (sparsity, native, etc) ********************************
        // These methods help reduce branch miss predictions and instruction-cache misses.
@@ -275,18 +285,26 @@ public class LibMatrixDNNHelper {
        
        // *********************************** utility methods ******************************************************
        
+       protected static CellIndex3 computeTensorIndexes(int j, int H, int W) {
+               return computeTensorIndexes(j, H, W, new CellIndex3());
+       }
+       
        /**
-        * Computes tensor indexes from column index such that column index  is equal to ret[0]*HW + ret[1]*W + ret[2]
+        * Computes tensor indexes from a linearized column index such that
+        * the column index is equal to ix1*NM + ix2*M + ix3
         * 
         * @param j column index
-        * @param ret tensor indexes
-        * @param H second last dimension
-        * @param W last dimension
+        * @param N second last dimension
+        * @param M last dimension
+        * @param ret output object for reuse
+        * @return tensor indexes
         */
-       static void computeTensorIndexes(int j, int [] ret, int H, int W) {
-               ret[0] = j / (H*W);
-               ret[1] = (j - ret[0]*(H*W))/W;
-               ret[2] = j % W;
+       protected static CellIndex3 computeTensorIndexes(int j, int N, int M, CellIndex3 ret) {
+               int tmp = j / M;
+               ret.ix1 = tmp / N;
+               ret.ix2 = tmp % N;
+               ret.ix3 = j % M;
+               return ret;
        }
        
        //Split a filter of size [K, CRS] into c filters of [K, RS]
@@ -310,23 +328,21 @@ public class LibMatrixDNNHelper {
                                }
                        }
                        else {
+                               SparseBlock sblock = _params.input2.sparseBlock;
+                               CellIndex3 ix = new CellIndex3();
                                for(int k = 0; k < _params.K; k++) {
-                                       if( !_params.input2.sparseBlock.isEmpty(k) ) {
-                                               int [] tensorIndexes = new int[3];
-                                               // Find maxIndex
-                                               int apos = _params.input2.sparseBlock.pos(k);
-                                               int alen = _params.input2.sparseBlock.size(k);
-                                               int[] aix = _params.input2.sparseBlock.indexes(k);
-                                               double[] avals = _params.input2.sparseBlock.values(k);
-                                               for(int j=apos; j<apos+alen; j++) {
-                                                       computeTensorIndexes(aix[j], tensorIndexes, _params.R, _params.S);
-                                                       if(c != tensorIndexes[0])
-                                                               continue;
-                                                       int r = tensorIndexes[1];
-                                                       int s = tensorIndexes[2];
-                                                       outputArr[k*RS + r*S + s] = avals[j];
-                                                       nnz += outputArr[k*RS + r*S + s] != 0 ? 1 : 0;
-                                               }
+                                       if( sblock.isEmpty(k) ) continue;
+                                       // Find maxIndex
+                                       int apos = sblock.pos(k);
+                                       int alen = sblock.size(k);
+                                       int[] aix = sblock.indexes(k);
+                                       double[] avals = sblock.values(k);
+                                       for(int j=apos; j<apos+alen; j++) {
+                                               ix = computeTensorIndexes(aix[j], _params.R, _params.S, ix);
+                                               if(c != ix.ix1)
+                                                       continue;
+                                               outputArr[k*RS + ix.ix2*S + ix.ix3] = avals[j];
+                                               nnz += outputArr[k*RS + ix.ix2*S + ix.ix3] != 0 ? 1 : 0;
                                        }
                                }
                        }
@@ -415,8 +431,6 @@ public class LibMatrixDNNHelper {
         * @throws DMLRuntimeException if error occurs
         */
        static int getMaxIndexSparse(int p, int q, int inputOffset, int n, int c, MatrixBlock input, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
-               int [] tensorIndexes = new int[3];
-               
                int start_h = params.start_indexes_h[p];
                int end_h = params.end_indexes_h[p];
                int start_w = params.start_indexes_w[q];
@@ -432,18 +446,19 @@ public class LibMatrixDNNHelper {
                // if start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W
 
                // input.isEmptyBlock() check is done by the caller
-               if( !input.sparseBlock.isEmpty(n) ) {
+               CellIndex3 ix = new CellIndex3();
+               SparseBlock sblock = input.sparseBlock;
+               if( !sblock.isEmpty(n) ) {
                        // Find maxIndex
-                       int apos = input.sparseBlock.pos(n);
-                       int alen = input.sparseBlock.size(n);
-                       int[] aix = input.sparseBlock.indexes(n);
-                       double[] avals = input.sparseBlock.values(n);
+                       int apos = sblock.pos(n);
+                       int alen = sblock.size(n);
+                       int[] aix = sblock.indexes(n);
+                       double[] avals = sblock.values(n);
                        for(int j=apos; j<apos+alen; j++) {
-                               computeTensorIndexes(aix[j], tensorIndexes, params.H, params.W);
-                               if(c != tensorIndexes[0])
-                                       continue;
-                               int h = tensorIndexes[1];
-                               int w = tensorIndexes[2];
+                               ix = computeTensorIndexes(aix[j], params.H, params.W, ix);
+                               if(c != ix.ix1) continue;
+                               int h = ix.ix2;
+                               int w = ix.ix3;
                                if(h >= start_h && h < end_h && w >= start_w && w < end_w) {
                                        double val = performReluBackward && avals[j] < 0 ? 0 : avals[j];
                                        if(maxVal < val) {
@@ -514,32 +529,26 @@ public class LibMatrixDNNHelper {
                        if(!input.isEmptyBlock()) {
                                int outOffset = outputN*params.C*params.H*params.W;
                                int HW = params.H*params.W;
-                               int [] tensorIndexes = new int[3];
+                               CellIndex3 ix = new CellIndex3();
+                               SparseBlock sblock = input.sparseBlock;
                                for(int i = 0; i < input.getNumRows(); i++) {
-                                       if( !input.sparseBlock.isEmpty(i) ) {
-                                               computeTensorIndexes(i, tensorIndexes, params.P, params.Q);
-                                               int p = tensorIndexes[1];
-                                               int q = tensorIndexes[2];
-                                               int tmpP = p*params.stride_h - params.pad_h;
-                                               int tmpQ = q*params.stride_w - params.pad_w;
-                                               if(tensorIndexes[0] != 0)
-                                                       throw new DMLRuntimeException("Incorrect tensor indexes: " + tensorIndexes[0] + " != 0 <" + p + " " + q + " " + tensorIndexes[0] + params.P + " " + params.Q + ">");
-
-                                               int apos = input.sparseBlock.pos(i);
-                                               int alen = input.sparseBlock.size(i);
-                                               int[] aix = input.sparseBlock.indexes(i);
-                                               double[] avals = input.sparseBlock.values(i);
-                                               for(int j = apos; j < apos+alen; j++) {
-                                                       computeTensorIndexes(aix[j], tensorIndexes, params.R, params.S);
-                                                       int c = tensorIndexes[0];
-                                                       int r = tensorIndexes[1];
-                                                       int s = tensorIndexes[2];
-                                                       int h = tmpP + r;
-                                                       int w = tmpQ + s;
-                                                       if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
-                                                               int outIndex = outOffset + c*HW + h*params.W + w;
-                                                               outputArray[outIndex] += avals[j];
-                                                       }
+                                       if( sblock.isEmpty(i) ) continue;
+                                       ix = computeTensorIndexes(i, params.P, params.Q, ix);
+                                       int tmpP = ix.ix2*params.stride_h - params.pad_h;
+                                       int tmpQ = ix.ix3*params.stride_w - params.pad_w;
+                                       if(ix.ix1 != 0)
+                                               throw new DMLRuntimeException("Incorrect tensor indexes: "+ ix + ", " + params.P + " " + params.Q);
+                                       int apos = sblock.pos(i);
+                                       int alen = sblock.size(i);
+                                       int[] aix = sblock.indexes(i);
+                                       double[] avals = sblock.values(i);
+                                       for(int j = apos; j < apos+alen; j++) {
+                                               ix = computeTensorIndexes(aix[j], params.R, params.S, ix);
+                                               int h = tmpP + ix.ix2;
+                                               int w = tmpQ + ix.ix3;
+                                               if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+                                                       int outIndex = outOffset + ix.ix1*HW + h*params.W + w;
+                                                       outputArray[outIndex] += avals[j];
                                                }
                                        }
                                }

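For illustration, a minimal standalone sketch of the index decomposition used above. The class name and main harness here are hypothetical and not part of this commit, but the arithmetic mirrors the new CellIndex3-based computeTensorIndexes, which splits a linearized column index j into (ix1, ix2, ix3) such that j == ix1*N*M + ix2*M + ix3.

// Minimal sketch (hypothetical harness): decompose a linearized column index j
// into (ix1, ix2, ix3) such that j == ix1*N*M + ix2*M + ix3, mirroring the
// CellIndex3-based computeTensorIndexes introduced above.
public class CellIndex3Sketch {
    static final class CellIndex3 {
        int ix1, ix2, ix3;
        @Override public String toString() {
            return "("+ix1+", "+ix2+", "+ix3+")";
        }
    }

    // same arithmetic as the removed int[]-based variant, but writes into a
    // reusable CellIndex3 instead of filling a freshly allocated int[3]
    static CellIndex3 computeTensorIndexes(int j, int N, int M, CellIndex3 ret) {
        int tmp = j / M;
        ret.ix1 = tmp / N;
        ret.ix2 = tmp % N;
        ret.ix3 = j % M;
        return ret;
    }

    public static void main(String[] args) {
        int N = 3, M = 5;                 // e.g., filter height R and width S
        CellIndex3 ix = new CellIndex3(); // allocated once, reused per call
        for( int j = 0; j < 2*N*M; j++ ) {
            computeTensorIndexes(j, N, M, ix);
            // round-trip check: the decomposition must reproduce j
            if( j != ix.ix1*N*M + ix.ix2*M + ix.ix3 )
                throw new IllegalStateException("decomposition mismatch at j=" + j);
        }
        System.out.println(computeTensorIndexes(37, N, M, ix)); // prints (2, 1, 2)
    }
}

Reusing one scratch object per loop is what lets the patched call sites drop the per-row int[3] allocations.
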
http://git-wip-us.apache.org/repos/asf/systemml/blob/c9977b73/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
index 6e5e978..3dfb545 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
@@ -21,6 +21,8 @@ package org.apache.sysml.runtime.matrix.data;
 import java.util.Arrays;
 import java.util.concurrent.Callable;
 
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper.CellIndex3;
+
 /**
  * This class contains the set of operators used for performing pooling backward
  */
@@ -99,24 +101,22 @@ public class LibMatrixDNNPoolingBackwardHelper {
                
                @Override
                public Long call() throws Exception {
+                       CellIndex3 ix = new CellIndex3();
                        double[] out = output.getDenseBlockValues();
+                       SparseBlock sblock = dout.sparseBlock;
                        for(int n = _rl; n < _ru; n++)  {
-                               if( !dout.sparseBlock.isEmpty(n) ) {
-                                       int [] tensorIndexes = new int[3];
-                                       int apos = dout.sparseBlock.pos(n);
-                                       int alen = dout.sparseBlock.size(n);
-                                       int[] aix = dout.sparseBlock.indexes(n);
-                                       double[] avals = dout.sparseBlock.values(n);
-                                       for(int j = apos; j < apos+alen; j++) {
-                                               LibMatrixDNNHelper.computeTensorIndexes(aix[j], tensorIndexes, P, Q);
-                                               int c = tensorIndexes[0];
-                                               int p = tensorIndexes[1];
-                                               int q = tensorIndexes[2];
-                                               final int inputOffset = n*CHW + c*HW;
-                                               int maxIndex = LibMatrixDNNHelper.getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
-                                               if(maxIndex != -1)
-                                                       out[maxIndex] += avals[j];
-                                       }
+                               if( sblock.isEmpty(n) ) continue;
+                               int apos = sblock.pos(n);
+                               int alen = sblock.size(n);
+                               int[] aix = sblock.indexes(n);
+                               double[] avals = sblock.values(n);
+                               for(int j = apos; j < apos+alen; j++) {
+                                       ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], P, Q, ix);
+                                       final int inputOffset = n*CHW + ix.ix1*HW;
+                                       int maxIndex = LibMatrixDNNHelper.getMaxIndex(ix.ix2, ix.ix3,
+                                               inputOffset, inputArray, _params, performReluBackward);
+                                       if(maxIndex != -1)
+                                               out[maxIndex] += avals[j];
                                }
                        }
                        //thread-local nnz maintenance

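The pooling backward change follows the same pattern: the SparseBlock handle and a single CellIndex3 are hoisted above the row loop, empty rows are skipped with an early continue, and the leading decomposed index selects the channel. The rough sketch below shows that loop shape with a toy container; ToySparseRows and sumChannel are illustrative assumptions, not SystemML APIs.

// Sketch of the loop shape after the cleanup (hypothetical, standalone):
// the scratch CellIndex3 is created once, reused for every nonzero, and the
// decomposed leading index filters by channel, as in the patched helpers.
public class ReusePatternSketch {
    static final class CellIndex3 { int ix1, ix2, ix3; }

    static CellIndex3 computeTensorIndexes(int j, int N, int M, CellIndex3 ret) {
        int tmp = j / M;
        ret.ix1 = tmp / N;
        ret.ix2 = tmp % N;
        ret.ix3 = j % M;
        return ret;
    }

    // toy row-wise sparse container standing in for SystemML's SparseBlock
    static final class ToySparseRows {
        int[][] indexes;   // column indexes per row
        double[][] values; // values per row
        boolean isEmpty(int r) { return indexes[r] == null || indexes[r].length == 0; }
    }

    // sums all nonzeros of channel c, where columns encode (c, p, q) over P x Q
    static double sumChannel(ToySparseRows block, int rows, int c, int P, int Q) {
        CellIndex3 ix = new CellIndex3(); // hoisted out of the row loop
        double sum = 0;
        for( int n = 0; n < rows; n++ ) {
            if( block.isEmpty(n) ) continue; // early-out on empty rows
            int[] aix = block.indexes[n];
            double[] avals = block.values[n];
            for( int j = 0; j < aix.length; j++ ) {
                ix = computeTensorIndexes(aix[j], P, Q, ix);
                if( ix.ix1 != c ) continue;  // keep only the requested channel
                sum += avals[j];
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        ToySparseRows b = new ToySparseRows();
        b.indexes = new int[][]{ {0, 7, 12}, null };
        b.values  = new double[][]{ {1.0, 2.0, 3.0}, null };
        // with P=2, Q=3: column 7 decomposes to (1,0,1), column 12 to (2,0,0)
        System.out.println(sumChannel(b, 2, 1, 2, 3)); // prints 2.0
    }
}
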
http://git-wip-us.apache.org/repos/asf/systemml/blob/c9977b73/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
index 1ca6ad1..74e2baa 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
@@ -18,6 +18,8 @@
  */
 package org.apache.sysml.runtime.matrix.data;
 
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper.CellIndex3;
+
 /**
  * This class contains the different implementation of rotate180 operation
  */
@@ -90,21 +92,18 @@ public class LibMatrixDNNRotate180Helper {
                        if( sblock==null || sblock.isEmpty(inputN) )
                                return;
                        
+                       CellIndex3 ix = new CellIndex3();
                        int outputOffset = outputN*params.P*params.Q;
-                       int [] tensorIndexes = new int[3];
                        int apos = sblock.pos(inputN);
                        int alen = sblock.size(inputN);
                        int[] aix = sblock.indexes(inputN);
                        double[] avals = sblock.values(inputN);
                        for(int j = apos; j < apos+alen; j++) {
-                               LibMatrixDNNHelper.computeTensorIndexes(aix[j], tensorIndexes, params.P, params.Q);
-                               int k = tensorIndexes[0];
-                               int p = tensorIndexes[1];
-                               int q = tensorIndexes[2];
+                               ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], params.P, params.Q, ix);
                                if( trans )
-                                       out.appendValue(k, outputOffset + p*params.Q + q, avals[j]);
+                                       out.appendValue(ix.ix1, outputOffset + ix.ix2*params.Q + ix.ix3, avals[j]);
                                else
-                                       out.appendValue(outputOffset + p*params.Q + q, k, avals[j]);
+                                       out.appendValue(outputOffset + ix.ix2*params.Q + ix.ix3, ix.ix1, avals[j]);
                        }
                }
        }

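In the rotate180 helper the placement itself is unchanged: a nonzero of input row n at column k*P*Q + p*Q + q goes to output cell (outputOffset + p*Q + q, k) in the default case, with row and column swapped when trans is set; only the source of (k, p, q) moves from an int[3] to the reused CellIndex3. A small dense sketch of that default mapping follows, with array shapes and the method name chosen here purely for illustration.

// Dense sketch (illustrative shapes, not SystemML classes): one input row is a
// K x (P*Q) slice stored row-major as in[k*P*Q + p*Q + q]; the non-transposed
// output places each value at (row = outOffset + p*Q + q, col = k), i.e. at
// out[(outOffset + p*Q + q)*K + k] for a row-major output with K columns.
public class Rotate180Sketch {
    static void rotate180Row(double[] in, double[] out, int outOffset, int K, int P, int Q) {
        for( int k = 0; k < K; k++ )
            for( int p = 0; p < P; p++ )
                for( int q = 0; q < Q; q++ )
                    out[(outOffset + p*Q + q)*K + k] = in[k*P*Q + p*Q + q];
    }

    public static void main(String[] args) {
        int K = 2, P = 2, Q = 2;
        double[] in  = {1, 2, 3, 4,  5, 6, 7, 8}; // k=0 holds 1..4, k=1 holds 5..8
        double[] out = new double[P*Q*K];
        rotate180Row(in, out, 0, K, P, Q);
        // output rows are (p*Q + q), columns are k:
        // prints [1.0, 5.0, 2.0, 6.0, 3.0, 7.0, 4.0, 8.0]
        System.out.println(java.util.Arrays.toString(out));
    }
}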