This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 4ac76c8  Support for axis parameter in linalg.gemm (#10864)
4ac76c8 is described below

commit 4ac76c89da6d4d8feef629949dc0f9534b216e3d
Author: moin <asmushet...@yahoo.de>
AuthorDate: Tue May 29 20:15:11 2018 +0200

    Support for axis parameter in linalg.gemm (#10864)
---
 src/operator/linalg.h                  |   7 +
 src/operator/linalg_impl.h             | 276 ++++++++++++++++++++++-----------
 src/operator/tensor/la_op.cc           |  39 ++++-
 src/operator/tensor/la_op.cu           |   8 +-
 src/operator/tensor/la_op.h            | 206 +++++++++++++++++-------
 src/operator/tensor/la_op_inline.h     |  66 ++++----
 tests/python/unittest/test_operator.py |  33 +++-
 7 files changed, 447 insertions(+), 188 deletions(-)

diff --git a/src/operator/linalg.h b/src/operator/linalg.h
index aee67d7..dc59400 100644
--- a/src/operator/linalg.h
+++ b/src/operator/linalg.h
@@ -64,6 +64,13 @@ void linalg_batch_gemm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DTyp
                       const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
                       bool tA, bool tB, Stream<xpu> *s = 0);

+// Version of batch gemm where rows are indexed at axis 1 and columns at axis 3.
+template<typename xpu, typename DType>
+void linalg_batch_gemm(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B,
+                       const Tensor<xpu, 4, DType>& C, DType alpha, DType beta,
+                       bool tA, bool tB, Stream<xpu> *s = 0);
+
+
template<typename xpu, typename DType>
inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
                        const Tensor<xpu, 2, DType>& B,

diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index 151db60..08d2add 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -56,6 +56,11 @@ inline void check_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DTyp
    << "Non compatible matrix dimensions between inputs A and B for gemm";
}

+template<typename xpu, typename DType>
+void linalg_gemm_axis(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                      const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
+                      bool tA, bool tB, Stream<xpu> *s = 0);
+
#if (MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1)

#define LINALG_CPU_GEMM(fname, DType) \
@@ -80,6 +85,38 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<
  } \
}

+// Batched gemm where the batch coordinate is given by the second axis.
+#define LINALG_CPU_GEMM_AXIS(fname, DType) \
+template<> inline \
+void linalg_gemm_axis<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
+                                  const Tensor<cpu, 3, DType>& C, DType alpha, DType beta, \
+                                  bool tA, bool tB, Stream<cpu> *s) { \
+  linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
+  for (index_t i = 0; i < A.size(1); ++i) { \
+    cblas_##fname(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), \
+                  (tB ? CblasTrans : CblasNoTrans), \
+                  C.size(0), C.size(2), (tA ? A.size(0) : A.size(2)), alpha, \
+                  A.dptr_+i*A.stride_, A.size(1)*A.stride_, \
+                  B.dptr_+i*B.stride_, B.size(1)*B.stride_, beta, \
+                  C.dptr_+i*C.stride_, C.size(1)*C.stride_); \
+  } \
+}
+
+LINALG_CPU_GEMM_AXIS(sgemm, float)
+LINALG_CPU_GEMM_AXIS(dgemm, double)
+
+// Version where matrix rows are given by the second axis.
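For reference before the XPU-level wrapper that the comment above introduces: a minimal NumPy sketch, not part of the patch, of what LINALG_CPU_GEMM_AXIS computes. The macro issues one CBLAS call per index along axis 1, with pointer offset i*stride_ and leading dimension size(1)*stride_, so that consecutive rows of slice i lie a full batch row apart in memory; in NumPy terms each call multiplies the strided views A[:, i, :] and B[:, i, :]. The shape names (m, b, k, n) are illustrative.

import numpy as np

# Rows on axis 0, batch on axis 1, columns on axis 2: the layout that the
# LINALG_CPU_GEMM_AXIS macro above operates on.
m, b, k, n = 2, 3, 4, 5
A = np.random.rand(m, b, k)
B = np.random.rand(k, b, n)
C = np.empty((m, b, n))
for i in range(b):
    # One gemm per batch slice; the C++ macro expresses the same strided
    # view via pointer offset i*stride_ and row stride size(1)*stride_.
    C[:, i, :] = A[:, i, :] @ B[:, i, :]
# Cross-check against a one-shot contraction:
assert np.allclose(C, np.einsum('mbk,kbn->mbn', A, B))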
+#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \ +template<> inline \ +void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \ + const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \ + bool tA, bool tB, Stream<xpu> *s) { \ + linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \ + for (index_t i = 0; i < A.size(0); ++i) { \ + linalg_gemm_axis(A[i], B[i], C[i], alpha, beta, tA, tB, s); \ + } \ +} + #else #define LINALG_CPU_GEMM(fname, DType) \ @@ -98,6 +135,14 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor< LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \ } +#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \ +template<> inline \ +void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \ + const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \ + bool tA, bool tB, Stream<xpu> *s) { \ + LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \ +} + #endif // MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1 LINALG_CPU_GEMM(sgemm, float) @@ -106,6 +151,9 @@ LINALG_CPU_GEMM(dgemm, double) LINALG_XPU_BATCH_GEMM(cpu, float) LINALG_XPU_BATCH_GEMM(cpu, double) +LINALG_XPU_BATCH_GEMM_AXIS(cpu, float) +LINALG_XPU_BATCH_GEMM_AXIS(cpu, double) + // Specialization of linalg_gemm<cpu, DType> for DType=mshadow::half::half_t. template<> inline void linalg_gemm<cpu, mshadow::half::half_t>(const Tensor<cpu, 2, mshadow::half::half_t>& A, @@ -140,6 +188,28 @@ void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2 LINALG_GPU_GEMM(Sgemm, float) LINALG_GPU_GEMM(Dgemm, double) +// Version where matrix rows are given by first axis. +#define LINALG_GPU_GEMM_AXIS(fname, DType) \ +template<> inline \ +void linalg_gemm_axis<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \ + const Tensor<gpu, 3, DType>& C, DType alpha, DType beta, \ + bool tA, bool tB, Stream<gpu> *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \ + CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \ + (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \ + (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \ + C.size(2), C.size(0), (tB ? B.size(2) : B.size(0)), &alpha, \ + B.dptr_, B.size(1)*B.stride_, B.stride_, \ + A.dptr_, A.size(1)*A.stride_, A.stride_, &beta, \ + C.dptr_, C.size(1)*C.stride_, C.stride_, A.size(1))) \ +} +LINALG_GPU_GEMM_AXIS(SgemmStridedBatched, float) +LINALG_GPU_GEMM_AXIS(DgemmStridedBatched, double) + +// Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t. // Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t. template<> inline void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half::half_t>& A, @@ -192,6 +262,8 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half: #if CUDA_VERSION < 8000 LINALG_XPU_BATCH_GEMM(gpu, float) LINALG_XPU_BATCH_GEMM(gpu, double) + LINALG_XPU_BATCH_GEMM_AXIS(gpu, float) + LINALG_XPU_BATCH_GEMM_AXIS(gpu, double) #else #define LINALG_GPU_BATCH_GEMM(fname, DType) \ template<> inline \ @@ -217,10 +289,125 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half: LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float) LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double) +// Version where matrix rows are given by second axis. 
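One detail of the GPU macros above, and of the strided-batched variant that follows, deserves a note: cuBLAS assumes column-major storage while mshadow tensors are row-major, which is why every call passes B before A and swaps the transpose flags. A row-major buffer reinterpreted column-major is the transposed matrix, so cuBLAS is asked for the product that reads back row-major as A @ B. A short NumPy sketch of that identity (not part of the patch; shapes are illustrative):

import numpy as np

m, k, n = 3, 4, 5
A = np.random.rand(m, k)
B = np.random.rand(k, n)
# A row-major m-by-k buffer, read column-major, is A.T. Asking a
# column-major BLAS for B.T @ A.T therefore yields a buffer that,
# read back row-major, is (B.T @ A.T).T == A @ B.
assert np.allclose((B.T @ A.T).T, A @ B)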
+#define LINALG_GPU_BATCH_GEMM_AXIS(fname, DType) \ + template<> inline \ + void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 4, DType>& A, \ + const Tensor<gpu, 4, DType>& B, \ + const Tensor<gpu, 4, DType>& C, DType alpha, DType beta, \ + bool tA, bool tB, Stream<gpu> *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \ + linalg_check_batch_size(A.size(2), B.size(2), C.size(2)); \ + for (index_t i = 0; i < A.size(2); ++i) { \ + CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \ + (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \ + (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \ + C.size(3), C.size(1), (tB ? B.size(3) : B.size(1)), &alpha, \ + B.dptr_+i*B.stride_, B.size(2) * B.stride_, B.size(1)*B.size(2)*B.stride_, \ + A.dptr_+i*A.stride_, A.size(2) * A.stride_, A.size(1)*A.size(2)*A.stride_, &beta, \ + C.dptr_+i*C.stride_, C.size(2) * C.stride_, C.size(1)*C.size(2)*C.stride_, A.size(0))) \ + }\ + } + + LINALG_GPU_BATCH_GEMM_AXIS(SgemmStridedBatched, float) + LINALG_GPU_BATCH_GEMM_AXIS(DgemmStridedBatched, double) + #endif // CUDA < 8000 #endif // __CUDACC__ +/*! + * \brief Performs gemm, setting alpha and beta as appropriate for `req`. + * + * \param A the first operand of the gemm + * \param B the second operand of the gemm + * \param C the data to be assigned + * \param tA whether the `A` operand should be transposed first. + * \param tB whether the `B` operand should be transposed first. + * \param s the stream to perform the operation + * \param req the assignment request + */ +template<typename xpu, typename DType> +inline void linalg_gemm(const Tensor<xpu, 2, DType>& A, + const Tensor<xpu, 2, DType>& B, + const Tensor<xpu, 2, DType>& C, + bool tA, bool tB, Stream<xpu> *s, + mxnet::OpReqType req) { + using namespace mxnet; + switch (req) { + case kNullOp: + break; + case kWriteTo: + case kWriteInplace: + linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s); + break; + case kAddTo: + linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s); + break; + default: + LOG(FATAL) << "not reached"; + } +} + +#if (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0) + +// A template for a cpu linalg_gemm implementation using mshadow::dot() +#define LINALG_CPU_GEMM_NO_CBLAS(DType) \ +template<> inline \ +void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, \ + const Tensor<cpu, 2, DType>& B, \ + const Tensor<cpu, 2, DType>& C, \ + bool tA, bool tB, Stream<cpu> *s, \ + mxnet::OpReqType req) { \ + using namespace mxnet; \ + using mshadow::cpu; \ + switch (req) { \ + case kNullOp: \ + break; \ + case kWriteTo: \ + case kWriteInplace: \ + if (tA) { \ + if (tB) { \ + const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B.T()); \ + } else { \ + const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B); \ + } \ + } else { \ + if (tB) { \ + const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B.T()); \ + } else { \ + const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B); \ + } \ + } \ + break; \ + case kAddTo: \ + if (tA) { \ + if (tB) { \ + const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B.T()); \ + } else { \ + const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B); \ + } \ + } else { \ + if (tB) { \ + const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B.T()); \ + } else { \ + const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B); \ + } \ + } \ + break; \ + default: \ + LOG(FATAL) << "not reached"; \ + } \ +} + +LINALG_CPU_GEMM_NO_CBLAS(float) +LINALG_CPU_GEMM_NO_CBLAS(double) + +#endif // (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 
0) + //////////////////////////////// TRSM //////////////////////////////////////////// // CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation @@ -313,95 +500,6 @@ LINALG_XPU_BATCH_TRSM(gpu, double) #endif // __CUDACC__ -/*! - * \brief Performs gemm, setting alpha and beta as appropriate for `req`. - * - * \param A the first operand of the gemm - * \param B the second operand of the gemm - * \param C the data to be assigned - * \param tA whether the `A` operand should be transposed first. - * \param tB whether the `B` operand should be transposed first. - * \param s the stream to perform the operation - * \param req the assignment request - */ -template<typename xpu, typename DType> -inline void linalg_gemm(const Tensor<xpu, 2, DType>& A, - const Tensor<xpu, 2, DType>& B, - const Tensor<xpu, 2, DType>& C, - bool tA, bool tB, Stream<xpu> *s, - mxnet::OpReqType req) { - using namespace mxnet; - switch (req) { - case kNullOp: - break; - case kWriteTo: - case kWriteInplace: - linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s); - break; - case kAddTo: - linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s); - break; - default: - LOG(FATAL) << "not reached"; - } -} - -#if (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0) - -// A template for a cpu linalg_gemm implementation using mshadow::dot() -#define LINALG_CPU_GEMM_NO_CBLAS(DType) \ -template<> inline \ -void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, \ - const Tensor<cpu, 2, DType>& B, \ - const Tensor<cpu, 2, DType>& C, \ - bool tA, bool tB, Stream<cpu> *s, \ - mxnet::OpReqType req) { \ - using namespace mxnet; \ - using mshadow::cpu; \ - switch (req) { \ - case kNullOp: \ - break; \ - case kWriteTo: \ - case kWriteInplace: \ - if (tA) { \ - if (tB) { \ - const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B.T()); \ - } else { \ - const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B); \ - } \ - } else { \ - if (tB) { \ - const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B.T()); \ - } else { \ - const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B); \ - } \ - } \ - break; \ - case kAddTo: \ - if (tA) { \ - if (tB) { \ - const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B.T()); \ - } else { \ - const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B); \ - } \ - } else { \ - if (tB) { \ - const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B.T()); \ - } else { \ - const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B); \ - } \ - } \ - break; \ - default: \ - LOG(FATAL) << "not reached"; \ - } \ -} - -LINALG_CPU_GEMM_NO_CBLAS(float) -LINALG_CPU_GEMM_NO_CBLAS(double) - -#endif // (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0) - //////////////////////////////// TRMM //////////////////////////////////////////// // CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 7083efe..b177165 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -46,8 +46,20 @@ If *n=2*, the BLAS3 function *gemm* is performed: Here, *alpha* and *beta* are scalar parameters, and *op()* is either the identity or matrix transposition (depending on *transpose_a*, *transpose_b*). -If *n>2*, *gemm* is performed separately on the trailing two dimensions for all inputs -(batch mode). +If *n>2*, *gemm* is performed separately for a batch of matrices. The column indices of the matrices +are given by the last dimensions of the tensors, the row indices by the axis specified with the *axis* +parameter. 
By default, the trailing two dimensions will be used for matrix encoding.
+
+For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes
+calls. For example, let *A*, *B*, *C* be 5-dimensional tensors. Then gemm(*A*, *B*, *C*, axis=1) is equivalent to
+
+    A1 = swapaxes(A, dim1=1, dim2=3)
+    B1 = swapaxes(B, dim1=1, dim2=3)
+    C = swapaxes(C, dim1=1, dim2=3)
+    C = gemm(A1, B1, C)
+    C = swapaxes(C, dim1=1, dim2=3)
+
+without the overhead of the additional swapaxes operations.

.. note:: The operator supports float32 and float64 data types only.

@@ -76,7 +88,7 @@ Examples::
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
  { return std::vector<std::pair<int, int>>{{2, 0}}; })
-.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 3, 1, gemm>)
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmForward<cpu, 2, 2, 3, 1, gemm>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_linalg_gemm"})
.add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices")
.add_argument("B", "NDArray-or-Symbol", "Tensor of input matrices")
@@ -92,7 +104,7 @@ NNVM_REGISTER_OP(_backward_linalg_gemm)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
  { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 4, 3, gemm_backward>);
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmBackward<cpu, 2, 2, 4, 3, gemm_backward>);

NNVM_REGISTER_OP(_linalg_gemm2)
.add_alias("linalg_gemm2")
@@ -107,8 +119,19 @@ If *n=2*, the BLAS3 function *gemm* is performed:

Here *alpha* is a scalar parameter and *op()* is either the identity or the matrix
transposition (depending on *transpose_a*, *transpose_b*).

-If *n>2*, *gemm* is performed separately on the trailing two dimensions for all inputs
-(batch mode).
+If *n>2*, *gemm* is performed separately for a batch of matrices. The column indices of the matrices
+are given by the last dimensions of the tensors, the row indices by the axis specified with the *axis*
+parameter. By default, the trailing two dimensions will be used for matrix encoding.
+
+For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes
+calls. For example, let *A*, *B* be 5-dimensional tensors. Then gemm2(*A*, *B*, axis=1) is equivalent to
+
+    A1 = swapaxes(A, dim1=1, dim2=3)
+    B1 = swapaxes(B, dim1=1, dim2=3)
+    C = gemm2(A1, B1)
+    C = swapaxes(C, dim1=1, dim2=3)
+
+without the overhead of the additional swapaxes operations.

.. note:: The operator supports float32 and float64 data types only.
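The swapaxes equivalence documented above can be checked directly in NumPy. The sketch below (illustrative shapes, not part of the patch) mirrors the gemm case with non-trivial alpha and beta: rows on axis 1, columns on the last axis, and the remaining axes treated as batch dimensions.

import numpy as np

alpha, beta = 4.0, 7.0
A = np.random.rand(2, 3, 4, 5, 6)   # rows = 3 (axis 1), cols = 6 (last axis)
B = np.random.rand(2, 6, 4, 5, 7)
C = np.random.rand(2, 3, 4, 5, 7)

# The documented swapaxes/gemm/swapaxes route: move the row axis to the
# second-to-last position, run an ordinary batched gemm, move it back.
A1 = np.swapaxes(A, 1, 3)
B1 = np.swapaxes(B, 1, 3)
C1 = np.swapaxes(C, 1, 3)
out = np.swapaxes(alpha * np.matmul(A1, B1) + beta * C1, 1, 3)

# The same result computed directly, without the swaps:
ref = alpha * np.einsum('pmqrk,pkqrn->pmqrn', A, B) + beta * C
assert np.allclose(out, ref)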
@@ -133,7 +156,7 @@ Examples:: { return std::vector<std::string>{"A", "B"}; } ) .set_attr<nnvm::FInferShape>("FInferShape", LaMatrixMultMacOpShape) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) -.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 2, 1, gemm2>) +.set_attr<FCompute>("FCompute<cpu>", LaOpGemmForward<cpu, 2, 2, 2, 1, gemm2>) .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_linalg_gemm2"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices") .add_argument("B", "NDArray-or-Symbol", "Tensor of input matrices") @@ -148,7 +171,7 @@ NNVM_REGISTER_OP(_backward_linalg_gemm2) .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; }) .set_attr<nnvm::TIsBackward>("TIsBackward", true) -.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 3, 2, gemm2_backward>); +.set_attr<FCompute>("FCompute<cpu>", LaOpGemmBackward<cpu, 2, 2, 3, 2, gemm2_backward>); NNVM_REGISTER_OP(_linalg_potrf) .add_alias("linalg_potrf") diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu index efd705f..d736845 100644 --- a/src/operator/tensor/la_op.cu +++ b/src/operator/tensor/la_op.cu @@ -28,16 +28,16 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_linalg_gemm) -.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 3, 1, gemm>); +.set_attr<FCompute>("FCompute<gpu>", LaOpGemmForward<gpu, 2, 2, 3, 1, gemm>); NNVM_REGISTER_OP(_backward_linalg_gemm) -.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 3, gemm_backward>); +.set_attr<FCompute>("FCompute<gpu>", LaOpGemmBackward<gpu, 2, 2, 4, 3, gemm_backward>); NNVM_REGISTER_OP(_linalg_gemm2) -.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, gemm2>); +.set_attr<FCompute>("FCompute<gpu>", LaOpGemmForward<gpu, 2, 2, 2, 1, gemm2>); NNVM_REGISTER_OP(_backward_linalg_gemm2) -.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 3, 2, gemm2_backward>); +.set_attr<FCompute>("FCompute<gpu>", LaOpGemmBackward<gpu, 2, 2, 3, 2, gemm2_backward>); NNVM_REGISTER_OP(_linalg_trmm) .set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, trmm>); diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index 3d411b2..8e2acd7 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -40,6 +40,7 @@ namespace op { struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> { bool transpose_a, transpose_b; double alpha, beta; + int axis; DMLC_DECLARE_PARAMETER(LaMatrixMacParam) { DMLC_DECLARE_FIELD(transpose_a) .set_default(false) @@ -53,6 +54,9 @@ struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> { DMLC_DECLARE_FIELD(beta) .set_default(1.0) .describe("Scalar factor multiplied with C."); + DMLC_DECLARE_FIELD(axis) + .set_default(-2) + .describe("Axis corresponding to the matrix rows."); } }; @@ -60,6 +64,7 @@ struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> { struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> { bool transpose_a, transpose_b; double alpha; + int axis; DMLC_DECLARE_PARAMETER(LaMatrixMultParam) { DMLC_DECLARE_FIELD(transpose_a) .set_default(false) @@ -70,6 +75,9 @@ struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> { DMLC_DECLARE_FIELD(alpha) .set_default(1.0) .describe("Scalar factor multiplied with A*B."); + DMLC_DECLARE_FIELD(axis) + .set_default(-2) + .describe("Axis corresponding to the matrix row indices."); } }; @@ -112,30 +120,37 @@ 
inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs, CHECK_GE(in_attrs->size(), 2); CHECK_EQ(out_attrs->size(), 1); bool transpose_a(false), transpose_b(false); + int axis_param(-2); if ( in_attrs->size() == 2 ) { // Matrix-Matrix mult transpose_a = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_a; transpose_b = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_b; + axis_param = nnvm::get<LaMatrixMultParam>(attrs.parsed).axis; } else { // Matrix-Matrix mac transpose_a = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_a; transpose_b = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_b; + axis_param = nnvm::get<LaMatrixMacParam>(attrs.parsed).axis; } if ( (*in_attrs)[0].ndim() >= 2 && (*in_attrs)[0].ndim() == (*in_attrs)[1].ndim() ) { // Forward shape inference. - const int ndim((*in_attrs)[0].ndim()); + const int ndim((*in_attrs)[0].ndim()), axis(axis_param < 0 ? ndim + axis_param : axis_param); + CHECK(axis >= 0 && axis < ndim-1) + << "Invalid row axis (" << axis_param << ")"; std::vector<int> oshape(ndim); - for ( int i = 0; i < ndim-2; ++i ) { - // Both inputs must have same shape except for last two dimensions. - CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i]) - << "Shapes of inputs 0, 1 must be the same, except on last two dimensions"; + for ( int i = 0; i < ndim-1; ++i ) { + if (i != axis) { + // Both inputs must have same shape except for row/col dimensions. + CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i]) + << "Shapes of inputs 0, 1 must be the same, except on row/col axis"; + } oshape[i] = (*in_attrs)[0][i]; } - CHECK_EQ((transpose_a ? (*in_attrs)[0][ndim-2] : (*in_attrs)[0][ndim-1]), - (transpose_b ? (*in_attrs)[1][ndim-1] : (*in_attrs)[1][ndim-2])) + CHECK_EQ((transpose_a ? (*in_attrs)[0][axis] : (*in_attrs)[0][ndim-1]), + (transpose_b ? (*in_attrs)[1][ndim-1] : (*in_attrs)[1][axis])) << "Incompatible matrix dimensions for multiplication"; - oshape[ndim-2] = (transpose_a ? (*in_attrs)[0][ndim-1] : (*in_attrs)[0][ndim-2]); - oshape[ndim-1] = (transpose_b ? (*in_attrs)[1][ndim-2] : (*in_attrs)[1][ndim-1]); + oshape[axis] = (transpose_a ? (*in_attrs)[0][ndim-1] : (*in_attrs)[0][axis]); + oshape[ndim-1] = (transpose_b ? (*in_attrs)[1][axis] : (*in_attrs)[1][ndim-1]); TShape tshape(oshape.begin(), oshape.end()); SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape); if ( in_attrs->size() > 2 ) { @@ -340,6 +355,33 @@ inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs, return false; } +// Flattener for following adaptors. +template<typename xpu, int dim, typename DType> +mshadow::Tensor<xpu, dim, DType> LaOpFlatten(const TBlob& blob, + mshadow::Stream<xpu> *s, int axis = -2) { + if (axis < 0) { + axis = blob.ndim() + axis; + } + if (axis >= blob.ndim()-2) { + // Leave highest axis, collapse rest. + return blob.FlatToKD<xpu, dim, DType>(s); + } + // Collapse ranges [0,axis-1] and [axis+1,ndim-2]. + CHECK_EQ(dim, 4); + TShape shape(dim); + shape[0] = 1; + for (int i = 0; i < axis; ++i) { + shape[0] *= blob.shape_[i]; + } + shape[1] = blob.shape_[axis]; + shape[2] = 1; + for (int i = axis+1; i < blob.ndim()-1; ++i) { + shape[2] *= blob.shape_[i]; + } + shape[3] = blob.shape_[blob.ndim()-1]; + return blob.get_with_shape<xpu, dim, DType>(shape.get<dim>(), s); +} + // Adapters for calling the various operators with appropriate signatures. 
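The adapter specializations below thread an axis argument through to LaOpFlatten. For a non-default axis, that helper is a pure reshape into a canonical 4-D layout (prefix batch, rows, middle batch, columns); no data is moved, because the collapsed axis ranges are contiguous. A NumPy sketch of that branch, not part of the patch:

import numpy as np

def la_op_flatten4(x, axis):
    # Non-default-axis branch of LaOpFlatten: collapse all axes before
    # `axis` into dim 0 and all axes between `axis` and the last into
    # dim 2, keeping the row axis as dim 1 and the column axis as dim 3.
    axis = axis % x.ndim
    prefix = int(np.prod(x.shape[:axis], dtype=int))
    middle = int(np.prod(x.shape[axis + 1:-1], dtype=int))
    return x.reshape(prefix, x.shape[axis], middle, x.shape[-1])

x = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
y = la_op_flatten4(x, axis=1)
assert y.shape == (2, 3, 20, 6)
assert np.shares_memory(x, y)   # a view; the 4-D gemm then batches over
                                # dims 0 and 2 while rows sit on dim 1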
template<typename xpu, typename DType, int idim, int odim, int inum, int onum, typename laop> @@ -347,7 +389,7 @@ struct LaOpCaller { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, const nnvm::NodeAttrs& attrs, - const OpContext& ctx) { + const OpContext& ctx, int axis = -2) { CHECK(false) << "no specialized LaOpCaller defined for template parameters"; } }; @@ -356,10 +398,10 @@ struct LaOpCaller<xpu, DType, idim, odim, 1, 1, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, const nnvm::NodeAttrs& attrs, - const OpContext& ctx) { + const OpContext& ctx, int axis = -2) { mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), - outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs); + laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> @@ -367,11 +409,11 @@ struct LaOpCaller<xpu, DType, idim, odim, 1, 2, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, const nnvm::NodeAttrs& attrs, - const OpContext& ctx) { + const OpContext& ctx, int axis = -2) { mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), - outputs[0].FlatToKD<xpu, odim+1, DType>(s), - outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs); + laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> @@ -379,11 +421,11 @@ struct LaOpCaller<xpu, DType, idim, odim, 2, 1, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, const nnvm::NodeAttrs& attrs, - const OpContext& ctx) { + const OpContext& ctx, int axis = -2) { mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), - inputs[1].FlatToKD<xpu, idim+1, DType>(s), - outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs); + laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> @@ -391,12 +433,12 @@ struct LaOpCaller<xpu, DType, idim, odim, 3, 1, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, const nnvm::NodeAttrs& attrs, - const OpContext& ctx) { + const OpContext& ctx, int axis = -2) { mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), - inputs[1].FlatToKD<xpu, idim+1, DType>(s), - inputs[2].FlatToKD<xpu, idim+1, DType>(s), - outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs); + laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> @@ -404,13 +446,13 @@ struct LaOpCaller<xpu, DType, idim, odim, 3, 2, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, const nnvm::NodeAttrs& attrs, - const OpContext& ctx) { + 
const OpContext& ctx, int axis = -2) { mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), - inputs[1].FlatToKD<xpu, idim+1, DType>(s), - inputs[2].FlatToKD<xpu, idim+1, DType>(s), - outputs[0].FlatToKD<xpu, odim+1, DType>(s), - outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs); + laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> @@ -418,13 +460,13 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 1, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, const nnvm::NodeAttrs& attrs, - const OpContext& ctx) { + const OpContext& ctx, int axis = -2) { mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), - inputs[1].FlatToKD<xpu, idim+1, DType>(s), - inputs[2].FlatToKD<xpu, idim+1, DType>(s), - inputs[3].FlatToKD<xpu, idim+1, DType>(s), - outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs); + laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> @@ -432,14 +474,14 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 2, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, const nnvm::NodeAttrs& attrs, - const OpContext& ctx) { + const OpContext& ctx, int axis = -2) { mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), - inputs[1].FlatToKD<xpu, idim+1, DType>(s), - inputs[2].FlatToKD<xpu, idim+1, DType>(s), - inputs[3].FlatToKD<xpu, idim+1, DType>(s), - outputs[0].FlatToKD<xpu, odim+1, DType>(s), - outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs); + laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> @@ -447,15 +489,15 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 3, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, const nnvm::NodeAttrs& attrs, - const OpContext& ctx) { + const OpContext& ctx, int axis = -2) { mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), - inputs[1].FlatToKD<xpu, idim+1, DType>(s), - inputs[2].FlatToKD<xpu, idim+1, DType>(s), - inputs[3].FlatToKD<xpu, idim+1, DType>(s), - outputs[0].FlatToKD<xpu, odim+1, DType>(s), - outputs[1].FlatToKD<xpu, odim+1, DType>(s), - outputs[2].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs); + laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis), + LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, 
axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), + LaOpFlatten<xpu, odim+1, DType>(outputs[2], s, axis), ctx, attrs); } }; @@ -504,6 +546,64 @@ void LaOpBackward(const nnvm::NodeAttrs& attrs, }); } +template<typename xpu, int idim, int odim, int inum, int onum, typename laop> +void LaOpGemmForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), inum); + CHECK_EQ(outputs.size(), onum); + const int axis(inputs.size() == 2 ? nnvm::get<LaMatrixMultParam>(attrs.parsed).axis + : nnvm::get<LaMatrixMacParam>(attrs.parsed).axis); + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + if (axis == -2 || axis == inputs[0].ndim()-2) { + LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs, + attrs, ctx); + } else { + LaOpCaller<xpu, OType, idim+1, odim+1, inum, onum, laop>::op(inputs, outputs, + attrs, ctx, axis); + } + }); +} + +template<typename xpu, int idim, int odim, int inum, int onum, typename laop> +void LaOpGemmBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace mshadow; + Stream<xpu> *s = ctx.get_stream<xpu>(); + CHECK_EQ(inputs.size(), inum); + CHECK_EQ(outputs.size(), onum); + const int axis(inputs.size() == 3 ? nnvm::get<LaMatrixMultParam>(attrs.parsed).axis + : nnvm::get<LaMatrixMacParam>(attrs.parsed).axis); + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + std::vector<TBlob> tspace(outputs); + for ( int i = 0; i < onum; ++i ) { + if ( req[i] == kAddTo ) { + tspace[i].dptr_ = ctx.requested[0] + .get_space_typed<xpu, 1, OType>(Shape1(outputs[i].Size()), s).dptr_; + } + } + if (axis == -2 || axis == inputs[0].ndim()-2) { + LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs, + attrs, ctx); + } else { + LaOpCaller<xpu, OType, idim+1, odim+1, inum, onum, laop>::op(inputs, outputs, + attrs, ctx, axis); + } + for ( int i = 0; i < onum; ++i ) { + if ( req[i] == kAddTo ) { + Tensor<xpu, 1, OType> out = outputs[i].FlatTo1D<xpu, OType>(s); + out += tspace[i].FlatTo1D<xpu, OType>(s); + } + } + }); +} + // Specific wrapper for syevd (cannot use the default ones, because A, U have // different dimensionality than L diff --git a/src/operator/tensor/la_op_inline.h b/src/operator/tensor/la_op_inline.h index a508eb7..b483108 100644 --- a/src/operator/tensor/la_op_inline.h +++ b/src/operator/tensor/la_op_inline.h @@ -60,24 +60,24 @@ struct Scale { // D = gemm(A,B,C) struct gemm { - template<typename xpu, typename DType> - static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, - const Tensor<xpu, 3, DType>& C, DType alpha, DType beta, + template<typename xpu, int dim, typename DType> + static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B, + const Tensor<xpu, dim, DType>& C, DType alpha, DType beta, bool tA, bool tB, Stream<xpu> *s) { linalg_batch_gemm(A, B, C, alpha, beta, tA, tB, s); } - template<typename xpu, typename DType> - static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, - const Tensor<xpu, 3, DType>& C, const Tensor<xpu, 3, DType>& D, + template<typename xpu, int dim, typename DType> + static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B, + const 
Tensor<xpu, dim, DType>& C, const Tensor<xpu, dim, DType>& D, Stream<xpu> *s, const nnvm::NodeAttrs& attrs) { if ( C.dptr_ != D.dptr_ ) Copy(D, C, s); const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed); op(A, B, D, DType(param.alpha), DType(param.beta), param.transpose_a, param.transpose_b, s); } - template<typename xpu, typename DType> - static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, - const Tensor<xpu, 3, DType>& C, const Tensor<xpu, 3, DType>& D, + template<typename xpu, int dim, typename DType> + static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B, + const Tensor<xpu, dim, DType>& C, const Tensor<xpu, dim, DType>& D, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { Stream<xpu> *s = ctx.get_stream<xpu>(); op(A, B, C, D, s, attrs); @@ -86,17 +86,17 @@ struct gemm { // C = gemm2(A,B) struct gemm2 { - template<typename xpu, typename DType> - static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, - const Tensor<xpu, 3, DType>& C, Stream<xpu> *s, + template<typename xpu, int dim, typename DType> + static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B, + const Tensor<xpu, dim, DType>& C, Stream<xpu> *s, const nnvm::NodeAttrs& attrs) { const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed); gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a, param.transpose_b, s); } - template<typename xpu, typename DType> - static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, - const Tensor<xpu, 3, DType>& C, const OpContext& ctx, + template<typename xpu, int dim, typename DType> + static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B, + const Tensor<xpu, dim, DType>& C, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { Stream<xpu> *s = ctx.get_stream<xpu>(); op(A, B, C, s, attrs); @@ -343,11 +343,11 @@ struct syevd { // Backward operators (always using batch processing) struct gemm_backward { - template<typename xpu, typename DType> - static void op(const Tensor<xpu, 3, DType>& dD, const Tensor<xpu, 3, DType>& A, - const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C, - const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB, - const Tensor<xpu, 3, DType>& dC, + template<typename xpu, int dim, typename DType> + static void op(const Tensor<xpu, dim, DType>& dD, const Tensor<xpu, dim, DType>& A, + const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& C, + const Tensor<xpu, dim, DType>& dA, const Tensor<xpu, dim, DType>& dB, + const Tensor<xpu, dim, DType>& dC, Stream<xpu>* s, const nnvm::NodeAttrs& attrs) { const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed); bool tA(param.transpose_a), tB(param.transpose_b); @@ -359,11 +359,11 @@ struct gemm_backward { using namespace mxnet_op; Kernel<Scale, xpu>::Launch(s, dC.MSize(), DType(param.beta), dC.dptr_); } - template<typename xpu, typename DType> - static void op(const Tensor<xpu, 3, DType>& dD, const Tensor<xpu, 3, DType>& A, - const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C, - const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB, - const Tensor<xpu, 3, DType>& dC, + template<typename xpu, int dim, typename DType> + static void op(const Tensor<xpu, dim, DType>& dD, const Tensor<xpu, dim, DType>& A, + const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& C, + const Tensor<xpu, dim, DType>& dA, const Tensor<xpu, dim, DType>& dB, + const 
Tensor<xpu, dim, DType>& dC, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { Stream<xpu> *s = ctx.get_stream<xpu>(); op(dD, A, B, C, dA, dB, dC, s, attrs); @@ -371,10 +371,10 @@ struct gemm_backward { }; struct gemm2_backward { - template<typename xpu, typename DType> - static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A, - const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA, - const Tensor<xpu, 3, DType>& dB, + template<typename xpu, int dim, typename DType> + static void op(const Tensor<xpu, dim, DType>& dC, const Tensor<xpu, dim, DType>& A, + const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& dA, + const Tensor<xpu, dim, DType>& dB, Stream<xpu>* s, const nnvm::NodeAttrs& attrs) { const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed); bool tA(param.transpose_a), tB(param.transpose_b); @@ -383,10 +383,10 @@ struct gemm2_backward { (tB ? gemm::op(dC, A, dB, DType(param.alpha), DType(0), true, tA, s) : gemm::op(A, dC, dB, DType(param.alpha), DType(0), !tA, false, s)); } - template<typename xpu, typename DType> - static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A, - const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA, - const Tensor<xpu, 3, DType>& dB, + template<typename xpu, int dim, typename DType> + static void op(const Tensor<xpu, dim, DType>& dC, const Tensor<xpu, dim, DType>& A, + const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& dA, + const Tensor<xpu, dim, DType>& dB, const OpContext& ctx, const nnvm::NodeAttrs& attrs) { Stream<xpu> *s = ctx.get_stream<xpu>(); op(dC, A, B, dA, dB, s, attrs); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 3f08971..923a453 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4688,7 +4688,24 @@ def test_laop(): check_fw(test_gemm, [a, b, c], [r]) if grad_check == 1: check_grad(test_gemm, [a, b, c]) - + # Check for different axis that describes matrix rows. + a2 = np.copy(np.swapaxes(a, 0, 2)) + b2 = np.copy(np.swapaxes(b, 0, 2)) + c2 = np.copy(np.swapaxes(c, 0, 2)) + r2 = np.copy(np.swapaxes(r, 0, 2)) + test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7., axis = 0) + check_fw(test_gemm, [a2, b2, c2], [r2]) + if grad_check == 1: + check_grad(test_gemm, [a2, b2, c2]) + a2 = np.copy(np.swapaxes(a, 1, 2)) + b2 = np.copy(np.swapaxes(b, 1, 2)) + c2 = np.copy(np.swapaxes(c, 1, 2)) + r2 = np.copy(np.swapaxes(r, 1, 2)) + test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7., axis = -3) + check_fw(test_gemm, [a2, b2, c2], [r2]) + if grad_check == 1: + check_grad(test_gemm, [a2, b2, c2]) + # Check gemm2 operator same way as gemm. res_gemm = 4. * np.dot(data_in1, data_in2) test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4.) 
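The axis test cases above (and the gemm2 ones that follow) all use the same pattern: swap the row axis of known-good reference data, run the operator with the matching axis value, and compare against the swapped reference. The identity being relied on, sketched in NumPy (illustrative shapes, not part of the patch; the tests wrap the swapped arrays in np.copy, presumably to hand MXNet contiguous buffers):

import numpy as np

alpha, beta = 4.0, 7.0
a = np.random.rand(3, 2, 4, 5)
b = np.random.rand(3, 2, 5, 6)
c = np.random.rand(3, 2, 4, 6)
r = alpha * np.matmul(a, b) + beta * c   # reference result, default axis=-2

# Move the row axis to position 0, as the axis=0 test cases do:
a2 = np.swapaxes(a, 0, 2)
b2 = np.swapaxes(b, 0, 2)
c2 = np.swapaxes(c, 0, 2)
# gemm with rows on axis 0 and columns on the last axis, written as an
# explicit contraction, must reproduce the swapped reference:
r2 = alpha * np.einsum('ipqk,kpqj->ipqj', a2, b2) + beta * c2
assert np.allclose(r2, np.swapaxes(r, 0, 2))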
@@ -4720,6 +4737,20 @@ def test_laop():
    check_fw(test_gemm, [a, b], [r])
    if grad_check == 1:
        check_grad(test_gemm, [a, b])
+    a2 = np.copy(np.swapaxes(a, 0, 2))
+    b2 = np.copy(np.swapaxes(b, 0, 2))
+    r2 = np.copy(np.swapaxes(r, 0, 2))
+    test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., axis = 0)
+    check_fw(test_gemm, [a2, b2], [r2])
+    if grad_check == 1:
+        check_grad(test_gemm, [a2, b2])
+    a2 = np.copy(np.swapaxes(a, 1, 2))
+    b2 = np.copy(np.swapaxes(b, 1, 2))
+    r2 = np.copy(np.swapaxes(r, 1, 2))
+    test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., axis = -3)
+    check_fw(test_gemm, [a2, b2], [r2])
+    if grad_check == 1:
+        check_grad(test_gemm, [a2, b2])

    # Now test all the other operators.

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.