This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4ac76c8  Support for axis parameter in linalg.gemm (#10864)
4ac76c8 is described below

commit 4ac76c89da6d4d8feef629949dc0f9534b216e3d
Author: moin <asmushet...@yahoo.de>
AuthorDate: Tue May 29 20:15:11 2018 +0200

    Support for axis parameter in linalg.gemm (#10864)
---
 src/operator/linalg.h                  |   7 +
 src/operator/linalg_impl.h             | 276 ++++++++++++++++++++++-----------
 src/operator/tensor/la_op.cc           |  39 ++++-
 src/operator/tensor/la_op.cu           |   8 +-
 src/operator/tensor/la_op.h            | 206 +++++++++++++++++-------
 src/operator/tensor/la_op_inline.h     |  66 ++++----
 tests/python/unittest/test_operator.py |  33 +++-
 7 files changed, 447 insertions(+), 188 deletions(-)
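
For readers skimming the patch, the sketch below illustrates how the new *axis*
parameter is meant to be used. Shapes and values are illustrative only, and the
snippet assumes an MXNet build that already contains this commit; it is not part
of the patch:

    import mxnet as mx
    import numpy as np

    # Batch dims 2 and 4; the matrix rows (m=3) sit at axis 1 instead of the
    # default axis -2, the columns at the last axis.
    A = mx.nd.random.uniform(shape=(2, 3, 4, 5))
    B = mx.nd.random.uniform(shape=(2, 5, 4, 6))

    # New in this commit: select the row axis via the axis parameter.
    C = mx.nd.linalg.gemm2(A, B, axis=1)

    # Reference: swap the row axis into the default position, multiply, swap back.
    A1 = mx.nd.swapaxes(A, 1, 2)
    B1 = mx.nd.swapaxes(B, 1, 2)
    C_ref = mx.nd.swapaxes(mx.nd.linalg.gemm2(A1, B1), 1, 2)
    assert np.allclose(C.asnumpy(), C_ref.asnumpy(), atol=1e-6)

The operator avoids the overhead of the explicit swapaxes round trip shown in the
reference computation.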

diff --git a/src/operator/linalg.h b/src/operator/linalg.h
index aee67d7..dc59400 100644
--- a/src/operator/linalg.h
+++ b/src/operator/linalg.h
@@ -64,6 +64,13 @@ void linalg_batch_gemm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DTyp
                        const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
                        bool tA, bool tB, Stream<xpu> *s = 0);
 
+// Version of batch gemm where rows are indexed at axis 1 and columns at axis 3.
+template<typename xpu, typename DType>
+void linalg_batch_gemm(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B,
+                       const Tensor<xpu, 4, DType>& C, DType alpha, DType beta,
+                       bool tA, bool tB, Stream<xpu> *s = 0);
+
+
 template<typename xpu, typename DType>
 inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
                         const Tensor<xpu, 2, DType>& B,
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index 151db60..08d2add 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -56,6 +56,11 @@ inline void check_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DTyp
     << "Non compatible matrix dimensions between inputs A and B for gemm";
 }
 
+template<typename xpu, typename DType>
+void linalg_gemm_axis(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                      const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
+                      bool tA, bool tB, Stream<xpu> *s = 0);
+
 #if (MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1)
 
 #define LINALG_CPU_GEMM(fname, DType) \
@@ -80,6 +85,38 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<
   } \
 }
 
+// Batched gemm where the batch coordinate is given by the second axis.
+#define LINALG_CPU_GEMM_AXIS(fname, DType) \
+template<> inline \
+void linalg_gemm_axis<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
+                                  const Tensor<cpu, 3, DType>& C, DType alpha, DType beta, \
+                                  bool tA, bool tB, Stream<cpu> *s) { \
+  linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
+  for (index_t i = 0; i < A.size(1); ++i) { \
+     cblas_##fname(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), \
+                   (tB ? CblasTrans : CblasNoTrans), \
+                   C.size(0), C.size(2), (tA ? A.size(0) : A.size(2)), alpha, \
+                   A.dptr_+i*A.stride_, A.size(1)*A.stride_, \
+                   B.dptr_+i*B.stride_, B.size(1)*B.stride_, beta, \
+                   C.dptr_+i*C.stride_, C.size(1)*C.stride_); \
+  } \
+}
+
+LINALG_CPU_GEMM_AXIS(sgemm, float)
+LINALG_CPU_GEMM_AXIS(dgemm, double)
+
+// Version where matrix rows are given by the second axis.
+#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \
+template<> inline \
+void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \
+                                   const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \
+                                   bool tA, bool tB, Stream<xpu> *s) { \
+  linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
+  for (index_t i = 0; i < A.size(0); ++i) { \
+    linalg_gemm_axis(A[i], B[i], C[i], alpha, beta, tA, tB, s); \
+  } \
+}
+
 #else
 
 #define LINALG_CPU_GEMM(fname, DType) \
@@ -98,6 +135,14 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<
   LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
 }
 
+#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \
+template<> inline \
+void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \
+                                   const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \
+                                   bool tA, bool tB, Stream<xpu> *s) { \
+  LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
+}
+
 #endif  // MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1
 
 LINALG_CPU_GEMM(sgemm, float)
@@ -106,6 +151,9 @@ LINALG_CPU_GEMM(dgemm, double)
 LINALG_XPU_BATCH_GEMM(cpu, float)
 LINALG_XPU_BATCH_GEMM(cpu, double)
 
+LINALG_XPU_BATCH_GEMM_AXIS(cpu, float)
+LINALG_XPU_BATCH_GEMM_AXIS(cpu, double)
+
 // Specialization of linalg_gemm<cpu, DType> for DType=mshadow::half::half_t.
 template<> inline
 void linalg_gemm<cpu, mshadow::half::half_t>(const Tensor<cpu, 2, mshadow::half::half_t>& A,
@@ -140,6 +188,28 @@ void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2
 LINALG_GPU_GEMM(Sgemm, float)
 LINALG_GPU_GEMM(Dgemm, double)
 
+// Version where matrix rows are given by first axis.
+#define LINALG_GPU_GEMM_AXIS(fname, DType) \
+template<> inline \
+void linalg_gemm_axis<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
+                                  const Tensor<gpu, 3, DType>& C, DType alpha, DType beta, \
+                                  bool tA, bool tB, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
+  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                            (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            C.size(2), C.size(0), (tB ? B.size(2) : B.size(0)), &alpha, \
+                            B.dptr_, B.size(1)*B.stride_, B.stride_, \
+                            A.dptr_, A.size(1)*A.stride_, A.stride_, &beta, \
+                            C.dptr_, C.size(1)*C.stride_, C.stride_, A.size(1))) \
+}
+LINALG_GPU_GEMM_AXIS(SgemmStridedBatched, float)
+LINALG_GPU_GEMM_AXIS(DgemmStridedBatched, double)
+
 // Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t.
 template<> inline
 void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half::half_t>& A,
@@ -192,6 +262,8 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
 #if CUDA_VERSION < 8000
   LINALG_XPU_BATCH_GEMM(gpu, float)
   LINALG_XPU_BATCH_GEMM(gpu, double)
+  LINALG_XPU_BATCH_GEMM_AXIS(gpu, float)
+  LINALG_XPU_BATCH_GEMM_AXIS(gpu, double)
 #else
 #define LINALG_GPU_BATCH_GEMM(fname, DType) \
   template<> inline \
@@ -217,10 +289,125 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
   LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float)
   LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double)
 
+// Version where matrix rows are given by second axis.
+#define LINALG_GPU_BATCH_GEMM_AXIS(fname, DType) \
+  template<> inline \
+  void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 4, DType>& A, \
+                                     const Tensor<gpu, 4, DType>& B, \
+                                     const Tensor<gpu, 4, DType>& C, DType alpha, DType beta, \
+                                     bool tA, bool tB, Stream<gpu> *s) { \
+    using namespace mxnet; \
+    using mshadow::gpu; \
+    CHECK_NOTNULL(s); \
+    linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
+    linalg_check_batch_size(A.size(2), B.size(2), C.size(2)); \
+    for (index_t i = 0; i < A.size(2); ++i) { \
+      CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+          (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
+          (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
+          C.size(3), C.size(1), (tB ? B.size(3) : B.size(1)), &alpha, \
+          B.dptr_+i*B.stride_, B.size(2) * B.stride_, B.size(1)*B.size(2)*B.stride_, \
+          A.dptr_+i*A.stride_, A.size(2) * A.stride_, A.size(1)*A.size(2)*A.stride_, &beta, \
+          C.dptr_+i*C.stride_, C.size(2) * C.stride_, C.size(1)*C.size(2)*C.stride_, A.size(0))) \
+    }\
+  }
+
+  LINALG_GPU_BATCH_GEMM_AXIS(SgemmStridedBatched, float)
+  LINALG_GPU_BATCH_GEMM_AXIS(DgemmStridedBatched, double)
+
 #endif  // CUDA < 8000
 
 #endif  // __CUDACC__
 
+/*!
+ * \brief Performs gemm, setting alpha and beta as appropriate for `req`.
+ *
+ * \param A the first operand of the gemm
+ * \param B the second operand of the gemm
+ * \param C the data to be assigned
+ * \param tA whether the `A` operand should be transposed first.
+ * \param tB whether the `B` operand should be transposed first.
+ * \param s the stream to perform the operation
+ * \param req the assignment request
+ */
+template<typename xpu, typename DType>
+inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
+                        const Tensor<xpu, 2, DType>& B,
+                        const Tensor<xpu, 2, DType>& C,
+                        bool tA, bool tB, Stream<xpu> *s,
+                        mxnet::OpReqType req) {
+  using namespace mxnet;
+  switch (req) {
+    case kNullOp:
+      break;
+    case kWriteTo:
+    case kWriteInplace:
+      linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s);
+      break;
+    case kAddTo:
+      linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s);
+      break;
+    default:
+      LOG(FATAL) << "not reached";
+  }
+}
+
+#if (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
+
+// A template for a cpu linalg_gemm implementation using mshadow::dot()
+#define LINALG_CPU_GEMM_NO_CBLAS(DType) \
+template<> inline \
+void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
+                             const Tensor<cpu, 2, DType>& B, \
+                             const Tensor<cpu, 2, DType>& C, \
+                             bool tA, bool tB, Stream<cpu> *s, \
+                             mxnet::OpReqType req) { \
+  using namespace mxnet; \
+  using mshadow::cpu; \
+  switch (req) { \
+    case kNullOp: \
+      break; \
+    case kWriteTo: \
+    case kWriteInplace: \
+      if (tA) { \
+        if (tB) { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B.T()); \
+        } else { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B); \
+        } \
+      } else { \
+        if (tB) { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B.T()); \
+        } else { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B); \
+        } \
+      } \
+      break; \
+    case kAddTo: \
+      if (tA) { \
+        if (tB) { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B.T()); \
+        } else { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B); \
+        } \
+      } else { \
+        if (tB) { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B.T()); \
+        } else { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B); \
+        } \
+      } \
+      break; \
+    default: \
+      LOG(FATAL) << "not reached"; \
+  } \
+}
+
+LINALG_CPU_GEMM_NO_CBLAS(float)
+LINALG_CPU_GEMM_NO_CBLAS(double)
+
+#endif  // (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
+
 //////////////////////////////// TRSM ////////////////////////////////////////////
 
 // CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation
@@ -313,95 +500,6 @@ LINALG_XPU_BATCH_TRSM(gpu, double)
 
 #endif  // __CUDACC__
 
-/*!
- * \brief Performs gemm, setting alpha and beta as appropriate for `req`.
- *
- * \param A the first operand of the gemm
- * \param B the second operand of the gemm
- * \param C the data to be assigned
- * \param tA whether the `A` operand should be transposed first.
- * \param tB whether the `B` operand should be transposed first.
- * \param s the stream to perform the operation
- * \param req the assignment request
- */
-template<typename xpu, typename DType>
-inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
-                        const Tensor<xpu, 2, DType>& B,
-                        const Tensor<xpu, 2, DType>& C,
-                        bool tA, bool tB, Stream<xpu> *s,
-                        mxnet::OpReqType req) {
-  using namespace mxnet;
-  switch (req) {
-    case kNullOp:
-      break;
-    case kWriteTo:
-    case kWriteInplace:
-      linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s);
-      break;
-    case kAddTo:
-      linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s);
-      break;
-    default:
-      LOG(FATAL) << "not reached";
-  }
-}
-
-#if (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
-
-// A template for a cpu linalg_gemm implementation using mshadow::dot()
-#define LINALG_CPU_GEMM_NO_CBLAS(DType) \
-template<> inline \
-void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
-                             const Tensor<cpu, 2, DType>& B, \
-                             const Tensor<cpu, 2, DType>& C, \
-                             bool tA, bool tB, Stream<cpu> *s, \
-                             mxnet::OpReqType req) { \
-  using namespace mxnet; \
-  using mshadow::cpu; \
-  switch (req) { \
-    case kNullOp: \
-      break; \
-    case kWriteTo: \
-    case kWriteInplace: \
-      if (tA) { \
-        if (tB) { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B.T()); \
-        } else { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B); \
-        } \
-      } else { \
-        if (tB) { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B.T()); \
-        } else { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B); \
-        } \
-      } \
-      break; \
-    case kAddTo: \
-      if (tA) { \
-        if (tB) { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B.T()); \
-        } else { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B); \
-        } \
-      } else { \
-        if (tB) { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B.T()); \
-        } else { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B); \
-        } \
-      } \
-      break; \
-    default: \
-      LOG(FATAL) << "not reached"; \
-  } \
-}
-
-LINALG_CPU_GEMM_NO_CBLAS(float)
-LINALG_CPU_GEMM_NO_CBLAS(double)
-
-#endif  // (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
-
 //////////////////////////////// TRMM ////////////////////////////////////////////
 
 // CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index 7083efe..b177165 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -46,8 +46,20 @@ If *n=2*, the BLAS3 function *gemm* is performed:
 Here, *alpha* and *beta* are scalar parameters, and *op()* is either the identity or
 matrix transposition (depending on *transpose_a*, *transpose_b*).
 
-If *n>2*, *gemm* is performed separately on the trailing two dimensions for all inputs
-(batch mode).
+If *n>2*, *gemm* is performed separately for a batch of matrices. The column indices of the matrices
+are given by the last dimensions of the tensors, the row indices by the axis specified with the *axis*
+parameter. By default, the trailing two dimensions will be used for matrix encoding.
+
+For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes
+calls. For example, let *A*, *B*, *C* be 5-dimensional tensors. Then gemm(*A*, *B*, *C*, axis=1) is equivalent to
+
+    A1 = swapaxes(A, dim1=1, dim2=3)
+    B1 = swapaxes(B, dim1=1, dim2=3)
+    C = swapaxes(C, dim1=1, dim2=3)
+    C = gemm(A1, B1, C)
+    C = swapaxes(C, dim1=1, dim2=3)
+
+without the overhead of the additional swapaxes operations.
 
 .. note:: The operator supports float32 and float64 data types only.
 
@@ -76,7 +88,7 @@ Examples::
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
   { return std::vector<std::pair<int, int>>{{2, 0}}; })
-.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 3, 1, gemm>)
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmForward<cpu, 2, 2, 3, 1, gemm>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_linalg_gemm"})
 .add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices")
 .add_argument("B", "NDArray-or-Symbol", "Tensor of input matrices")
@@ -92,7 +104,7 @@ NNVM_REGISTER_OP(_backward_linalg_gemm)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
   { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 4, 3, gemm_backward>);
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmBackward<cpu, 2, 2, 4, 3, gemm_backward>);
 
 NNVM_REGISTER_OP(_linalg_gemm2)
 .add_alias("linalg_gemm2")
@@ -107,8 +119,19 @@ If *n=2*, the BLAS3 function *gemm* is performed:
 Here *alpha* is a scalar parameter and *op()* is either the identity or the matrix
 transposition (depending on *transpose_a*, *transpose_b*).
 
-If *n>2*, *gemm* is performed separately on the trailing two dimensions for all inputs
-(batch mode).
+If *n>2*, *gemm* is performed separately for a batch of matrices. The column indices of the matrices
+are given by the last dimensions of the tensors, the row indices by the axis specified with the *axis*
+parameter. By default, the trailing two dimensions will be used for matrix encoding.
+
+For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm2/swapaxes
+calls. For example, let *A*, *B* be 5-dimensional tensors. Then gemm2(*A*, *B*, axis=1) is equivalent to
+
+    A1 = swapaxes(A, dim1=1, dim2=3)
+    B1 = swapaxes(B, dim1=1, dim2=3)
+    C = gemm2(A1, B1)
+    C = swapaxes(C, dim1=1, dim2=3)
+
+without the overhead of the additional swapaxes operations.
 
 .. note:: The operator supports float32 and float64 data types only.
 
@@ -133,7 +156,7 @@ Examples::
   { return std::vector<std::string>{"A", "B"}; } )
 .set_attr<nnvm::FInferShape>("FInferShape", LaMatrixMultMacOpShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
-.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 2, 1, gemm2>)
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmForward<cpu, 2, 2, 2, 1, gemm2>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_linalg_gemm2"})
 .add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices")
 .add_argument("B", "NDArray-or-Symbol", "Tensor of input matrices")
@@ -148,7 +171,7 @@ NNVM_REGISTER_OP(_backward_linalg_gemm2)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
   { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 3, 2, gemm2_backward>);
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmBackward<cpu, 2, 2, 3, 2, gemm2_backward>);
 
 NNVM_REGISTER_OP(_linalg_potrf)
 .add_alias("linalg_potrf")
diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu
index efd705f..d736845 100644
--- a/src/operator/tensor/la_op.cu
+++ b/src/operator/tensor/la_op.cu
@@ -28,16 +28,16 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_linalg_gemm)
-.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 3, 1, gemm>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmForward<gpu, 2, 2, 3, 1, gemm>);
 
 NNVM_REGISTER_OP(_backward_linalg_gemm)
-.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 3, gemm_backward>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmBackward<gpu, 2, 2, 4, 3, gemm_backward>);
 
 NNVM_REGISTER_OP(_linalg_gemm2)
-.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, gemm2>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmForward<gpu, 2, 2, 2, 1, gemm2>);
 
 NNVM_REGISTER_OP(_backward_linalg_gemm2)
-.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 3, 2, gemm2_backward>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmBackward<gpu, 2, 2, 3, 2, gemm2_backward>);
 
 NNVM_REGISTER_OP(_linalg_trmm)
 .set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, trmm>);
diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index 3d411b2..8e2acd7 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -40,6 +40,7 @@ namespace op {
 struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
   bool transpose_a, transpose_b;
   double alpha, beta;
+  int axis;
   DMLC_DECLARE_PARAMETER(LaMatrixMacParam) {
     DMLC_DECLARE_FIELD(transpose_a)
       .set_default(false)
@@ -53,6 +54,9 @@ struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
     DMLC_DECLARE_FIELD(beta)
       .set_default(1.0)
       .describe("Scalar factor multiplied with C.");
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(-2)
+      .describe("Axis corresponding to the matrix rows.");
   }
 };
 
@@ -60,6 +64,7 @@ struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
 struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> {
   bool transpose_a, transpose_b;
   double alpha;
+  int axis;
   DMLC_DECLARE_PARAMETER(LaMatrixMultParam) {
     DMLC_DECLARE_FIELD(transpose_a)
       .set_default(false)
@@ -70,6 +75,9 @@ struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> {
     DMLC_DECLARE_FIELD(alpha)
       .set_default(1.0)
       .describe("Scalar factor multiplied with A*B.");
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(-2)
+      .describe("Axis corresponding to the matrix row indices.");
   }
 };
 
@@ -112,30 +120,37 @@ inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
   CHECK_GE(in_attrs->size(), 2);
   CHECK_EQ(out_attrs->size(), 1);
   bool transpose_a(false), transpose_b(false);
+  int axis_param(-2);
   if ( in_attrs->size() == 2 ) {
      // Matrix-Matrix mult
      transpose_a = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_a;
      transpose_b = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_b;
+     axis_param  = nnvm::get<LaMatrixMultParam>(attrs.parsed).axis;
   } else {
      // Matrix-Matrix mac
      transpose_a = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_a;
      transpose_b = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_b;
+     axis_param  = nnvm::get<LaMatrixMacParam>(attrs.parsed).axis;
   }
   if ( (*in_attrs)[0].ndim() >= 2 && (*in_attrs)[0].ndim() == (*in_attrs)[1].ndim() ) {
     // Forward shape inference.
-    const int ndim((*in_attrs)[0].ndim());
+    const int ndim((*in_attrs)[0].ndim()), axis(axis_param < 0 ? ndim + axis_param : axis_param);
+    CHECK(axis >= 0 && axis < ndim-1)
+      << "Invalid row axis (" << axis_param << ")";
     std::vector<int> oshape(ndim);
-    for ( int i = 0; i < ndim-2; ++i ) {
-      // Both inputs must have same shape except for last two dimensions.
-      CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i])
-        << "Shapes of inputs 0, 1 must be the same, except on last two dimensions";
+    for ( int i = 0; i < ndim-1; ++i ) {
+      if (i != axis) {
+        // Both inputs must have same shape except for row/col dimensions.
+        CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i])
+          << "Shapes of inputs 0, 1 must be the same, except on row/col axis";
+      }
       oshape[i] = (*in_attrs)[0][i];
     }
-    CHECK_EQ((transpose_a ? (*in_attrs)[0][ndim-2] : (*in_attrs)[0][ndim-1]),
-             (transpose_b ? (*in_attrs)[1][ndim-1] : (*in_attrs)[1][ndim-2]))
+    CHECK_EQ((transpose_a ? (*in_attrs)[0][axis] : (*in_attrs)[0][ndim-1]),
+             (transpose_b ? (*in_attrs)[1][ndim-1] : (*in_attrs)[1][axis]))
              << "Incompatible matrix dimensions for multiplication";
-    oshape[ndim-2] = (transpose_a ? (*in_attrs)[0][ndim-1] : (*in_attrs)[0][ndim-2]);
-    oshape[ndim-1] = (transpose_b ? (*in_attrs)[1][ndim-2] : (*in_attrs)[1][ndim-1]);
+    oshape[axis] = (transpose_a ? (*in_attrs)[0][ndim-1] : (*in_attrs)[0][axis]);
+    oshape[ndim-1] = (transpose_b ? (*in_attrs)[1][axis] : (*in_attrs)[1][ndim-1]);
     TShape tshape(oshape.begin(), oshape.end());
     SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape);
     if ( in_attrs->size() > 2 ) {
@@ -340,6 +355,33 @@ inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs,
   return false;
 }
 
+// Flattener for following adaptors.
+template<typename xpu, int dim, typename DType>
+mshadow::Tensor<xpu, dim, DType> LaOpFlatten(const TBlob& blob,
+                                             mshadow::Stream<xpu> *s, int axis = -2) {
+  if (axis < 0) {
+    axis = blob.ndim() + axis;
+  }
+  if (axis >= blob.ndim()-2) {
+    // Leave highest axis, collapse rest.
+    return blob.FlatToKD<xpu, dim, DType>(s);
+  }
+  // Collapse ranges [0,axis-1] and [axis+1,ndim-2].
+  CHECK_EQ(dim, 4);
+  TShape shape(dim);
+  shape[0] = 1;
+  for (int i = 0; i < axis; ++i) {
+    shape[0] *= blob.shape_[i];
+  }
+  shape[1] = blob.shape_[axis];
+  shape[2] = 1;
+  for (int i = axis+1; i < blob.ndim()-1; ++i) {
+    shape[2] *= blob.shape_[i];
+  }
+  shape[3] = blob.shape_[blob.ndim()-1];
+  return blob.get_with_shape<xpu, dim, DType>(shape.get<dim>(), s);
+}
+
 // Adapters for calling the various operators with appropriate signatures.
 
 template<typename xpu, typename DType, int idim, int odim, int inum, int onum, typename laop>
@@ -347,7 +389,7 @@ struct LaOpCaller {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     CHECK(false) << "no specialized LaOpCaller defined for template parameters";
   }
 };
@@ -356,10 +398,10 @@ struct LaOpCaller<xpu, DType, idim, odim, 1, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -367,11 +409,11 @@ struct LaOpCaller<xpu, DType, idim, odim, 1, 2, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -379,11 +421,11 @@ struct LaOpCaller<xpu, DType, idim, odim, 2, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -391,12 +433,12 @@ struct LaOpCaller<xpu, DType, idim, odim, 3, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -404,13 +446,13 @@ struct LaOpCaller<xpu, DType, idim, odim, 3, 2, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -418,13 +460,13 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[3].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -432,14 +474,14 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 2, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[3].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -447,15 +489,15 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 3, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[3].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s),
-             outputs[2].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[2], s, axis), ctx, attrs);
   }
 };
 
@@ -504,6 +546,64 @@ void LaOpBackward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
+void LaOpGemmForward(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), inum);
+  CHECK_EQ(outputs.size(), onum);
+  const int axis(inputs.size() == 2 ? nnvm::get<LaMatrixMultParam>(attrs.parsed).axis
+                                    : nnvm::get<LaMatrixMacParam>(attrs.parsed).axis);
+  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    if (axis == -2 || axis == inputs[0].ndim()-2) {
+      LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs,
+                                                               attrs, ctx);
+    } else {
+      LaOpCaller<xpu, OType, idim+1, odim+1, inum, onum, laop>::op(inputs, outputs,
+                                                                   attrs, ctx, axis);
+    }
+  });
+}
+
+template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
+void LaOpGemmBackward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  CHECK_EQ(inputs.size(), inum);
+  CHECK_EQ(outputs.size(), onum);
+  const int axis(inputs.size() == 3 ? nnvm::get<LaMatrixMultParam>(attrs.parsed).axis
+                                    : nnvm::get<LaMatrixMacParam>(attrs.parsed).axis);
+  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    std::vector<TBlob> tspace(outputs);
+    for ( int i = 0; i < onum; ++i ) {
+      if ( req[i] == kAddTo ) {
+        tspace[i].dptr_ = ctx.requested[0]
+                             .get_space_typed<xpu, 1, OType>(Shape1(outputs[i].Size()), s).dptr_;
+      }
+    }
+    if (axis == -2 || axis == inputs[0].ndim()-2) {
+      LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs,
+                                                               attrs, ctx);
+    } else {
+      LaOpCaller<xpu, OType, idim+1, odim+1, inum, onum, laop>::op(inputs, outputs,
+                                                                   attrs, ctx, axis);
+    }
+    for ( int i = 0; i < onum; ++i ) {
+      if ( req[i] == kAddTo ) {
+        Tensor<xpu, 1, OType> out = outputs[i].FlatTo1D<xpu, OType>(s);
+        out += tspace[i].FlatTo1D<xpu, OType>(s);
+      }
+    }
+  });
+}
+
 // Specific wrapper for syevd (cannot use the default ones, because A, U have
 // different dimensionality than L
 
diff --git a/src/operator/tensor/la_op_inline.h b/src/operator/tensor/la_op_inline.h
index a508eb7..b483108 100644
--- a/src/operator/tensor/la_op_inline.h
+++ b/src/operator/tensor/la_op_inline.h
@@ -60,24 +60,24 @@ struct Scale {
 
 // D = gemm(A,B,C)
 struct gemm {
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, DType alpha, DType beta,
                  bool tA, bool tB, Stream<xpu> *s) {
     linalg_batch_gemm(A, B, C, alpha, beta, tA, tB, s);
   }
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, const Tensor<xpu, 3, DType>& D,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, const Tensor<xpu, dim, DType>& D,
                  Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
     if ( C.dptr_ != D.dptr_ ) Copy(D, C, s);
     const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed);
     op(A, B, D, DType(param.alpha), DType(param.beta), param.transpose_a,
        param.transpose_b, s);
   }
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, const Tensor<xpu, 3, DType>& D,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, const Tensor<xpu, dim, DType>& D,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     op(A, B, C, D, s, attrs);
@@ -86,17 +86,17 @@ struct gemm {
 
 // C = gemm2(A,B)
 struct gemm2 {
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, Stream<xpu> *s,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, Stream<xpu> *s,
                  const nnvm::NodeAttrs& attrs) {
     const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed);
     gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a,
              param.transpose_b, s);
   }
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, const OpContext& ctx,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, const OpContext& ctx,
                  const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     op(A, B, C, s, attrs);
@@ -343,11 +343,11 @@ struct syevd {
 // Backward operators (always using batch processing)
 
 struct gemm_backward {
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dD, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
-                 const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
-                 const Tensor<xpu, 3, DType>& dC,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dD, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& C,
+                 const Tensor<xpu, dim, DType>& dA, const Tensor<xpu, dim, DType>& dB,
+                 const Tensor<xpu, dim, DType>& dC,
                  Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed);
     bool tA(param.transpose_a), tB(param.transpose_b);
@@ -359,11 +359,11 @@ struct gemm_backward {
     using namespace mxnet_op;
     Kernel<Scale, xpu>::Launch(s, dC.MSize(), DType(param.beta), dC.dptr_);
   }
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dD, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
-                 const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
-                 const Tensor<xpu, 3, DType>& dC,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dD, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& C,
+                 const Tensor<xpu, dim, DType>& dA, const Tensor<xpu, dim, DType>& dB,
+                 const Tensor<xpu, dim, DType>& dC,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     op(dD, A, B, C, dA, dB, dC, s, attrs);
@@ -371,10 +371,10 @@ struct gemm_backward {
 };
 
 struct gemm2_backward {
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
-                 const Tensor<xpu, 3, DType>& dB,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dC, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& dA,
+                 const Tensor<xpu, dim, DType>& dB,
                  Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed);
     bool tA(param.transpose_a), tB(param.transpose_b);
@@ -383,10 +383,10 @@ struct gemm2_backward {
     (tB ? gemm::op(dC, A, dB, DType(param.alpha), DType(0), true, tA, s)
         : gemm::op(A, dC, dB, DType(param.alpha), DType(0), !tA, false, s));
   }
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
-                 const Tensor<xpu, 3, DType>& dB,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dC, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& dA,
+                 const Tensor<xpu, dim, DType>& dB,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     op(dC, A, B, dA, dB, s, attrs);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3f08971..923a453 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4688,7 +4688,24 @@ def test_laop():
     check_fw(test_gemm, [a, b, c], [r])
     if grad_check == 1:
         check_grad(test_gemm, [a, b, c])
-
+    # Check for different axis that describes matrix rows.
+    a2 = np.copy(np.swapaxes(a, 0, 2))
+    b2 = np.copy(np.swapaxes(b, 0, 2))
+    c2 = np.copy(np.swapaxes(c, 0, 2))
+    r2 = np.copy(np.swapaxes(r, 0, 2))
+    test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7., axis=0)
+    check_fw(test_gemm, [a2, b2, c2], [r2])
+    if grad_check == 1:
+        check_grad(test_gemm, [a2, b2, c2])
+    a2 = np.copy(np.swapaxes(a, 1, 2))
+    b2 = np.copy(np.swapaxes(b, 1, 2))
+    c2 = np.copy(np.swapaxes(c, 1, 2))
+    r2 = np.copy(np.swapaxes(r, 1, 2))
+    test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7., axis=-3)
+    check_fw(test_gemm, [a2, b2, c2], [r2])
+    if grad_check == 1:
+        check_grad(test_gemm, [a2, b2, c2])
+
     # Check gemm2 operator same way as gemm.
     res_gemm = 4. * np.dot(data_in1, data_in2)
     test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4.)
@@ -4720,6 +4737,20 @@ def test_laop():
     check_fw(test_gemm, [a, b], [r])
     if grad_check == 1:
         check_grad(test_gemm, [a, b])
+    a2 = np.copy(np.swapaxes(a, 0, 2))
+    b2 = np.copy(np.swapaxes(b, 0, 2))
+    r2 = np.copy(np.swapaxes(r, 0, 2))
+    test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., axis=0)
+    check_fw(test_gemm, [a2, b2], [r2])
+    if grad_check == 1:
+        check_grad(test_gemm, [a2, b2])
+    a2 = np.copy(np.swapaxes(a, 1, 2))
+    b2 = np.copy(np.swapaxes(b, 1, 2))
+    r2 = np.copy(np.swapaxes(r, 1, 2))
+    test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., axis=-3)
+    check_fw(test_gemm, [a2, b2], [r2])
+    if grad_check == 1:
+        check_grad(test_gemm, [a2, b2])
 
     # Now test all the other operators.
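
As a side note for reviewers, the axis semantics exercised by the new tests can be
reproduced in plain NumPy. The helper below is a hypothetical reference (not part of
the patch) for the default transpose_a=False, transpose_b=False case, using np.matmul
for the batched product:

    import numpy as np

    def gemm_axis_reference(A, B, C, alpha=1.0, beta=1.0, axis=-2):
        # Move the row axis into the default position -2, multiply, then move it back.
        A1 = np.swapaxes(A, axis, -2)
        B1 = np.swapaxes(B, axis, -2)
        C1 = np.swapaxes(C, axis, -2)
        out = alpha * np.matmul(A1, B1) + beta * C1
        return np.swapaxes(out, axis, -2)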
 

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.
