This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 4ac76c8 Support for axis parameter in linalg.gemm (#10864)
4ac76c8 is described below
commit 4ac76c89da6d4d8feef629949dc0f9534b216e3d
Author: moin <[email protected]>
AuthorDate: Tue May 29 20:15:11 2018 +0200
Support for axis parameter in linalg.gemm (#10864)
---
src/operator/linalg.h | 7 +
src/operator/linalg_impl.h | 276 ++++++++++++++++++++++-----------
src/operator/tensor/la_op.cc | 39 ++++-
src/operator/tensor/la_op.cu | 8 +-
src/operator/tensor/la_op.h | 206 +++++++++++++++++-------
src/operator/tensor/la_op_inline.h | 66 ++++----
tests/python/unittest/test_operator.py | 33 +++-
7 files changed, 447 insertions(+), 188 deletions(-)
diff --git a/src/operator/linalg.h b/src/operator/linalg.h
index aee67d7..dc59400 100644
--- a/src/operator/linalg.h
+++ b/src/operator/linalg.h
@@ -64,6 +64,13 @@ void linalg_batch_gemm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DTyp
                        const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
                        bool tA, bool tB, Stream<xpu> *s = 0);
+// Version of batch gemm where rows are indexed at axis 1 and columns at axis 3.
+template<typename xpu, typename DType>
+void linalg_batch_gemm(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B,
+                       const Tensor<xpu, 4, DType>& C, DType alpha, DType beta,
+                       bool tA, bool tB, Stream<xpu> *s = 0);
+
+
template<typename xpu, typename DType>
inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 2, DType>& B,
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index 151db60..08d2add 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -56,6 +56,11 @@ inline void check_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DTyp
     << "Non compatible matrix dimensions between inputs A and B for gemm";
 }
+template<typename xpu, typename DType>
+void linalg_gemm_axis(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                      const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
+                      bool tA, bool tB, Stream<xpu> *s = 0);
+
#if (MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1)
#define LINALG_CPU_GEMM(fname, DType) \
@@ -80,6 +85,38 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<
} \
}
+// Batched gemm where the batch coordinate is given by the second axis.
+#define LINALG_CPU_GEMM_AXIS(fname, DType) \
+template<> inline \
+void linalg_gemm_axis<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
+                                  const Tensor<cpu, 3, DType>& C, DType alpha, DType beta, \
+                                  bool tA, bool tB, Stream<cpu> *s) { \
+ linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
+ for (index_t i = 0; i < A.size(1); ++i) { \
+ cblas_##fname(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), \
+ (tB ? CblasTrans : CblasNoTrans), \
+ C.size(0), C.size(2), (tA ? A.size(0) : A.size(2)), alpha, \
+ A.dptr_+i*A.stride_, A.size(1)*A.stride_, \
+ B.dptr_+i*B.stride_, B.size(1)*B.stride_, beta, \
+ C.dptr_+i*C.stride_, C.size(1)*C.stride_); \
+ } \
+}
+
+LINALG_CPU_GEMM_AXIS(sgemm, float)
+LINALG_CPU_GEMM_AXIS(dgemm, double)
+
+// Version where matrix rows are given by the second axis.
+#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \
+template<> inline \
+void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \
+                                   const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \
+                                   bool tA, bool tB, Stream<xpu> *s) { \
+ linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
+ for (index_t i = 0; i < A.size(0); ++i) { \
+ linalg_gemm_axis(A[i], B[i], C[i], alpha, beta, tA, tB, s); \
+ } \
+}
+
#else
#define LINALG_CPU_GEMM(fname, DType) \
@@ -98,6 +135,14 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<
  LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
}
+#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \
+template<> inline \
+void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \
+                                   const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \
+                                   bool tA, bool tB, Stream<xpu> *s) { \
+  LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
+}
+
#endif // MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1
LINALG_CPU_GEMM(sgemm, float)
@@ -106,6 +151,9 @@ LINALG_CPU_GEMM(dgemm, double)
LINALG_XPU_BATCH_GEMM(cpu, float)
LINALG_XPU_BATCH_GEMM(cpu, double)
+LINALG_XPU_BATCH_GEMM_AXIS(cpu, float)
+LINALG_XPU_BATCH_GEMM_AXIS(cpu, double)
+
// Specialization of linalg_gemm<cpu, DType> for DType=mshadow::half::half_t.
template<> inline
void linalg_gemm<cpu, mshadow::half::half_t>(const Tensor<cpu, 2, mshadow::half::half_t>& A,
@@ -140,6 +188,28 @@ void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2
LINALG_GPU_GEMM(Sgemm, float)
LINALG_GPU_GEMM(Dgemm, double)
+// Version where matrix rows are given by first axis.
+#define LINALG_GPU_GEMM_AXIS(fname, DType) \
+template<> inline \
+void linalg_gemm_axis<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
+                                  const Tensor<gpu, 3, DType>& C, DType alpha, DType beta, \
+                                  bool tA, bool tB, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
+  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                            (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            C.size(2), C.size(0), (tB ? B.size(2) : B.size(0)), &alpha, \
+                            B.dptr_, B.size(1)*B.stride_, B.stride_, \
+                            A.dptr_, A.size(1)*A.stride_, A.stride_, &beta, \
+                            C.dptr_, C.size(1)*C.stride_, C.stride_, A.size(1))) \
+}
+LINALG_GPU_GEMM_AXIS(SgemmStridedBatched, float)
+LINALG_GPU_GEMM_AXIS(DgemmStridedBatched, double)
+
+// Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t.
// Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t.
template<> inline
void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half::half_t>& A,
@@ -192,6 +262,8 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
#if CUDA_VERSION < 8000
LINALG_XPU_BATCH_GEMM(gpu, float)
LINALG_XPU_BATCH_GEMM(gpu, double)
+ LINALG_XPU_BATCH_GEMM_AXIS(gpu, float)
+ LINALG_XPU_BATCH_GEMM_AXIS(gpu, double)
#else
#define LINALG_GPU_BATCH_GEMM(fname, DType) \
template<> inline \
@@ -217,10 +289,125 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float)
LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double)
+// Version where matrix rows are given by second axis.
+#define LINALG_GPU_BATCH_GEMM_AXIS(fname, DType) \
+ template<> inline \
+ void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 4, DType>& A, \
+ const Tensor<gpu, 4, DType>& B, \
+                                     const Tensor<gpu, 4, DType>& C, DType alpha, DType beta, \
+                                     bool tA, bool tB, Stream<gpu> *s) { \
+ using namespace mxnet; \
+ using mshadow::gpu; \
+ CHECK_NOTNULL(s); \
+ linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
+ linalg_check_batch_size(A.size(2), B.size(2), C.size(2)); \
+ for (index_t i = 0; i < A.size(2); ++i) { \
+ CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+ (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
+ (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
+ C.size(3), C.size(1), (tB ? B.size(3) : B.size(1)), &alpha, \
+                B.dptr_+i*B.stride_, B.size(2) * B.stride_, B.size(1)*B.size(2)*B.stride_, \
+                A.dptr_+i*A.stride_, A.size(2) * A.stride_, A.size(1)*A.size(2)*A.stride_, &beta, \
+                C.dptr_+i*C.stride_, C.size(2) * C.stride_, C.size(1)*C.size(2)*C.stride_, A.size(0))) \
+    } \
+  }
+
+ LINALG_GPU_BATCH_GEMM_AXIS(SgemmStridedBatched, float)
+ LINALG_GPU_BATCH_GEMM_AXIS(DgemmStridedBatched, double)
+
#endif // CUDA < 8000
#endif // __CUDACC__
+/*!
+ * \brief Performs gemm, setting alpha and beta as appropriate for `req`.
+ *
+ * \param A the first operand of the gemm
+ * \param B the second operand of the gemm
+ * \param C the data to be assigned
+ * \param tA whether the `A` operand should be transposed first.
+ * \param tB whether the `B` operand should be transposed first.
+ * \param s the stream to perform the operation
+ * \param req the assignment request
+ */
+template<typename xpu, typename DType>
+inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
+ const Tensor<xpu, 2, DType>& B,
+ const Tensor<xpu, 2, DType>& C,
+ bool tA, bool tB, Stream<xpu> *s,
+ mxnet::OpReqType req) {
+ using namespace mxnet;
+ switch (req) {
+ case kNullOp:
+ break;
+ case kWriteTo:
+ case kWriteInplace:
+ linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s);
+ break;
+ case kAddTo:
+ linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s);
+ break;
+ default:
+ LOG(FATAL) << "not reached";
+ }
+}
+
+#if (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
+
+// A template for a cpu linalg_gemm implementation using mshadow::dot()
+#define LINALG_CPU_GEMM_NO_CBLAS(DType) \
+template<> inline \
+void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
+ const Tensor<cpu, 2, DType>& B, \
+ const Tensor<cpu, 2, DType>& C, \
+ bool tA, bool tB, Stream<cpu> *s, \
+ mxnet::OpReqType req) { \
+ using namespace mxnet; \
+ using mshadow::cpu; \
+ switch (req) { \
+ case kNullOp: \
+ break; \
+ case kWriteTo: \
+ case kWriteInplace: \
+ if (tA) { \
+ if (tB) { \
+ const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B.T()); \
+ } else { \
+ const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B); \
+ } \
+ } else { \
+ if (tB) { \
+ const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B.T()); \
+ } else { \
+ const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B); \
+ } \
+ } \
+ break; \
+ case kAddTo: \
+ if (tA) { \
+ if (tB) { \
+ const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B.T()); \
+ } else { \
+ const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B); \
+ } \
+ } else { \
+ if (tB) { \
+ const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B.T()); \
+ } else { \
+ const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B); \
+ } \
+ } \
+ break; \
+ default: \
+ LOG(FATAL) << "not reached"; \
+ } \
+}
+
+LINALG_CPU_GEMM_NO_CBLAS(float)
+LINALG_CPU_GEMM_NO_CBLAS(double)
+
+#endif // (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
+
//////////////////////////////// TRSM ////////////////////////////////////////////
// CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation
@@ -313,95 +500,6 @@ LINALG_XPU_BATCH_TRSM(gpu, double)
#endif // __CUDACC__
-/*!
- * \brief Performs gemm, setting alpha and beta as appropriate for `req`.
- *
- * \param A the first operand of the gemm
- * \param B the second operand of the gemm
- * \param C the data to be assigned
- * \param tA whether the `A` operand should be transposed first.
- * \param tB whether the `B` operand should be transposed first.
- * \param s the stream to perform the operation
- * \param req the assignment request
- */
-template<typename xpu, typename DType>
-inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
- const Tensor<xpu, 2, DType>& B,
- const Tensor<xpu, 2, DType>& C,
- bool tA, bool tB, Stream<xpu> *s,
- mxnet::OpReqType req) {
- using namespace mxnet;
- switch (req) {
- case kNullOp:
- break;
- case kWriteTo:
- case kWriteInplace:
- linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s);
- break;
- case kAddTo:
- linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s);
- break;
- default:
- LOG(FATAL) << "not reached";
- }
-}
-
-#if (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
-
-// A template for a cpu linalg_gemm implementation using mshadow::dot()
-#define LINALG_CPU_GEMM_NO_CBLAS(DType) \
-template<> inline \
-void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
- const Tensor<cpu, 2, DType>& B, \
- const Tensor<cpu, 2, DType>& C, \
- bool tA, bool tB, Stream<cpu> *s, \
- mxnet::OpReqType req) { \
- using namespace mxnet; \
- using mshadow::cpu; \
- switch (req) { \
- case kNullOp: \
- break; \
- case kWriteTo: \
- case kWriteInplace: \
- if (tA) { \
- if (tB) { \
- const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B.T()); \
- } else { \
- const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B); \
- } \
- } else { \
- if (tB) { \
- const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B.T()); \
- } else { \
- const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B); \
- } \
- } \
- break; \
- case kAddTo: \
- if (tA) { \
- if (tB) { \
- const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B.T()); \
- } else { \
- const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B); \
- } \
- } else { \
- if (tB) { \
- const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B.T()); \
- } else { \
- const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B); \
- } \
- } \
- break; \
- default: \
- LOG(FATAL) << "not reached"; \
- } \
-}
-
-LINALG_CPU_GEMM_NO_CBLAS(float)
-LINALG_CPU_GEMM_NO_CBLAS(double)
-
-#endif // (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
-
//////////////////////////////// TRMM ////////////////////////////////////////////
// CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation
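
Note (illustration, not part of the commit): the semantics of the new linalg_gemm_axis
kernels above can be summarized in plain numpy. The batch runs along axis 1 of the 3D
tensors, with matrix rows on axis 0 and columns on axis 2; gemm_axis_ref is a
hypothetical reference helper, not code from this patch.

    import numpy as np

    def gemm_axis_ref(A, B, C, alpha, beta, tA, tB):
        # batch index i runs over axis 1; matrix i is the 2D slice [:, i, :]
        for i in range(A.shape[1]):
            a = A[:, i, :].T if tA else A[:, i, :]
            b = B[:, i, :].T if tB else B[:, i, :]
            C[:, i, :] = alpha * (a @ b) + beta * C[:, i, :]

    A = np.random.rand(3, 4, 6)   # rows m=3, batch=4, k=6
    B = np.random.rand(6, 4, 5)   # k=6, batch=4, cols n=5
    C = np.zeros((3, 4, 5))
    gemm_axis_ref(A, B, C, 1.0, 0.0, False, False)

This mirrors what the cBLAS/cuBLAS calls above do with strides: the leading dimension
size(1)*stride_ skips over the batch axis between consecutive rows of the same matrix.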
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index 7083efe..b177165 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -46,8 +46,20 @@ If *n=2*, the BLAS3 function *gemm* is performed:
Here, *alpha* and *beta* are scalar parameters, and *op()* is either the identity or
matrix transposition (depending on *transpose_a*, *transpose_b*).
-If *n>2*, *gemm* is performed separately on the trailing two dimensions for all inputs
-(batch mode).
+If *n>2*, *gemm* is performed separately for a batch of matrices. The column indices of
+the matrices are given by the last dimensions of the tensors, the row indices by the axis
+specified with the *axis* parameter. By default, the trailing two dimensions will be used
+for matrix encoding.
+
+For a non-default axis parameter, the operation performed is equivalent to a series of
+swapaxes/gemm/swapaxes calls. For example, let *A*, *B*, *C* be 5-dimensional tensors.
+Then gemm(*A*, *B*, *C*, axis=1) is equivalent to
+
+    A1 = swapaxes(A, dim1=1, dim2=3)
+    B1 = swapaxes(B, dim1=1, dim2=3)
+    C = swapaxes(C, dim1=1, dim2=3)
+    C = gemm(A1, B1, C)
+    C = swapaxes(C, dim1=1, dim2=3)
+
+without the overhead of the additional swapaxes operations.
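
A quick numerical check of the equivalence stated above (illustration only, not part of
the commit; shapes are arbitrary, and the NDArray API is assumed to mirror the symbol
API registered by this patch):

    import mxnet as mx

    A = mx.nd.random.uniform(shape=(2, 3, 4, 5, 6))
    B = mx.nd.random.uniform(shape=(2, 6, 4, 5, 7))
    C = mx.nd.random.uniform(shape=(2, 3, 4, 5, 7))

    direct = mx.nd.linalg.gemm(A, B, C, axis=1)

    A1 = mx.nd.swapaxes(A, dim1=1, dim2=3)
    B1 = mx.nd.swapaxes(B, dim1=1, dim2=3)
    C1 = mx.nd.swapaxes(C, dim1=1, dim2=3)
    via_swap = mx.nd.swapaxes(mx.nd.linalg.gemm(A1, B1, C1), dim1=1, dim2=3)

    assert mx.nd.abs(direct - via_swap).max().asscalar() < 1e-5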
.. note:: The operator supports float32 and float64 data types only.
@@ -76,7 +88,7 @@ Examples::
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
{ return std::vector<std::pair<int, int>>{{2, 0}}; })
-.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 3, 1, gemm>)
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmForward<cpu, 2, 2, 3, 1, gemm>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_linalg_gemm"})
.add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices")
.add_argument("B", "NDArray-or-Symbol", "Tensor of input matrices")
@@ -92,7 +104,7 @@ NNVM_REGISTER_OP(_backward_linalg_gemm)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
{ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 4, 3,
gemm_backward>);
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmBackward<cpu, 2, 2, 4, 3,
gemm_backward>);
NNVM_REGISTER_OP(_linalg_gemm2)
.add_alias("linalg_gemm2")
@@ -107,8 +119,19 @@ If *n=2*, the BLAS3 function *gemm* is performed:
Here *alpha* is a scalar parameter and *op()* is either the identity or the matrix
transposition (depending on *transpose_a*, *transpose_b*).
-If *n>2*, *gemm* is performed separately on the trailing two dimensions for all inputs
-(batch mode).
+If *n>2*, *gemm* is performed separately for a batch of matrices. The column indices of
+the matrices are given by the last dimensions of the tensors, the row indices by the axis
+specified with the *axis* parameter. By default, the trailing two dimensions will be used
+for matrix encoding.
+
+For a non-default axis parameter, the operation performed is equivalent to a series of
+swapaxes/gemm/swapaxes calls. For example, let *A*, *B* be 5-dimensional tensors. Then
+gemm2(*A*, *B*, axis=1) is equivalent to
+
+    A1 = swapaxes(A, dim1=1, dim2=3)
+    B1 = swapaxes(B, dim1=1, dim2=3)
+    C = gemm2(A1, B1)
+    C = swapaxes(C, dim1=1, dim2=3)
+
+without the overhead of the additional swapaxes operations.
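
The same check carries over to gemm2 (illustration only; continuing the assumptions and
the tensors A, B of the previous sketch):

    direct = mx.nd.linalg.gemm2(A, B, axis=1)
    via_swap = mx.nd.swapaxes(
        mx.nd.linalg.gemm2(mx.nd.swapaxes(A, dim1=1, dim2=3),
                           mx.nd.swapaxes(B, dim1=1, dim2=3)),
        dim1=1, dim2=3)
    assert mx.nd.abs(direct - via_swap).max().asscalar() < 1e-5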
.. note:: The operator supports float32 and float64 data types only.
@@ -133,7 +156,7 @@ Examples::
{ return std::vector<std::string>{"A", "B"}; } )
.set_attr<nnvm::FInferShape>("FInferShape", LaMatrixMultMacOpShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
-.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 2, 1, gemm2>)
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmForward<cpu, 2, 2, 2, 1, gemm2>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_linalg_gemm2"})
.add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices")
.add_argument("B", "NDArray-or-Symbol", "Tensor of input matrices")
@@ -148,7 +171,7 @@ NNVM_REGISTER_OP(_backward_linalg_gemm2)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
{ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 3, 2,
gemm2_backward>);
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmBackward<cpu, 2, 2, 3, 2,
gemm2_backward>);
NNVM_REGISTER_OP(_linalg_potrf)
.add_alias("linalg_potrf")
diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu
index efd705f..d736845 100644
--- a/src/operator/tensor/la_op.cu
+++ b/src/operator/tensor/la_op.cu
@@ -28,16 +28,16 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(_linalg_gemm)
-.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 3, 1, gemm>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmForward<gpu, 2, 2, 3, 1, gemm>);
NNVM_REGISTER_OP(_backward_linalg_gemm)
-.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 3,
gemm_backward>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmBackward<gpu, 2, 2, 4, 3,
gemm_backward>);
NNVM_REGISTER_OP(_linalg_gemm2)
-.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, gemm2>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmForward<gpu, 2, 2, 2, 1, gemm2>);
NNVM_REGISTER_OP(_backward_linalg_gemm2)
-.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 3, 2,
gemm2_backward>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmBackward<gpu, 2, 2, 3, 2,
gemm2_backward>);
NNVM_REGISTER_OP(_linalg_trmm)
.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, trmm>);
diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index 3d411b2..8e2acd7 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -40,6 +40,7 @@ namespace op {
struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
bool transpose_a, transpose_b;
double alpha, beta;
+ int axis;
DMLC_DECLARE_PARAMETER(LaMatrixMacParam) {
DMLC_DECLARE_FIELD(transpose_a)
.set_default(false)
@@ -53,6 +54,9 @@ struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
DMLC_DECLARE_FIELD(beta)
.set_default(1.0)
.describe("Scalar factor multiplied with C.");
+ DMLC_DECLARE_FIELD(axis)
+ .set_default(-2)
+ .describe("Axis corresponding to the matrix rows.");
}
};
@@ -60,6 +64,7 @@ struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> {
bool transpose_a, transpose_b;
double alpha;
+ int axis;
DMLC_DECLARE_PARAMETER(LaMatrixMultParam) {
DMLC_DECLARE_FIELD(transpose_a)
.set_default(false)
@@ -70,6 +75,9 @@ struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> {
DMLC_DECLARE_FIELD(alpha)
.set_default(1.0)
.describe("Scalar factor multiplied with A*B.");
+ DMLC_DECLARE_FIELD(axis)
+ .set_default(-2)
+ .describe("Axis corresponding to the matrix row indices.");
}
};
@@ -112,30 +120,37 @@ inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
CHECK_GE(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
bool transpose_a(false), transpose_b(false);
+ int axis_param(-2);
if ( in_attrs->size() == 2 ) {
// Matrix-Matrix mult
transpose_a = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_a;
transpose_b = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_b;
+ axis_param = nnvm::get<LaMatrixMultParam>(attrs.parsed).axis;
} else {
// Matrix-Matrix mac
transpose_a = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_a;
transpose_b = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_b;
+ axis_param = nnvm::get<LaMatrixMacParam>(attrs.parsed).axis;
}
   if ( (*in_attrs)[0].ndim() >= 2 && (*in_attrs)[0].ndim() == (*in_attrs)[1].ndim() ) {
     // Forward shape inference.
-    const int ndim((*in_attrs)[0].ndim());
+    const int ndim((*in_attrs)[0].ndim()), axis(axis_param < 0 ? ndim + axis_param : axis_param);
+    CHECK(axis >= 0 && axis < ndim-1)
+      << "Invalid row axis (" << axis_param << ")";
     std::vector<int> oshape(ndim);
-    for ( int i = 0; i < ndim-2; ++i ) {
-      // Both inputs must have same shape except for last two dimensions.
-      CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i])
-        << "Shapes of inputs 0, 1 must be the same, except on last two dimensions";
+    for ( int i = 0; i < ndim-1; ++i ) {
+      if (i != axis) {
+        // Both inputs must have same shape except for row/col dimensions.
+        CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i])
+          << "Shapes of inputs 0, 1 must be the same, except on row/col axis";
+      }
       oshape[i] = (*in_attrs)[0][i];
     }
-    CHECK_EQ((transpose_a ? (*in_attrs)[0][ndim-2] : (*in_attrs)[0][ndim-1]),
-             (transpose_b ? (*in_attrs)[1][ndim-1] : (*in_attrs)[1][ndim-2]))
+    CHECK_EQ((transpose_a ? (*in_attrs)[0][axis] : (*in_attrs)[0][ndim-1]),
+             (transpose_b ? (*in_attrs)[1][ndim-1] : (*in_attrs)[1][axis]))
       << "Incompatible matrix dimensions for multiplication";
-    oshape[ndim-2] = (transpose_a ? (*in_attrs)[0][ndim-1] : (*in_attrs)[0][ndim-2]);
-    oshape[ndim-1] = (transpose_b ? (*in_attrs)[1][ndim-2] : (*in_attrs)[1][ndim-1]);
+    oshape[axis] = (transpose_a ? (*in_attrs)[0][ndim-1] : (*in_attrs)[0][axis]);
+    oshape[ndim-1] = (transpose_b ? (*in_attrs)[1][axis] : (*in_attrs)[1][ndim-1]);
TShape tshape(oshape.begin(), oshape.end());
SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape);
if ( in_attrs->size() > 2 ) {
@@ -340,6 +355,33 @@ inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs,
return false;
}
+// Flattener for following adaptors.
+template<typename xpu, int dim, typename DType>
+mshadow::Tensor<xpu, dim, DType> LaOpFlatten(const TBlob& blob,
+                                             mshadow::Stream<xpu> *s, int axis = -2) {
+ if (axis < 0) {
+ axis = blob.ndim() + axis;
+ }
+ if (axis >= blob.ndim()-2) {
+ // Leave highest axis, collapse rest.
+ return blob.FlatToKD<xpu, dim, DType>(s);
+ }
+ // Collapse ranges [0,axis-1] and [axis+1,ndim-2].
+ CHECK_EQ(dim, 4);
+ TShape shape(dim);
+ shape[0] = 1;
+ for (int i = 0; i < axis; ++i) {
+ shape[0] *= blob.shape_[i];
+ }
+ shape[1] = blob.shape_[axis];
+ shape[2] = 1;
+ for (int i = axis+1; i < blob.ndim()-1; ++i) {
+ shape[2] *= blob.shape_[i];
+ }
+ shape[3] = blob.shape_[blob.ndim()-1];
+ return blob.get_with_shape<xpu, dim, DType>(shape.get<dim>(), s);
+}
+
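
Note (illustration, not part of the commit): the shape arithmetic of LaOpFlatten in
plain Python; flatten_shape is a hypothetical helper. Everything left of the row axis
collapses into one batch dimension, and everything between the row axis and the trailing
column axis collapses into an inner batch dimension:

    import math

    def flatten_shape(shape, axis=-2):
        if axis < 0:
            axis += len(shape)
        if axis >= len(shape) - 2:
            return shape  # default row axis: the C++ code defers to FlatToKD here
        return (math.prod(shape[:axis]),        # collapsed leading batch dims
                shape[axis],                    # matrix rows
                math.prod(shape[axis + 1:-1]),  # collapsed inner batch dims
                shape[-1])                      # matrix columns

    assert flatten_shape((2, 3, 4, 5, 6), axis=1) == (2, 3, 20, 6)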
// Adapters for calling the various operators with appropriate signatures.
template<typename xpu, typename DType, int idim, int odim, int inum, int onum, typename laop>
@@ -347,7 +389,7 @@ struct LaOpCaller {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const nnvm::NodeAttrs& attrs,
- const OpContext& ctx) {
+ const OpContext& ctx, int axis = -2) {
CHECK(false) << "no specialized LaOpCaller defined for template
parameters";
}
};
@@ -356,10 +398,10 @@ struct LaOpCaller<xpu, DType, idim, odim, 1, 1, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const nnvm::NodeAttrs& attrs,
- const OpContext& ctx) {
+ const OpContext& ctx, int axis = -2) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
- outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+ laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -367,11 +409,11 @@ struct LaOpCaller<xpu, DType, idim, odim, 1, 2, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const nnvm::NodeAttrs& attrs,
- const OpContext& ctx) {
+ const OpContext& ctx, int axis = -2) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
- outputs[0].FlatToKD<xpu, odim+1, DType>(s),
- outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+ laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -379,11 +421,11 @@ struct LaOpCaller<xpu, DType, idim, odim, 2, 1, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const nnvm::NodeAttrs& attrs,
- const OpContext& ctx) {
+ const OpContext& ctx, int axis = -2) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
- inputs[1].FlatToKD<xpu, idim+1, DType>(s),
- outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+ laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -391,12 +433,12 @@ struct LaOpCaller<xpu, DType, idim, odim, 3, 1, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const nnvm::NodeAttrs& attrs,
- const OpContext& ctx) {
+ const OpContext& ctx, int axis = -2) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
- inputs[1].FlatToKD<xpu, idim+1, DType>(s),
- inputs[2].FlatToKD<xpu, idim+1, DType>(s),
- outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+ laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -404,13 +446,13 @@ struct LaOpCaller<xpu, DType, idim, odim, 3, 2, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const nnvm::NodeAttrs& attrs,
- const OpContext& ctx) {
+ const OpContext& ctx, int axis = -2) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
- inputs[1].FlatToKD<xpu, idim+1, DType>(s),
- inputs[2].FlatToKD<xpu, idim+1, DType>(s),
- outputs[0].FlatToKD<xpu, odim+1, DType>(s),
- outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+ laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -418,13 +460,13 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 1, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const nnvm::NodeAttrs& attrs,
- const OpContext& ctx) {
+ const OpContext& ctx, int axis = -2) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
- inputs[1].FlatToKD<xpu, idim+1, DType>(s),
- inputs[2].FlatToKD<xpu, idim+1, DType>(s),
- inputs[3].FlatToKD<xpu, idim+1, DType>(s),
- outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+ laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -432,14 +474,14 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 2, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const nnvm::NodeAttrs& attrs,
- const OpContext& ctx) {
+ const OpContext& ctx, int axis = -2) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
- inputs[1].FlatToKD<xpu, idim+1, DType>(s),
- inputs[2].FlatToKD<xpu, idim+1, DType>(s),
- inputs[3].FlatToKD<xpu, idim+1, DType>(s),
- outputs[0].FlatToKD<xpu, odim+1, DType>(s),
- outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+ laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs);
}
};
template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -447,15 +489,15 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 3, laop> {
static void op(const std::vector<TBlob>& inputs,
const std::vector<TBlob>& outputs,
const nnvm::NodeAttrs& attrs,
- const OpContext& ctx) {
+ const OpContext& ctx, int axis = -2) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
- inputs[1].FlatToKD<xpu, idim+1, DType>(s),
- inputs[2].FlatToKD<xpu, idim+1, DType>(s),
- inputs[3].FlatToKD<xpu, idim+1, DType>(s),
- outputs[0].FlatToKD<xpu, odim+1, DType>(s),
- outputs[1].FlatToKD<xpu, odim+1, DType>(s),
- outputs[2].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+ laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+ LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis),
+ LaOpFlatten<xpu, odim+1, DType>(outputs[2], s, axis), ctx, attrs);
}
};
@@ -504,6 +546,64 @@ void LaOpBackward(const nnvm::NodeAttrs& attrs,
});
}
+template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
+void LaOpGemmForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mshadow;
+ CHECK_EQ(inputs.size(), inum);
+ CHECK_EQ(outputs.size(), onum);
+  const int axis(inputs.size() == 2 ? nnvm::get<LaMatrixMultParam>(attrs.parsed).axis
+                                    : nnvm::get<LaMatrixMacParam>(attrs.parsed).axis);
+ MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ if (axis == -2 || axis == inputs[0].ndim()-2) {
+ LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs,
+ attrs, ctx);
+ } else {
+      LaOpCaller<xpu, OType, idim+1, odim+1, inum, onum, laop>::op(inputs, outputs,
+                                                                   attrs, ctx, axis);
+ }
+ });
+}
+
+template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
+void LaOpGemmBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mshadow;
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+ CHECK_EQ(inputs.size(), inum);
+ CHECK_EQ(outputs.size(), onum);
+  const int axis(inputs.size() == 3 ? nnvm::get<LaMatrixMultParam>(attrs.parsed).axis
+                                    : nnvm::get<LaMatrixMacParam>(attrs.parsed).axis);
+ MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ std::vector<TBlob> tspace(outputs);
+ for ( int i = 0; i < onum; ++i ) {
+ if ( req[i] == kAddTo ) {
+        tspace[i].dptr_ = ctx.requested[0]
+            .get_space_typed<xpu, 1, OType>(Shape1(outputs[i].Size()), s).dptr_;
+ }
+ }
+ if (axis == -2 || axis == inputs[0].ndim()-2) {
+ LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs,
+ attrs, ctx);
+ } else {
+      LaOpCaller<xpu, OType, idim+1, odim+1, inum, onum, laop>::op(inputs, outputs,
+                                                                   attrs, ctx, axis);
+ }
+ for ( int i = 0; i < onum; ++i ) {
+ if ( req[i] == kAddTo ) {
+ Tensor<xpu, 1, OType> out = outputs[i].FlatTo1D<xpu, OType>(s);
+ out += tspace[i].FlatTo1D<xpu, OType>(s);
+ }
+ }
+ });
+}
+
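
Note (illustration, not part of the commit): the kAddTo handling in LaOpGemmBackward
follows a common pattern: the kernel writes into scratch space, and the accumulation
into the real output happens afterwards. A minimal numpy sketch of that pattern:

    import numpy as np

    def add_to(dst, compute):
        scratch = np.empty_like(dst)
        compute(scratch)   # the kernel only knows how to overwrite (kWriteTo)
        dst += scratch     # accumulation performed as a separate step

    out = np.ones((2, 2))
    add_to(out, lambda buf: np.copyto(buf, 41.0))
    assert (out == 42.0).all()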
// Specific wrapper for syevd (cannot use the default ones, because A, U have
// different dimensionality than L
diff --git a/src/operator/tensor/la_op_inline.h b/src/operator/tensor/la_op_inline.h
index a508eb7..b483108 100644
--- a/src/operator/tensor/la_op_inline.h
+++ b/src/operator/tensor/la_op_inline.h
@@ -60,24 +60,24 @@ struct Scale {
// D = gemm(A,B,C)
struct gemm {
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, DType alpha, DType beta,
bool tA, bool tB, Stream<xpu> *s) {
linalg_batch_gemm(A, B, C, alpha, beta, tA, tB, s);
}
- template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, const Tensor<xpu, 3, DType>& D,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, const Tensor<xpu, dim, DType>& D,
Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
if ( C.dptr_ != D.dptr_ ) Copy(D, C, s);
const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed);
op(A, B, D, DType(param.alpha), DType(param.beta), param.transpose_a,
param.transpose_b, s);
}
- template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, const Tensor<xpu, 3, DType>& D,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, const Tensor<xpu, dim, DType>& D,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(A, B, C, D, s, attrs);
@@ -86,17 +86,17 @@ struct gemm {
// C = gemm2(A,B)
struct gemm2 {
- template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, Stream<xpu> *s,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, Stream<xpu> *s,
const nnvm::NodeAttrs& attrs) {
    const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed);
gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a,
param.transpose_b, s);
}
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, const OpContext& ctx,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, const OpContext& ctx,
const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(A, B, C, s, attrs);
@@ -343,11 +343,11 @@ struct syevd {
// Backward operators (always using batch processing)
struct gemm_backward {
- template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dD, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
-                 const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
-                 const Tensor<xpu, 3, DType>& dC,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dD, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& C,
+                 const Tensor<xpu, dim, DType>& dA, const Tensor<xpu, dim, DType>& dB,
+                 const Tensor<xpu, dim, DType>& dC,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed);
bool tA(param.transpose_a), tB(param.transpose_b);
@@ -359,11 +359,11 @@ struct gemm_backward {
using namespace mxnet_op;
Kernel<Scale, xpu>::Launch(s, dC.MSize(), DType(param.beta), dC.dptr_);
}
- template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dD, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
-                 const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
-                 const Tensor<xpu, 3, DType>& dC,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dD, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& C,
+                 const Tensor<xpu, dim, DType>& dA, const Tensor<xpu, dim, DType>& dB,
+                 const Tensor<xpu, dim, DType>& dC,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(dD, A, B, C, dA, dB, dC, s, attrs);
@@ -371,10 +371,10 @@ struct gemm_backward {
};
struct gemm2_backward {
- template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
-                 const Tensor<xpu, 3, DType>& dB,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dC, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& dA,
+                 const Tensor<xpu, dim, DType>& dB,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
    const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed);
bool tA(param.transpose_a), tB(param.transpose_b);
@@ -383,10 +383,10 @@ struct gemm2_backward {
(tB ? gemm::op(dC, A, dB, DType(param.alpha), DType(0), true, tA, s)
: gemm::op(A, dC, dB, DType(param.alpha), DType(0), !tA, false, s));
}
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
-                 const Tensor<xpu, 3, DType>& dB,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dC, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& dA,
+                 const Tensor<xpu, dim, DType>& dB,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
op(dC, A, B, dA, dB, s, attrs);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3f08971..923a453 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4688,7 +4688,24 @@ def test_laop():
check_fw(test_gemm, [a, b, c], [r])
if grad_check == 1:
check_grad(test_gemm, [a, b, c])
-
+ # Check for different axis that describes matrix rows.
+ a2 = np.copy(np.swapaxes(a, 0, 2))
+ b2 = np.copy(np.swapaxes(b, 0, 2))
+ c2 = np.copy(np.swapaxes(c, 0, 2))
+ r2 = np.copy(np.swapaxes(r, 0, 2))
+ test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7., axis = 0)
+ check_fw(test_gemm, [a2, b2, c2], [r2])
+ if grad_check == 1:
+ check_grad(test_gemm, [a2, b2, c2])
+ a2 = np.copy(np.swapaxes(a, 1, 2))
+ b2 = np.copy(np.swapaxes(b, 1, 2))
+ c2 = np.copy(np.swapaxes(c, 1, 2))
+ r2 = np.copy(np.swapaxes(r, 1, 2))
+ test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7., axis = -3)
+ check_fw(test_gemm, [a2, b2, c2], [r2])
+ if grad_check == 1:
+ check_grad(test_gemm, [a2, b2, c2])
+
# Check gemm2 operator same way as gemm.
res_gemm = 4. * np.dot(data_in1, data_in2)
test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4.)
@@ -4720,6 +4737,20 @@ def test_laop():
check_fw(test_gemm, [a, b], [r])
if grad_check == 1:
check_grad(test_gemm, [a, b])
+ a2 = np.copy(np.swapaxes(a, 0, 2))
+ b2 = np.copy(np.swapaxes(b, 0, 2))
+ r2 = np.copy(np.swapaxes(r, 0, 2))
+ test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., axis = 0)
+ check_fw(test_gemm, [a2, b2], [r2])
+ if grad_check == 1:
+ check_grad(test_gemm, [a2, b2])
+ a2 = np.copy(np.swapaxes(a, 1, 2))
+ b2 = np.copy(np.swapaxes(b, 1, 2))
+ r2 = np.copy(np.swapaxes(r, 1, 2))
+ test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., axis = -3)
+ check_fw(test_gemm, [a2, b2], [r2])
+ if grad_check == 1:
+ check_grad(test_gemm, [a2, b2])
# Now test all the other operators.
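
Note (illustration, not part of the commit): the identity the new test cases rely on,
written out in plain numpy with hypothetical shapes. Swapping an outer axis with the
row axis on every operand, and passing that axis to gemm, reproduces the swapped result:

    import numpy as np

    rng = np.random.RandomState(0)
    a = rng.rand(2, 3, 4, 5)
    b = rng.rand(2, 3, 5, 6)
    c = rng.rand(2, 3, 4, 6)
    r = 4. * np.matmul(a, b) + 7. * c
    # operands with the row axis moved to position 0, as in the test:
    a2, b2, c2 = (np.swapaxes(x, 0, 2) for x in (a, b, c))
    r2 = np.swapaxes(r, 0, 2)
    # linalg.gemm(a2, b2, c2, alpha=4., beta=7., axis=0) is expected to equal r2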
--
To stop receiving notification emails like this one, please contact
[email protected].