This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 83078d7  cuda support for linalg-functions, restructuring of linalg interfaces (#7147)
83078d7 is described below

commit 83078d7b21491936dfe552866556b40040dadf5b
Author: moin <asmushet...@yahoo.de>
AuthorDate: Sat Aug 12 21:12:44 2017 +0200

    cuda support for linalg-functions, restructuring of linalg interfaces (#7147)
    
    * cuda support for linalg-functions, restructuring of linalg interfaces
    
    * incorporate newest mshadow
    
    * adjustments to linalg operators
---
 CMakeLists.txt                                 |   4 +-
 Makefile                                       |   6 +
 include/mxnet/base.h                           |   7 +
 mshadow                                        |   2 +-
 src/common/cuda_utils.h                        |  42 ++
 src/io/inst_vector.h                           |   1 +
 {include/mxnet => src/operator}/c_lapack_api.h |  53 ++-
 src/operator/contrib/krprod.h                  |   2 +-
 src/operator/linalg.h                          | 118 ++++++
 src/operator/linalg_impl.h                     | 508 +++++++++++++++++++++++++
 src/operator/tensor/la_op.cc                   |   4 +-
 src/operator/tensor/la_op.cu                   |  77 ++++
 src/operator/tensor/la_op.h                    | 171 ++-------
 src/operator/tensor/la_op_inline.h             | 373 ++++++++----------
 tests/python/unittest/test_operator.py         |   9 +-
 15 files changed, 1008 insertions(+), 369 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ab29b6a..dc9ca5f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -353,8 +353,10 @@ if(USE_CUDA)
     list(APPEND mxnet_LINKER_LIBS ${CUDA_cuda_LIBRARY})
     FIND_LIBRARY(CUDA_cufft_LIBRARY nvrtc "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64"  "${CUDA_TOOLKIT_ROOT_DIR}/lib/win32")
     list(APPEND mxnet_LINKER_LIBS "${CUDA_cufft_LIBRARY}/../cufft.lib") # For fft operator
+    FIND_LIBRARY(CUDA_cusolver_LIBRARY nvrtc "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64"  "${CUDA_TOOLKIT_ROOT_DIR}/lib/win32")
+    list(APPEND mxnet_LINKER_LIBS "${CUDA_cusolver_LIBRARY}/../cusolver.lib") # For cusolver
   else(MSVC)
-    list(APPEND mxnet_LINKER_LIBS nvrtc cuda cufft)
+    list(APPEND mxnet_LINKER_LIBS nvrtc cuda cufft cusolver)
     link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
   endif()
   list(APPEND SOURCE ${cuda_objs} ${CUDA})
diff --git a/Makefile b/Makefile
index 560b77a..33151e5 100644
--- a/Makefile
+++ b/Makefile
@@ -23,6 +23,10 @@ ifndef DLPACK_PATH
        DLPACK_PATH = $(ROOTDIR)/dlpack
 endif
 
+ifndef AMALGAMATION_PATH
+       AMALGAMATION_PATH = $(ROOTDIR)/amalgamation
+endif
+
 ifneq ($(USE_OPENMP), 1)
        export NO_OPENMP = 1
 endif
@@ -439,6 +443,7 @@ clean: cyclean $(EXTRA_PACKAGES_CLEAN)
        cd $(DMLC_CORE); $(MAKE) clean; cd -
        cd $(PS_PATH); $(MAKE) clean; cd -
        cd $(NNVM_PATH); $(MAKE) clean; cd -
+       cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
        $(RM) -r  $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS))
        $(RM) -r  $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS))
 else
@@ -448,6 +453,7 @@ clean: cyclean testclean $(EXTRA_PACKAGES_CLEAN)
        cd $(DMLC_CORE); $(MAKE) clean; cd -
        cd $(PS_PATH); $(MAKE) clean; cd -
        cd $(NNVM_PATH); $(MAKE) clean; cd -
+       cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
 endif
 
 clean_all: clean
diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index 514bb0c..6954083 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -56,6 +56,13 @@
 #define MXNET_USE_CUDNN MSHADOW_USE_CUDNN
 #endif
 
+/*!
+ *\brief whether to use cusolver library
+ */
+#ifndef MXNET_USE_CUSOLVER
+#define MXNET_USE_CUSOLVER MSHADOW_USE_CUSOLVER
+#endif
+
 /*! \brief Error message for using gpu when MXNET_USE_CUDA==0 */
 #define MXNET_GPU_NOT_ENABLED_ERROR  "GPU is not enabled"
 
diff --git a/mshadow b/mshadow
index d32b5da..497eb91 160000
--- a/mshadow
+++ b/mshadow
@@ -1 +1 @@
-Subproject commit d32b5dacf2bb5af4121df5fd60eb7775704f9131
+Subproject commit 497eb9180b24592b7332e7e08f2c053ec5346524
diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 2879ab3..8897007 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -88,6 +88,35 @@ inline const char* CublasGetErrorString(cublasStatus_t error) {
 }
 
 /*!
+ * \brief Get string representation of cuSOLVER errors.
+ * \param error The error.
+ * \return String representation.
+ */
+inline const char* CusolverGetErrorString(cusolverStatus_t error) {
+  switch (error) {
+  case CUSOLVER_STATUS_SUCCESS:
+    return "CUSOLVER_STATUS_SUCCESS";
+  case CUSOLVER_STATUS_NOT_INITIALIZED:
+    return "CUSOLVER_STATUS_NOT_INITIALIZED";
+  case CUSOLVER_STATUS_ALLOC_FAILED:
+    return "CUSOLVER_STATUS_ALLOC_FAILED";
+  case CUSOLVER_STATUS_INVALID_VALUE:
+    return "CUSOLVER_STATUS_INVALID_VALUE";
+  case CUSOLVER_STATUS_ARCH_MISMATCH:
+    return "CUSOLVER_STATUS_ARCH_MISMATCH";
+  case CUSOLVER_STATUS_EXECUTION_FAILED:
+    return "CUSOLVER_STATUS_EXECUTION_FAILED";
+  case CUSOLVER_STATUS_INTERNAL_ERROR:
+    return "CUSOLVER_STATUS_INTERNAL_ERROR";
+  case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
+    return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
+  default:
+    break;
+  }
+  return "Unknown cuSOLVER status";
+}
+
+/*!
  * \brief Get string representation of cuRAND errors.
  * \param status The status.
  * \return String representation.
@@ -165,6 +194,19 @@ inline const char* CurandGetErrorString(curandStatus_t status) {
   }
 
 /*!
+ * \brief Protected cuSolver call.
+ * \param func Expression to call.
+ *
+ * It checks for cuSolver errors after invocation of the expression.
+ */
+#define CUSOLVER_CALL(func)                                         \
+  {                                                                 \
+    cusolverStatus_t e = (func);                                    \
+    CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS)                            \
+        << "cuSolver: " << common::cuda::CusolverGetErrorString(e); \
+  }
+
+/*!
  * \brief Protected cuRAND call.
  * \param func Expression to call.
  *
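
For illustration only (not part of the patch): the new CUSOLVER_CALL macro wraps a cuSOLVER call and aborts with the decoded status string on failure. The function name, handle `h` and device pointer `d_A` below are assumptions for the example; the code is presumed to sit inside namespace mxnet with src/common/cuda_utils.h included, here querying the workspace size for a Cholesky factorization:

    #include <cusolverDn.h>

    namespace mxnet {
    // Hypothetical helper: ask cuSOLVER how much workspace an n x n SPOTRF needs.
    inline int example_potrf_worksize(cusolverDnHandle_t h, float* d_A, int n, int lda) {
      int lwork = 0;
      // On error, CHECK_EQ fails and prints common::cuda::CusolverGetErrorString(e).
      CUSOLVER_CALL(cusolverDnSpotrf_bufferSize(h, CUBLAS_FILL_MODE_LOWER,
                                                n, d_A, lda, &lwork));
      return lwork;
    }
    }  // namespace mxnet
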
diff --git a/src/io/inst_vector.h b/src/io/inst_vector.h
index 4bc2a6c..6dc7bdf 100644
--- a/src/io/inst_vector.h
+++ b/src/io/inst_vector.h
@@ -30,6 +30,7 @@
 #include <mxnet/base.h>
 #include <dmlc/base.h>
 #include <mshadow/tensor.h>
+#include <mshadow/tensor_blob.h>
 #include <vector>
 #include <string>
 
diff --git a/include/mxnet/c_lapack_api.h b/src/operator/c_lapack_api.h
similarity index 74%
rename from include/mxnet/c_lapack_api.h
rename to src/operator/c_lapack_api.h
index 1ae90a9..96a9b3a 100644
--- a/include/mxnet/c_lapack_api.h
+++ b/src/operator/c_lapack_api.h
@@ -19,14 +19,24 @@
 
 /*!
  * \file c_lapack_api.h
- * \brief Unified interface for LAPACK calls from within mxnet.
+ * \brief Unified interface for CPU-based LAPACK calls.
  *  Purpose is to hide the platform specific differences.
  */
-#ifndef MXNET_C_LAPACK_API_H_
-#define MXNET_C_LAPACK_API_H_
+#ifndef MXNET_OPERATOR_C_LAPACK_API_H_
+#define MXNET_OPERATOR_C_LAPACK_API_H_
 
 // Manually maintained list of LAPACK interfaces that can be used
 // within MXNET. Conventions:
+//    - We should only import LAPACK-functions that are useful and
+//      ensure that we support them most efficiently on CPU/GPU. As an
+//      example take "potrs": It can be emulated by two calls to
+//      "trsm" (from BLAS3) so not really needed from functionality point
+//      of view. In addition, trsm on GPU supports batch-mode processing
+//      which is much more efficient for a bunch of smaller matrices while
+//      there is no such batch support for potrs. As a result, we may
+//      not support "potrs" internally and if we want to expose it to the user as
+//      a convenience operator at some time, then we may implement it internally
+//      as a sequence of trsm.
 //    - Interfaces must be compliant with lapacke.h in terms of signature and
 //      naming conventions so wrapping a function "foo" which has the
 //      signature
@@ -36,14 +46,21 @@
 //      Note that function signatures in lapacke.h will always have as first
 //      argument the storage order (row/col-major). All wrappers have to support
 //      that argument. The underlying fortran functions will always assume a
-//      column-major layout. It is the responsibility of the wrapper function
-//      to handle the (usual) case that it is called with data in row-major
-//      format, either by doing appropriate transpositions explicitly or using
-//      transposition options of the underlying fortran function.
-//    - It is ok to assume that matrices are stored in contiguous memory
-//      (which removes the need to do special handling for lda/ldb parameters
-//      and enables us to save additional matrix transpositions around
-//      the fortran calls).
+//      column-major layout.
+//    - In the (usual) case that a wrapper is called specifying row-major storage
+//      order of input/output data, there are two ways to handle this:
+//        1) The wrapper may support this without allocating any additional memory
+//           for example by exploiting the fact that a matrix is symmetric and switching
+//           certain flags (upper/lower triangular) when calling the fortran code.
+//        2) The wrapper may cause a runtime error. In that case it should be clearly
+//           documented that these functions do only support col-major layout.
+//      Rationale: This is a low level interface that is not expected to be called
+//      directly from many upstream functions. Usually all calls should go through
+//      the tensor-based interfaces in linalg.h which simplify calls to lapack further
+//      and are better suited to handle additional transpositions that may be necessary.
+//      Also we want to push allocation of temporary storage higher up in order to
+//      allow more efficient re-use of temporal storage. And don't want to plaster
+//      these interfaces here with additional requirements of providing buffers.
 //    - It is desired to add some basic checking in the C++-wrappers in order
 //      to catch simple mistakes when calling these wrappers.
 //    - Must support compilation without lapack-package but issue runtime error in this case.
@@ -54,9 +71,10 @@
 using namespace mshadow;
 
 extern "C" {
+
   // Fortran signatures
   #define MXNET_LAPACK_FSIGNATURE1(func, dtype) \
-    void func##_(char* uplo, int* n, dtype* a, int* lda, int *info);
+    void func##_(char *uplo, int *n, dtype *a, int *lda, int *info);
 
   MXNET_LAPACK_FSIGNATURE1(spotrf, float)
   MXNET_LAPACK_FSIGNATURE1(dpotrf, double)
@@ -73,9 +91,6 @@ extern "C" {
 #define MXNET_LAPACK_ROW_MAJOR 101
 #define MXNET_LAPACK_COL_MAJOR 102
 
-#define CHECK_LAPACK_CONTIGUOUS(a, b) \
-  CHECK_EQ(a, b) << "non contiguous memory for array in lapack call";
-
 #define CHECK_LAPACK_UPLO(a) \
   CHECK(a == 'U' || a == 'L') << "neither L nor U specified as triangle in lapack call";
 
@@ -117,9 +132,9 @@ inline void flip<cpu, double>(int m, int n,
 
 #if MXNET_USE_LAPACK
 
+  // These functions can be called with either row- or col-major format.
   #define MXNET_LAPACK_CWRAPPER1(func, dtype) \
-  inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda ) { \
-    CHECK_LAPACK_CONTIGUOUS(n, lda); \
+  inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, int lda) { \
     CHECK_LAPACK_UPLO(uplo); \
     char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \
     int ret(0); \
@@ -172,7 +187,7 @@ inline void flip<cpu, double>(int m, int n,
 
   // Define compilable stubs.
   #define MXNET_LAPACK_CWRAPPER1(func, dtype) \
-  inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda ) { \
+  inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda) { \
     LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
     return 1; \
   }
@@ -209,4 +224,4 @@ inline int MXNET_LAPACK_posv<double>(int matrix_layout, char uplo, int n,
   return mxnet_lapack_dposv(matrix_layout, uplo, n, nrhs, a, lda, b, ldb);
 }
 
-#endif  // MXNET_C_LAPACK_API_H_
+#endif  // MXNET_OPERATOR_C_LAPACK_API_H_
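
To make the wrapper convention above concrete, a minimal sketch (illustration only, not part of the patch) for a hypothetical LAPACK routine "foo"; the names sfoo_/MXNET_LAPACK_sfoo are invented for the example and simply mirror the MXNET_LAPACK_CWRAPPER1 pattern:

    // Hypothetical wrapper: same argument list as the lapacke-style signature
    //   int LAPACKE_sfoo(int matrix_layout, char uplo, int n, float* a, int lda).
    // Row-major input is mapped onto the column-major Fortran routine by flipping
    // the referenced triangle, which works because the matrix is symmetric.
    inline int MXNET_LAPACK_sfoo(int matrix_layout, char uplo, int n, float *a, int lda) {
      CHECK_LAPACK_UPLO(uplo);
      char o = (matrix_layout == MXNET_LAPACK_ROW_MAJOR ? (uplo == 'U' ? 'L' : 'U') : uplo);
      int info(0);
      sfoo_(&o, &n, a, &lda, &info);  // hypothetical Fortran symbol, cf. spotrf_/dpotrf_
      return info;
    }
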
diff --git a/src/operator/contrib/krprod.h b/src/operator/contrib/krprod.h
index 6ce94c6..a54ece7 100644
--- a/src/operator/contrib/krprod.h
+++ b/src/operator/contrib/krprod.h
@@ -26,7 +26,7 @@
 #define MXNET_OPERATOR_CONTRIB_KRPROD_H_
 #include <vector>
 #include "mshadow/tensor.h"
-#include "mxnet/c_lapack_api.h"
+#include "../c_lapack_api.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/linalg.h b/src/operator/linalg.h
new file mode 100644
index 0000000..9284a58
--- /dev/null
+++ b/src/operator/linalg.h
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file linalg.h
+ * \brief Unified tensor interface for advanced linear algebra functions
+ * (specifically BLAS3/LAPACK) from within mxnet.
+ */
+#ifndef MXNET_OPERATOR_LINALG_H_
+#define MXNET_OPERATOR_LINALG_H_
+
+#include <mshadow/tensor.h>
+#include "./c_lapack_api.h"
+using namespace mshadow;
+
+// The purpose of this header is to expose the interfaces of the advanced
+// linear algebra functions without clutter by the implementations. In contrast
+// to the implementations in linalg_inline.h, no macros are used to generate
+// similar functions that just differ by name/type in order to improve readability.
+//
+// Guidelines for extensions:
+// For any type of computation the following should be provided at minimum:
+//   - 1 templated function supporting cpu/gpu float/double in non-batch mode
+//   - 1 templated function supporting cpu/gpu float/double in batch mode
+// Naming conventions:
+//   - linalg_<func>()
+//   - linalg_batch_<func>()
+// Signatures of CPU/GPU versions should be equivalent whenever possible including
+// that a stream is supplied to the cpu-versions as (optional) last argument.
+// The batched versions all work on tensors with one more dimension than the
+// non-batched ones and the first/highest dimension iterates over the elements
+// within the batch.
+
+//////////////////////////////// GEMM ////////////////////////////////////////////
+
+// CPU/GPU-versions of BLAS3 function "gemm". Please refer to the BLAS3-documentation
+// for further information about the function and its parameters.
+// Note that this is C = gemm(A,B,C), so C is input and output parameter.
+template<typename xpu, typename DType>
+void linalg_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
+                 const Tensor<xpu, 2, DType>& C, DType alpha, DType beta,
+                 bool tA, bool tB, Stream<xpu> *s = 0);
+
+template<typename xpu, typename DType>
+void linalg_batch_gemm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                       const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
+                       bool tA, bool tB, Stream<xpu> *s = 0);
+
+//////////////////////////////// TRSM ////////////////////////////////////////////
+
+// CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation
+// for further information about the function and its parameters.
+// Note that this is B = trsm(A,B), so B is input and output parameter.
+template<typename xpu, typename DType>
+void linalg_trsm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
+                 DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s = 0);
+
+template<typename xpu, typename DType>
+inline void linalg_batch_trsm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                   DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s = 0);
+
+//////////////////////////////// TRMM ////////////////////////////////////////////
+
+// CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation
+// for further information about the function and its parameters.
+// Note that this is B = trmm(A,B), so B is input and output parameter.
+
+template<typename xpu, typename DType>
+void linalg_trmm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
+                 DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s = 0);
+
+template<typename xpu, typename DType>
+void linalg_batch_trmm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                    DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s = 0);
+
+//////////////////////////////// POTRF ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "potrf". Please refer to the LAPACK-documentation
+// for further information about the function and its parameters.
+// Note that this is A = potrf(A), so A is input and output parameter.
+
+template<typename xpu, typename DType>
+void linalg_potrf(const Tensor<xpu, 2, DType>& A, bool lower, Stream<xpu> *s = 0);
+
+template<typename xpu, typename DType>
+void linalg_batch_potrf(const Tensor<xpu, 3, DType>& A, bool lower, Stream<xpu> *s = 0);
+
+//////////////////////////////// POTRI ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "potri". Please refer to the LAPACK-documentation
+// for further information about the function and its parameters.
+// Note that this is A = potri(A), so A is input and output parameter.
+
+template<typename xpu, typename DType>
+void linalg_potri(const Tensor<xpu, 2, DType>& A, bool lower, Stream<xpu> *s = 0);
+
+template<typename xpu, typename DType>
+void linalg_batch_potri(const Tensor<xpu, 3, DType>& A, bool lower, Stream<xpu> *s = 0);
+
+#include "linalg_impl.h"
+
+#endif  // MXNET_OPERATOR_LINALG_H_
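
For illustration (not part of the patch), a minimal sketch of how the tensor-based interface above is meant to be called from operator code; the function name and the float/alpha/beta choices are assumptions for the example:

    // C = 1.0 * A * B + 0.0 * C, no transposition; works for cpu and gpu since
    // linalg_gemm is specialized for both. The stream is required by the GPU versions.
    template<typename xpu>
    void example_gemm(const mshadow::Tensor<xpu, 2, float>& A,
                      const mshadow::Tensor<xpu, 2, float>& B,
                      const mshadow::Tensor<xpu, 2, float>& C,
                      mshadow::Stream<xpu>* s) {
      linalg_gemm(A, B, C, 1.0f, 0.0f, false, false, s);
    }
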
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
new file mode 100644
index 0000000..affa794
--- /dev/null
+++ b/src/operator/linalg_impl.h
@@ -0,0 +1,508 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file linalg.h
+ * \brief Implementation of unified tensor interface for advanced linear algebra functions
+ * (specifically BLAS3/LAPACK) from within mxnet.
+ */
+#ifndef MXNET_OPERATOR_LINALG_IMPL_H_
+#define MXNET_OPERATOR_LINALG_IMPL_H_
+
+#include <algorithm>
+
+// Convenience functions.
+inline void linalg_check_batch_size(int A, int B, int C) {
+  CHECK_EQ(A, B) << "Inconsistent batch size between arguments to linear algebra operator";
+  CHECK_EQ(A, C) << "Inconsistent batch size between arguments to linear algebra operator";
+  CHECK_GT(A, 0) << "Zero batch size for arguments to linear algebra operator";
+}
+
+//////////////////////////////// GEMM ////////////////////////////////////////////
+
+// CPU/GPU-versions of BLAS3 function "gemm". Please refer to the BLAS3-documentation
+// for further information about the function and its parameters.
+// Note that this is C = gemm(A,B,C), so C is input and output parameter.
+
+template<typename xpu, typename DType>
+inline void check_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
+                const Tensor<xpu, 2, DType>& C, DType alpha, DType beta, bool tA, bool tB) {
+  // Any checking that helps user debug potential problems.
+  CHECK_EQ((tA ? A.size(1) : A.size(0)), C.size(0))
+    << "Non compatible matrix dimensions between inputs A and C for gemm";
+  CHECK_EQ((tB ? B.size(0) : B.size(1)), C.size(1))
+    << "Non compatible matrix dimensions between inputs B and C for gemm";
+  CHECK_EQ((tA ? A.size(0) : A.size(1)), (tB ? B.size(1) : B.size(0)))
+    << "Non compatible matrix dimensions between inputs A and B for gemm";
+}
+
+#define LINALG_CPU_GEMM(fname, DType) \
+template<> inline \
+void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, const Tensor<cpu, 2, DType>& B, \
+                             const Tensor<cpu, 2, DType>& C, DType alpha, DType beta, \
+                             bool tA, bool tB, Stream<cpu> *s) { \
+  check_gemm(A, B, C, alpha, beta, tA, tB); \
+  cblas_##fname(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), (tB ? CblasTrans : CblasNoTrans), \
+                C.size(0), C.size(1), (tA ? A.size(0) : A.size(1)), alpha, \
+                A.dptr_, A.stride_, B.dptr_, B.stride_, beta, C.dptr_, C.stride_); \
+}
+LINALG_CPU_GEMM(sgemm, float)
+LINALG_CPU_GEMM(dgemm, double)
+
+#define LINALG_CPU_BATCH_GEMM(DType) \
+template<> inline \
+void linalg_batch_gemm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
+                                   const Tensor<cpu, 3, DType>& C, DType alpha, DType beta, \
+                                   bool tA, bool tB, Stream<cpu> *s) { \
+  linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
+  for (index_t i = 0; i < A.size(0); ++i) { \
+    linalg_gemm(A[i], B[i], C[i], alpha, beta, tA, tB); \
+  } \
+}
+LINALG_CPU_BATCH_GEMM(float)
+LINALG_CPU_BATCH_GEMM(double)
+
+#ifdef __CUDACC__
+
+template<typename DType>
+__global__ void linalgCollectBatchOffsetsGPU(DType *a[], DType* b, int stride, int N) {
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
+      a[i] = b + i * stride;
+    }
+}
+
+// cublas col-major processing accounted for by switching first two operands
+
+#define LINALG_GPU_GEMM(fname, DType) \
+template<> inline \
+void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2, DType>& B, \
+                             const Tensor<gpu, 2, DType>& C, DType alpha, DType beta, \
+                             bool tA, bool tB, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  check_gemm(A, B, C, alpha, beta, tA, tB); \
+  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                            (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)), \
+                            &alpha, B.dptr_, B.stride_, A.dptr_, A.stride_, \
+                            &beta, C.dptr_, C.stride_)) \
+}
+LINALG_GPU_GEMM(Sgemm, float)
+LINALG_GPU_GEMM(Dgemm, double)
+
+#define LINALG_GPU_BATCH_GEMM(fname, DType) \
+template<> inline \
+void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
+                                   const Tensor<gpu, 3, DType>& C, DType alpha, DType beta, \
+                                   bool tA, bool tB, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
+  check_gemm(A[0], B[0], C[0], alpha, beta, tA, tB); \
+  Storage::Handle offsetsA, offsetsB, offsetsC; \
+  offsetsA = Storage::Get()->Alloc(sizeof(DType*)*A.size(0), Context::GPU()); \
+  offsetsB = Storage::Get()->Alloc(sizeof(DType*)*B.size(0), Context::GPU()); \
+  offsetsC = Storage::Get()->Alloc(sizeof(DType*)*C.size(0), Context::GPU()); \
+  using namespace mshadow::cuda; \
+  int ngrid = std::min(kMaxGridNum, \
+                       static_cast<int>((A.size(0) + kBaseThreadNum - 1) / kBaseThreadNum)); \
+  linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
+    (static_cast<DType **>(offsetsA.dptr), A.dptr_, A.size(1)*A.stride_, A.size(0)); \
+  linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
+    (static_cast<DType **>(offsetsB.dptr), B.dptr_, B.size(1)*B.stride_, B.size(0)); \
+  linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
+    (static_cast<DType **>(offsetsC.dptr), C.dptr_, C.size(1)*C.stride_, C.size(0)); \
+  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                            (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            C.size(2), C.size(1), (tB ? B.size(2) : B.size(1)), \
+                            &alpha, static_cast<const DType **>(offsetsB.dptr), B.stride_, \
+                            static_cast<const DType **>(offsetsA.dptr),  A.stride_, \
+                            &beta, static_cast<DType **>(offsetsC.dptr), C.stride_, A.size(0))) \
+  Storage::Get()->Free(offsetsA); \
+  Storage::Get()->Free(offsetsB); \
+  Storage::Get()->Free(offsetsC); \
+}
+LINALG_GPU_BATCH_GEMM(SgemmBatched, float)
+LINALG_GPU_BATCH_GEMM(DgemmBatched, double)
+
+#endif
+
+//////////////////////////////// TRSM ////////////////////////////////////////////
+
+// CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation
+// for further information about the function and its parameters.
+// Note that this is B = trsm(A,B), so B is input and output parameter.
+
+template<typename xpu, typename DType>
+inline void check_trsm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
+                DType alpha, bool rightside, bool lower, bool transpose) {
+  // Any checking that helps user debug potential problems.
+  CHECK_EQ(A.size(0), A.size(1))
+    << "First input of trsm is not a square matrix.";
+  CHECK(!rightside || (B.size(1) == A.size(0)))
+    << "Non compatible matrix dimensions between inputs A and B for trsm";
+  CHECK(rightside || (B.size(0) == A.size(1)))
+    << "Non compatible matrix dimensions between inputs A and B for trsm";
+}
+
+#define LINALG_CPU_TRSM(fname, DType) \
+template<> inline \
+void linalg_trsm<cpu, DType>(const Tensor<cpu, 2, DType>& A, const Tensor<cpu, 2, DType>& B, \
+                 DType alpha, bool rightside, bool lower, bool transpose, Stream<cpu> *s) { \
+  check_trsm(A, B, alpha, rightside, lower, transpose); \
+  cblas_##fname(CblasRowMajor, (rightside ? CblasRight : CblasLeft), \
+                (lower ? CblasLower : CblasUpper), (transpose ? CblasTrans : CblasNoTrans), \
+                CblasNonUnit, B.size(0), B.size(1), alpha, A.dptr_, \
+                A.stride_, B.dptr_, B.stride_); \
+}
+LINALG_CPU_TRSM(strsm, float)
+LINALG_CPU_TRSM(dtrsm, double)
+
+#define LINALG_CPU_BATCH_TRSM(DType) \
+template<> inline \
+void linalg_batch_trsm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
+                   DType alpha, bool rightside, bool lower, bool transpose, Stream<cpu> *s) { \
+  linalg_check_batch_size(A.size(0), B.size(0), B.size(0)); \
+  for (index_t i = 0; i < A.size(0); ++i) { \
+    linalg_trsm(A[i], B[i], alpha, rightside, lower, transpose); \
+  } \
+}
+LINALG_CPU_BATCH_TRSM(float)
+LINALG_CPU_BATCH_TRSM(double)
+
+#ifdef __CUDACC__
+
+// cublas col-major processing accounted for by switching sides and fill mode
+
+#define LINALG_GPU_TRSM(fname, DType) \
+template<> inline \
+void linalg_trsm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2, DType>& B, \
+                 DType alpha, bool rightside, bool lower, bool transpose, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  check_trsm(A, B, alpha, rightside, lower, transpose); \
+  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                            (rightside ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT), \
+                            (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \
+                            (transpose ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            CUBLAS_DIAG_NON_UNIT, B.size(1), B.size(0), &alpha, \
+                            A.dptr_, A.stride_, B.dptr_, B.stride_)); \
+}
+LINALG_GPU_TRSM(Strsm, float)
+LINALG_GPU_TRSM(Dtrsm, double)
+
+#define LINALG_GPU_BATCH_TRSM(fname, DType) \
+template<> inline \
+void linalg_batch_trsm<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
+                   DType alpha, bool rightside, bool lower, bool transpose, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  linalg_check_batch_size(A.size(0), B.size(0), B.size(0)); \
+  check_trsm(A[0], B[0], alpha, rightside, lower, transpose); \
+  Storage::Handle offsetsA, offsetsB; \
+  offsetsA = Storage::Get()->Alloc(sizeof(DType*)*A.size(0), Context::GPU()); \
+  offsetsB = Storage::Get()->Alloc(sizeof(DType*)*B.size(0), Context::GPU()); \
+  using namespace mshadow::cuda; \
+  int ngrid = std::min(kMaxGridNum, \
+                       static_cast<int>((A.size(0) + kBaseThreadNum - 1) / kBaseThreadNum)); \
+  linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
+    (static_cast<DType **>(offsetsA.dptr), A.dptr_, A.size(1)*A.stride_, A.size(0)); \
+  linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
+    (static_cast<DType **>(offsetsB.dptr), B.dptr_, B.size(1)*B.stride_, A.size(0)); \
+  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                            (rightside ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT), \
+                            (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \
+                            (transpose ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            CUBLAS_DIAG_NON_UNIT, B.size(2), B.size(1), &alpha, \
+                            static_cast<const DType **>(offsetsA.dptr), A.stride_, \
+                            static_cast<DType **>(offsetsB.dptr), B.stride_, A.size(0))); \
+  Storage::Get()->Free(offsetsA); \
+  Storage::Get()->Free(offsetsB); \
+}
+LINALG_GPU_BATCH_TRSM(StrsmBatched, float)
+LINALG_GPU_BATCH_TRSM(DtrsmBatched, double)
+
+#endif
+
+//////////////////////////////// TRMM ////////////////////////////////////////////
+
+// CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation
+// for further information about the function and its parameters.
+// Note that this is B = trmm(A,B), so B is input and output parameter.
+
+template<typename xpu, typename DType>
+inline void check_trmm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
+                DType alpha, bool rightside, bool lower, bool transpose) {
+  // Any checking that helps user debug potential problems.
+  CHECK_EQ(A.size(0), A.size(1))
+    << "First input of trmm is not a square matrix.";
+  CHECK(!rightside || (B.size(1) == A.size(0)))
+    << "Non compatible matrix dimensions between inputs A and B for trmm";
+  CHECK(rightside || (B.size(0) == A.size(1)))
+    << "Non compatible matrix dimensions between inputs A and B for trmm";
+}
+
+#define LINALG_CPU_TRMM(fname, DType) \
+template<> inline \
+void linalg_trmm<cpu, DType>(const Tensor<cpu, 2, DType>& A, const Tensor<cpu, 2, DType>& B, \
+                 DType alpha, bool rightside, bool lower, bool transpose, Stream<cpu> *s) { \
+  check_trmm(A, B, alpha, rightside, lower, transpose); \
+  cblas_##fname(CblasRowMajor, (rightside ? CblasRight : CblasLeft), \
+                (lower ? CblasLower : CblasUpper), (transpose ? CblasTrans : CblasNoTrans), \
+                CblasNonUnit, B.size(0), B.size(1), alpha, A.dptr_, \
+                A.stride_, B.dptr_, B.stride_); \
+}
+LINALG_CPU_TRMM(strmm, float)
+LINALG_CPU_TRMM(dtrmm, double)
+
+#define LINALG_XPU_BATCH_TRMM(xpu, DType) \
+template<> inline \
+void linalg_batch_trmm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, \
+                    DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s) { \
+  linalg_check_batch_size(A.size(0), B.size(0), B.size(0)); \
+  for (index_t i = 0; i < A.size(0); ++i) { \
+    linalg_trmm(A[i], B[i], alpha, rightside, lower, transpose, s); \
+  } \
+}
+LINALG_XPU_BATCH_TRMM(cpu, float)
+LINALG_XPU_BATCH_TRMM(cpu, double)
+
+#ifdef __CUDACC__
+
+// cublas col-major processing accounted for by switching sides and fill mode
+// doing in-place computation by supplying B as second and third matrix
+#define LINALG_GPU_TRMM(fname, DType) \
+template<> inline \
+void linalg_trmm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2, DType>& B, \
+                 DType alpha, bool rightside, bool lower, bool transpose, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  check_trmm(A, B, alpha, rightside, lower, transpose); \
+  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                            (rightside ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT), \
+                            (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \
+                            (transpose ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            CUBLAS_DIAG_NON_UNIT, B.size(0), B.size(1), &alpha, \
+                            A.dptr_, A.stride_, B.dptr_, B.stride_, \
+                            B.dptr_, B.stride_)); \
+}
+LINALG_GPU_TRMM(Strmm, float)
+LINALG_GPU_TRMM(Dtrmm, double)
+
+LINALG_XPU_BATCH_TRMM(gpu, float)
+LINALG_XPU_BATCH_TRMM(gpu, double)
+
+#endif
+
+//////////////////////////////// POTRF ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "potrf". Please refer to the LAPACK-documentation
+// for further information about the function and its parameters.
+// Note that this is A = potrf(A), so A is input and output parameter.
+
+template<typename xpu, typename DType>
+inline void check_potrf(const Tensor<xpu, 2, DType>& A, bool lower) {
+  // Any checking that helps user debug potential problems.
+  CHECK_EQ(A.size(0), A.size(1))
+    << "No square matrix as input to potrf.";
+}
+
+#define LINALG_CPU_POTRF(fname, DType) \
+template<> inline \
+void linalg_potrf<cpu, DType>(const Tensor<cpu, 2, DType>& A, bool lower, Stream<cpu> *s) { \
+  check_potrf(A, lower); \
+  int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, (lower ? 'L' : 'U'), A.size(0),  \
+          A.dptr_ , A.stride_)); \
+  CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \
+}
+LINALG_CPU_POTRF(spotrf, float)
+LINALG_CPU_POTRF(dpotrf, double)
+
+#define LINALG_CPU_BATCH_POTRF(DType) \
+template<> inline \
+void linalg_batch_potrf<cpu, DType>(const Tensor<cpu, 3, DType>& A, bool lower, Stream<cpu> *s) { \
+  for (index_t i = 0; i < A.size(0); ++i) { \
+    linalg_potrf(A[i], lower); \
+  } \
+}
+LINALG_CPU_BATCH_POTRF(float)
+LINALG_CPU_BATCH_POTRF(double)
+
+#if MXNET_USE_CUSOLVER == 1
+
+#define LINALG_GPU_BUFFSIZE_POTRF(fname, DType) \
+inline int linalg_potrf_buffsize(const Tensor<gpu, 2, DType>& A, bool lower, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  int buffsize(0); \
+  CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
+                                (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \
+                                 A.size(0), A.dptr_, A.stride_, &buffsize)); \
+  return buffsize;  \
+}
+LINALG_GPU_BUFFSIZE_POTRF(DnSpotrf_bufferSize, float)
+LINALG_GPU_BUFFSIZE_POTRF(DnDpotrf_bufferSize, double)
+
+#define LINALG_GPU_POTRF(fname, DType) \
+template<> inline \
+void linalg_potrf<gpu, DType>(const Tensor<gpu, 2, DType>& A, bool lower, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  check_potrf(A, lower); \
+  int buffsize(linalg_potrf_buffsize(A, lower, s)); \
+  Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*buffsize, Context::GPU()); \
+  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
+  CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
+                (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \
+                A.size(0), A.dptr_, A.stride_, static_cast<DType *>(buffer.dptr), buffsize, \
+                static_cast<int *>(info.dptr))); \
+  Storage::Get()->Free(buffer); \
+  Storage::Get()->Free(info); \
+}
+LINALG_GPU_POTRF(DnSpotrf, float)
+LINALG_GPU_POTRF(DnDpotrf, double)
+
+#define LINALG_GPU_BATCH_POTRF(fname, DType) \
+template<> inline \
+void linalg_batch_potrf<gpu, DType>(const Tensor<gpu, 3, DType>& A, bool lower, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  CHECK_GT(A.size(0), 0); \
+  check_potrf(A[0], lower); \
+  int buffsize(linalg_potrf_buffsize(A[0], lower, s)); \
+  Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*buffsize, Context::GPU()); \
+  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
+  for (mshadow::index_t i = 0; i < A.size(0); ++i) { \
+    CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
+                 (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \
+                 A[i].size(0), A[i].dptr_, A[i].stride_, \
+                 static_cast<DType *>(buffer.dptr), buffsize, static_cast<int *>(info.dptr))); \
+  } \
+  Storage::Get()->Free(buffer); \
+  Storage::Get()->Free(info); \
+}
+LINALG_GPU_BATCH_POTRF(DnSpotrf, float)
+LINALG_GPU_BATCH_POTRF(DnDpotrf, double)
+
+#endif
+
+//////////////////////////////// POTRI ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "potri". Please refer to the LAPACK-documentation
+// for further information about the function and its parameters.
+// Note that this is A = potri(A), so A is input and output parameter.
+
+template<typename xpu, typename DType>
+inline void check_potri(const Tensor<xpu, 2, DType>& A, bool lower) {
+  // Any checking that helps user debug potential problems.
+  CHECK_EQ(A.size(0), A.size(1)) << "No square matrix as input to potri.";
+}
+
+#define LINALG_CPU_POTRI(fname, DType) \
+template<> inline \
+void linalg_potri<cpu, DType>(const Tensor<cpu, 2, DType>& A, bool lower, Stream<cpu> *s) { \
+  check_potri(A, lower); \
+  int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, (lower ? 'L' : 'U'), A.size(0),  \
+          A.dptr_ , A.stride_)); \
+  CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \
+}
+LINALG_CPU_POTRI(spotri, float)
+LINALG_CPU_POTRI(dpotri, double)
+
+#define LINALG_CPU_BATCH_POTRI(DType) \
+template<> inline \
+void linalg_batch_potri<cpu, DType>(const Tensor<cpu, 3, DType>& A, bool lower, Stream<cpu> *s) { \
+  for (index_t i = 0; i < A.size(0); ++i) { \
+    linalg_potri(A[i], lower); \
+  } \
+}
+LINALG_CPU_BATCH_POTRI(float)
+LINALG_CPU_BATCH_POTRI(double)
+
+#ifdef __CUDACC__
+
+// Initializes multiple identity matrices on the same vector.
+template<typename DType>
+__global__ void linalgInitIdentityGPU(DType *a, int stride, int lda, int N) {
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
+      // index relative to the matrix.
+      int index(i % stride);
+      a[i] = (index / lda == index % lda ? DType(1.0) : DType(0));
+    }
+}
+
+// There is no direct support for potri in cuda. We emulate the function by two calls to trsm.
+#define LINALG_GPU_POTRI(DType) \
+template<> inline \
+void linalg_potri<gpu, DType>(const Tensor<gpu, 2, DType>& A, bool lower, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  CHECK_NOTNULL(s); \
+  check_potri(A, lower); \
+  Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*A.MSize(), Context::GPU()); \
+  using namespace mshadow::cuda; \
+  int ngrid = std::min(kMaxGridNum, \
+                       static_cast<int>((A.MSize() + kBaseThreadNum - 1) / kBaseThreadNum)); \
+  linalgInitIdentityGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
+    (static_cast<DType *>(buffer.dptr), A.MSize(), A.stride_, A.MSize());  \
+  Tensor<gpu, 2, DType> B((DType *)buffer.dptr, A.shape_, A.stride_, s); \
+  linalg_trsm(A, B, DType(1.0), false, lower, !lower, s); \
+  linalg_trsm(A, B, DType(1.0), false, lower, lower, s); \
+  Copy(A, B, s); \
+  B.dptr_ = 0; \
+  Storage::Get()->Free(buffer); \
+}
+LINALG_GPU_POTRI(float)
+LINALG_GPU_POTRI(double)
+
+#define LINALG_GPU_BATCH_POTRI(DType) \
+template<> inline \
+void linalg_batch_potri<gpu, DType>(const Tensor<gpu, 3, DType>& A, bool lower, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  CHECK_NOTNULL(s); \
+  CHECK_GT(A.size(0), 0); \
+  check_potri(A[0], lower); \
+  Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*A.MSize(), Context::GPU()); \
+  using namespace mshadow::cuda; \
+  int ngrid = std::min(kMaxGridNum, \
+                       static_cast<int>((A.MSize() + kBaseThreadNum - 1) / kBaseThreadNum)); \
+  linalgInitIdentityGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
+    (static_cast<DType *>(buffer.dptr), A.size(1)*A.stride_, A.stride_, A.MSize()); \
+  Tensor<gpu, 3, DType> B((DType *)buffer.dptr, A.shape_, A.stride_, s); \
+  linalg_batch_trsm(A, B, DType(1.0), false, lower, !lower, s); \
+  linalg_batch_trsm(A, B, DType(1.0), false, lower, lower, s); \
+  Copy(A, B, s); \
+  B.dptr_ = 0; \
+  Storage::Get()->Free(buffer); \
+}
+LINALG_GPU_BATCH_POTRI(float)
+LINALG_GPU_BATCH_POTRI(double)
+
+#endif
+
+#endif  // MXNET_OPERATOR_LINALG_IMPL_H_
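
For illustration (not part of the patch), the identity behind the operand swap used in the GPU wrappers above: column-major cuBLAS reads a row-major matrix as its transpose, so a row-major C = A*B is computed as C^T = B^T * A^T by passing the operands in reverse order and swapping the m/n sizes. A standalone sketch with raw cuBLAS, assuming contiguous row-major float matrices:

    #include <cublas_v2.h>

    // Row-major C(m x n) = A(m x k) * B(k x n) expressed on a column-major BLAS.
    void rowmajor_gemm(cublasHandle_t h, int m, int n, int k,
                       const float* A, const float* B, float* C) {
      const float alpha = 1.0f, beta = 0.0f;
      cublasSgemm(h, CUBLAS_OP_N, CUBLAS_OP_N,
                  n, m, k,           // m/n swapped relative to the row-major view
                  &alpha, B, n,      // B passed first, leading dimension n
                  A, k,              // A second, leading dimension k
                  &beta, C, n);      // C written with leading dimension n
    }
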
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index 1b726ce..70d4f9b 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -401,7 +401,7 @@ Examples::
   { return std::vector<std::string>{"A"}; } )
 .set_attr<nnvm::FInferShape>("FInferShape", LaReduceShape<2>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<FCompute>("FCompute<cpu>", LaReduceForward<cpu, 2, sumlogdiag>)
+.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 0, 1, 1, sumlogdiag>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_linalg_sumlogdiag"})
 .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrices");
 
@@ -411,7 +411,7 @@ NNVM_REGISTER_OP(_backward_linalg_sumlogdiag)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
   { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", LaReduceBackward<cpu, 2, sumlogdiag_backward>);
+.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 2, 1, sumlogdiag_backward>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu
new file mode 100644
index 0000000..a89d98f
--- /dev/null
+++ b/src/operator/tensor/la_op.cu
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file la_op.cu
+ * \brief GPU-Operators for advanced linear algebra.
+ */
+#include "./la_op.h"
+#include "./la_op_inline.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(linalg_gemm)
+.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 3, 1, gemm>);
+
+NNVM_REGISTER_OP(_backward_linalg_gemm)
+.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 3, gemm_backward>);
+
+NNVM_REGISTER_OP(linalg_gemm2)
+.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, gemm2>);
+
+NNVM_REGISTER_OP(_backward_linalg_gemm2)
+.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 3, 2, gemm2_backward>);
+
+NNVM_REGISTER_OP(linalg_trmm)
+.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, trmm>);
+
+NNVM_REGISTER_OP(_backward_linalg_trmm)
+.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 2, trmm_backward>);
+
+NNVM_REGISTER_OP(linalg_trsm)
+.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, trsm>);
+
+NNVM_REGISTER_OP(_backward_linalg_trsm)
+.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 2, trsm_backward>);
+
+NNVM_REGISTER_OP(linalg_sumlogdiag)
+.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 0, 1, 1, sumlogdiag>);
+
+NNVM_REGISTER_OP(_backward_linalg_sumlogdiag)
+.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 2, 1, sumlogdiag_backward>);
+
+NNVM_REGISTER_OP(linalg_potri)
+.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 1, 1, potri>);
+
+NNVM_REGISTER_OP(_backward_linalg_potri)
+.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 3, 1, potri_backward>);
+
+#if MXNET_USE_CUSOLVER == 1
+
+NNVM_REGISTER_OP(linalg_potrf)
+.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 1, 1, potrf>);
+
+NNVM_REGISTER_OP(_backward_linalg_potrf)
+.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 2, 1, potrf_backward>);
+
+#endif
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index 9779988..dd5fab9 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -91,9 +91,9 @@ struct LaTriangMatrixMultParam : public dmlc::Parameter<LaTriangMatrixMultParam>
 };
 
 // Common function for shape inference for matrix mult and matrix mac.
-bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
-                            std::vector<TShape>* in_attrs,
-                            std::vector<TShape>* out_attrs) {
+inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
+                                   std::vector<TShape>* in_attrs,
+                                   std::vector<TShape>* out_attrs) {
   CHECK_GE(in_attrs->size(), 2);
   CHECK_EQ(out_attrs->size(), 1);
   bool transpose_a(false), transpose_b(false);
@@ -132,9 +132,9 @@ bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
   return false;
 }
 
-bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs,
-                               std::vector<TShape>* in_attrs,
-                               std::vector<TShape>* out_attrs) {
+inline bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs,
+                                      std::vector<TShape>* in_attrs,
+                                      std::vector<TShape>* out_attrs) {
   const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 2);
   CHECK_EQ(out_attrs->size(), 1);
@@ -192,9 +192,9 @@ bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs,
 }
 
 template<int dim>
-bool LaReduceShape(const nnvm::NodeAttrs& attrs,
-                   std::vector<TShape>* in_attrs,
-                   std::vector<TShape>* out_attrs) {
+inline bool LaReduceShape(const nnvm::NodeAttrs& attrs,
+                          std::vector<TShape>* in_attrs,
+                          std::vector<TShape>* out_attrs) {
   // Shape for reduction of the dim lowest dimensions to a scalar.
   // Can only deduct in forward direction.
   CHECK_EQ(in_attrs->size(), 1);
@@ -203,7 +203,8 @@ bool LaReduceShape(const nnvm::NodeAttrs& attrs,
   if ( ndim < dim ) {
      return false;
   }
-  std::vector<int> oshape(std::max(1, ndim-dim), 1);
+  std::vector<int> oshape(std::max(1, ndim-dim));
+  oshape[0] = 1;
   for ( int i = 0; i < ndim - dim; ++i ) {
     oshape[i] = (*in_attrs)[0][i];
   }
@@ -218,7 +219,6 @@ template<typename xpu, typename DType, int idim, int odim, int inum, int onum, t
 struct LaOpCaller {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const int index,
                  const nnvm::NodeAttrs& attrs,
                        mshadow::Stream<xpu> *s) {
     CHECK(false) << "no specialized LaOpCaller defined for template parameters";
@@ -228,86 +228,75 @@ template<typename xpu, typename DType, int idim, int odim, typename laop>
 struct LaOpCaller<xpu, DType, idim, odim, 1, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const int index,
                  const nnvm::NodeAttrs& attrs,
                        mshadow::Stream<xpu> *s) {
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index], attrs);
+    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
+             outputs[0].FlatToKD<xpu, odim+1, DType>(s), s, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
 struct LaOpCaller<xpu, DType, idim, odim, 2, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const int index,
                  const nnvm::NodeAttrs& attrs,
                        mshadow::Stream<xpu> *s) {
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index],
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index],
-             attrs);
+    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
+             outputs[0].FlatToKD<xpu, odim+1, DType>(s), s, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
 struct LaOpCaller<xpu, DType, idim, odim, 3, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const int index,
                  const nnvm::NodeAttrs& attrs,
                        mshadow::Stream<xpu> *s) {
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index],
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index],
-             attrs);
+    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
+             outputs[0].FlatToKD<xpu, odim+1, DType>(s), s, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
 struct LaOpCaller<xpu, DType, idim, odim, 3, 2, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const int index,
                  const nnvm::NodeAttrs& attrs,
                        mshadow::Stream<xpu> *s) {
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index],
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index],
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s)[index],
-             attrs);
+    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
+             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
+             outputs[1].FlatToKD<xpu, odim+1, DType>(s), s, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
 struct LaOpCaller<xpu, DType, idim, odim, 4, 2, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const int index,
                  const nnvm::NodeAttrs& attrs,
                        mshadow::Stream<xpu> *s) {
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[3].FlatToKD<xpu, idim+1, DType>(s)[index],
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index],
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s)[index],
-             attrs);
+    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[3].FlatToKD<xpu, idim+1, DType>(s),
+             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
+             outputs[1].FlatToKD<xpu, odim+1, DType>(s), s, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
 struct LaOpCaller<xpu, DType, idim, odim, 4, 3, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
-                 const int index,
                  const nnvm::NodeAttrs& attrs,
                        mshadow::Stream<xpu> *s) {
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index],
-             inputs[3].FlatToKD<xpu, idim+1, DType>(s)[index],
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index],
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s)[index],
-             outputs[2].FlatToKD<xpu, odim+1, DType>(s)[index],
-             attrs);
+    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
+             inputs[3].FlatToKD<xpu, idim+1, DType>(s),
+             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
+             outputs[1].FlatToKD<xpu, odim+1, DType>(s),
+             outputs[2].FlatToKD<xpu, odim+1, DType>(s), s, attrs);
   }
 };
 
@@ -322,24 +311,8 @@ void LaOpForward(const nnvm::NodeAttrs& attrs,
   Stream<xpu> *s = ctx.get_stream<xpu>();
   CHECK_EQ(inputs.size(), inum);
   CHECK_EQ(outputs.size(), onum);
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-    int N(-1);
-    for ( int i = 0; i < inum; ++i ) {
-      CHECK_EQ(inputs[i].CheckContiguous(), true);
-      const int M(inputs[i].FlatToKD<xpu, idim+1, OType>(s).size(0));
-      CHECK_EQ((N == -1 || N == M), true);
-      N = M;
-    }
-    for ( int i = 0; i < onum; ++i ) {
-      CHECK_EQ(outputs[i].CheckContiguous(), true);
-      CHECK_EQ((req[i] == kWriteTo || req[i] == kWriteInplace), true);
-      const int M(outputs[i].FlatToKD<xpu, odim+1, OType>(s).size(0));
-      CHECK_EQ((N == -1 || N == M), true);
-      N = M;
-    }
-    for ( int i = 0; i < N; ++i ) {
-      LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs, i, attrs, s);
-    }
+  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs, attrs, s);
   });
 }
 
@@ -354,28 +327,15 @@ void LaOpBackward(const nnvm::NodeAttrs& attrs,
   Stream<xpu> *s = ctx.get_stream<xpu>();
   CHECK_EQ(inputs.size(), inum);
   CHECK_EQ(outputs.size(), onum);
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-    int N(-1);
-    for ( int i = 0; i < inum; ++i ) {
-      CHECK_EQ(inputs[i].CheckContiguous(), true);
-      const int M(inputs[i].FlatToKD<xpu, idim+1, OType>(s).size(0));
-      CHECK_EQ((N == -1 || N == M), true);
-      N = M;
-    }
+  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
     std::vector<TBlob> tspace(outputs);
     for ( int i = 0; i < onum; ++i ) {
-      CHECK_EQ(outputs[i].CheckContiguous(), true);
-      const int M(outputs[i].FlatToKD<xpu, odim+1, OType>(s).size(0));
-      CHECK_EQ((N == -1 || N == M), true);
-      N = M;
       if ( req[i] == kAddTo ) {
         tspace[i].dptr_ = ctx.requested[ResourceRequest::kTempSpace]
                              .get_space_typed<xpu, 1, OType>(Shape1(outputs[i].Size()), s).dptr_;
       }
     }
-    for ( int i = 0; i < N; ++i ) {
-      LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, tspace, i, attrs, s);
-    }
+    LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, tspace, attrs, s);
     for ( int i = 0; i < onum; ++i ) {
       if ( req[i] == kAddTo ) {
         Tensor<xpu, 1, OType> out = outputs[i].FlatTo1D<xpu, OType>(s);
@@ -385,53 +345,6 @@ void LaOpBackward(const nnvm::NodeAttrs& attrs,
   });
 }
 
-template<typename xpu, int idim, typename laop>
-void LaReduceForward(const nnvm::NodeAttrs& attrs,
-                     const OpContext& ctx,
-                     const std::vector<TBlob>& inputs,
-                     const std::vector<OpReqType>& req,
-                     const std::vector<TBlob>& outputs) {
-  using namespace mshadow;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  CHECK_EQ(inputs.size(), 1);
-  CHECK_EQ(outputs.size(), 1);
-  CHECK_EQ(inputs[0].CheckContiguous(), true);
-  CHECK_EQ(outputs[0].CheckContiguous(), true);
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-    Tensor<xpu, idim+1, OType> in(inputs[0].FlatToKD<xpu, idim+1, OType>(s));
-    Tensor<xpu, 1, OType> out(outputs[0].FlatTo1D<xpu, OType>(s));
-    const int N(outputs[0].Size());
-    CHECK_EQ(in.size(0), N);
-    for ( int i = 0; i < N; ++i ) {
-      laop::op(in[i], out[i], attrs);
-    }
-  });
-}
-
-template<typename xpu, int idim, typename laop>
-void LaReduceBackward(const nnvm::NodeAttrs& attrs,
-                      const OpContext& ctx,
-                      const std::vector<TBlob>& inputs,
-                      const std::vector<OpReqType>& req,
-                      const std::vector<TBlob>& outputs) {
-  using namespace mshadow;
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  CHECK_EQ(inputs.size(), 2);
-  CHECK_EQ(outputs.size(), 1);
-  CHECK_EQ(inputs[0].CheckContiguous(), true);
-  CHECK_EQ(inputs[1].CheckContiguous(), true);
-  CHECK_EQ(outputs[0].CheckContiguous(), true);
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-    const int N(inputs[0].Size());
-    Tensor<xpu, 1, OType> in0(inputs[0].FlatTo1D<xpu, OType>(s));
-    Tensor<xpu, idim+1, OType> in1(inputs[1].FlatToKD<xpu, idim+1, OType>(s));
-    Tensor<xpu, idim+1, OType> out(outputs[0].FlatToKD<xpu, idim+1, OType>(s));
-    for ( int i = 0; i < N; ++i ) {
-      laop::op(in0[i], in1[i], out[i], attrs, (req[i] == kAddTo));
-    }
-  });
-}
-
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/tensor/la_op_inline.h b/src/operator/tensor/la_op_inline.h
index a032988..34fb441 100644
--- a/src/operator/tensor/la_op_inline.h
+++ b/src/operator/tensor/la_op_inline.h
@@ -24,244 +24,186 @@
 #ifndef MXNET_OPERATOR_TENSOR_LA_OP_INLINE_H_
 #define MXNET_OPERATOR_TENSOR_LA_OP_INLINE_H_
 
-#include <mxnet/c_lapack_api.h>
+#include "../linalg.h"
 
 namespace mxnet {
 namespace op {
 
 using namespace mshadow;
 
-#define LA_OP_NOT_AVAIL " operator can only be called with float/double data type."
-
-// Signature for single matrix operations (decomposition/inversion).
-#define FUNC_SIGNATURE_1(fname, arg1) {CHECK_EQ(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, 'L', \
-  arg1.size(0), arg1.dptr_, arg1.size(0)), 0) << "fname failed in lapack";}
-
-// Signature for matrix-matrix multiplications involving one diagonal matrix.
-#define FUNC_SIGNATURE_2(fname, arg1, arg2) \
-  { cblas_##fname(CblasRowMajor, (rightside ? CblasRight : CblasLeft), \
-                  CblasLower, (transpose ? CblasTrans : CblasNoTrans), \
-                  CblasNonUnit, arg2.size(0), arg2.size(1), alpha, arg1.dptr_, \
-                  (rightside ? arg2.size(1) : arg2.size(0)), arg2.dptr_, arg2.size(1)); }
-
-
 // Helper functions.
-template<typename DType>
-void CopyLowerToUpper(DType *dptr, int N)
-  { for (int i = 1; i < N; ++i ) for ( int j = 0; j < i; ++j ) dptr[j*N+i] = dptr[i*N+j]; }
-template<typename DType>
-void ZeroUpper(DType *dptr, int N)
-  { for (int i = 0; i < N; ++i ) for ( int j = i+1; j < N; ++j ) dptr[i*N+j] = 0; }
+struct CopyLowerToUpper {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data) {
+    // Below computation works even when we are dealing with a batch of matrices.
+    const int row((i % matrix_size) / stride), col(i % stride);
+    if ( row > col ) data[i + (col - row) * (stride - 1)] = data[i];
+  }
+};
+struct ZeroUpper {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data) {
+    const int row((i % matrix_size) / stride), col(i % stride);
+    if ( row < col ) data[i] = 0;
+  }
+};
+struct Scale {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType scale, DType* data) {
+    data[i] *= scale;
+  }
+};
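
A quick NumPy check (illustrative only) of the flat-index arithmetic used by CopyLowerToUpper above, assuming row-major matrices whose stride equals the number of columns:

# For a flat index i into a batch of row-major n x n matrices:
#   row = (i % matrix_size) / stride, col = i % stride, and the mirrored
#   position (col, row) lives at i + (col - row) * (stride - 1).
import numpy as np

batch, n = 3, 4
a = np.random.rand(batch, n, n)
flat = a.reshape(-1).copy()
stride, matrix_size = n, n * n
for i in range(flat.size):
    row, col = (i % matrix_size) // stride, i % stride
    if row > col:                                       # lower-triangle entry: mirror it upward
        flat[i + (col - row) * (stride - 1)] = flat[i]
mirrored = flat.reshape(batch, n, n)
expected = np.tril(a) + np.transpose(np.tril(a, -1), (0, 2, 1))
assert np.allclose(mirrored, expected)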
 
-// Forward operators
+// Forward computations (always using batched processing)
 
 // D = gemm(A,B,C)
 struct gemm {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
-                 const Tensor<xpu, 2, DType>& C, DType alpha, DType beta, bool tA, bool tB)
-    { CHECK(false) << "gemm" << LA_OP_NOT_AVAIL; }
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+    const Tensor<xpu, 3, DType>& C, DType alpha, DType beta, bool tA, bool tB, Stream<xpu> *s) {
+    linalg_batch_gemm(A, B, C, alpha, beta, tA, tB, s);
+  }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
-                 const Tensor<xpu, 2, DType>& C, const Tensor<xpu, 2, DType>& D,
-                 const nnvm::NodeAttrs& attrs) {
-    if ( C.dptr_ != D.dptr_ ) Copy(D, C);
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                 const Tensor<xpu, 3, DType>& C, const Tensor<xpu, 3, DType>& D,
+                 Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
+    if ( C.dptr_ != D.dptr_ ) Copy(D, C, s);
     const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed);
-    gemm::op(A, B, D, DType(param.alpha), DType(param.beta), param.transpose_a, param.transpose_b);
+    gemm::op(A, B, D, DType(param.alpha), DType(param.beta),
+             param.transpose_a, param.transpose_b, s);
   }
 };
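
The per-slice semantics assumed for the batched gemm above, written out in NumPy (illustrative only; batch_gemm_ref below is not an MXNet API):

import numpy as np

def batch_gemm_ref(A, B, C, alpha, beta, tA, tB):
    # D = alpha * op(A) @ op(B) + beta * C for every batch index
    opA = np.transpose(A, (0, 2, 1)) if tA else A
    opB = np.transpose(B, (0, 2, 1)) if tB else B
    return alpha * np.matmul(opA, opB) + beta * C

A, B, C = np.random.rand(5, 2, 3), np.random.rand(5, 3, 4), np.random.rand(5, 2, 4)
D = batch_gemm_ref(A, B, C, alpha=2.0, beta=0.5, tA=False, tB=False)
assert D.shape == (5, 2, 4)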
-template<>
-void gemm::op<cpu, float>(const Tensor<cpu, 2, float>& A, const Tensor<cpu, 2, float>& B,
-                          const Tensor<cpu, 2, float>& C,
-                          float alpha, float beta, bool tA, bool tB ) {
-  CHECK_EQ((tA ? A.size(1) : A.size(0)), C.size(0))
-    << "Non compatible matrix dimensions between inputs A and C for gemm operator";
-  CHECK_EQ((tB ? B.size(0) : B.size(1)), C.size(1))
-    << "Non compatible matrix dimensions between inputs B and C for gemm operator";
-  CHECK_EQ((tA ? A.size(0) : A.size(1)), (tB ? B.size(1) : B.size(0)))
-    << "Non compatible matrix dimensions between inputs A and B for gemm operator";
-  cblas_sgemm(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), (tB ? CblasTrans : CblasNoTrans),
-              (tA ? A.size(1):A.size(0)), (tB ? B.size(0): B.size(1)),
-              (tA ? A.size(0):A.size(1)), alpha, A.dptr_, A.size(1), B.dptr_, B.size(1),
-              beta, C.dptr_, (tB ? B.size(0): B.size(1)));
-}
-template<>
-void gemm::op<cpu, double>(const Tensor<cpu, 2, double>& A, const Tensor<cpu, 2, double>& B,
-                           const Tensor<cpu, 2, double>& C,
-                           double alpha, double beta, bool tA, bool tB) {
-  CHECK_EQ((tA ? A.size(1) : A.size(0)), C.size(0))
-    << "Non compatible matrix dimensions between inputs A and C for gemm operator";
-  CHECK_EQ((tB ? B.size(0) : B.size(1)), C.size(1))
-    << "Non compatible matrix dimensions between inputs B and C for gemm operator";
-  CHECK_EQ((tA ? A.size(0) : A.size(1)), (tB ? B.size(1) : B.size(0)))
-    << "Non compatible matrix dimensions between inputs A and B for gemm operator";
-  cblas_dgemm(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), (tB ? CblasTrans : CblasNoTrans),
-              (tA ? A.size(1):A.size(0)), (tB ? B.size(0): B.size(1)),
-              (tA ? A.size(0):A.size(1)), alpha, A.dptr_, A.size(1), B.dptr_, B.size(1),
-              beta, C.dptr_, (tB ? B.size(0): B.size(1)));
-}
 
 // C = gemm2(A,B)
 struct gemm2 {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
-                 const Tensor<xpu, 2, DType>& C, const nnvm::NodeAttrs& attrs) {
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                 const Tensor<xpu, 3, DType>& C, Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
     const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed);
-    gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a, param.transpose_b);
+    gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a, param.transpose_b, s);
   }
 };
 
 // L = potrf(A).
 struct potrf {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& L,
-                 const nnvm::NodeAttrs& attrs)
-    { CHECK(false) << "potrf" << LA_OP_NOT_AVAIL; }
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& L,
+                 Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
+    if ( A.dptr_ != L.dptr_ ) Copy(L, A, s);
+    linalg_batch_potrf(L, true, s);
+    using namespace mxnet_op;
+    Kernel<ZeroUpper, xpu>::Launch(s, L.MSize(), L.size(1)*L.stride_, L.stride_, L.dptr_);
+  }
 };
-template<>
-void potrf::op<cpu, float>(const Tensor<cpu, 2, float>& A, const Tensor<cpu, 2, float>& L,
-                           const nnvm::NodeAttrs& attrs) {
-  if ( A.dptr_ != L.dptr_ ) Copy(L, A);
-  FUNC_SIGNATURE_1(spotrf, L);
-  ZeroUpper(L.dptr_, L.size(0));
-}
-template<>
-void potrf::op<cpu, double>(const Tensor<cpu, 2, double>& A, const Tensor<cpu, 2, double>& L,
-                            const nnvm::NodeAttrs& attrs) {
-  if ( A.dptr_ != L.dptr_ ) Copy(L, A);
-  FUNC_SIGNATURE_1(dpotrf, L);
-  ZeroUpper(L.dptr_, L.size(0));
-}
 
 // A = potri(L).
 struct potri {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& L, const Tensor<xpu, 2, DType>& A,
-                 const nnvm::NodeAttrs& attrs)
-    { CHECK(false) << "potri" << LA_OP_NOT_AVAIL; }
+  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
+                 Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
+    if ( A.dptr_ != L.dptr_ ) Copy(A, L, s);
+    linalg_batch_potri(A, true, s);
+    using namespace mxnet_op;
+    Kernel<CopyLowerToUpper, xpu>::Launch(s, A.MSize(), A.size(1)*A.stride_, A.stride_, A.dptr_);
+  }
 };
-template<>
-void potri::op<cpu, float>(const Tensor<cpu, 2, float>& L, const Tensor<cpu, 2, float>& A,
-                           const nnvm::NodeAttrs& attrs) {
-  if ( A.dptr_ != L.dptr_ ) Copy(A, L);
-  FUNC_SIGNATURE_1(spotri, A);
-  CopyLowerToUpper(A.dptr_, A.size(0));
-}
-template<>
-void potri::op<cpu, double>(const Tensor<cpu, 2, double>& A, const Tensor<cpu, 2, double>& L,
-                            const nnvm::NodeAttrs& attrs) {
-  if ( A.dptr_ != L.dptr_ ) Copy(A, L);
-  FUNC_SIGNATURE_1(dpotri, A);
-  CopyLowerToUpper(A.dptr_, A.size(0));
-}
 
 // B = trsm(L,A)
 struct trsm {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& L, const Tensor<xpu, 2, DType>& B,
-                 DType alpha, bool rightside, bool transpose)
-    { CHECK(false) << "trsm" << LA_OP_NOT_AVAIL; }
+  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& B,
+                 DType alpha, bool rightside, bool transpose, Stream<xpu> *s) {
+    linalg_batch_trsm(L, B, alpha, rightside, true, transpose, s);
+  }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& L, const Tensor<xpu, 2, DType>& A,
-                 const Tensor<xpu, 2, DType>& B, const nnvm::NodeAttrs& attrs) {
-    if ( A.dptr_ != B.dptr_ ) Copy(B, A);
+  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
+                 const Tensor<xpu, 3, DType>& B,
+                 Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
+    if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
     const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
-    op(L, B, DType(param.alpha), param.rightside, param.transpose);
+    op(L, B, DType(param.alpha), param.rightside, param.transpose, s);
   }
 };
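
A NumPy sketch (illustrative only) of the trsm semantics assumed above for the rightside=false, transpose=false case: a triangular solve B = alpha * L**(-1) * A rather than an explicit inverse:

import numpy as np

L = np.tril(np.random.rand(4, 4)) + 4 * np.eye(4)   # well-conditioned lower-triangular matrix
A = np.random.rand(4, 2)
alpha = 0.5
B = alpha * np.linalg.solve(L, A)                   # generic solve suffices for the sketch
assert np.allclose(L @ B, alpha * A)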
-template<>
-void trsm::op<cpu, float>(const Tensor<cpu, 2, float>& L, const Tensor<cpu, 2, float>& B,
-                          float alpha, bool rightside, bool transpose) {
-  FUNC_SIGNATURE_2(strsm, L, B);
-}
-template<>
-void trsm::op<cpu, double>(const Tensor<cpu, 2, double>& L, const Tensor<cpu, 2, double>& B,
-                           double alpha, bool rightside, bool transpose) {
-  FUNC_SIGNATURE_2(dtrsm, L, B);
-}
 
 // B = trmm(L,A)
 struct trmm {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& L, const Tensor<xpu, 2, DType>& B,
-                 DType alpha, bool rightside, bool transpose)
-    { CHECK(false) << "trmm" << LA_OP_NOT_AVAIL; }
+  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& B,
+                 DType alpha, bool rightside, bool transpose, Stream<xpu> *s) {
+    linalg_batch_trmm(L, B, alpha, rightside, true, transpose, s);
+  }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& L, const Tensor<xpu, 2, DType>& A,
-                 const Tensor<xpu, 2, DType>& B, const nnvm::NodeAttrs& attrs) {
-    if ( A.dptr_ != B.dptr_ ) Copy(B, A);
+  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
+                 const Tensor<xpu, 3, DType>& B, Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
+    if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
     const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
-    op(L, B, DType(param.alpha), param.rightside, param.transpose);
+    op(L, B, DType(param.alpha), param.rightside, param.transpose, s);
   }
 };
-template<>
-void trmm::op<cpu, float>(const Tensor<cpu, 2, float>& L, const Tensor<cpu, 2, float>& B,
-                          float alpha, bool rightside, bool transpose) {
-  FUNC_SIGNATURE_2(strmm, L, B);
-}
-template<>
-void trmm::op<cpu, double>(const Tensor<cpu, 2, double>& L, const Tensor<cpu, 2, double>& B,
-                           double alpha, bool rightside, bool transpose) {
-  FUNC_SIGNATURE_2(dtrmm, L, B);
-}
 
 // Useful operator that is not part of BLAS/LAPACK.
-struct sumlogdiag {
-  template<typename xpu, typename DType,
-           typename std::enable_if<!std::is_floating_point<DType>::value, int>::type = 0>
-  static void op(const Tensor<xpu, 2, DType>& A, DType& L, const nnvm::NodeAttrs& attrs)
-    { CHECK(false) << "sumlogdiag operator can only be called with float/double data type."; }
-  template<typename xpu, typename DType,
-           typename std::enable_if<std::is_floating_point<DType>::value, int>::type = 0>
-  static void op(const Tensor<xpu, 2, DType>& A, DType& B, const nnvm::NodeAttrs& attrs) {
-    CHECK_EQ(A.size(0), A.size(1)) << "sumlogdiag operator requires a NxN matrix as input.";
-    const int N(A.size(0));
+struct ForwardSumLogDiag {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, int N, int stride, DType* A, DType* B) {
     DType sum(0);
-    DType *p(A.dptr_);
-    for ( int i = 0; i < N; ++i, p += N+1 ) {
-      sum += log(*p);
+    const int offset(i * N * stride);
+    for ( int j = 0; j < N; ++j ) {
+      sum += log(A[offset+j*(stride+1)]);
     }
-    B = sum;
+    B[i] = sum;
+  }
+};
+struct sumlogdiag {
+  template<typename xpu, typename DType>
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 1, DType>& B,
+                 Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
+    CHECK_EQ(A.size(1), A.size(2)) << "sumlogdiag operator requires square matrices as input.";
+    using namespace mxnet_op;
+    Kernel<ForwardSumLogDiag, xpu>::Launch(s, A.size(0), A.size(1), A.stride_, A.dptr_, B.dptr_);
   }
 };
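
A NumPy reference (illustrative only) for the batched sumlogdiag above, including the strided diagonal indexing offset + j*(stride+1) used by the kernel (stride equal to the matrix width is assumed):

import numpy as np

A = np.random.rand(3, 4, 4) + np.eye(4)             # keep the diagonals positive
ref = np.sum(np.log(np.diagonal(A, axis1=1, axis2=2)), axis=1)
flat, n = A.reshape(-1), A.shape[1]
kernel_style = np.array([sum(np.log(flat[i * n * n + j * (n + 1)]) for j in range(n))
                         for i in range(A.shape[0])])
assert np.allclose(ref, kernel_style)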
 
-// Backward operators
+// Backward operators (always using batch processing)
 
 struct gemm_backward {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& dD, const Tensor<xpu, 2, DType>& A,
-                 const Tensor<xpu, 2, DType>& B, const Tensor<xpu, 2, DType>& C,
-                 const Tensor<xpu, 2, DType>& dA, const Tensor<xpu, 2, DType>& dB,
-                 const Tensor<xpu, 2, DType>& dC, const nnvm::NodeAttrs& attrs) {
+  static void op(const Tensor<xpu, 3, DType>& dD, const Tensor<xpu, 3, DType>& A,
+                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
+                 const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
+                 const Tensor<xpu, 3, DType>& dC,
+                 Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed);
-    (param.transpose_a ? gemm::op(B, dD, dA, DType(param.alpha), DType(0), param.transpose_b, true)
-                  : gemm::op(dD, B, dA, DType(param.alpha), DType(0), false, !param.transpose_b));
-    (param.transpose_b ? gemm::op(dD, A, dB, DType(param.alpha), DType(0), true, param.transpose_a)
-                  : gemm::op(A, dD, dB, DType(param.alpha), DType(0), !param.transpose_a, false));
-    const int N(dC.size(0)*dC.size(1));
-    for ( int i = 0; i < N; ++i ) {
-      dC.dptr_[i] = param.beta * dD.dptr_[i];
-    }
+    bool tA(param.transpose_a), tB(param.transpose_b);
+    (tA ? gemm::op(B, dD, dA, DType(param.alpha), DType(0), tB, true, s)
+        : gemm::op(dD, B, dA, DType(param.alpha), DType(0), false, !tB, s));
+    (tB ? gemm::op(dD, A, dB, DType(param.alpha), DType(0), true, tA, s)
+        : gemm::op(A, dD, dB, DType(param.alpha), DType(0), !tA, false, s));
+    Copy(dC, dD, s);
+    using namespace mxnet_op;
+    Kernel<Scale, xpu>::Launch(s, dC.MSize(), DType(param.beta), dC.dptr_);
   }
 };
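
A NumPy check (illustrative only) of the gradient formulas coded above, restricted to the non-transposed case tA = tB = false: dA = alpha * dD * B**T, dB = alpha * A**T * dD and dC = beta * dD.

import numpy as np

rng = np.random.RandomState(0)
A, B, C = rng.rand(2, 3), rng.rand(3, 4), rng.rand(2, 4)
G = rng.rand(2, 4)                       # incoming gradient dD of loss = sum(G * D)
alpha, beta = 1.5, 0.7

def loss(A_):
    return np.sum(G * (alpha * A_ @ B + beta * C))

dA = alpha * G @ B.T                     # analytic gradient, as in the formula above
dA_num, eps = np.zeros_like(A), 1e-6
for idx in np.ndindex(A.shape):          # central finite differences
    Ap, Am = A.copy(), A.copy()
    Ap[idx] += eps
    Am[idx] -= eps
    dA_num[idx] = (loss(Ap) - loss(Am)) / (2 * eps)
assert np.allclose(dA, dA_num, atol=1e-5)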
 
 struct gemm2_backward {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& dC, const Tensor<xpu, 2, DType>& A,
-                 const Tensor<xpu, 2, DType>& B, const Tensor<xpu, 2, DType>& dA,
-                 const Tensor<xpu, 2, DType>& dB, const nnvm::NodeAttrs& attrs) {
+  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
+                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
+                 const Tensor<xpu, 3, DType>& dB,
+                 Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed);
-    (param.transpose_a ? gemm::op(B, dC, dA, DType(param.alpha), DType(0), param.transpose_b, true)
-                   : gemm::op(dC, B, dA, DType(param.alpha), DType(0), false, !param.transpose_b));
-    (param.transpose_b ? gemm::op(dC, A, dB, DType(param.alpha), DType(0), true, param.transpose_a)
-                   : gemm::op(A, dC, dB, DType(param.alpha), DType(0), !param.transpose_a, false));
+    bool tA(param.transpose_a), tB(param.transpose_b);
+    (tA ? gemm::op(B, dC, dA, DType(param.alpha), DType(0), tB, true, s)
+        : gemm::op(dC, B, dA, DType(param.alpha), DType(0), false, !tB, s));
+    (tB ? gemm::op(dC, A, dB, DType(param.alpha), DType(0), true, tA, s)
+        : gemm::op(A, dC, dB, DType(param.alpha), DType(0), !tA, false, s));
   }
 };
 
 struct potrf_backward {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& dL, const Tensor<xpu, 2, DType>& L,
-                 const Tensor<xpu, 2, DType>& dA, const nnvm::NodeAttrs& attrs) {
+  static void op(const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& L,
+                 const Tensor<xpu, 3, DType>& dA,
+                 Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     // Backward of L = potrf(A).
     // dA = 0.5 * L**(-T) * symm(L**T * dL # E) * L**(-1) where
     //     '#' denotes Hadamard product
@@ -269,81 +211,96 @@ struct potrf_backward {
     //      symm(X) = 0.5 * (X + X**T)
     // Hadamard product and symm can be realized by a single copy from lower to upper triangle.
     if ( dL.dptr_ != dA.dptr_ ) {
-      Copy(dA, dL);
+      Copy(dA, dL, s);
     }
-    trmm::op(L, dA, DType(1.0), false, true);
-    CopyLowerToUpper(dA.dptr_, dA.size(0));
-    trsm::op(L, dA, DType(1.0), false, true);
-    trsm::op(L, dA, DType(0.5), true, false);
+    trmm::op(L, dA, DType(1.0), false, true, s);
+    using namespace mxnet_op;
+    Kernel<CopyLowerToUpper, xpu>::Launch
+           (s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_, dA.dptr_);
+    trsm::op(L, dA, DType(1.0), false, true, s);
+    trsm::op(L, dA, DType(0.5), true, false, s);
   }
 };
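
A NumPy check (illustrative only) of the statement in the comment above: the Hadamard product with E followed by symm() equals a plain copy of the lower triangle onto the upper one, which is what the CopyLowerToUpper kernel performs:

import numpy as np

M = np.random.rand(5, 5)
E = np.tril(2 * np.ones((5, 5)), -1) + np.eye(5)    # 1 on diagonal, 2 below, 0 above
symm = 0.5 * (M * E + (M * E).T)
copy_ltu = np.tril(M) + np.tril(M, -1).T            # lower triangle mirrored upward
assert np.allclose(symm, copy_ltu)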
 
 struct potri_backward {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& dA, const Tensor<xpu, 2, DType>& L,
-                 const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& dL,
-                 const nnvm::NodeAttrs& attrs) {
+  static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& L,
+                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
+                 Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     // Backward of A = potri(L).
     // dL = -2 * tril(A * dA * L**(-T)), where tril() extracts lower triangle and diagonal.
-    gemm::op(A, dA, dL, DType(1.0), DType(0), false, false);
-    trsm::op(L, dL, DType(-2.0), true, true);
-    ZeroUpper(dL.dptr_, dL.size(0));
+    gemm::op(A, dA, dL, DType(1.0), DType(0), false, false, s);
+    trsm::op(L, dL, DType(-2.0), true, true, s);
+    using namespace mxnet_op;
+    Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_);
   }
 };
 
 struct trsm_backward {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& dB, const Tensor<xpu, 2, DType>& L,
-                 const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
-                 const Tensor<xpu, 2, DType>& dL, const Tensor<xpu, 2, DType>& dA,
-                 const nnvm::NodeAttrs& attrs) {
+  static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
+                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                 const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& dA,
+                 Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     // Backward of B = trsm(L,A).
     const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
     // Compute dA
-    if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB);
-    trsm::op(L, dA, DType(param.alpha), param.rightside, !param.transpose);
+    if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB, s);
+    trsm::op(L, dA, DType(param.alpha), param.rightside, !param.transpose, s);
     // Compute dL
     const bool da_left(param.rightside == param.transpose);
-    (da_left ?
-        gemm::op(dA, B, dL, DType(-1.0/param.alpha), DType(0), param.transpose, !param.transpose)
-      : gemm::op(B, dA, dL, DType(-1.0/param.alpha), DType(0), !param.transpose, param.transpose));
-    ZeroUpper(dL.dptr_, dL.size(0));
+    DType scale(-1.0/param.alpha);
+    (da_left ? gemm::op(dA, B, dL, scale, DType(0), param.transpose, !param.transpose, s)
+             : gemm::op(B, dA, dL, scale, DType(0), !param.transpose, param.transpose, s));
+    using namespace mxnet_op;
+    Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_);
   }
 };
 
 struct trmm_backward {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 2, DType>& dB, const Tensor<xpu, 2, DType>& L,
-                 const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B,
-                 const Tensor<xpu, 2, DType>& dL, const Tensor<xpu, 2, DType>& dA,
-                 const nnvm::NodeAttrs& attrs) {
+  static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
+                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                 const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& dA,
+                 Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     // Backward of B = trmm(L,A).
     const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
     // Compute dL
     const bool db_left(param.rightside == param.transpose);
-    (db_left ? gemm::op(dB, A, dL, DType(param.alpha), DType(0), param.transpose, !param.transpose)
-           : gemm::op(A, dB, dL, DType(param.alpha), DType(0), !param.transpose, param.transpose));
-    ZeroUpper(dL.dptr_, dL.size(0));
+    DType scale(param.alpha);
+    (db_left ? gemm::op(dB, A, dL, scale, DType(0), param.transpose, !param.transpose, s)
+             : gemm::op(A, dB, dL, scale, DType(0), !param.transpose, param.transpose, s));
+    using namespace mxnet_op;
+    Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_);
     // Compute dA
-    if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB);
-    trmm::op(L, dA, DType(param.alpha), param.rightside, !param.transpose);
+    if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB, s);
+    trmm::op(L, dA, scale, param.rightside, !param.transpose, s);
   }
 };
 
+struct BackwardSumLogDiag {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, int N, int stride, DType* dB, DType* A, DType* dA) {
+    const int offset(i * N * stride);
+    for ( int j = 0; j < N; ++j ) {
+      dA[offset+j*(stride+1)] = dB[i]/A[offset+j*(stride+1)];
+    }
+  }
+};
 struct sumlogdiag_backward {
   template<typename xpu, typename DType>
-  static void op(const DType& dB, const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& dA,
-                 const nnvm::NodeAttrs& attrs, bool add) {
+  static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& A,
+                 const Tensor<xpu, 3, DType>& dA,
+                 Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     // Backward of B = sumlogdiag(A).
-    const int N(A.size(0));
-    if ( !add ) {
-      for ( int i = 0; i < N*N; ++i ) {
-        dA.dptr_[i] = 0;
-      }
-    }
-    for ( int i = 0; i < N; ++i ) {
-      dA.dptr_[i*(N+1)] += dB / A.dptr_[i*N+i];
-    }
+    // dB is actually a 1-d tensor but we convert it to a 3-D one before calling
+    // this function as the LaOpCaller-adapters can only deal with a uniform
+    // dimension for all tensor inputs. This doesn't matter as we will interpret
+    // it correctly internally in this function.
+    using namespace mxnet_op;
+    Kernel<Scale, xpu>::Launch(s, dA.MSize(), DType(0), dA.dptr_);
+    Kernel<BackwardSumLogDiag, xpu>::Launch
+         (s, A.size(0), A.size(1), A.stride_, dB.dptr_, A.dptr_, dA.dptr_);
   }
 };
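
A NumPy spot-check (illustrative only) of the sumlogdiag gradient computed above: dA carries dB / diag(A) on the diagonal and zeros elsewhere:

import numpy as np

rng = np.random.RandomState(1)
A = rng.rand(4, 4) + np.eye(4)            # positive diagonal
dB = 0.3                                  # incoming scalar gradient for this matrix
dA = np.zeros_like(A)
np.fill_diagonal(dA, dB / np.diag(A))     # analytic gradient of dB * sum(log(diag(A)))
eps, i = 1e-6, 2                          # numerically check one diagonal entry
Ap, Am = A.copy(), A.copy()
Ap[i, i] += eps
Am[i, i] -= eps
num = dB * (np.sum(np.log(np.diag(Ap))) - np.sum(np.log(np.diag(Am)))) / (2 * eps)
assert np.isclose(dA[i, i], num)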
 
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 718e3df..7d56b46 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3412,15 +3412,8 @@ def test_deformable_psroipooling():
 
 
 def test_laop():
-    return
-
-    # Currently no support for GPU. Will be added soon
-    # so keep these tests here in this file and activate
-    # gpu-testing when it is ready.
-    dev = default_context()
-    if dev.device_type == 'gpu':
-       return
 
+    # enable numerical checking of gradients
     grad_check = 1
 
     data1 = mx.symbol.Variable('data1')

-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].
