Repository: systemml
Updated Branches:
  refs/heads/master 744df8139 -> 2ca62e34b


[SYSTEMML-2067] Refactoring of im2col operations into functional API

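This patch makes a minor refactoring (without changing the behavior) of
the im2col operations, moving them from a stateful worker design
(Im2colWorker) to a stateless functional API (LibMatrixDNNIm2Col.im2col).
Callers now allocate the im2col output block themselves and preallocate
its sparse rows via preallocateSparseOutput, and doCol2imOverSingleImage
is renamed to col2imOverSingleImage. This is a preparation step for
supporting conv2d operations in codegen, where im2col will be called
directly from the generated operators.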


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/2ca62e34
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/2ca62e34
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/2ca62e34

Branch: refs/heads/master
Commit: 2ca62e34b3f6996d4d51bd60d3baae06e0c65332
Parents: 744df81
Author: Matthias Boehm <[email protected]>
Authored: Sat Feb 10 20:32:24 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Sat Feb 10 20:32:24 2018 -0800

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixDNNConv2d.java |  30 +--
 .../runtime/matrix/data/LibMatrixDNNIm2Col.java | 250 +++++++------------
 2 files changed, 110 insertions(+), 170 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca62e34/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2d.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2d.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2d.java
index 22a5e1a..436735e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2d.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2d.java
@@ -25,7 +25,6 @@ import java.util.concurrent.Callable;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.matrix.data.LibMatrixDNNIm2Col.Im2colWorker;
 import 
org.apache.sysml.runtime.matrix.data.LibMatrixDNNRotate180.Rotate180Worker;
 import org.apache.sysml.utils.NativeHelper;
 import org.apache.sysml.utils.Statistics;
@@ -171,14 +170,13 @@ public class LibMatrixDNNConv2d
                @Override
                public Long call() throws Exception {
                        final int PQ = _params.P*_params.Q, K = _params.K, CRS 
= _params.C*_params.R*_params.S;
-                       MatrixBlock outIm2col = new MatrixBlock(CRS, PQ, 
_params.input1.sparse);
+                       MatrixBlock outIm2col = new MatrixBlock(CRS, PQ, 
_params.input1.sparse).allocateBlock();
+                       
LibMatrixDNNIm2Col.preallocateSparseOutput(_params.input1, outIm2col);
                        MatrixBlock outMM = new MatrixBlock(K, PQ, 
_params.output.sparse);
-                       Im2colWorker im2ColWorker = Im2colWorker.getWorker( 
_params.input1, outIm2col, _params, false);
                        long time1 = 0; long time2 = 0;
                        for(int n = _rl; n < _ru; n++)  {
-                               // im2col(input) => _im2ColOutBlock
                                long t1 = DMLScript.FINEGRAINED_STATISTICS ? 
System.nanoTime() : 0;
-                               im2ColWorker.execute(n);
+                               LibMatrixDNNIm2Col.im2col(_params.input1, 
outIm2col, n, _params, false);
                                long t2 = DMLScript.FINEGRAINED_STATISTICS ? 
System.nanoTime() : 0;
                                
                                // filter %*% _im2ColOutBlock => matMultOutBlock
@@ -266,13 +264,12 @@ public class LibMatrixDNNConv2d
                @Override
                public Long call() throws Exception {
                        final int PQ = _params.P*_params.Q, K = _params.K, CRS 
= _params.C*_params.R*_params.S;
-                       MatrixBlock outIm2col = new MatrixBlock(PQ, CRS, false);
+                       MatrixBlock outIm2col = new MatrixBlock(PQ, CRS, 
_params.input1.sparse).allocateBlock();
+                       
LibMatrixDNNIm2Col.preallocateSparseOutput(_params.input1, outIm2col);
                        MatrixBlock outMM = new MatrixBlock(PQ, K, false);
-                       Im2colWorker im2ColWorker = Im2colWorker.getWorker( 
_params.input1, outIm2col, _params, true);
                        
                        for(int n = _rl; n < _ru; n++)  {
-                               // im2col(input) => _im2ColOutBlock
-                               im2ColWorker.execute(n);
+                               LibMatrixDNNIm2Col.im2col(_params.input1, 
outIm2col, n, _params, true);
                                
                                // t(_im2ColOutBlock) %*% t(filter) => 
t(matMultOutBlock)
                                outMM.reset(outMM.rlen, outMM.clen, false);
@@ -427,7 +424,7 @@ public class LibMatrixDNNConv2d
                                
LibMatrixDNNHelper.singleThreadedMatMult(outRotate, filter, outMM, 
!outRotate.sparse, false, _params);
                                long t2 = DMLScript.FINEGRAINED_STATISTICS ? 
System.nanoTime() : 0;
                                // col2im(temp) => output[n,] 
-                               LibMatrixDNNIm2Col.doCol2imOverSingleImage(n, 
outMM, _params);
+                               LibMatrixDNNIm2Col.col2imOverSingleImage(n, 
outMM, _params);
                                long t3 = DMLScript.FINEGRAINED_STATISTICS ? 
System.nanoTime() : 0;
                                
                                if(DMLScript.FINEGRAINED_STATISTICS) {
@@ -504,12 +501,12 @@ public class LibMatrixDNNConv2d
                public Long call() throws Exception {
                        int PQ = _params.P*_params.Q, K = _params.K, CRS = 
_params.C*_params.R*_params.S;
                        MatrixBlock dout = _params.input2;
-                       MatrixBlock im2ColOutBlock = new MatrixBlock(CRS, PQ, 
false);
+                       MatrixBlock im2ColOutBlock = new MatrixBlock(CRS, PQ, 
_params.input1.sparse).allocateBlock();
+                       
LibMatrixDNNIm2Col.preallocateSparseOutput(_params.input1, im2ColOutBlock);
                        MatrixBlock outRotate = new MatrixBlock(PQ, K, 
dout.sparse);
                        MatrixBlock outMM = new MatrixBlock(CRS, K, false);
                        outRotate.allocateBlock();
                        
-                       Im2colWorker im2ColWorker = Im2colWorker.getWorker( 
_params.input1, im2ColOutBlock, _params, false);
                        Rotate180Worker rotate180Worker = 
Rotate180Worker.getWorker( dout, outRotate, _params, true, false);
                        double [] partRet = new double[CRS*_params.K];
                        long time1 = 0; long time2 = 0;
@@ -519,7 +516,7 @@ public class LibMatrixDNNConv2d
                                
                                // im2col(input) => _im2ColOutBlock
                                long t1 = DMLScript.FINEGRAINED_STATISTICS ? 
System.nanoTime() : 0;
-                               im2ColWorker.execute(n);
+                               LibMatrixDNNIm2Col.im2col(_params.input1, 
im2ColOutBlock, n, _params, false);
                                long t2 = DMLScript.FINEGRAINED_STATISTICS ? 
System.nanoTime() : 0;
                                
                                outMM.reset(CRS, K, false);
@@ -556,11 +553,10 @@ public class LibMatrixDNNConv2d
                public Long call() throws Exception {
                        int PQ = _params.P*_params.Q, K = _params.K, CRS = 
_params.C*_params.R*_params.S;
                        MatrixBlock dout = _params.input2;
-                       MatrixBlock im2ColOutBlock = new MatrixBlock(PQ, CRS, 
false).allocateBlock();
+                       MatrixBlock im2ColOutBlock = new MatrixBlock(PQ, CRS, 
_params.input1.sparse).allocateBlock();
+                       
LibMatrixDNNIm2Col.preallocateSparseOutput(_params.input1, im2ColOutBlock);
                        MatrixBlock outRotate = new MatrixBlock(K, PQ, 
dout.sparse).allocateBlock();
                        MatrixBlock outMM = new MatrixBlock(K, CRS, 
false).allocateBlock();
-                       
-                       Im2colWorker im2ColWorker = Im2colWorker.getWorker( 
_params.input1, im2ColOutBlock, _params, true);
                        Rotate180Worker rotate180Worker = 
Rotate180Worker.getWorker( dout, outRotate, _params, true, true);
                        double [] partRet = new double[CRS*_params.K];
                        long time1 = 0; long time2 = 0;
@@ -570,7 +566,7 @@ public class LibMatrixDNNConv2d
                                
                                // im2col(input) => _im2ColOutBlock
                                long t1 = DMLScript.FINEGRAINED_STATISTICS ? 
System.nanoTime() : 0;
-                               im2ColWorker.execute(n);
+                               LibMatrixDNNIm2Col.im2col(_params.input1, 
im2ColOutBlock, n, _params, true);
                                long t2 = DMLScript.FINEGRAINED_STATISTICS ? 
System.nanoTime() : 0;
                                
                                outMM.reset(K, CRS, false);

http://git-wip-us.apache.org/repos/asf/systemml/blob/2ca62e34/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
index d1f4b30..5c72c4e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
@@ -20,178 +20,113 @@ package org.apache.sysml.runtime.matrix.data;
 
 import java.util.Arrays;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper.CellIndex3;
 
 /**
  * This class contains the different implementation of im2col operation
  */
-public class LibMatrixDNNIm2Col {
-       private static final Log LOG = 
LogFactory.getLog(LibMatrixDNNIm2Col.class.getName());
-       static interface Im2colWorker {
-               public void execute(int n);
-               public static Im2colWorker getWorker(MatrixBlock input, 
MatrixBlock out, ConvolutionParameters params, boolean trans) {
-                       if(!input.isInSparseFormat()) {
-                               boolean stride1Pad0 = params.stride_h == 1 && 
params.stride_w == 1 
-                                               && params.pad_h == 0 && 
params.pad_w == 0;
-                               // Note: Only dense im2col operators require 
the im2ColOutBlock to be allocated in the dense format.
-                               out.reset(out.rlen, out.clen, false);
-                               out.allocateDenseBlock();
-                               if( LOG.isTraceEnabled() ) 
-                                       LOG.trace("Using 
DenseIm2colWorkerAllChannels operator to perform im2col 
(stride1pad0="+stride1Pad0+").");
-                               if( stride1Pad0 && !trans )
-                                       return new 
DenseIm2colWorkerStride1Pad0AllChannels(input.getDenseBlockValues(), 
out.getDenseBlockValues(), params);
-                               else
-                                       return new 
DenseIm2colWorkerAllChannels(input.getDenseBlockValues(), 
out.getDenseBlockValues(), params, trans);
-                       }
-                       else {
-                               if(LOG.isTraceEnabled()) 
-                                       LOG.trace("Using SparseIm2colWorker 
operator to perform im2col.");
-                               out.reset(out.rlen, out.clen, true);
-                               out.allocateSparseRowsBlock();
-                               //preallocate sparse-rows (larger than average 
sparsity to account for skew)
-                               int estnnz = 
(int)Math.ceil(4*input.getSparsity()*out.clen);
-                               for(int r = 0; r < out.rlen; r++)
-                                       out.getSparseBlock().allocate(r, 
Math.max(Math.min(estnnz, out.clen),16));
-                               return new 
SparseSparseIm2colWorkerAllChan(input, out, params, trans);
-                       }
-               }
+public class LibMatrixDNNIm2Col 
+{
+       public static void im2col(MatrixBlock in, MatrixBlock out, int r, 
ConvolutionParameters params, boolean trans) {
+               im2col(in, out, r, params.C, params.R, params.S, params.H, 
params.W, params.P, params.Q,
+                       params.stride_h, params.stride_w, params.pad_h, 
params.pad_w, trans);
        }
        
-       /**
-        * Special case operator for performing dense im2col when stride = [1, 
1] and pad = [0, 0] by using System.arraycopy
-        */
-       private static class DenseIm2colWorkerStride1Pad0AllChannels implements 
Im2colWorker {
-               private final double [] inputArray, outputArray; 
-               private final int CRS, S, R, P, Q, CHW, H, W;
-               public DenseIm2colWorkerStride1Pad0AllChannels(double [] 
inputArray, double [] outputArray, ConvolutionParameters params) {
-                       this.inputArray = inputArray;
-                       this.outputArray = outputArray;
-                       this.CRS = params.C * params.R * params.S;
-                       this.H = params.H; this.W = params.W; this.R = 
params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
-                       this.CHW = params.C*params.H*params.W;
-               }
+       public static void im2col(MatrixBlock in, MatrixBlock out, int r, int 
C, int R, int S, int H, int W, int P, int Q,
+                       int stride_h, int stride_w, int pad_h, int pad_w, 
boolean trans) {
+               boolean stride1Pad0 = stride_h == 1 
+                       && stride_w == 1 && pad_h == 0 && pad_w == 0;
                
-               @Override
-               public void execute(int n) {
-                       int nOffset = n * CHW;
-                       for (int c = 0; c < CRS; ++c) {
-                               int wOffset = c % S;
-                               int hOffset = (c / S) % R;
-                               int cInput = c / R / S;
-                               for (int h = 0; h < P; ++h) {
-                                       int hPadded = h + hOffset;
-                                       int outOffset = (c * P + h) * Q;
-                                       int inputOffset = nOffset + (cInput * H 
+ hPadded) * W;
-                                       System.arraycopy(inputArray, 
inputOffset + wOffset, outputArray, outOffset, Q);
-                                       int w = Q - 1;
-                                       int wPadded = w + wOffset;
-                                       boolean assign = (hPadded < H && 
wPadded < W);
-                                       outputArray[outOffset + w] = assign ? 
inputArray[inputOffset + wPadded] : 0;
-                               }
+               //dense and sparse operation dispatch
+               if( !in.sparse && stride1Pad0 && !trans )
+                       im2colDenseStride1Pad0(in.getDenseBlockValues(),
+                               out.getDenseBlockValues(), r, C, R, S, H, W, P, 
Q);
+               else if( !in.sparse )
+                       im2colDense(in.getDenseBlockValues(), 
out.getDenseBlockValues(),
+                               r, C, R, S, H, W, P, Q, stride_h, stride_w, 
pad_h, pad_w, trans);
+               else
+                       im2colSparse(in, out, r, C, R, S, H, W, P, Q,
+                               stride_h, stride_w, pad_h, pad_w, trans);
+       }
+       
+       public static void im2colDenseStride1Pad0(double[] in, double[] out, 
int r, int C, int R, int S, int H, int W, int P, int Q) {
+               int nOffset = r * C * H * W;
+               int CRS = C * R * S;
+               for (int c = 0; c < CRS; ++c) {
+                       int wOffset = c % S;
+                       int hOffset = (c / S) % R;
+                       int cInput = c / R / S;
+                       for (int h = 0; h < P; ++h) {
+                               int hPadded = h + hOffset;
+                               int outOffset = (c * P + h) * Q;
+                               int inputOffset = nOffset + (cInput * H + 
hPadded) * W;
+                               System.arraycopy(in, inputOffset + wOffset, 
out, outOffset, Q);
+                               int w = Q - 1;
+                               int wPadded = w + wOffset;
+                               boolean assign = (hPadded < H && wPadded < W);
+                               out[outOffset + w] = assign ? in[inputOffset + 
wPadded] : 0;
                        }
                }
        }
        
-       /**
-        * Performing dense im2col (general case)
-        */
-       private static class DenseIm2colWorkerAllChannels implements 
Im2colWorker {
-               private final double[] inputArray, outputArray; 
-               private final int CRS, S, R, P, Q, CHW, H, W; 
-               private final int stride_h, stride_w, pad_h, pad_w;
-               private final boolean trans;
-               public DenseIm2colWorkerAllChannels(double [] inputArray, 
double [] outputArray, ConvolutionParameters params, boolean trans) {
-                       this.inputArray = inputArray;
-                       this.outputArray = outputArray;
-                       this.CRS = params.C * params.R * params.S;
-                       this.H = params.H; this.W = params.W; this.R = 
params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
-                       this.CHW = params.C*params.H*params.W;
-                       this.stride_h = params.stride_h; this.stride_w = 
params.stride_w;
-                       this.pad_h = params.pad_h; this.pad_w = params.pad_w;
-                       this.trans = trans;
-               }
-               
-               @Override
-               public void execute(int n) {
-                       //reset for selective copy
-                       Arrays.fill(outputArray, 0);
-                       
-                       int nOffset = n * CHW;
-                       for (int c = 0; c < CRS; ++c) {
-                               int wOffset = c % S;
-                               int hOffset = (c / S) % R;
-                               int cInput = c / R / S;
-                               for (int h = 0; h < P; ++h) {
-                                       int outOffset = trans ? c+(h*Q*CRS) : 
(c*P+h)*Q;
-                                       int hPadded = h * stride_h - pad_h + 
hOffset;
-                                       int inputOffset = nOffset + (cInput * H 
+ hPadded) * W;
-                                       if (hPadded < 0 || hPadded >= H ) 
continue;
-                                       for (int w = 0; w < Q; ++w) {
-                                               int wPadded = w * stride_w - 
pad_w + wOffset;
-                                               if( wPadded >= 0 && wPadded < W 
)
-                                                       outputArray[outOffset + 
(trans?w*CRS:w)] 
-                                                               = 
inputArray[inputOffset + wPadded];
-                                       }
+       public static void im2colDense(double[] in, double[] out, int r, int C, 
int R, int S, int H, int W, int P, int Q,
+                       int stride_h, int stride_w, int pad_h, int pad_w, 
boolean trans) {
+               Arrays.fill(out, 0); //reset for selective copy
+               int CHW = C * H * W;
+               int CRS = C * R * S;
+               int nOffset = r * CHW;
+               for (int c = 0; c < CRS; ++c) {
+                       int wOffset = c % S;
+                       int hOffset = (c / S) % R;
+                       int cInput = c / R / S;
+                       for (int h = 0; h < P; ++h) {
+                               int outOffset = trans ? c+(h*Q*CRS) : (c*P+h)*Q;
+                               int hPadded = h * stride_h - pad_h + hOffset;
+                               int inputOffset = nOffset + (cInput * H + 
hPadded) * W;
+                               if (hPadded < 0 || hPadded >= H ) continue;
+                               for (int w = 0; w < Q; ++w) {
+                                       int wPadded = w * stride_w - pad_w + 
wOffset;
+                                       if( wPadded >= 0 && wPadded < W )
+                                               out[outOffset + 
(trans?w*CRS:w)] 
+                                                       = in[inputOffset + 
wPadded];
                                }
                        }
                }
        }
        
-       /**
-        * Performing sparse im2col for all channels for a given row n of the 
input matrix.
-        */
-       private static class SparseSparseIm2colWorkerAllChan implements 
Im2colWorker {
-               private final MatrixBlock input, output;
-               private final int S, R, P, Q, H, W, RS;
-               private final int stride_h, stride_w, pad_h, pad_w;
-               private final boolean trans;
-               private final boolean simple;
-               public SparseSparseIm2colWorkerAllChan(MatrixBlock input, 
MatrixBlock im2ColOutBlock, ConvolutionParameters params, boolean trans) {
-                       this.input = input;
-                       this.output = im2ColOutBlock;
-                       this.RS = params.R * params.S;
-                       this.H = params.H; this.W = params.W; this.R = 
params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
-                       this.stride_h = params.stride_h; this.stride_w = 
params.stride_w;
-                       this.pad_h = params.pad_h; this.pad_w = params.pad_w;
-                       this.trans = trans;
-                       this.simple = params.isStride1Pad0() && W == S && Q == 
1;
-                       if(!input.isInSparseFormat()) 
-                               throw new RuntimeException("Incorrect operator 
selection. Expected dense input for SparseIm2colWorkerAllChannels");
-               }
+       public static void im2colSparse(MatrixBlock in, MatrixBlock out, int r, 
int C, int R, int S, int H, int W, int P, int Q,
+                       int stride_h, int stride_w, int pad_h, int pad_w, 
boolean trans) {
+               out.reset();
+               SparseBlock sblock = in.sparseBlock;
+               if( sblock.isEmpty(r) )
+                       return;
+               int apos = sblock.pos(r);
+               int alen = sblock.size(r);
+               int[] aix = sblock.indexes(r);
+               double[] avals = sblock.values(r);
+               boolean simple = (stride_h==1 && stride_w==1
+                       && pad_h==0 && pad_w==0 && W == S && Q == 1);
+               int RS = R * S;
                
-               @Override
-               public void execute(int n) {
-                       output.reset();
-                       SparseBlock sblock = input.sparseBlock;
-                       if( sblock.isEmpty(n) )
-                               return;
-                       int apos = sblock.pos(n);
-                       int alen = sblock.size(n);
-                       int[] aix = sblock.indexes(n);
-                       double[] avals = sblock.values(n);
+               // Iterate over the sparse block
+               CellIndex3 ix = new CellIndex3();
+               for(int j=apos; j<apos+alen; j++) {
+                       // Note: the input is of shape [N, CHW]
+                       int chw = aix[j];
                        
-                       // Iterate over the sparse block
-                       CellIndex3 ix = new CellIndex3();
-                       for(int j=apos; j<apos+alen; j++) {
-                               // Note: the input is of shape [N, CHW]
-                               int chw = aix[j];
-                               
-                               // Get individual zero-based c,h,w indexes from 
zero-based 'chw'
-                               ix = 
LibMatrixDNNHelper.computeTensorIndexes(chw, H, W, ix);
-                               
-                               if( simple )
-                                       
appendInputValueToIm2colOutputSimple(output, ix.ix1, ix.ix2, ix.ix3, 
-                                               avals[j], R, S, RS, P, trans);
-                               else
-                                       appendInputValueToIm2colOutput(output, 
ix.ix1, ix.ix2, ix.ix3, avals[j], 
-                                               R, S, RS, P, Q, stride_h, 
stride_w, pad_h, pad_w, trans);
-                       }
+                       // Get individual zero-based c,h,w indexes from 
zero-based 'chw'
+                       ix = LibMatrixDNNHelper.computeTensorIndexes(chw, H, W, 
ix);
                        
-                       output.sortSparseRows();
+                       if( simple )
+                               appendInputValueToIm2colOutputSimple(out, 
ix.ix1, ix.ix2, ix.ix3, 
+                                       avals[j], R, S, RS, P, trans);
+                       else
+                               appendInputValueToIm2colOutput(out, ix.ix1, 
ix.ix2, ix.ix3, avals[j], 
+                                       R, S, RS, P, Q, stride_h, stride_w, 
pad_h, pad_w, trans);
                }
+               
+               out.sortSparseRows();
        }
        
        /**
@@ -258,7 +193,7 @@ public class LibMatrixDNNIm2Col {
        // Therefore, it is provided as utility function rather than an 
operator (like im2col or rotate180)
        
        //Converts input: PQ X CRS matrix and writes to 1 X CHW
-       static void doCol2imOverSingleImage(int outputN, MatrixBlock input, 
ConvolutionParameters params) throws DMLRuntimeException {
+       public static void col2imOverSingleImage(int outputN, MatrixBlock 
input, ConvolutionParameters params) throws DMLRuntimeException {
                if(input.rlen != params.P*params.Q || input.clen != 
params.C*params.R*params.S) {
                        throw new DMLRuntimeException("Incorrect input 
dimensions");
                }
@@ -272,7 +207,7 @@ public class LibMatrixDNNIm2Col {
                
                if(!input.isInSparseFormat()) {
                        double [] inputArray = input.getDenseBlockValues();
-                       doCol2IMDenseInput(0, outputN, inputArray, outputArray, 
params);
+                       col2IMDenseInput(0, outputN, inputArray, outputArray, 
params);
                }
                else {
                        if(!input.isEmptyBlock()) {
@@ -307,7 +242,7 @@ public class LibMatrixDNNIm2Col {
        
        // Converts input: PQ X CRS matrix and writes to 1 X CHW if inputN == 0
        // Or converts input: NPQ X CRS matrix and writes to N X CHW 
-       private static void doCol2IMDenseInput(int inputN, int outputN, double 
[] inputArray, double [] outputArray, ConvolutionParameters params) throws 
DMLRuntimeException {
+       private static void col2IMDenseInput(int inputN, int outputN, double [] 
inputArray, double [] outputArray, ConvolutionParameters params) throws 
DMLRuntimeException {
                final int outputNOffset = outputN*params.C*params.H*params.W;
                final int HW = params.H*params.W;
                final int inputNPQ = inputN*params.P*params.Q;
@@ -342,4 +277,13 @@ public class LibMatrixDNNIm2Col {
                        }
                }
        }
+       
+       public static void preallocateSparseOutput(MatrixBlock in, MatrixBlock 
out) {
+               if( !in.sparse )
+                       return;
+               //preallocate sparse-rows (larger than average sparsity to 
account for skew)
+               int estnnz = (int)Math.ceil(4*in.getSparsity()*out.clen);
+               for(int r = 0; r < out.rlen; r++)
+                       out.getSparseBlock().allocate(r, 
Math.max(Math.min(estnnz, out.clen),16));
+       }
 }

Reply via email to