Repository: systemml
Updated Branches:
  refs/heads/master 416ebc02a -> 6c468f98b


[SYSTEMML-2105] New single-precision native conv2d/conv2d_bias_add ops

This patch improves the performance of compute-intensive CNN scoring
applications by adding support for single-precision native conv2d
and conv2d_bias_add operations. Furthermore, this also includes
single-precision matrix multiplication (sgemm) and templatized
primitives for im2col, computeNNZ, and biasAdd.

Due to the conversion of double inputs and outputs, initial experiments
showed only minor improvements because the conversion costs were not
amortized. However, with a careful exchange via preallocated direct
float buffers (which avoid unnecessary copies on JNI invocations) as
well as parallelized input/output conversions, this patch significantly
improved performance from 445s to 283s for dense conv2d_bias_add
operations in an end-to-end scoring application.

Furthermore, this also includes minor fixes to the existing native code
for compilation with OpenBLAS libraries.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6c468f98
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6c468f98
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6c468f98

Branch: refs/heads/master
Commit: 6c468f98bfb0c00f21abcec9d49ab63590c393b1
Parents: 416ebc0
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Feb 1 21:16:25 2018 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu Feb 1 21:16:41 2018 -0800

----------------------------------------------------------------------
 .../lib/libsystemml_openblas-Linux-x86_64.so    | Bin 27520 -> 36136 bytes
 src/main/cpp/libmatrixdnn.cpp                   | 163 +++++++++++++------
 src/main/cpp/libmatrixdnn.h                     |  15 +-
 src/main/cpp/libmatrixmult.cpp                  |  23 +--
 src/main/cpp/libmatrixmult.h                    |  10 +-
 src/main/cpp/systemml.cpp                       | 102 +++++++-----
 src/main/cpp/systemml.h                         |  13 +-
 .../runtime/matrix/data/LibMatrixNative.java    |  66 +++++++-
 .../sysml/runtime/util/UtilFunctions.java       |   7 +
 .../org/apache/sysml/utils/NativeHelper.java    |   7 +-
 10 files changed, 282 insertions(+), 124 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/6c468f98/src/main/cpp/lib/libsystemml_openblas-Linux-x86_64.so
----------------------------------------------------------------------
diff --git a/src/main/cpp/lib/libsystemml_openblas-Linux-x86_64.so 
b/src/main/cpp/lib/libsystemml_openblas-Linux-x86_64.so
index dfd1ecb..c04e735 100755
Binary files a/src/main/cpp/lib/libsystemml_openblas-Linux-x86_64.so and 
b/src/main/cpp/lib/libsystemml_openblas-Linux-x86_64.so differ

http://git-wip-us.apache.org/repos/asf/systemml/blob/6c468f98/src/main/cpp/libmatrixdnn.cpp
----------------------------------------------------------------------
diff --git a/src/main/cpp/libmatrixdnn.cpp b/src/main/cpp/libmatrixdnn.cpp
index d6a09b7..f90857c 100644
--- a/src/main/cpp/libmatrixdnn.cpp
+++ b/src/main/cpp/libmatrixdnn.cpp
@@ -31,7 +31,7 @@
   #include "omp.h"
 #endif
 
-int computeNNZ(double* arr, int limit) {
+template<class FP> int computeNNZ(FP* arr, int limit) {
   int nnz = 0;
 #ifndef USE_INTEL_MKL
   #pragma omp parallel for reduction(+: nnz)
@@ -41,9 +41,14 @@ int computeNNZ(double* arr, int limit) {
   return nnz;
 }
 
+template<class FP> void biasAdd(FP* bias, FP* output, int K, int PQ) {
+  for(int k = 0, index=0; k < K; k++)
+    for(int pq = 0; pq < PQ; pq++, index++)
+      output[index] += bias[k];
+}
+
 void rotate180(double* inputArray, double* outputArray, int N, int C, int H, 
int W,
-            int K, int R, int S, int stride_h, int stride_w, int pad_h,
-            int pad_w, int P, int Q) {
+    int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int 
P, int Q) {
     int PQ = P*Q;
     int KQ = K*Q;
        for (int k = 0; k < K; k++) {
@@ -56,8 +61,7 @@ void rotate180(double* inputArray, double* outputArray, int 
N, int C, int H, int
 }
 
 void col2im(double* inputArray, double* outputArray, int N, int C, int H, int 
W,
-            int K, int R, int S, int stride_h, int stride_w, int pad_h,
-            int pad_w, int P, int Q) {
+    int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int 
P, int Q) {
        for (int p = 0; p < P; p++) {
                // h = p*stride_h + r - pad_h
                //   = r + hOffset
@@ -88,11 +92,10 @@ void col2im(double* inputArray, double* outputArray, int N, 
int C, int H, int W,
        }
 }
 
-void im2col(double* inputArray, double* outputArray, int N, int C, int H, int 
W,
-            int K, int R, int S, int stride_h, int stride_w, int pad_h,
-            int pad_w, int P, int Q) {
+template<class FP> void im2col(FP* inputArray, FP* outputArray, int N, int C, 
int H, int W,
+    int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int 
P, int Q) {
   int CRS = C * R * S;
-  std::size_t size = Q * sizeof(double);
+  std::size_t size = Q * sizeof(FP);
   if (stride_h == 1 && stride_w == 1 && pad_h == 0 && pad_w == 0) {
     for (int c = 0; c < CRS; ++c) {
       int wOffset = c % S;
@@ -102,8 +105,7 @@ void im2col(double* inputArray, double* outputArray, int N, 
int C, int H, int W,
         int hPadded = h + hOffset;
         int outOffset = (c * P + h) * Q;
         int inputOffset = (cInput * H + hPadded) * W;
-        std::memcpy(outputArray + outOffset, inputArray + inputOffset + 
wOffset,
-                    size);
+        std::memcpy(outputArray + outOffset, inputArray + inputOffset + 
wOffset, size);
         int w = Q - 1;
         int wPadded = w + wOffset;
         if (hPadded < H && wPadded < W)
@@ -149,7 +151,8 @@ bool MKL_DNN_ERROR(dnnError_t code) {
 } 
 #endif
 
-int conv2dBackwardFilterDense(double* inputPtr, double* doutPtr, double* 
retPtr, int N, int C, int H, int W, int K, int R, int S,
+int conv2dBackwardFilterDense(double* inputPtr, double* doutPtr, double* 
retPtr,
+    int N, int C, int H, int W, int K, int R, int S,
     int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int 
numThreads) {
   int CRS = C*R*S;
 #ifdef USE_INTEL_MKL
@@ -167,7 +170,7 @@ int conv2dBackwardFilterDense(double* inputPtr, double* 
doutPtr, double* retPtr,
   resources[dnnResourceSrc] = inputPtr;
   resources[dnnResourceDiffFilter] = retPtr;
   dnnConvolutionCreateBackwardFilter_F64(&pConvolution, NULL, 
dnnAlgorithmConvolutionDirect, dimension, 
-      srcSize, dstSize, filterSize, convolutionStrides, pads, dnnBorderZeros);
+    srcSize, dstSize, filterSize, convolutionStrides, pads, dnnBorderZeros);
   
   // Step 2: Perform the DNN operation
   if(MKL_DNN_ERROR(dnnExecute_F64(pConvolution, resources))) {
@@ -201,18 +204,16 @@ int conv2dBackwardFilterDense(double* inputPtr, double* 
doutPtr, double* retPtr,
 #pragma omp parallel for num_threads(numOpenMPThreads)
   for (int n = 0; n < N; n++) {
     int threadID = omp_get_thread_num();
-       double* loweredMat = loweredMatArrays + numIm2ColElem*threadID;
+    double* loweredMat = loweredMatArrays + numIm2ColElem*threadID;
 
     // Step 1: Perform im2col
-    im2col(inputPtr + n * CHW, loweredMat, 1, C, H, W, K,
-           R, S, stride_h, stride_w, pad_h, pad_w,
-           P, Q);
-           
+    im2col<double>(inputPtr + n * CHW, loweredMat, 1, C, H, W, K,
+      R, S, stride_h, stride_w, pad_h, pad_w, P, Q);
+    
     // Step 2: Rotate dout
     double* rotatedDoutPtr = rotatedDoutPtrArrays + numRotatedElem*threadID;
     rotate180(doutPtr + n * KPQ, rotatedDoutPtr, 1, C, H, W, K,
-           R, S, stride_h, stride_w, pad_h, pad_w,
-           P, Q);
+           R, S, stride_h, stride_w, pad_h, pad_w, P, Q);
     
     // Multiply to get tmp1 = CRS X K
     double* temp1 = temp + numTempElem*threadID;
@@ -240,7 +241,7 @@ int conv2dBackwardFilterDense(double* inputPtr, double* 
doutPtr, double* retPtr,
   
   delete [] temp;
 #endif
-  return computeNNZ(retPtr, K*CRS);
+  return computeNNZ<double>(retPtr, K*CRS);
 }
 
 int conv2dBackwardDataDense(double* filterPtr, double* doutPtr, double* 
retPtr, int N, int C, int H, int W, int K, int R, int S,
@@ -296,20 +297,18 @@ int conv2dBackwardDataDense(double* filterPtr, double* 
doutPtr, double* retPtr,
 
     // Step 2: t(rotatedDout (PQ X K) %*% filter (K X CRS))
     double* col2imInput = col2imInputArrays + numCol2ImElem*threadID;
-    matmult(rotatedDoutPtr, filterPtr, col2imInput,
-            PQ, K, CRS, 1);
+    dmatmult(rotatedDoutPtr, filterPtr, col2imInput, PQ, K, CRS, 1);
 
     // Step 3: Perform col2im
     double* outputArr = retPtr + n * CHW;
     col2im(col2imInput, outputArr, 1, C, H, W, K,
-           R, S, stride_h, stride_w, pad_h, pad_w,
-           P, Q);
+           R, S, stride_h, stride_w, pad_h, pad_w, P, Q);
   } // end omp parallel for
   
   delete [] rotatedDoutPtrArrays;
   delete [] col2imInputArrays;
 #endif
-  return computeNNZ(retPtr, N*CHW);
+  return computeNNZ<double>(retPtr, N*CHW);
 }
 
 void conv2dSparse(int apos, int alen, int* aix, double* avals, double* 
filterPtr, double* retPtr, int N, int C, int H, int W, 
@@ -323,14 +322,13 @@ void conv2dSparse(int apos, int alen, int* aix, double* 
avals, double* filterPtr
        std::memset(temp, 0, size);
        for(int j=apos; j<apos+alen; j++)
                temp[ aix[j] ] = avals[j];
-       im2col(temp, loweredMat, 1, C, H, W, K,
-       R, S, stride_h, stride_w, pad_h, pad_w,
-       P, Q);  
+       im2col<double>(temp, loweredMat, 1, C, H, W, K,
+               R, S, stride_h, stride_w, pad_h, pad_w, P, Q);
        delete [] temp;
        
        // Step 2: filter (K X CRS) %*% loweredMat (CRS X PQ)
-    matmult(filterPtr, loweredMat, retPtr, K, C * R * S, P * Q, 1);
-    
+       dmatmult(filterPtr, loweredMat, retPtr, K, C * R * S, P * Q, 1);
+
        delete [] loweredMat;
 }
 
@@ -361,7 +359,7 @@ void conv2dBackwardFilterSparseDense(int apos, int alen, 
int* aix, double* avals
        // Multiply to get CRS X K
        double* temp1 = new double[CRS * K];
        // Step 3: loweredMat (CRS X PQ) %*% rotatedDoutPtr (PQ X K) 
-    matmult(loweredMat, rotatedDoutPtr, temp1, C * R * S, P * Q, K, 1);
+    dmatmult(loweredMat, rotatedDoutPtr, temp1, C * R * S, P * Q, K, 1);
     delete [] loweredMat;
      
     // Inplace addition
@@ -372,8 +370,8 @@ void conv2dBackwardFilterSparseDense(int apos, int alen, 
int* aix, double* avals
        delete [] temp1;
 }
 
-
-int conv2dBiasAddDense(double* inputPtr, double* biasPtr, double* filterPtr, 
double* retPtr, int N, int C, int H, int W, int K, int R, int S,
+int dconv2dBiasAddDense(double* inputPtr, double* biasPtr, double* filterPtr, 
double* retPtr, 
+    int N, int C, int H, int W, int K, int R, int S,
     int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, bool 
addBias, int numThreads) {
   int KPQ = K * P * Q;
   
@@ -427,28 +425,97 @@ int conv2dBiasAddDense(double* inputPtr, double* biasPtr, 
double* filterPtr, dou
     double* loweredMat = loweredMatArrays + numIm2ColElem*threadID;
 
     // Step 1: Perform im2col
-    im2col(inputPtr + n * CHW, loweredMat, 1, C, H, W, K,
-           R, S, stride_h, stride_w, pad_h, pad_w,
-           P, Q);
+    im2col<double>(inputPtr + n * CHW, loweredMat, 1, C, H, W, K,
+      R, S, stride_h, stride_w, pad_h, pad_w, P, Q);
 
     // Step 2: filter (K X CRS) %*% loweredMat (CRS X PQ)
-    matmult(filterPtr, loweredMat, retPtr + n * KPQ, K,
+    dmatmult(filterPtr, loweredMat, retPtr + n * KPQ, K,
             C * R * S, P * Q, 1);
-    
+
     // Step 3: Add bias
     double* outputArr = retPtr + n*KPQ;
-    if(addBias) {
-           int index = 0;
-               for(int k = 0; k < K; k++) {
-                       for(int pq = 0; pq < PQ; pq++, index++) {
-                               outputArr[index] += biasPtr[k];
-                       }
-               }
-    }
+    if( addBias )
+       biasAdd<double>(biasPtr, outputArr, K, PQ);
   } // end omp parallel for
   delete [] loweredMatArrays;
   // 
------------------------------------------------------------------------------------
 #endif
   
-  return computeNNZ(retPtr, N*KPQ);
+  return computeNNZ<double>(retPtr, N*KPQ);
+}
+
+int sconv2dBiasAddDense(float* inputPtr, float* biasPtr, float* filterPtr, 
float* retPtr, 
+    int N, int C, int H, int W, int K, int R, int S,
+    int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, bool 
addBias, int numThreads) {
+  int KPQ = K * P * Q;
+  
+#ifdef USE_INTEL_MKL
+  setNumThreadsForBLAS(numThreads);
+  // Step 1: Create a description of a DNN operation
+  dnnPrimitive_t pConvolution;
+  size_t dimension = 4;
+  size_t srcSize[4] = {W, H, C, N};
+  size_t dstSize[4] = {Q, P, K, N};
+  size_t filterSize[4] = {S, R, C, K};
+  size_t convolutionStrides[2] = {stride_w, stride_h};
+  int pads[2] = {-pad_w, -pad_h};
+  void* resources[dnnResourceNumber] = {0};
+  resources[dnnResourceSrc] = inputPtr;
+  resources[dnnResourceFilter] = filterPtr;
+  resources[dnnResourceDst] = retPtr;
+  if(addBias) {
+    dnnConvolutionCreateForwardBias_F32(&pConvolution, NULL, 
dnnAlgorithmConvolutionDirect, dimension, 
+      srcSize, dstSize, filterSize, convolutionStrides, pads, dnnBorderZeros);
+    resources[dnnResourceBias] = biasPtr;
+  }
+  else { 
+    dnnConvolutionCreateForward_F32(&pConvolution, NULL, 
dnnAlgorithmConvolutionDirect, dimension, 
+      srcSize, dstSize, filterSize, convolutionStrides, pads, dnnBorderZeros);
+  }
+  
+  // Step 2: Perform the DNN operation
+  if(MKL_DNN_ERROR(dnnExecute_F32(pConvolution, resources))) {
+    return -1; // nnz == -1 indicates error.
+  }
+  
+  // Step 3: Destroy the description of the operation
+  dnnDelete_F32(pConvolution);
+#else 
+  // First step:  Avoids oversubscription and other openmp/internal blas 
threading issues
+  setNumThreadsForBLAS(1);
+  
+  int CHW = C * H * W;
+  int PQ = P * Q;
+  int numIm2ColElem = C * R * S * P * Q;
+  
+  // Allocate temporary data structures used in parallel for
+  int numOpenMPThreads = MIN(numThreads, N);
+  float* loweredMatArrays = new float[numIm2ColElem*numOpenMPThreads];
+  int nnz = 0;
+  
+#pragma omp parallel for reduction(+: nnz) num_threads(numOpenMPThreads)
+  for (int n = 0; n < N; n++) {
+    int threadID = omp_get_thread_num();
+    float* loweredMat = loweredMatArrays + numIm2ColElem*threadID;
+
+    // Step 1: Perform im2col
+    im2col<float>(inputPtr + n * CHW, loweredMat, 1, C, H, W, K,
+      R, S, stride_h, stride_w, pad_h, pad_w, P, Q);
+
+    // Step 2: filter (K X CRS) %*% loweredMat (CRS X PQ)
+    smatmult(filterPtr, loweredMat, retPtr + n * KPQ, K,
+            C * R * S, P * Q, 1);
+
+    // Step 3: Add bias
+    float* outputArr = retPtr + n*KPQ;
+    if( addBias )
+       biasAdd<float>(biasPtr, outputArr, K, PQ);
+       
+    // Step 4: thread-local nnz maintenance
+    nnz += computeNNZ<float>(retPtr + n*KPQ, KPQ);   
+  }
+  delete [] loweredMatArrays;
+#endif
+  
+  return nnz;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/6c468f98/src/main/cpp/libmatrixdnn.h
----------------------------------------------------------------------
diff --git a/src/main/cpp/libmatrixdnn.h b/src/main/cpp/libmatrixdnn.h
index f7d746f..3f457b0 100644
--- a/src/main/cpp/libmatrixdnn.h
+++ b/src/main/cpp/libmatrixdnn.h
@@ -21,7 +21,7 @@
 #define _libmatrixdnn_h
 
 #ifdef USE_INTEL_MKL
-       #include <mkl.h>
+  #include <mkl.h>
        #if INTEL_MKL_VERSION < 20170000
                // Will throw an error at development time in non-standard 
settings
                PLEASE DONOT COMPILE SHARED LIBRARIES WITH OLDER MKL VERSIONS
@@ -33,10 +33,15 @@ int conv2dBackwardFilterDense(double* inputPtr, double* 
doutPtr, double* retPtr,
 
 int conv2dBackwardDataDense(double* filterPtr, double* doutPtr, double* 
retPtr, int N, int C, int H, int W, int K, int R, int S,
     int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int 
numThreads);
-    
-int conv2dBiasAddDense(double* inputPtr, double* biasPtr, double* filterPtr, 
double* retPtr, int N, int C, int H, int W, int K, int R, int S,
-    int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, bool 
addBias, int numThreads);
-    
+
+int dconv2dBiasAddDense(double* inputPtr, double* biasPtr, double* filterPtr, 
double* retPtr,
+   int N, int C, int H, int W, int K, int R, int S,
+   int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, bool 
addBias, int numThreads);
+
+int sconv2dBiasAddDense(float* inputPtr, float* biasPtr, float* filterPtr, 
float* retPtr,
+   int N, int C, int H, int W, int K, int R, int S,
+   int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, bool 
addBias, int numThreads);
+
 void conv2dSparse(int apos, int alen, int* aix, double* avals, double* filter, 
double* ret, int N, int C, int H, int W, 
                        int K, int R, int S, int stride_h, int stride_w, int 
pad_h, int pad_w, int P, int Q, int numThreads);
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/6c468f98/src/main/cpp/libmatrixmult.cpp
----------------------------------------------------------------------
diff --git a/src/main/cpp/libmatrixmult.cpp b/src/main/cpp/libmatrixmult.cpp
index 6844c2a..0fce70d 100644
--- a/src/main/cpp/libmatrixmult.cpp
+++ b/src/main/cpp/libmatrixmult.cpp
@@ -22,6 +22,7 @@
 #include <cstdlib>
 #include "omp.h"
 #include <cmath>
+#include <cblas.h>
 
 int SYSML_CURRENT_NUM_THREADS = -1;
 void setNumThreadsForBLAS(int numThreads) {
@@ -35,20 +36,14 @@ void setNumThreadsForBLAS(int numThreads) {
        }
 }
  
-// Multiplies two matrices m1Ptr and m2Ptr in row-major format of shape
-// (m1rlen, m1clen) and (m1clen, m2clen)
-void matmult(double* m1Ptr, double* m2Ptr, double* retPtr, int m1rlen,
-             int m1clen, int m2clen, int numThreads) {
-  int m = m1rlen;
-  int n = m2clen;
-  int k = m1clen;
-  
+void dmatmult(double* m1Ptr, double* m2Ptr, double* retPtr, int m, int k, int 
n, int numThreads) {
   setNumThreadsForBLAS(numThreads);
-  
-  // if(m2clen == 1)
-  //   cblas_dgemv(CblasRowMajor, CblasNoTrans, m1rlen, m1clen, 1, m1Ptr, 
m1clen, m2Ptr, 1, 0, retPtr, 1);
-  // else 
-       cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, 
m1Ptr, k, m2Ptr, n, 0, retPtr, n);
+  cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, m1Ptr, k, 
m2Ptr, n, 0, retPtr, n);
+}
+
+void smatmult(float* m1Ptr, float* m2Ptr, float* retPtr, int m, int k, int n, 
int numThreads) {  
+  setNumThreadsForBLAS(numThreads);
+  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, m1Ptr, k, 
m2Ptr, n, 0, retPtr, n);
 }
 
 void tsmm(double* m1Ptr, double* retPtr, int m1rlen, int m1clen, bool 
isLeftTranspose, int numThreads) {
@@ -57,7 +52,5 @@ void tsmm(double* m1Ptr, double* retPtr, int m1rlen, int 
m1clen, bool isLeftTran
   int k = isLeftTranspose ? m1rlen : m1clen;
   
   setNumThreadsForBLAS(numThreads);
-  
   cblas_dgemm(CblasRowMajor, isLeftTranspose ? CblasTrans : CblasNoTrans, 
isLeftTranspose ? CblasNoTrans : CblasTrans, m, n, k, 1, m1Ptr, k, m1Ptr, n, 0, 
retPtr, n);
 }
- 

http://git-wip-us.apache.org/repos/asf/systemml/blob/6c468f98/src/main/cpp/libmatrixmult.h
----------------------------------------------------------------------
diff --git a/src/main/cpp/libmatrixmult.h b/src/main/cpp/libmatrixmult.h
index d39242e..ee231d4 100644
--- a/src/main/cpp/libmatrixmult.h
+++ b/src/main/cpp/libmatrixmult.h
@@ -44,7 +44,7 @@
 // for MKL we use mkl_set_num_threads. This avoids performance degradation due 
to overprovisioning.
 #ifdef USE_OPEN_BLAS
 #include <cblas.h>
-// extern "C" void openblas_set_num_threads(int numThreads);
+extern "C" void openblas_set_num_threads(int numThreads);
 #elif defined USE_INTEL_MKL
 #include <mkl.h>
 #include <mkl_service.h>
@@ -54,9 +54,9 @@ void setNumThreadsForBLAS(int numThreads);
 
 // Multiplies two matrices m1Ptr and m2Ptr in row-major format of shape
 // (m1rlen, m1clen) and (m1clen, m2clen)
-void matmult(double* m1Ptr, double* m2Ptr, double* retPtr, int m1rlen,
-             int m1clen, int m2clen, int numThreads);
-             
+void dmatmult(double* m1Ptr, double* m2Ptr, double* retPtr, int m, int k, int 
n, int numThreads);
+void smatmult(float* m1Ptr, float* m2Ptr, float* retPtr, int m, int k, int n, 
int numThreads);
+
 void tsmm(double* m1Ptr, double* retPtr, int m1rlen, int m1clen, bool 
isLeftTranspose,  int numThreads);
-             
+
 #endif

http://git-wip-us.apache.org/repos/asf/systemml/blob/6c468f98/src/main/cpp/systemml.cpp
----------------------------------------------------------------------
diff --git a/src/main/cpp/systemml.cpp b/src/main/cpp/systemml.cpp
index 0f00afd..35a0074 100644
--- a/src/main/cpp/systemml.cpp
+++ b/src/main/cpp/systemml.cpp
@@ -46,7 +46,7 @@
 // 1. We chose GetDoubleArrayElements over GetPrimitiveArrayCritical in a 
multi-threaded scenario. This avoids any potential OOM related to GC halts.
 // 2. For input array, we don't copy back the array using JNI_ABORT.
 
-// JNI Methods to get/release double* 
+// JNI Methods to get/release double*
 #define GET_DOUBLE_ARRAY(env, input, numThreads) \
        ((double*)env->GetPrimitiveArrayCritical(input, NULL))
 // ( maxThreads != -1 && ((int)numThreads) == maxThreads ? 
((double*)env->GetPrimitiveArrayCritical(input, NULL)) :  
env->GetDoubleArrayElements(input,NULL) )
@@ -59,11 +59,11 @@
 // JNI_ABORT
 // Actual: the array object is un-pinned. Earlier writes are not aborted.
 // Copy: the buffer with the copy is freed; any changes to it are lost.
-#define RELEASE_INPUT_DOUBLE_ARRAY(env, input, inputPtr, numThreads) \
+#define RELEASE_INPUT_ARRAY(env, input, inputPtr, numThreads) \
        env->ReleasePrimitiveArrayCritical(input, inputPtr, JNI_ABORT)
 // ( maxThreads != -1 && ((int)numThreads) == maxThreads ? 
env->ReleasePrimitiveArrayCritical(input, inputPtr, JNI_ABORT) : 
env->ReleaseDoubleArrayElements(input, inputPtr, JNI_ABORT) )
 
-#define RELEASE_DOUBLE_ARRAY(env, input, inputPtr, numThreads) \
+#define RELEASE_ARRAY(env, input, inputPtr, numThreads) \
        env->ReleasePrimitiveArrayCritical(input, inputPtr, 0)
 // ( maxThreads != -1 && ((int)numThreads) == maxThreads ? 
env->ReleasePrimitiveArrayCritical(input, inputPtr, 0) :  
env->ReleaseDoubleArrayElements(input, inputPtr, 0) )
   
@@ -84,11 +84,11 @@ JNIEXPORT jboolean JNICALL 
Java_org_apache_sysml_utils_NativeHelper_matrixMultDe
   if(m1Ptr == NULL || m2Ptr == NULL || retPtr == NULL)
        return (jboolean) false;
 
-  matmult(m1Ptr, m2Ptr, retPtr, (int)m1rlen, (int)m1clen, (int)m2clen, 
(int)numThreads);
+  dmatmult(m1Ptr, m2Ptr, retPtr, (int)m1rlen, (int)m1clen, (int)m2clen, 
(int)numThreads);
 
-  RELEASE_INPUT_DOUBLE_ARRAY(env, m1, m1Ptr, numThreads);
-  RELEASE_INPUT_DOUBLE_ARRAY(env, m2, m2Ptr, numThreads);
-  RELEASE_DOUBLE_ARRAY(env, ret, retPtr, numThreads); 
+  RELEASE_INPUT_ARRAY(env, m1, m1Ptr, numThreads);
+  RELEASE_INPUT_ARRAY(env, m2, m2Ptr, numThreads);
+  RELEASE_ARRAY(env, ret, retPtr, numThreads); 
   return (jboolean) true;
 }
 
@@ -101,8 +101,8 @@ JNIEXPORT jboolean JNICALL 
Java_org_apache_sysml_utils_NativeHelper_tsmm
 
   tsmm(m1Ptr, retPtr, (int) m1rlen, (int) m1clen, (bool) isLeftTranspose, 
(int) numThreads);
   
-  RELEASE_INPUT_DOUBLE_ARRAY(env, m1, m1Ptr, numThreads);
-  RELEASE_DOUBLE_ARRAY(env, ret, retPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, m1, m1Ptr, numThreads);
+  RELEASE_ARRAY(env, ret, retPtr, numThreads);
   return (jboolean) true;
 }
 
@@ -118,10 +118,10 @@ JNIEXPORT jboolean JNICALL 
Java_org_apache_sysml_utils_NativeHelper_conv2dSparse
   conv2dSparse((int)apos, (int)alen, aixPtr, avalsPtr, filterPtr, retPtr, 
(int)N, (int)C, (int)H, (int)W, 
                        (int)K, (int)R, (int)S, (int)stride_h, (int)stride_w, 
(int)pad_h, (int)pad_w, (int)P, (int)Q, (int)numThreads);
   
-  RELEASE_INPUT_DOUBLE_ARRAY(env, avals, avalsPtr, numThreads);
-  RELEASE_INPUT_DOUBLE_ARRAY(env, filter, filterPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, avals, avalsPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, filter, filterPtr, numThreads);
   env->ReleasePrimitiveArrayCritical(aix, aixPtr, JNI_ABORT);
-  RELEASE_DOUBLE_ARRAY(env, ret, retPtr, numThreads); 
+  RELEASE_ARRAY(env, ret, retPtr, numThreads); 
   return (jboolean) true;
 }
 
@@ -137,51 +137,73 @@ JNIEXPORT jboolean JNICALL 
Java_org_apache_sysml_utils_NativeHelper_conv2dBackwa
   conv2dBackwardFilterSparseDense((int)apos, (int)alen, aixPtr, avalsPtr, 
doutPtr, retPtr, (int)N, (int)C, (int)H, (int)W, 
                        (int)K, (int)R, (int)S, (int)stride_h, (int)stride_w, 
(int)pad_h, (int)pad_w, (int)P, (int)Q, (int)numThreads);
   
-  RELEASE_INPUT_DOUBLE_ARRAY(env, avals, avalsPtr, numThreads);
-  RELEASE_INPUT_DOUBLE_ARRAY(env, dout, doutPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, avals, avalsPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, dout, doutPtr, numThreads);
   env->ReleasePrimitiveArrayCritical(aix, aixPtr, JNI_ABORT);
-  RELEASE_DOUBLE_ARRAY(env, ret, retPtr, numThreads); 
+  RELEASE_ARRAY(env, ret, retPtr, numThreads); 
   return (jboolean) true;
 }
 
 JNIEXPORT jint JNICALL Java_org_apache_sysml_utils_NativeHelper_conv2dDense(
-       JNIEnv* env, jclass, jdoubleArray input, jdoubleArray filter,
+    JNIEnv* env, jclass, jdoubleArray input, jdoubleArray filter,
     jdoubleArray ret, jint N, jint C, jint H, jint W, jint K, jint R, jint S,
-    jint stride_h, jint stride_w, jint pad_h, jint pad_w, jint P, jint Q, jint 
numThreads) {
+    jint stride_h, jint stride_w, jint pad_h, jint pad_w, jint P, jint Q, jint 
numThreads)
+{
   double* inputPtr = GET_DOUBLE_ARRAY(env, input, numThreads);
   double* filterPtr = GET_DOUBLE_ARRAY(env, filter, numThreads);
   double* retPtr = GET_DOUBLE_ARRAY(env, ret, numThreads);
   if(inputPtr == NULL || filterPtr == NULL || retPtr == NULL)
-       return (jint) -1;
+    return (jint) -1;
   
-  int nnz = conv2dBiasAddDense(inputPtr, 0, filterPtr, retPtr, (int) N, (int) 
C, (int) H, (int) W, (int) K, (int) R, (int) S,
+  int nnz = dconv2dBiasAddDense(inputPtr, 0, filterPtr, retPtr,
+    (int) N, (int) C, (int) H, (int) W, (int) K, (int) R, (int) S,
     (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w, (int) P, (int) 
Q, false, (int) numThreads);
-    
-  RELEASE_INPUT_DOUBLE_ARRAY(env, input, inputPtr, numThreads);
-  RELEASE_INPUT_DOUBLE_ARRAY(env, filter, filterPtr, numThreads);
-  RELEASE_DOUBLE_ARRAY(env, ret, retPtr, numThreads); 
+  
+  RELEASE_INPUT_ARRAY(env, input, inputPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, filter, filterPtr, numThreads);
+  RELEASE_ARRAY(env, ret, retPtr, numThreads);
   return (jint) nnz;
 }
 
-JNIEXPORT jint JNICALL 
Java_org_apache_sysml_utils_NativeHelper_conv2dBiasAddDense(
-       JNIEnv* env, jclass, jdoubleArray input, jdoubleArray bias, 
jdoubleArray filter,
+JNIEXPORT jint JNICALL 
Java_org_apache_sysml_utils_NativeHelper_dconv2dBiasAddDense(
+    JNIEnv* env, jclass, jdoubleArray input, jdoubleArray bias, jdoubleArray 
filter,
     jdoubleArray ret, jint N, jint C, jint H, jint W, jint K, jint R, jint S,
-    jint stride_h, jint stride_w, jint pad_h, jint pad_w, jint P, jint Q, jint 
numThreads) {
-    
+    jint stride_h, jint stride_w, jint pad_h, jint pad_w, jint P, jint Q, jint 
numThreads)
+{
   double* inputPtr = GET_DOUBLE_ARRAY(env, input, numThreads);
   double* biasPtr = GET_DOUBLE_ARRAY(env, bias, numThreads);
   double* filterPtr = GET_DOUBLE_ARRAY(env, filter, numThreads);
   double* retPtr = GET_DOUBLE_ARRAY(env, ret, numThreads);
   if(inputPtr == NULL || biasPtr == NULL || filterPtr == NULL || retPtr == 
NULL)
-       return (jint) -1;
+    return (jint) -1;
   
-  int nnz = conv2dBiasAddDense(inputPtr, biasPtr, filterPtr, retPtr, (int) N, 
(int) C, (int) H, (int) W, (int) K, (int) R, (int) S,
+  int nnz = dconv2dBiasAddDense(inputPtr, biasPtr, filterPtr, retPtr,
+    (int) N, (int) C, (int) H, (int) W, (int) K, (int) R, (int) S,
     (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w, (int) P, (int) 
Q, true, (int) numThreads);
-    
-  RELEASE_INPUT_DOUBLE_ARRAY(env, input, inputPtr, numThreads);
-  RELEASE_INPUT_DOUBLE_ARRAY(env, bias, biasPtr, numThreads);
-  RELEASE_INPUT_DOUBLE_ARRAY(env, filter, filterPtr, numThreads);
-  RELEASE_DOUBLE_ARRAY(env, ret, retPtr, numThreads); 
+  
+  RELEASE_INPUT_ARRAY(env, input, inputPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, bias, biasPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, filter, filterPtr, numThreads);
+  RELEASE_ARRAY(env, ret, retPtr, numThreads);
+  return (jint) nnz;
+}
+
+JNIEXPORT jint JNICALL 
Java_org_apache_sysml_utils_NativeHelper_sconv2dBiasAddDense(
+    JNIEnv* env, jclass, jobject input, jobject bias, jobject filter,
+    jobject ret, jint N, jint C, jint H, jint W, jint K, jint R, jint S,
+    jint stride_h, jint stride_w, jint pad_h, jint pad_w, jint P, jint Q, jint 
numThreads) 
+{
+  float* inputPtr = (float*) env->GetDirectBufferAddress(input);
+  float* biasPtr =  (float*) env->GetDirectBufferAddress(bias);
+  float* filterPtr = (float*) env->GetDirectBufferAddress(filter);
+  float* retPtr = (float*) env->GetDirectBufferAddress(ret);
+  if(inputPtr == NULL || biasPtr == NULL || filterPtr == NULL || retPtr == 
NULL)
+    return (jint) -1;
+  
+  int nnz = sconv2dBiasAddDense(inputPtr, biasPtr, filterPtr, retPtr, 
+    (int) N, (int) C, (int) H, (int) W, (int) K, (int) R, (int) S,
+    (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w, (int) P, (int) 
Q, true, (int) numThreads);
+  
   return (jint) nnz;
 }
 
@@ -199,9 +221,9 @@ JNIEXPORT jint JNICALL 
Java_org_apache_sysml_utils_NativeHelper_conv2dBackwardDa
   int nnz = conv2dBackwardDataDense(filterPtr, doutPtr, retPtr, (int) N, (int) 
C, (int) H, (int) W, (int) K, (int) R, (int) S,
     (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w, (int) P, (int) 
Q, (int) numThreads);
   
-  RELEASE_INPUT_DOUBLE_ARRAY(env, filter, filterPtr, numThreads);
-  RELEASE_INPUT_DOUBLE_ARRAY(env, dout, doutPtr, numThreads);
-  RELEASE_DOUBLE_ARRAY(env, ret, retPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, filter, filterPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, dout, doutPtr, numThreads);
+  RELEASE_ARRAY(env, ret, retPtr, numThreads);
   return (jint) nnz;
 }
 
@@ -218,8 +240,8 @@ JNIEXPORT jint JNICALL 
Java_org_apache_sysml_utils_NativeHelper_conv2dBackwardFi
   int nnz = conv2dBackwardFilterDense(inputPtr, doutPtr, retPtr, (int) N, 
(int) C, (int) H, (int) W, (int) K, (int) R, (int) S,
     (int) stride_h, (int) stride_w, (int) pad_h, (int) pad_w, (int) P, (int) 
Q, (int) numThreads);
   
-  RELEASE_INPUT_DOUBLE_ARRAY(env, input, inputPtr, numThreads);
-  RELEASE_INPUT_DOUBLE_ARRAY(env, dout, doutPtr, numThreads);
-  RELEASE_DOUBLE_ARRAY(env, ret, retPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, input, inputPtr, numThreads);
+  RELEASE_INPUT_ARRAY(env, dout, doutPtr, numThreads);
+  RELEASE_ARRAY(env, ret, retPtr, numThreads);
   return (jint) nnz;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/6c468f98/src/main/cpp/systemml.h
----------------------------------------------------------------------
diff --git a/src/main/cpp/systemml.h b/src/main/cpp/systemml.h
index f6f5cd2..71155fa 100644
--- a/src/main/cpp/systemml.h
+++ b/src/main/cpp/systemml.h
@@ -52,14 +52,21 @@ JNIEXPORT jint JNICALL 
Java_org_apache_sysml_utils_NativeHelper_conv2dDense
 
 /*
  * Class:     org_apache_sysml_utils_NativeHelper
- * Method:    conv2dBiasAddDense
- * Signature: ([D[D[D[DIIIIIIIIIIIIII)I
+ * Method:    dconv2dBiasAddDense
  */
-JNIEXPORT jint JNICALL 
Java_org_apache_sysml_utils_NativeHelper_conv2dBiasAddDense
+JNIEXPORT jint JNICALL 
Java_org_apache_sysml_utils_NativeHelper_dconv2dBiasAddDense
   (JNIEnv *, jclass, jdoubleArray, jdoubleArray, jdoubleArray, jdoubleArray, 
jint, jint, jint, jint, jint, jint, jint, jint, jint, jint, jint, jint, jint, 
jint);
 
 /*
  * Class:     org_apache_sysml_utils_NativeHelper
+ * Method:    sconv2dBiasAddDense
+ */
+JNIEXPORT jint JNICALL 
Java_org_apache_sysml_utils_NativeHelper_sconv2dBiasAddDense
+  (JNIEnv *, jclass, jobject, jobject, jobject, jobject, jint, jint, jint, 
jint, jint, jint, jint, jint, jint, jint, jint, jint, jint, jint);
+
+
+/*
+ * Class:     org_apache_sysml_utils_NativeHelper
  * Method:    conv2dBackwardFilterDense
  * Signature: ([D[D[DIIIIIIIIIIIIII)I
  */

http://git-wip-us.apache.org/repos/asf/systemml/blob/6c468f98/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
index 51a79e1..dfb8abd 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
@@ -18,13 +18,27 @@
  */
 package org.apache.sysml.runtime.matrix.data;
 
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.util.Arrays;
+import java.util.stream.IntStream;
+
 import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.conf.DMLConfig;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.utils.NativeHelper;
 import org.apache.sysml.utils.Statistics;
 
-public class LibMatrixNative {
+public class LibMatrixNative
+{
+       /** ThreadLocal reuse of direct buffers for inputs/outputs (extended on 
demand).*/
+       private static ThreadLocal<FloatBuffer> inBuff = new 
ThreadLocal<FloatBuffer>();
+       private static ThreadLocal<FloatBuffer> biasBuff = new 
ThreadLocal<FloatBuffer>();
+       private static ThreadLocal<FloatBuffer> filterBuff = new 
ThreadLocal<FloatBuffer>();
+       private static ThreadLocal<FloatBuffer> outBuff = new 
ThreadLocal<FloatBuffer>();
        
        // We could encapsulate heuristics in this function
        // For now, we only consider matrix-vector operation to be memory bound
@@ -121,12 +135,31 @@ public class LibMatrixNative {
                        else {
                                if(params.bias.isInSparseFormat())
                                        params.bias.sparseToDense(); // Bias 
matrix is usually extremely small
+                               boolean singlePrecision = 
ConfigurationManager.getDMLConfig()
+                                       
.getTextValue(DMLConfig.FLOATING_POINT_PRECISION).equals("single");
                                long start = DMLScript.STATISTICS ? 
System.nanoTime() : 0;
-                               int nnz = 
NativeHelper.conv2dBiasAddDense(input.getDenseBlockValues(), 
params.bias.getDenseBlockValues(),
-                                               filter.getDenseBlockValues(), 
outputBlock.getDenseBlockValues(),
-                                               params.N, params.C, params.H, 
params.W, 
-                                               params.K, params.R, params.S, 
params.stride_h, params.stride_w, params.pad_h, params.pad_w, 
+                               int nnz = -1;
+                               if( singlePrecision ) {
+                                       //note: since we anyway have to convert 
from double to float, we use
+                                       //preallocated direct buffers (with 
thread-local reuse and resizing on demand)
+                                       //to ensure there are no additional 
copies created by the transfer over jni
+                                       FloatBuffer finput = 
toFloatBuffer(input.getDenseBlockValues(), inBuff, true);
+                                       FloatBuffer fbias = 
toFloatBuffer(params.bias.getDenseBlockValues(), biasBuff, true);
+                                       FloatBuffer ffilter = 
toFloatBuffer(filter.getDenseBlockValues(), filterBuff, true);
+                                       FloatBuffer foutput = 
toFloatBuffer(outputBlock.getDenseBlockValues(), outBuff, false);
+                                       nnz = 
NativeHelper.sconv2dBiasAddDense(finput, fbias, ffilter, foutput,
+                                               params.N, params.C, params.H, 
params.W, params.K, params.R, params.S,
+                                               params.stride_h, 
params.stride_w, params.pad_h, params.pad_w, 
                                                params.P, params.Q, 
params.numThreads);
+                                       fromFloatBuffer(outBuff.get(), 
outputBlock.getDenseBlockValues());
+                               }
+                               else { //Double
+                                       nnz = 
NativeHelper.dconv2dBiasAddDense(input.getDenseBlockValues(), 
params.bias.getDenseBlockValues(),
+                                               filter.getDenseBlockValues(), 
outputBlock.getDenseBlockValues(),
+                                               params.N, params.C, params.H, 
params.W, params.K, params.R, params.S,
+                                               params.stride_h, 
params.stride_w, params.pad_h, params.pad_w, 
+                                               params.P, params.Q, 
params.numThreads); 
+                               }
                                if(nnz != -1) {
                                        if(DMLScript.STATISTICS) {
                                                Statistics.nativeConv2dTime += 
System.nanoTime() - start;
@@ -191,7 +224,7 @@ public class LibMatrixNative {
        }
        
        /**
-        * This method computes the backpropogation errors for previous layer 
of convolution operation
+        * This method computes the backpropagation errors for previous layer 
of convolution operation
         * 
         * @param filter filter used in conv2d 
         * @param dout errors from next layer
@@ -226,4 +259,25 @@ public class LibMatrixNative {
                // Fall back to Java when failures or sparse
                LibMatrixDNN.conv2dBackwardData(filter, dout, outputBlock, 
params);
        }
+       
+       private static FloatBuffer toFloatBuffer(double[] input, 
ThreadLocal<FloatBuffer> buff, boolean copy) {
+               //maintain thread-local buffer (resized on demand)
+               FloatBuffer ret = buff.get();
+               if( ret == null || ret.capacity() < input.length ) {
+                       ret = ByteBuffer.allocateDirect(4*input.length)
+                               .order(ByteOrder.nativeOrder()).asFloatBuffer();
+                       buff.set(ret);
+               }
+               //copy to direct byte buffer
+               final FloatBuffer ret2 = ret;
+               if( copy ) {
+                       IntStream.range(0, input.length).parallel()
+                               .forEach(i -> ret2.put(i, (float)input[i]));
+               }
+               return ret2;
+       }
+       
+       private static void fromFloatBuffer(FloatBuffer buff, double[] output) {
+               Arrays.parallelSetAll(output, i -> (double)buff.get(i) );
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/6c468f98/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java 
b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
index e7a3bdc..993c9b2 100644
--- a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
@@ -304,6 +304,13 @@ public class UtilFunctions
                        ((Long)obj).intValue() : ((Integer)obj).intValue();
        }
        
+       public static float[] toFloat(double[] data) {
+               float[] ret = new float[data.length];
+               for( int i=0; i<data.length; i++ )
+                       ret[i] = (float)data[i];
+               return ret;
+       }
+       
        public static long getSeqLength(double from, double to, double incr) {
                return getSeqLength(from, to, incr, true);
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/6c468f98/src/main/java/org/apache/sysml/utils/NativeHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/NativeHelper.java 
b/src/main/java/org/apache/sysml/utils/NativeHelper.java
index 6f22970..6ec990d 100644
--- a/src/main/java/org/apache/sysml/utils/NativeHelper.java
+++ b/src/main/java/org/apache/sysml/utils/NativeHelper.java
@@ -27,6 +27,7 @@ import org.apache.commons.logging.LogFactory;
 import java.util.Vector;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.nio.FloatBuffer;
 import java.io.File;
 
 import org.apache.commons.io.FileUtils;
@@ -336,8 +337,10 @@ public class NativeHelper {
        // Called by ConvolutionCPInstruction if both input and filter are dense
        public static native int conv2dDense(double [] input, double [] filter, 
double [] ret, int N, int C, int H, int W, 
                        int K, int R, int S, int stride_h, int stride_w, int 
pad_h, int pad_w, int P, int Q, int numThreads);
-       public static native int conv2dBiasAddDense(double [] input, double [] 
bias, double [] filter, double [] ret, int N, int C, int H, int W, 
-                       int K, int R, int S, int stride_h, int stride_w, int 
pad_h, int pad_w, int P, int Q, int numThreads);
+       public static native int dconv2dBiasAddDense(double [] input, double [] 
bias, double [] filter, double [] ret, int N,
+               int C, int H, int W, int K, int R, int S, int stride_h, int 
stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);
+       public static native int sconv2dBiasAddDense(FloatBuffer input, 
FloatBuffer bias, FloatBuffer filter, FloatBuffer ret,
+               int N, int C, int H, int W, int K, int R, int S, int stride_h, 
int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);
        // Called by ConvolutionCPInstruction if both input and filter are dense
        public static native int conv2dBackwardFilterDense(double [] input, 
double [] dout, double [] ret, int N, int C, int H, int W, 
                        int K, int R, int S, int stride_h, int stride_w, int 
pad_h, int pad_w, int P, int Q, int numThreads);

Reply via email to