This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 83078d7 cuda support for linalg-functions, restructuring of linalg interfaces (#7147) 83078d7 is described below commit 83078d7b21491936dfe552866556b40040dadf5b Author: moin <asmushet...@yahoo.de> AuthorDate: Sat Aug 12 21:12:44 2017 +0200 cuda support for linalg-functions, restructuring of linalg interfaces (#7147) * cuda support for linalg-functions, restructuring of linalg interfaces * incorporate newest mshadow * adjustments to linalg operators --- CMakeLists.txt | 4 +- Makefile | 6 + include/mxnet/base.h | 7 + mshadow | 2 +- src/common/cuda_utils.h | 42 ++ src/io/inst_vector.h | 1 + {include/mxnet => src/operator}/c_lapack_api.h | 53 ++- src/operator/contrib/krprod.h | 2 +- src/operator/linalg.h | 118 ++++++ src/operator/linalg_impl.h | 508 +++++++++++++++++++++++++ src/operator/tensor/la_op.cc | 4 +- src/operator/tensor/la_op.cu | 77 ++++ src/operator/tensor/la_op.h | 171 ++------- src/operator/tensor/la_op_inline.h | 373 ++++++++---------- tests/python/unittest/test_operator.py | 9 +- 15 files changed, 1008 insertions(+), 369 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ab29b6a..dc9ca5f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -353,8 +353,10 @@ if(USE_CUDA) list(APPEND mxnet_LINKER_LIBS ${CUDA_cuda_LIBRARY}) FIND_LIBRARY(CUDA_cufft_LIBRARY nvrtc "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64" "${CUDA_TOOLKIT_ROOT_DIR}/lib/win32") list(APPEND mxnet_LINKER_LIBS "${CUDA_cufft_LIBRARY}/../cufft.lib") # For fft operator + FIND_LIBRARY(CUDA_cusolver_LIBRARY nvrtc "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64" "${CUDA_TOOLKIT_ROOT_DIR}/lib/win32") + list(APPEND mxnet_LINKER_LIBS "${CUDA_cusolver_LIBRARY}/../cusolver.lib") # For cusolver else(MSVC) - list(APPEND mxnet_LINKER_LIBS nvrtc cuda cufft) + list(APPEND mxnet_LINKER_LIBS nvrtc cuda cufft cusolver) link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64") endif() list(APPEND SOURCE ${cuda_objs} ${CUDA}) diff --git a/Makefile b/Makefile index 560b77a..33151e5 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,10 @@ ifndef DLPACK_PATH DLPACK_PATH = $(ROOTDIR)/dlpack endif +ifndef AMALGAMATION_PATH + AMALGAMATION_PATH = $(ROOTDIR)/amalgamation +endif + ifneq ($(USE_OPENMP), 1) export NO_OPENMP = 1 endif @@ -439,6 +443,7 @@ clean: cyclean $(EXTRA_PACKAGES_CLEAN) cd $(DMLC_CORE); $(MAKE) clean; cd - cd $(PS_PATH); $(MAKE) clean; cd - cd $(NNVM_PATH); $(MAKE) clean; cd - + cd $(AMALGAMATION_PATH); $(MAKE) clean; cd - $(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS)) $(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS)) else @@ -448,6 +453,7 @@ clean: cyclean testclean $(EXTRA_PACKAGES_CLEAN) cd $(DMLC_CORE); $(MAKE) clean; cd - cd $(PS_PATH); $(MAKE) clean; cd - cd $(NNVM_PATH); $(MAKE) clean; cd - + cd $(AMALGAMATION_PATH); $(MAKE) clean; cd - endif clean_all: clean diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 514bb0c..6954083 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -56,6 +56,13 @@ #define MXNET_USE_CUDNN MSHADOW_USE_CUDNN #endif +/*! + *\brief whether to use cusolver library + */ +#ifndef MXNET_USE_CUSOLVER +#define MXNET_USE_CUSOLVER MSHADOW_USE_CUSOLVER +#endif + /*! 
\brief Error message for using gpu when MXNET_USE_CUDA==0 */ #define MXNET_GPU_NOT_ENABLED_ERROR "GPU is not enabled" diff --git a/mshadow b/mshadow index d32b5da..497eb91 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit d32b5dacf2bb5af4121df5fd60eb7775704f9131 +Subproject commit 497eb9180b24592b7332e7e08f2c053ec5346524 diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h index 2879ab3..8897007 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -88,6 +88,35 @@ inline const char* CublasGetErrorString(cublasStatus_t error) { } /*! + * \brief Get string representation of cuSOLVER errors. + * \param error The error. + * \return String representation. + */ +inline const char* CusolverGetErrorString(cusolverStatus_t error) { + switch (error) { + case CUSOLVER_STATUS_SUCCESS: + return "CUSOLVER_STATUS_SUCCESS"; + case CUSOLVER_STATUS_NOT_INITIALIZED: + return "CUSOLVER_STATUS_NOT_INITIALIZED"; + case CUSOLVER_STATUS_ALLOC_FAILED: + return "CUSOLVER_STATUS_ALLOC_FAILED"; + case CUSOLVER_STATUS_INVALID_VALUE: + return "CUSOLVER_STATUS_INVALID_VALUE"; + case CUSOLVER_STATUS_ARCH_MISMATCH: + return "CUSOLVER_STATUS_ARCH_MISMATCH"; + case CUSOLVER_STATUS_EXECUTION_FAILED: + return "CUSOLVER_STATUS_EXECUTION_FAILED"; + case CUSOLVER_STATUS_INTERNAL_ERROR: + return "CUSOLVER_STATUS_INTERNAL_ERROR"; + case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: + return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; + default: + break; + } + return "Unknown cuSOLVER status"; +} + +/*! * \brief Get string representation of cuRAND errors. * \param status The status. * \return String representation. @@ -165,6 +194,19 @@ inline const char* CurandGetErrorString(curandStatus_t status) { } /*! + * \brief Protected cuSolver call. + * \param func Expression to call. + * + * It checks for cuSolver errors after invocation of the expression. + */ +#define CUSOLVER_CALL(func) \ + { \ + cusolverStatus_t e = (func); \ + CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \ + << "cuSolver: " << common::cuda::CusolverGetErrorString(e); \ + } + +/*! * \brief Protected cuRAND call. * \param func Expression to call. * diff --git a/src/io/inst_vector.h b/src/io/inst_vector.h index 4bc2a6c..6dc7bdf 100644 --- a/src/io/inst_vector.h +++ b/src/io/inst_vector.h @@ -30,6 +30,7 @@ #include <mxnet/base.h> #include <dmlc/base.h> #include <mshadow/tensor.h> +#include <mshadow/tensor_blob.h> #include <vector> #include <string> diff --git a/include/mxnet/c_lapack_api.h b/src/operator/c_lapack_api.h similarity index 74% rename from include/mxnet/c_lapack_api.h rename to src/operator/c_lapack_api.h index 1ae90a9..96a9b3a 100644 --- a/include/mxnet/c_lapack_api.h +++ b/src/operator/c_lapack_api.h @@ -19,14 +19,24 @@ /*! * \file c_lapack_api.h - * \brief Unified interface for LAPACK calls from within mxnet. + * \brief Unified interface for CPU-based LAPACK calls. * Purpose is to hide the platform specific differences. */ -#ifndef MXNET_C_LAPACK_API_H_ -#define MXNET_C_LAPACK_API_H_ +#ifndef MXNET_OPERATOR_C_LAPACK_API_H_ +#define MXNET_OPERATOR_C_LAPACK_API_H_ // Manually maintained list of LAPACK interfaces that can be used // within MXNET. Conventions: +// - We should only import LAPACK-functions that are useful and +// ensure that we support them most efficiently on CPU/GPU. As an +// example take "potrs": It can be emulated by two calls to +// "trsm" (from BLAS3) so not really needed from functionality point +// of view. 
In addition, trsm on GPU supports batch-mode processing +// which is much more efficient for a bunch of smaller matrices while +// there is no such batch support for potrs. As a result, we may +// not support "potrs" internally and if we want to expose it to the user as +// a convenience operator at some time, then we may implement it internally +// as a sequence of trsm. // - Interfaces must be compliant with lapacke.h in terms of signature and // naming conventions so wrapping a function "foo" which has the // signature @@ -36,14 +46,21 @@ // Note that function signatures in lapacke.h will always have as first // argument the storage order (row/col-major). All wrappers have to support // that argument. The underlying fortran functions will always assume a -// column-major layout. It is the responsibility of the wrapper function -// to handle the (usual) case that it is called with data in row-major -// format, either by doing appropriate transpositions explicitly or using -// transposition options of the underlying fortran function. -// - It is ok to assume that matrices are stored in contiguous memory -// (which removes the need to do special handling for lda/ldb parameters -// and enables us to save additional matrix transpositions around -// the fortran calls). +// column-major layout. +// - In the (usual) case that a wrapper is called specifying row-major storage +// order of input/output data, there are two ways to handle this: +// 1) The wrapper may support this without allocating any additional memory +// for example by exploiting the fact that a matrix is symmetric and switching +// certain flags (upper/lower triangular) when calling the fortran code. +// 2) The wrapper may cause a runtime error. In that case it should be clearly +// documented that these functions do only support col-major layout. +// Rationale: This is a low level interface that is not expected to be called +// directly from many upstream functions. Usually all calls should go through +// the tensor-based interfaces in linalg.h which simplify calls to lapack further +// and are better suited to handle additional transpositions that may be necessary. +// Also we want to push allocation of temporary storage higher up in order to +// allow more efficient re-use of temporal storage. And don't want to plaster +// these interfaces here with additional requirements of providing buffers. // - It is desired to add some basic checking in the C++-wrappers in order // to catch simple mistakes when calling these wrappers. // - Must support compilation without lapack-package but issue runtime error in this case. @@ -54,9 +71,10 @@ using namespace mshadow; extern "C" { + // Fortran signatures #define MXNET_LAPACK_FSIGNATURE1(func, dtype) \ - void func##_(char* uplo, int* n, dtype* a, int* lda, int *info); + void func##_(char *uplo, int *n, dtype *a, int *lda, int *info); MXNET_LAPACK_FSIGNATURE1(spotrf, float) MXNET_LAPACK_FSIGNATURE1(dpotrf, double) @@ -73,9 +91,6 @@ extern "C" { #define MXNET_LAPACK_ROW_MAJOR 101 #define MXNET_LAPACK_COL_MAJOR 102 -#define CHECK_LAPACK_CONTIGUOUS(a, b) \ - CHECK_EQ(a, b) << "non contiguous memory for array in lapack call"; - #define CHECK_LAPACK_UPLO(a) \ CHECK(a == 'U' || a == 'L') << "neither L nor U specified as triangle in lapack call"; @@ -117,9 +132,9 @@ inline void flip<cpu, double>(int m, int n, #if MXNET_USE_LAPACK + // These functions can be called with either row- or col-major format. 
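// For example (illustrative call, following the row-major handling described above):
//   MXNET_LAPACK_spotrf(MXNET_LAPACK_ROW_MAJOR, 'L', n, a, lda)
// needs no extra memory or transposition. The symmetric input re-interpreted in
// col-major order is the same matrix, so the wrapper can forward the call to the
// fortran routine as an 'U' factorization of the same buffer; the factor it returns,
// read back in row-major order, is the requested lower-triangular factor.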
#define MXNET_LAPACK_CWRAPPER1(func, dtype) \ - inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda ) { \ - CHECK_LAPACK_CONTIGUOUS(n, lda); \ + inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, int lda) { \ CHECK_LAPACK_UPLO(uplo); \ char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \ int ret(0); \ @@ -172,7 +187,7 @@ inline void flip<cpu, double>(int m, int n, // Define compilable stubs. #define MXNET_LAPACK_CWRAPPER1(func, dtype) \ - inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda ) { \ + inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ return 1; \ } @@ -209,4 +224,4 @@ inline int MXNET_LAPACK_posv<double>(int matrix_layout, char uplo, int n, return mxnet_lapack_dposv(matrix_layout, uplo, n, nrhs, a, lda, b, ldb); } -#endif // MXNET_C_LAPACK_API_H_ +#endif // MXNET_OPERATOR_C_LAPACK_API_H_ diff --git a/src/operator/contrib/krprod.h b/src/operator/contrib/krprod.h index 6ce94c6..a54ece7 100644 --- a/src/operator/contrib/krprod.h +++ b/src/operator/contrib/krprod.h @@ -26,7 +26,7 @@ #define MXNET_OPERATOR_CONTRIB_KRPROD_H_ #include <vector> #include "mshadow/tensor.h" -#include "mxnet/c_lapack_api.h" +#include "../c_lapack_api.h" namespace mxnet { namespace op { diff --git a/src/operator/linalg.h b/src/operator/linalg.h new file mode 100644 index 0000000..9284a58 --- /dev/null +++ b/src/operator/linalg.h @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file linalg.h + * \brief Unified tensor interface for advanced linear algebra functions + * (specifically BLAS3/LAPACK) from within mxnet. + */ +#ifndef MXNET_OPERATOR_LINALG_H_ +#define MXNET_OPERATOR_LINALG_H_ + +#include <mshadow/tensor.h> +#include "./c_lapack_api.h" +using namespace mshadow; + +// The purpose of this header is to expose the interfaces of the advanced +// linear algebra functions without clutter by the implementations. In contrast +// to the implementations in linalg_inline.h, no macros are used to generate +// similar functions that just differ by name/type in order to improve readability. +// +// Guidelines for extensions: +// For any type of computation the following should be provided at minimum: +// - 1 templated function supporting cpu/gpu float/double in non-batch mode +// - 1 templated function supporting cpu/gpu float/double in batch mode +// Naming conventions: +// - linalg_<func>() +// - linalg_batch_<func>() +// Signatures of CPU/GPU versions should be equivalent whenever possible including +// that a stream is supplied to the cpu-versions as (optional) last argument. 
+// The batched versions all work on tensors with one more dimension as the +// non-batched ones and the first/highest dimension iterates over the elements +// within the batch. + +//////////////////////////////// GEMM //////////////////////////////////////////// + +// CPU/GPU-versions of BLAS3 function "gemm". Please refer to the BLAS3-documentation +// for further information about the function and its parameters. +// Note that this is C = gemm(A,B,C), so C is input and output parameter. +template<typename xpu, typename DType> +void linalg_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, + const Tensor<xpu, 2, DType>& C, DType alpha, DType beta, + bool tA, bool tB, Stream<xpu> *s = 0); + +template<typename xpu, typename DType> +void linalg_batch_gemm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, + const Tensor<xpu, 3, DType>& C, DType alpha, DType beta, + bool tA, bool tB, Stream<xpu> *s = 0); + +//////////////////////////////// TRSM //////////////////////////////////////////// + +// CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation +// for further information about the function and its parameters. +// Note that this is B = trsm(A,B), so B is input and output parameter. +template<typename xpu, typename DType> +void linalg_trsm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, + DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s = 0); + +template<typename xpu, typename DType> +inline void linalg_batch_trsm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, + DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s = 0); + +//////////////////////////////// TRMM //////////////////////////////////////////// + +// CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation +// for further information about the function and its parameters. +// Note that this is B = trmm(A,B), so B is input and output parameter. + +template<typename xpu, typename DType> +void linalg_trmm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, + DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s = 0); + +template<typename xpu, typename DType> +void linalg_batch_trmm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, + DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s = 0); + +//////////////////////////////// POTRF //////////////////////////////////////////// + +// CPU/GPU-versions of LAPACK function "potrf". Please refer to the LAPACK-documentation +// for further information about the function and its parameters. +// Note that this is A = potrf(A), so A is input and output parameter. + +template<typename xpu, typename DType> +void linalg_potrf(const Tensor<xpu, 2, DType>& A, bool lower, Stream<xpu> *s = 0); + +template<typename xpu, typename DType> +void linalg_batch_potrf(const Tensor<xpu, 3, DType>& A, bool lower, Stream<xpu> *s = 0); + +//////////////////////////////// POTRI //////////////////////////////////////////// + +// CPU/GPU-versions of LAPACK function "potri". Please refer to the LAPACK-documentation +// for further information about the function and its parameters. +// Note that this is A = potri(A), so A is input and output parameter. 
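// Usage sketch for the interfaces above (illustrative only; shapes and variable
// names are assumed, not part of the diff): given 3-D mshadow tensors
//   Tensor<gpu, 3, float> A;   // (batch, m, k)
//   Tensor<gpu, 3, float> B;   // (batch, k, n)
//   Tensor<gpu, 3, float> C;   // (batch, m, n)
// and a non-null Stream<gpu> *s, an operator computes C = 1.0f*A*B for all batch
// elements at once via
//   linalg_batch_gemm(A, B, C, 1.0f, 0.0f, false, false, s);
// while the non-batched linalg_gemm takes the corresponding 2-D tensors.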
+ +template<typename xpu, typename DType> +void linalg_potri(const Tensor<xpu, 2, DType>& A, bool lower, Stream<xpu> *s = 0); + +template<typename xpu, typename DType> +void linalg_batch_potri(const Tensor<xpu, 3, DType>& A, bool lower, Stream<xpu> *s = 0); + +#include "linalg_impl.h" + +#endif // MXNET_OPERATOR_LINALG_H_ diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h new file mode 100644 index 0000000..affa794 --- /dev/null +++ b/src/operator/linalg_impl.h @@ -0,0 +1,508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file linalg.h + * \brief Implementation of unified tensor interface for advanced linear algebra functions + * (specifically BLAS3/LAPACK) from within mxnet. + */ +#ifndef MXNET_OPERATOR_LINALG_IMPL_H_ +#define MXNET_OPERATOR_LINALG_IMPL_H_ + +#include <algorithm> + +// Convenience functions. +inline void linalg_check_batch_size(int A, int B, int C) { + CHECK_EQ(A, B) << "Inconsistent batch size between arguments to linear algebra operator"; + CHECK_EQ(A, C) << "Inconsistent batch size between arguments to linear algebra operator"; + CHECK_GT(A, 0) << "Zero batch size for arguments to linear algebra operator"; +} + +//////////////////////////////// GEMM //////////////////////////////////////////// + +// CPU/GPU-versions of BLAS3 function "gemm". Please refer to the BLAS3-documentation +// for further information about the function and its parameters. +// Note that this is C = gemm(A,B,C), so C is input and output parameter. + +template<typename xpu, typename DType> +inline void check_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, + const Tensor<xpu, 2, DType>& C, DType alpha, DType beta, bool tA, bool tB) { + // Any checking that helps user debug potential problems. + CHECK_EQ((tA ? A.size(1) : A.size(0)), C.size(0)) + << "Non compatible matrix dimensions between inputs A and C for gemm"; + CHECK_EQ((tB ? B.size(0) : B.size(1)), C.size(1)) + << "Non compatible matrix dimensions between inputs B and C for gemm"; + CHECK_EQ((tA ? A.size(0) : A.size(1)), (tB ? B.size(1) : B.size(0))) + << "Non compatible matrix dimensions between inputs A and B for gemm"; +} + +#define LINALG_CPU_GEMM(fname, DType) \ +template<> inline \ +void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, const Tensor<cpu, 2, DType>& B, \ + const Tensor<cpu, 2, DType>& C, DType alpha, DType beta, \ + bool tA, bool tB, Stream<cpu> *s) { \ + check_gemm(A, B, C, alpha, beta, tA, tB); \ + cblas_##fname(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), (tB ? CblasTrans : CblasNoTrans), \ + C.size(0), C.size(1), (tA ? 
A.size(0) : A.size(1)), alpha, \ + A.dptr_, A.stride_, B.dptr_, B.stride_, beta, C.dptr_, C.stride_); \ +} +LINALG_CPU_GEMM(sgemm, float) +LINALG_CPU_GEMM(dgemm, double) + +#define LINALG_CPU_BATCH_GEMM(DType) \ +template<> inline \ +void linalg_batch_gemm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \ + const Tensor<cpu, 3, DType>& C, DType alpha, DType beta, \ + bool tA, bool tB, Stream<cpu> *s) { \ + linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \ + for (index_t i = 0; i < A.size(0); ++i) { \ + linalg_gemm(A[i], B[i], C[i], alpha, beta, tA, tB); \ + } \ +} +LINALG_CPU_BATCH_GEMM(float) +LINALG_CPU_BATCH_GEMM(double) + +#ifdef __CUDACC__ + +template<typename DType> +__global__ void linalgCollectBatchOffsetsGPU(DType *a[], DType* b, int stride, int N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { + a[i] = b + i * stride; + } +} + +// cublas col-major processing accounted for by switching first two operands + +#define LINALG_GPU_GEMM(fname, DType) \ +template<> inline \ +void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2, DType>& B, \ + const Tensor<gpu, 2, DType>& C, DType alpha, DType beta, \ + bool tA, bool tB, Stream<gpu> *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + check_gemm(A, B, C, alpha, beta, tA, tB); \ + CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \ + (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \ + (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \ + C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)), \ + &alpha, B.dptr_, B.stride_, A.dptr_, A.stride_, \ + &beta, C.dptr_, C.stride_)) \ +} +LINALG_GPU_GEMM(Sgemm, float) +LINALG_GPU_GEMM(Dgemm, double) + +#define LINALG_GPU_BATCH_GEMM(fname, DType) \ +template<> inline \ +void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \ + const Tensor<gpu, 3, DType>& C, DType alpha, DType beta, \ + bool tA, bool tB, Stream<gpu> *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \ + check_gemm(A[0], B[0], C[0], alpha, beta, tA, tB); \ + Storage::Handle offsetsA, offsetsB, offsetsC; \ + offsetsA = Storage::Get()->Alloc(sizeof(DType*)*A.size(0), Context::GPU()); \ + offsetsB = Storage::Get()->Alloc(sizeof(DType*)*B.size(0), Context::GPU()); \ + offsetsC = Storage::Get()->Alloc(sizeof(DType*)*C.size(0), Context::GPU()); \ + using namespace mshadow::cuda; \ + int ngrid = std::min(kMaxGridNum, \ + static_cast<int>((A.size(0) + kBaseThreadNum - 1) / kBaseThreadNum)); \ + linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \ + (static_cast<DType **>(offsetsA.dptr), A.dptr_, A.size(1)*A.stride_, A.size(0)); \ + linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \ + (static_cast<DType **>(offsetsB.dptr), B.dptr_, B.size(1)*B.stride_, B.size(0)); \ + linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \ + (static_cast<DType **>(offsetsC.dptr), C.dptr_, C.size(1)*C.stride_, C.size(0)); \ + CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \ + (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \ + (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \ + C.size(2), C.size(1), (tB ? 
B.size(2) : B.size(1)), \ + &alpha, static_cast<const DType **>(offsetsB.dptr), B.stride_, \ + static_cast<const DType **>(offsetsA.dptr), A.stride_, \ + &beta, static_cast<DType **>(offsetsC.dptr), C.stride_, A.size(0))) \ + Storage::Get()->Free(offsetsA); \ + Storage::Get()->Free(offsetsB); \ + Storage::Get()->Free(offsetsC); \ +} +LINALG_GPU_BATCH_GEMM(SgemmBatched, float) +LINALG_GPU_BATCH_GEMM(DgemmBatched, double) + +#endif + +//////////////////////////////// TRSM //////////////////////////////////////////// + +// CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation +// for further information about the function and its parameters. +// Note that this is B = trsm(A,B), so B is input and output parameter. + +template<typename xpu, typename DType> +inline void check_trsm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, + DType alpha, bool rightside, bool lower, bool transpose) { + // Any checking that helps user debug potential problems. + CHECK_EQ(A.size(0), A.size(1)) + << "First input of trsm is not a square matrix."; + CHECK(!rightside || (B.size(1) == A.size(0))) + << "Non compatible matrix dimensions between inputs A and B for trsm"; + CHECK(rightside || (B.size(0) == A.size(1))) + << "Non compatible matrix dimensions between inputs A and B for trsm"; +} + +#define LINALG_CPU_TRSM(fname, DType) \ +template<> inline \ +void linalg_trsm<cpu, DType>(const Tensor<cpu, 2, DType>& A, const Tensor<cpu, 2, DType>& B, \ + DType alpha, bool rightside, bool lower, bool transpose, Stream<cpu> *s) { \ + check_trsm(A, B, alpha, rightside, lower, transpose); \ + cblas_##fname(CblasRowMajor, (rightside ? CblasRight : CblasLeft), \ + (lower ? CblasLower : CblasUpper), (transpose ? CblasTrans : CblasNoTrans), \ + CblasNonUnit, B.size(0), B.size(1), alpha, A.dptr_, \ + A.stride_, B.dptr_, B.stride_); \ +} +LINALG_CPU_TRSM(strsm, float) +LINALG_CPU_TRSM(dtrsm, double) + +#define LINALG_CPU_BATCH_TRSM(DType) \ +template<> inline \ +void linalg_batch_trsm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \ + DType alpha, bool rightside, bool lower, bool transpose, Stream<cpu> *s) { \ + linalg_check_batch_size(A.size(0), B.size(0), B.size(0)); \ + for (index_t i = 0; i < A.size(0); ++i) { \ + linalg_trsm(A[i], B[i], alpha, rightside, lower, transpose); \ + } \ +} +LINALG_CPU_BATCH_TRSM(float) +LINALG_CPU_BATCH_TRSM(double) + +#ifdef __CUDACC__ + +// cublas col-major processing accounted for by switching sides and fill mode + +#define LINALG_GPU_TRSM(fname, DType) \ +template<> inline \ +void linalg_trsm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2, DType>& B, \ + DType alpha, bool rightside, bool lower, bool transpose, Stream<gpu> *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + check_trsm(A, B, alpha, rightside, lower, transpose); \ + CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \ + (rightside ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT), \ + (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \ + (transpose ? 
CUBLAS_OP_T : CUBLAS_OP_N), \ + CUBLAS_DIAG_NON_UNIT, B.size(1), B.size(0), &alpha, \ + A.dptr_, A.stride_, B.dptr_, B.stride_)); \ +} +LINALG_GPU_TRSM(Strsm, float) +LINALG_GPU_TRSM(Dtrsm, double) + +#define LINALG_GPU_BATCH_TRSM(fname, DType) \ +template<> inline \ +void linalg_batch_trsm<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \ + DType alpha, bool rightside, bool lower, bool transpose, Stream<gpu> *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + linalg_check_batch_size(A.size(0), B.size(0), B.size(0)); \ + check_trsm(A[0], B[0], alpha, rightside, lower, transpose); \ + Storage::Handle offsetsA, offsetsB; \ + offsetsA = Storage::Get()->Alloc(sizeof(DType*)*A.size(0), Context::GPU()); \ + offsetsB = Storage::Get()->Alloc(sizeof(DType*)*B.size(0), Context::GPU()); \ + using namespace mshadow::cuda; \ + int ngrid = std::min(kMaxGridNum, \ + static_cast<int>((A.size(0) + kBaseThreadNum - 1) / kBaseThreadNum)); \ + linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \ + (static_cast<DType **>(offsetsA.dptr), A.dptr_, A.size(1)*A.stride_, A.size(0)); \ + linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \ + (static_cast<DType **>(offsetsB.dptr), B.dptr_, B.size(1)*B.stride_, A.size(0)); \ + CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \ + (rightside ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT), \ + (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \ + (transpose ? CUBLAS_OP_T : CUBLAS_OP_N), \ + CUBLAS_DIAG_NON_UNIT, B.size(2), B.size(1), &alpha, \ + static_cast<const DType **>(offsetsA.dptr), A.stride_, \ + static_cast<DType **>(offsetsB.dptr), B.stride_, A.size(0))); \ + Storage::Get()->Free(offsetsA); \ + Storage::Get()->Free(offsetsB); \ +} +LINALG_GPU_BATCH_TRSM(StrsmBatched, float) +LINALG_GPU_BATCH_TRSM(DtrsmBatched, double) + +#endif + +//////////////////////////////// TRMM //////////////////////////////////////////// + +// CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation +// for further information about the function and its parameters. +// Note that this is B = trmm(A,B), so B is input and output parameter. + +template<typename xpu, typename DType> +inline void check_trmm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, + DType alpha, bool rightside, bool lower, bool transpose) { + // Any checking that helps user debug potential problems. + CHECK_EQ(A.size(0), A.size(1)) + << "First input of trmm is not a square matrix."; + CHECK(!rightside || (B.size(1) == A.size(0))) + << "Non compatible matrix dimensions between inputs A and B for trmm"; + CHECK(rightside || (B.size(0) == A.size(1))) + << "Non compatible matrix dimensions between inputs A and B for trmm"; +} + +#define LINALG_CPU_TRMM(fname, DType) \ +template<> inline \ +void linalg_trmm<cpu, DType>(const Tensor<cpu, 2, DType>& A, const Tensor<cpu, 2, DType>& B, \ + DType alpha, bool rightside, bool lower, bool transpose, Stream<cpu> *s) { \ + check_trmm(A, B, alpha, rightside, lower, transpose); \ + cblas_##fname(CblasRowMajor, (rightside ? CblasRight : CblasLeft), \ + (lower ? CblasLower : CblasUpper), (transpose ? 
CblasTrans : CblasNoTrans), \ + CblasNonUnit, B.size(0), B.size(1), alpha, A.dptr_, \ + A.stride_, B.dptr_, B.stride_); \ +} +LINALG_CPU_TRMM(strmm, float) +LINALG_CPU_TRMM(dtrmm, double) + +#define LINALG_XPU_BATCH_TRMM(xpu, DType) \ +template<> inline \ +void linalg_batch_trmm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, \ + DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s) { \ + linalg_check_batch_size(A.size(0), B.size(0), B.size(0)); \ + for (index_t i = 0; i < A.size(0); ++i) { \ + linalg_trmm(A[i], B[i], alpha, rightside, lower, transpose, s); \ + } \ +} +LINALG_XPU_BATCH_TRMM(cpu, float) +LINALG_XPU_BATCH_TRMM(cpu, double) + +#ifdef __CUDACC__ + +// cublas col-major processing accounted for by switching sides and fill mode +// doing in-place computation by supplying B as second and third matrix +#define LINALG_GPU_TRMM(fname, DType) \ +template<> inline \ +void linalg_trmm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2, DType>& B, \ + DType alpha, bool rightside, bool lower, bool transpose, Stream<gpu> *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + check_trmm(A, B, alpha, rightside, lower, transpose); \ + CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \ + (rightside ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT), \ + (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \ + (transpose ? CUBLAS_OP_T : CUBLAS_OP_N), \ + CUBLAS_DIAG_NON_UNIT, B.size(0), B.size(1), &alpha, \ + A.dptr_, A.stride_, B.dptr_, B.stride_, \ + B.dptr_, B.stride_)); \ +} +LINALG_GPU_TRMM(Strmm, float) +LINALG_GPU_TRMM(Dtrmm, double) + +LINALG_XPU_BATCH_TRMM(gpu, float) +LINALG_XPU_BATCH_TRMM(gpu, double) + +#endif + +//////////////////////////////// POTRF //////////////////////////////////////////// + +// CPU/GPU-versions of LAPACK function "potrf". Please refer to the LAPACK-documentation +// for further information about the function and its parameters. +// Note that this is A = potrf(A), so A is input and output parameter. + +template<typename xpu, typename DType> +inline void check_potrf(const Tensor<xpu, 2, DType>& A, bool lower) { + // Any checking that helps user debug potential problems. + CHECK_EQ(A.size(0), A.size(1)) + << "No square matrix as input to potrf."; +} + +#define LINALG_CPU_POTRF(fname, DType) \ +template<> inline \ +void linalg_potrf<cpu, DType>(const Tensor<cpu, 2, DType>& A, bool lower, Stream<cpu> *s) { \ + check_potrf(A, lower); \ + int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, (lower ? 'L' : 'U'), A.size(0), \ + A.dptr_ , A.stride_)); \ + CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ +} +LINALG_CPU_POTRF(spotrf, float) +LINALG_CPU_POTRF(dpotrf, double) + +#define LINALG_CPU_BATCH_POTRF(DType) \ +template<> inline \ +void linalg_batch_potrf<cpu, DType>(const Tensor<cpu, 3, DType>& A, bool lower, Stream<cpu> *s) { \ + for (index_t i = 0; i < A.size(0); ++i) { \ + linalg_potrf(A[i], lower); \ + } \ +} +LINALG_CPU_BATCH_POTRF(float) +LINALG_CPU_BATCH_POTRF(double) + +#if MXNET_USE_CUSOLVER == 1 + +#define LINALG_GPU_BUFFSIZE_POTRF(fname, DType) \ +inline int linalg_potrf_buffsize(const Tensor<gpu, 2, DType>& A, bool lower, Stream<gpu> *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + int buffsize(0); \ + CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \ + (lower ? 
CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \ + A.size(0), A.dptr_, A.stride_, &buffsize)); \ + return buffsize; \ +} +LINALG_GPU_BUFFSIZE_POTRF(DnSpotrf_bufferSize, float) +LINALG_GPU_BUFFSIZE_POTRF(DnDpotrf_bufferSize, double) + +#define LINALG_GPU_POTRF(fname, DType) \ +template<> inline \ +void linalg_potrf<gpu, DType>(const Tensor<gpu, 2, DType>& A, bool lower, Stream<gpu> *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + check_potrf(A, lower); \ + int buffsize(linalg_potrf_buffsize(A, lower, s)); \ + Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*buffsize, Context::GPU()); \ + Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \ + CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \ + (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \ + A.size(0), A.dptr_, A.stride_, static_cast<DType *>(buffer.dptr), buffsize, \ + static_cast<int *>(info.dptr))); \ + Storage::Get()->Free(buffer); \ + Storage::Get()->Free(info); \ +} +LINALG_GPU_POTRF(DnSpotrf, float) +LINALG_GPU_POTRF(DnDpotrf, double) + +#define LINALG_GPU_BATCH_POTRF(fname, DType) \ +template<> inline \ +void linalg_batch_potrf<gpu, DType>(const Tensor<gpu, 3, DType>& A, bool lower, Stream<gpu> *s) { \ + using namespace mxnet; \ + using mshadow::gpu; \ + CHECK_NOTNULL(s); \ + CHECK_GT(A.size(0), 0); \ + check_potrf(A[0], lower); \ + int buffsize(linalg_potrf_buffsize(A[0], lower, s)); \ + Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*buffsize, Context::GPU()); \ + Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \ + for (mshadow::index_t i = 0; i < A.size(0); ++i) { \ + CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \ + (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \ + A[i].size(0), A[i].dptr_, A[i].stride_, \ + static_cast<DType *>(buffer.dptr), buffsize, static_cast<int *>(info.dptr))); \ + } \ + Storage::Get()->Free(buffer); \ + Storage::Get()->Free(info); \ +} +LINALG_GPU_BATCH_POTRF(DnSpotrf, float) +LINALG_GPU_BATCH_POTRF(DnDpotrf, double) + +#endif + +//////////////////////////////// POTRI //////////////////////////////////////////// + +// CPU/GPU-versions of LAPACK function "potri". Please refer to the LAPACK-documentation +// for further information about the function and its parameters. +// Note that this is A = potri(A), so A is input and output parameter. + +template<typename xpu, typename DType> +inline void check_potri(const Tensor<xpu, 2, DType>& A, bool lower) { + // Any checking that helps user debug potential problems. + CHECK_EQ(A.size(0), A.size(1)) << "No square matrix as input to potri."; +} + +#define LINALG_CPU_POTRI(fname, DType) \ +template<> inline \ +void linalg_potri<cpu, DType>(const Tensor<cpu, 2, DType>& A, bool lower, Stream<cpu> *s) { \ + check_potri(A, lower); \ + int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, (lower ? 'L' : 'U'), A.size(0), \ + A.dptr_ , A.stride_)); \ + CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \ +} +LINALG_CPU_POTRI(spotri, float) +LINALG_CPU_POTRI(dpotri, double) + +#define LINALG_CPU_BATCH_POTRI(DType) \ +template<> inline \ +void linalg_batch_potri<cpu, DType>(const Tensor<cpu, 3, DType>& A, bool lower, Stream<cpu> *s) { \ + for (index_t i = 0; i < A.size(0); ++i) { \ + linalg_potri(A[i], lower); \ + } \ +} +LINALG_CPU_BATCH_POTRI(float) +LINALG_CPU_BATCH_POTRI(double) + +#ifdef __CUDACC__ + +// Initializes multiple identity matrices on the same vector. 
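// In the kernel below, "stride" is the number of elements per matrix and "lda" its
// leading dimension (row stride): element i lies on the diagonal of its matrix
// exactly when (i % stride) / lda == (i % stride) % lda, i.e. when row equals column.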
+template<typename DType> +__global__ void linalgInitIdentityGPU(DType *a, int stride, int lda, int N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { + // index relative to the matrix. + int index(i % stride); + a[i] = (index / lda == index % lda ? DType(1.0) : DType(0)); + } +} + +// There is no direct support for potri in cuda. We emulate the function by two calls to trsm. +#define LINALG_GPU_POTRI(DType) \ +template<> inline \ +void linalg_potri<gpu, DType>(const Tensor<gpu, 2, DType>& A, bool lower, Stream<gpu> *s) { \ + using namespace mxnet; \ + CHECK_NOTNULL(s); \ + check_potri(A, lower); \ + Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*A.MSize(), Context::GPU()); \ + using namespace mshadow::cuda; \ + int ngrid = std::min(kMaxGridNum, \ + static_cast<int>((A.MSize() + kBaseThreadNum - 1) / kBaseThreadNum)); \ + linalgInitIdentityGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \ + (static_cast<DType *>(buffer.dptr), A.MSize(), A.stride_, A.MSize()); \ + Tensor<gpu, 2, DType> B((DType *)buffer.dptr, A.shape_, A.stride_, s); \ + linalg_trsm(A, B, DType(1.0), false, lower, !lower, s); \ + linalg_trsm(A, B, DType(1.0), false, lower, lower, s); \ + Copy(A, B, s); \ + B.dptr_ = 0; \ + Storage::Get()->Free(buffer); \ +} +LINALG_GPU_POTRI(float) +LINALG_GPU_POTRI(double) + +#define LINALG_GPU_BATCH_POTRI(DType) \ +template<> inline \ +void linalg_batch_potri<gpu, DType>(const Tensor<gpu, 3, DType>& A, bool lower, Stream<gpu> *s) { \ + using namespace mxnet; \ + CHECK_NOTNULL(s); \ + CHECK_GT(A.size(0), 0); \ + check_potri(A[0], lower); \ + Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*A.MSize(), Context::GPU()); \ + using namespace mshadow::cuda; \ + int ngrid = std::min(kMaxGridNum, \ + static_cast<int>((A.MSize() + kBaseThreadNum - 1) / kBaseThreadNum)); \ + linalgInitIdentityGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \ + (static_cast<DType *>(buffer.dptr), A.size(1)*A.stride_, A.stride_, A.MSize()); \ + Tensor<gpu, 3, DType> B((DType *)buffer.dptr, A.shape_, A.stride_, s); \ + linalg_batch_trsm(A, B, DType(1.0), false, lower, !lower, s); \ + linalg_batch_trsm(A, B, DType(1.0), false, lower, lower, s); \ + Copy(A, B, s); \ + B.dptr_ = 0; \ + Storage::Get()->Free(buffer); \ +} +LINALG_GPU_BATCH_POTRI(float) +LINALG_GPU_BATCH_POTRI(double) + +#endif + +#endif // MXNET_OPERATOR_LINALG_IMPL_H_ diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc index 1b726ce..70d4f9b 100644 --- a/src/operator/tensor/la_op.cc +++ b/src/operator/tensor/la_op.cc @@ -401,7 +401,7 @@ Examples:: { return std::vector<std::string>{"A"}; } ) .set_attr<nnvm::FInferShape>("FInferShape", LaReduceShape<2>) .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) -.set_attr<FCompute>("FCompute<cpu>", LaReduceForward<cpu, 2, sumlogdiag>) +.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 0, 1, 1, sumlogdiag>) .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_linalg_sumlogdiag"}) .add_argument("A", "NDArray-or-Symbol", "Tensor of square matrices"); @@ -411,7 +411,7 @@ NNVM_REGISTER_OP(_backward_linalg_sumlogdiag) .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; }) .set_attr<nnvm::TIsBackward>("TIsBackward", true) -.set_attr<FCompute>("FCompute<cpu>", LaReduceBackward<cpu, 2, sumlogdiag_backward>); +.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 
2, 2, 1, sumlogdiag_backward>); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu new file mode 100644 index 0000000..a89d98f --- /dev/null +++ b/src/operator/tensor/la_op.cu @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file la_op.cu + * \brief GPU-Operators for advanced linear algebra. + */ +#include "./la_op.h" +#include "./la_op_inline.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(linalg_gemm) +.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 3, 1, gemm>); + +NNVM_REGISTER_OP(_backward_linalg_gemm) +.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 3, gemm_backward>); + +NNVM_REGISTER_OP(linalg_gemm2) +.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, gemm2>); + +NNVM_REGISTER_OP(_backward_linalg_gemm2) +.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 3, 2, gemm2_backward>); + +NNVM_REGISTER_OP(linalg_trmm) +.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, trmm>); + +NNVM_REGISTER_OP(_backward_linalg_trmm) +.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 2, trmm_backward>); + +NNVM_REGISTER_OP(linalg_trsm) +.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, trsm>); + +NNVM_REGISTER_OP(_backward_linalg_trsm) +.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 2, trsm_backward>); + +NNVM_REGISTER_OP(linalg_sumlogdiag) +.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 0, 1, 1, sumlogdiag>); + +NNVM_REGISTER_OP(_backward_linalg_sumlogdiag) +.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 2, 1, sumlogdiag_backward>); + +NNVM_REGISTER_OP(linalg_potri) +.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 1, 1, potri>); + +NNVM_REGISTER_OP(_backward_linalg_potri) +.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 3, 1, potri_backward>); + +#if MXNET_USE_CUSOLVER == 1 + +NNVM_REGISTER_OP(linalg_potrf) +.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 1, 1, potrf>); + +NNVM_REGISTER_OP(_backward_linalg_potrf) +.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 2, 1, potrf_backward>); + +#endif + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index 9779988..dd5fab9 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -91,9 +91,9 @@ struct LaTriangMatrixMultParam : public dmlc::Parameter<LaTriangMatrixMultParam> }; // Common function for shape inference for matrix mult and matrix mac. 
-bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs, - std::vector<TShape>* in_attrs, - std::vector<TShape>* out_attrs) { +inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs, + std::vector<TShape>* in_attrs, + std::vector<TShape>* out_attrs) { CHECK_GE(in_attrs->size(), 2); CHECK_EQ(out_attrs->size(), 1); bool transpose_a(false), transpose_b(false); @@ -132,9 +132,9 @@ bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs, return false; } -bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs, - std::vector<TShape>* in_attrs, - std::vector<TShape>* out_attrs) { +inline bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs, + std::vector<TShape>* in_attrs, + std::vector<TShape>* out_attrs) { const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed); CHECK_EQ(in_attrs->size(), 2); CHECK_EQ(out_attrs->size(), 1); @@ -192,9 +192,9 @@ bool LaTriangMatrixMultOpShape(const nnvm::NodeAttrs& attrs, } template<int dim> -bool LaReduceShape(const nnvm::NodeAttrs& attrs, - std::vector<TShape>* in_attrs, - std::vector<TShape>* out_attrs) { +inline bool LaReduceShape(const nnvm::NodeAttrs& attrs, + std::vector<TShape>* in_attrs, + std::vector<TShape>* out_attrs) { // Shape for reduction of the dim lowest dimensions to a scalar. // Can only deduct in forward direction. CHECK_EQ(in_attrs->size(), 1); @@ -203,7 +203,8 @@ bool LaReduceShape(const nnvm::NodeAttrs& attrs, if ( ndim < dim ) { return false; } - std::vector<int> oshape(std::max(1, ndim-dim), 1); + std::vector<int> oshape(std::max(1, ndim-dim)); + oshape[0] = 1; for ( int i = 0; i < ndim - dim; ++i ) { oshape[i] = (*in_attrs)[0][i]; } @@ -218,7 +219,6 @@ template<typename xpu, typename DType, int idim, int odim, int inum, int onum, t struct LaOpCaller { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, - const int index, const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu> *s) { CHECK(false) << "no specialized LaOpCaller defined for template parameters"; @@ -228,86 +228,75 @@ template<typename xpu, typename DType, int idim, int odim, typename laop> struct LaOpCaller<xpu, DType, idim, odim, 1, 1, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, - const int index, const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu> *s) { - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index], - outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index], attrs); + laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), + outputs[0].FlatToKD<xpu, odim+1, DType>(s), s, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> struct LaOpCaller<xpu, DType, idim, odim, 2, 1, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, - const int index, const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu> *s) { - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index], - inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index], - outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index], - attrs); + laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), + inputs[1].FlatToKD<xpu, idim+1, DType>(s), + outputs[0].FlatToKD<xpu, odim+1, DType>(s), s, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> struct LaOpCaller<xpu, DType, idim, odim, 3, 1, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, - const int index, const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu> *s) { - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index], 
- inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index], - inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index], - outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index], - attrs); + laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), + inputs[1].FlatToKD<xpu, idim+1, DType>(s), + inputs[2].FlatToKD<xpu, idim+1, DType>(s), + outputs[0].FlatToKD<xpu, odim+1, DType>(s), s, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> struct LaOpCaller<xpu, DType, idim, odim, 3, 2, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, - const int index, const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu> *s) { - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index], - inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index], - inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index], - outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index], - outputs[1].FlatToKD<xpu, odim+1, DType>(s)[index], - attrs); + laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), + inputs[1].FlatToKD<xpu, idim+1, DType>(s), + inputs[2].FlatToKD<xpu, idim+1, DType>(s), + outputs[0].FlatToKD<xpu, odim+1, DType>(s), + outputs[1].FlatToKD<xpu, odim+1, DType>(s), s, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> struct LaOpCaller<xpu, DType, idim, odim, 4, 2, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, - const int index, const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu> *s) { - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index], - inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index], - inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index], - inputs[3].FlatToKD<xpu, idim+1, DType>(s)[index], - outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index], - outputs[1].FlatToKD<xpu, odim+1, DType>(s)[index], - attrs); + laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), + inputs[1].FlatToKD<xpu, idim+1, DType>(s), + inputs[2].FlatToKD<xpu, idim+1, DType>(s), + inputs[3].FlatToKD<xpu, idim+1, DType>(s), + outputs[0].FlatToKD<xpu, odim+1, DType>(s), + outputs[1].FlatToKD<xpu, odim+1, DType>(s), s, attrs); } }; template<typename xpu, typename DType, int idim, int odim, typename laop> struct LaOpCaller<xpu, DType, idim, odim, 4, 3, laop> { static void op(const std::vector<TBlob>& inputs, const std::vector<TBlob>& outputs, - const int index, const nnvm::NodeAttrs& attrs, mshadow::Stream<xpu> *s) { - laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s)[index], - inputs[1].FlatToKD<xpu, idim+1, DType>(s)[index], - inputs[2].FlatToKD<xpu, idim+1, DType>(s)[index], - inputs[3].FlatToKD<xpu, idim+1, DType>(s)[index], - outputs[0].FlatToKD<xpu, odim+1, DType>(s)[index], - outputs[1].FlatToKD<xpu, odim+1, DType>(s)[index], - outputs[2].FlatToKD<xpu, odim+1, DType>(s)[index], - attrs); + laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s), + inputs[1].FlatToKD<xpu, idim+1, DType>(s), + inputs[2].FlatToKD<xpu, idim+1, DType>(s), + inputs[3].FlatToKD<xpu, idim+1, DType>(s), + outputs[0].FlatToKD<xpu, odim+1, DType>(s), + outputs[1].FlatToKD<xpu, odim+1, DType>(s), + outputs[2].FlatToKD<xpu, odim+1, DType>(s), s, attrs); } }; @@ -322,24 +311,8 @@ void LaOpForward(const nnvm::NodeAttrs& attrs, Stream<xpu> *s = ctx.get_stream<xpu>(); CHECK_EQ(inputs.size(), inum); CHECK_EQ(outputs.size(), onum); - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { - int N(-1); - for ( int i = 0; i < inum; ++i ) { - CHECK_EQ(inputs[i].CheckContiguous(), true); - const int M(inputs[i].FlatToKD<xpu, idim+1, OType>(s).size(0)); - CHECK_EQ((N == -1 || N 
== M), true); - N = M; - } - for ( int i = 0; i < onum; ++i ) { - CHECK_EQ(outputs[i].CheckContiguous(), true); - CHECK_EQ((req[i] == kWriteTo || req[i] == kWriteInplace), true); - const int M(outputs[i].FlatToKD<xpu, odim+1, OType>(s).size(0)); - CHECK_EQ((N == -1 || N == M), true); - N = M; - } - for ( int i = 0; i < N; ++i ) { - LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs, i, attrs, s); - } + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { + LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs, attrs, s); }); } @@ -354,28 +327,15 @@ void LaOpBackward(const nnvm::NodeAttrs& attrs, Stream<xpu> *s = ctx.get_stream<xpu>(); CHECK_EQ(inputs.size(), inum); CHECK_EQ(outputs.size(), onum); - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { - int N(-1); - for ( int i = 0; i < inum; ++i ) { - CHECK_EQ(inputs[i].CheckContiguous(), true); - const int M(inputs[i].FlatToKD<xpu, idim+1, OType>(s).size(0)); - CHECK_EQ((N == -1 || N == M), true); - N = M; - } + MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, { std::vector<TBlob> tspace(outputs); for ( int i = 0; i < onum; ++i ) { - CHECK_EQ(outputs[i].CheckContiguous(), true); - const int M(outputs[i].FlatToKD<xpu, odim+1, OType>(s).size(0)); - CHECK_EQ((N == -1 || N == M), true); - N = M; if ( req[i] == kAddTo ) { tspace[i].dptr_ = ctx.requested[ResourceRequest::kTempSpace] .get_space_typed<xpu, 1, OType>(Shape1(outputs[i].Size()), s).dptr_; } } - for ( int i = 0; i < N; ++i ) { - LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, tspace, i, attrs, s); - } + LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, tspace, attrs, s); for ( int i = 0; i < onum; ++i ) { if ( req[i] == kAddTo ) { Tensor<xpu, 1, OType> out = outputs[i].FlatTo1D<xpu, OType>(s); @@ -385,53 +345,6 @@ void LaOpBackward(const nnvm::NodeAttrs& attrs, }); } -template<typename xpu, int idim, typename laop> -void LaReduceForward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector<TBlob>& inputs, - const std::vector<OpReqType>& req, - const std::vector<TBlob>& outputs) { - using namespace mshadow; - Stream<xpu> *s = ctx.get_stream<xpu>(); - CHECK_EQ(inputs.size(), 1); - CHECK_EQ(outputs.size(), 1); - CHECK_EQ(inputs[0].CheckContiguous(), true); - CHECK_EQ(outputs[0].CheckContiguous(), true); - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { - Tensor<xpu, idim+1, OType> in(inputs[0].FlatToKD<xpu, idim+1, OType>(s)); - Tensor<xpu, 1, OType> out(outputs[0].FlatTo1D<xpu, OType>(s)); - const int N(outputs[0].Size()); - CHECK_EQ(in.size(0), N); - for ( int i = 0; i < N; ++i ) { - laop::op(in[i], out[i], attrs); - } - }); -} - -template<typename xpu, int idim, typename laop> -void LaReduceBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector<TBlob>& inputs, - const std::vector<OpReqType>& req, - const std::vector<TBlob>& outputs) { - using namespace mshadow; - Stream<xpu> *s = ctx.get_stream<xpu>(); - CHECK_EQ(inputs.size(), 2); - CHECK_EQ(outputs.size(), 1); - CHECK_EQ(inputs[0].CheckContiguous(), true); - CHECK_EQ(inputs[1].CheckContiguous(), true); - CHECK_EQ(outputs[0].CheckContiguous(), true); - MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, { - const int N(inputs[0].Size()); - Tensor<xpu, 1, OType> in0(inputs[0].FlatTo1D<xpu, OType>(s)); - Tensor<xpu, idim+1, OType> in1(inputs[1].FlatToKD<xpu, idim+1, OType>(s)); - Tensor<xpu, idim+1, OType> out(outputs[0].FlatToKD<xpu, idim+1, OType>(s)); - for ( int i = 0; i < N; ++i 
) { - laop::op(in0[i], in1[i], out[i], attrs, (req[i] == kAddTo)); - } - }); -} - } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/la_op_inline.h b/src/operator/tensor/la_op_inline.h index a032988..34fb441 100644 --- a/src/operator/tensor/la_op_inline.h +++ b/src/operator/tensor/la_op_inline.h @@ -24,244 +24,186 @@ #ifndef MXNET_OPERATOR_TENSOR_LA_OP_INLINE_H_ #define MXNET_OPERATOR_TENSOR_LA_OP_INLINE_H_ -#include <mxnet/c_lapack_api.h> +#include "../linalg.h" namespace mxnet { namespace op { using namespace mshadow; -#define LA_OP_NOT_AVAIL " operator can only be called with float/double data type." - -// Signature for single matrix operations (decomposition/inversion). -#define FUNC_SIGNATURE_1(fname, arg1) {CHECK_EQ(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, 'L', \ - arg1.size(0), arg1.dptr_, arg1.size(0)), 0) << "fname failed in lapack";} - -// Signature for matrix-matrix multiplications involving one diagonal matrix. -#define FUNC_SIGNATURE_2(fname, arg1, arg2) \ - { cblas_##fname(CblasRowMajor, (rightside ? CblasRight : CblasLeft), \ - CblasLower, (transpose ? CblasTrans : CblasNoTrans), \ - CblasNonUnit, arg2.size(0), arg2.size(1), alpha, arg1.dptr_, \ - (rightside ? arg2.size(1) : arg2.size(0)), arg2.dptr_, arg2.size(1)); } - - // Helper functions. -template<typename DType> -void CopyLowerToUpper(DType *dptr, int N) - { for (int i = 1; i < N; ++i ) for ( int j = 0; j < i; ++j ) dptr[j*N+i] = dptr[i*N+j]; } -template<typename DType> -void ZeroUpper(DType *dptr, int N) - { for (int i = 0; i < N; ++i ) for ( int j = i+1; j < N; ++j ) dptr[i*N+j] = 0; } +struct CopyLowerToUpper { + template<typename DType> + MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data) { + // Below computation works even when we are dealing with a batch of matrices. 
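// For element i at position (row, col) of its matrix, the flat index of the mirrored
// position (col, row) is i + (col - row)*stride + (row - col), which simplifies to
// i + (col - row)*(stride - 1) as used below.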
+ const int row((i % matrix_size) / stride), col(i % stride); + if ( row > col ) data[i + (col - row) * (stride - 1)] = data[i]; + } +}; +struct ZeroUpper { + template<typename DType> + MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data) { + const int row((i % matrix_size) / stride), col(i % stride); + if ( row < col ) data[i] = 0; + } +}; +struct Scale { + template<typename DType> + MSHADOW_XINLINE static void Map(int i, DType scale, DType* data) { + data[i] *= scale; + } +}; -// Forward operators +// Forward computations (always using batched processing) // D = gemm(A,B,C) struct gemm { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, - const Tensor<xpu, 2, DType>& C, DType alpha, DType beta, bool tA, bool tB) - { CHECK(false) << "gemm" << LA_OP_NOT_AVAIL; } + static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, + const Tensor<xpu, 3, DType>& C, DType alpha, DType beta, bool tA, bool tB, Stream<xpu> *s) { + linalg_batch_gemm(A, B, C, alpha, beta, tA, tB, s); + } template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, - const Tensor<xpu, 2, DType>& C, const Tensor<xpu, 2, DType>& D, - const nnvm::NodeAttrs& attrs) { - if ( C.dptr_ != D.dptr_ ) Copy(D, C); + static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, + const Tensor<xpu, 3, DType>& C, const Tensor<xpu, 3, DType>& D, + Stream<xpu> *s, const nnvm::NodeAttrs& attrs) { + if ( C.dptr_ != D.dptr_ ) Copy(D, C, s); const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed); - gemm::op(A, B, D, DType(param.alpha), DType(param.beta), param.transpose_a, param.transpose_b); + gemm::op(A, B, D, DType(param.alpha), DType(param.beta), + param.transpose_a, param.transpose_b, s); } }; -template<> -void gemm::op<cpu, float>(const Tensor<cpu, 2, float>& A, const Tensor<cpu, 2, float>& B, - const Tensor<cpu, 2, float>& C, - float alpha, float beta, bool tA, bool tB ) { - CHECK_EQ((tA ? A.size(1) : A.size(0)), C.size(0)) - << "Non compatible matrix dimensions between inputs A and C for gemm operator"; - CHECK_EQ((tB ? B.size(0) : B.size(1)), C.size(1)) - << "Non compatible matrix dimensions between inputs B and C for gemm operator"; - CHECK_EQ((tA ? A.size(0) : A.size(1)), (tB ? B.size(1) : B.size(0))) - << "Non compatible matrix dimensions between inputs A and B for gemm operator"; - cblas_sgemm(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), (tB ? CblasTrans : CblasNoTrans), - (tA ? A.size(1):A.size(0)), (tB ? B.size(0): B.size(1)), - (tA ? A.size(0):A.size(1)), alpha, A.dptr_, A.size(1), B.dptr_, B.size(1), - beta, C.dptr_, (tB ? B.size(0): B.size(1))); -} -template<> -void gemm::op<cpu, double>(const Tensor<cpu, 2, double>& A, const Tensor<cpu, 2, double>& B, - const Tensor<cpu, 2, double>& C, - double alpha, double beta, bool tA, bool tB) { - CHECK_EQ((tA ? A.size(1) : A.size(0)), C.size(0)) - << "Non compatible matrix dimensions between inputs A and C for gemm operator"; - CHECK_EQ((tB ? B.size(0) : B.size(1)), C.size(1)) - << "Non compatible matrix dimensions between inputs B and C for gemm operator"; - CHECK_EQ((tA ? A.size(0) : A.size(1)), (tB ? B.size(1) : B.size(0))) - << "Non compatible matrix dimensions between inputs A and B for gemm operator"; - cblas_dgemm(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), (tB ? CblasTrans : CblasNoTrans), - (tA ? A.size(1):A.size(0)), (tB ? B.size(0): B.size(1)), - (tA ? 
A.size(0):A.size(1)), alpha, A.dptr_, A.size(1), B.dptr_, B.size(1), - beta, C.dptr_, (tB ? B.size(0): B.size(1))); -} // C = gemm2(A,B) struct gemm2 { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, - const Tensor<xpu, 2, DType>& C, const nnvm::NodeAttrs& attrs) { + static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, + const Tensor<xpu, 3, DType>& C, Stream<xpu> *s, const nnvm::NodeAttrs& attrs) { const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed); - gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a, param.transpose_b); + gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a, param.transpose_b, s); } }; // L = potrf(A). struct potrf { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& L, - const nnvm::NodeAttrs& attrs) - { CHECK(false) << "potrf" << LA_OP_NOT_AVAIL; } + static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& L, + Stream<xpu> *s, const nnvm::NodeAttrs& attrs) { + if ( A.dptr_ != L.dptr_ ) Copy(L, A, s); + linalg_batch_potrf(L, true, s); + using namespace mxnet_op; + Kernel<ZeroUpper, xpu>::Launch(s, L.MSize(), L.size(1)*L.stride_, L.stride_, L.dptr_); + } }; -template<> -void potrf::op<cpu, float>(const Tensor<cpu, 2, float>& A, const Tensor<cpu, 2, float>& L, - const nnvm::NodeAttrs& attrs) { - if ( A.dptr_ != L.dptr_ ) Copy(L, A); - FUNC_SIGNATURE_1(spotrf, L); - ZeroUpper(L.dptr_, L.size(0)); -} -template<> -void potrf::op<cpu, double>(const Tensor<cpu, 2, double>& A, const Tensor<cpu, 2, double>& L, - const nnvm::NodeAttrs& attrs) { - if ( A.dptr_ != L.dptr_ ) Copy(L, A); - FUNC_SIGNATURE_1(dpotrf, L); - ZeroUpper(L.dptr_, L.size(0)); -} // A = potri(L). 
struct potri { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& L, const Tensor<xpu, 2, DType>& A, - const nnvm::NodeAttrs& attrs) - { CHECK(false) << "potri" << LA_OP_NOT_AVAIL; } + static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A, + Stream<xpu> *s, const nnvm::NodeAttrs& attrs) { + if ( A.dptr_ != L.dptr_ ) Copy(A, L, s); + linalg_batch_potri(A, true, s); + using namespace mxnet_op; + Kernel<CopyLowerToUpper, xpu>::Launch(s, A.MSize(), A.size(1)*A.stride_, A.stride_, A.dptr_); + } }; -template<> -void potri::op<cpu, float>(const Tensor<cpu, 2, float>& L, const Tensor<cpu, 2, float>& A, - const nnvm::NodeAttrs& attrs) { - if ( A.dptr_ != L.dptr_ ) Copy(A, L); - FUNC_SIGNATURE_1(spotri, A); - CopyLowerToUpper(A.dptr_, A.size(0)); -} -template<> -void potri::op<cpu, double>(const Tensor<cpu, 2, double>& A, const Tensor<cpu, 2, double>& L, - const nnvm::NodeAttrs& attrs) { - if ( A.dptr_ != L.dptr_ ) Copy(A, L); - FUNC_SIGNATURE_1(dpotri, A); - CopyLowerToUpper(A.dptr_, A.size(0)); -} // B = trsm(L,A) struct trsm { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& L, const Tensor<xpu, 2, DType>& B, - DType alpha, bool rightside, bool transpose) - { CHECK(false) << "trsm" << LA_OP_NOT_AVAIL; } + static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& B, + DType alpha, bool rightside, bool transpose, Stream<xpu> *s) { + linalg_batch_trsm(L, B, alpha, rightside, true, transpose, s); + } template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& L, const Tensor<xpu, 2, DType>& A, - const Tensor<xpu, 2, DType>& B, const nnvm::NodeAttrs& attrs) { - if ( A.dptr_ != B.dptr_ ) Copy(B, A); + static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A, + const Tensor<xpu, 3, DType>& B, + Stream<xpu> *s, const nnvm::NodeAttrs& attrs) { + if ( A.dptr_ != B.dptr_ ) Copy(B, A, s); const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed); - op(L, B, DType(param.alpha), param.rightside, param.transpose); + op(L, B, DType(param.alpha), param.rightside, param.transpose, s); } }; -template<> -void trsm::op<cpu, float>(const Tensor<cpu, 2, float>& L, const Tensor<cpu, 2, float>& B, - float alpha, bool rightside, bool transpose) { - FUNC_SIGNATURE_2(strsm, L, B); -} -template<> -void trsm::op<cpu, double>(const Tensor<cpu, 2, double>& L, const Tensor<cpu, 2, double>& B, - double alpha, bool rightside, bool transpose) { - FUNC_SIGNATURE_2(dtrsm, L, B); -} // B = trmm(L,A) struct trmm { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& L, const Tensor<xpu, 2, DType>& B, - DType alpha, bool rightside, bool transpose) - { CHECK(false) << "trmm" << LA_OP_NOT_AVAIL; } + static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& B, + DType alpha, bool rightside, bool transpose, Stream<xpu> *s) { + linalg_batch_trmm(L, B, alpha, rightside, true, transpose, s); + } template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& L, const Tensor<xpu, 2, DType>& A, - const Tensor<xpu, 2, DType>& B, const nnvm::NodeAttrs& attrs) { - if ( A.dptr_ != B.dptr_ ) Copy(B, A); + static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A, + const Tensor<xpu, 3, DType>& B, Stream<xpu> *s, const nnvm::NodeAttrs& attrs) { + if ( A.dptr_ != B.dptr_ ) Copy(B, A, s); const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed); - 
op(L, B, DType(param.alpha), param.rightside, param.transpose); + op(L, B, DType(param.alpha), param.rightside, param.transpose, s); } }; -template<> -void trmm::op<cpu, float>(const Tensor<cpu, 2, float>& L, const Tensor<cpu, 2, float>& B, - float alpha, bool rightside, bool transpose) { - FUNC_SIGNATURE_2(strmm, L, B); -} -template<> -void trmm::op<cpu, double>(const Tensor<cpu, 2, double>& L, const Tensor<cpu, 2, double>& B, - double alpha, bool rightside, bool transpose) { - FUNC_SIGNATURE_2(dtrmm, L, B); -} // Useful operator that is not part of BLAS/LAPACK. -struct sumlogdiag { - template<typename xpu, typename DType, - typename std::enable_if<!std::is_floating_point<DType>::value, int>::type = 0> - static void op(const Tensor<xpu, 2, DType>& A, DType& L, const nnvm::NodeAttrs& attrs) - { CHECK(false) << "sumlogdiag operator can only be called with float/double data type."; } - template<typename xpu, typename DType, - typename std::enable_if<std::is_floating_point<DType>::value, int>::type = 0> - static void op(const Tensor<xpu, 2, DType>& A, DType& B, const nnvm::NodeAttrs& attrs) { - CHECK_EQ(A.size(0), A.size(1)) << "sumlogdiag operator requires a NxN matrix as input."; - const int N(A.size(0)); +struct ForwardSumLogDiag { + template<typename DType> + MSHADOW_XINLINE static void Map(int i, int N, int stride, DType* A, DType* B) { DType sum(0); - DType *p(A.dptr_); - for ( int i = 0; i < N; ++i, p += N+1 ) { - sum += log(*p); + const int offset(i * N * stride); + for ( int j = 0; j < N; ++j ) { + sum += log(A[offset+j*(stride+1)]); } - B = sum; + B[i] = sum; + } +}; +struct sumlogdiag { + template<typename xpu, typename DType> + static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 1, DType>& B, + Stream<xpu> *s, const nnvm::NodeAttrs& attrs) { + CHECK_EQ(A.size(1), A.size(2)) << "sumlogdiag operator requires square matrices as input."; + using namespace mxnet_op; + Kernel<ForwardSumLogDiag, xpu>::Launch(s, A.size(0), A.size(1), A.stride_, A.dptr_, B.dptr_); } }; -// Backward operators +// Backward operators (always using batch processing) struct gemm_backward { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& dD, const Tensor<xpu, 2, DType>& A, - const Tensor<xpu, 2, DType>& B, const Tensor<xpu, 2, DType>& C, - const Tensor<xpu, 2, DType>& dA, const Tensor<xpu, 2, DType>& dB, - const Tensor<xpu, 2, DType>& dC, const nnvm::NodeAttrs& attrs) { + static void op(const Tensor<xpu, 3, DType>& dD, const Tensor<xpu, 3, DType>& A, + const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C, + const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB, + const Tensor<xpu, 3, DType>& dC, + Stream<xpu>* s, const nnvm::NodeAttrs& attrs) { const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed); - (param.transpose_a ? gemm::op(B, dD, dA, DType(param.alpha), DType(0), param.transpose_b, true) - : gemm::op(dD, B, dA, DType(param.alpha), DType(0), false, !param.transpose_b)); - (param.transpose_b ? gemm::op(dD, A, dB, DType(param.alpha), DType(0), true, param.transpose_a) - : gemm::op(A, dD, dB, DType(param.alpha), DType(0), !param.transpose_a, false)); - const int N(dC.size(0)*dC.size(1)); - for ( int i = 0; i < N; ++i ) { - dC.dptr_[i] = param.beta * dD.dptr_[i]; - } + bool tA(param.transpose_a), tB(param.transpose_b); + (tA ? gemm::op(B, dD, dA, DType(param.alpha), DType(0), tB, true, s) + : gemm::op(dD, B, dA, DType(param.alpha), DType(0), false, !tB, s)); + (tB ? 
gemm::op(dD, A, dB, DType(param.alpha), DType(0), true, tA, s) + : gemm::op(A, dD, dB, DType(param.alpha), DType(0), !tA, false, s)); + Copy(dC, dD, s); + using namespace mxnet_op; + Kernel<Scale, xpu>::Launch(s, dC.MSize(), DType(param.beta), dC.dptr_); } }; struct gemm2_backward { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& dC, const Tensor<xpu, 2, DType>& A, - const Tensor<xpu, 2, DType>& B, const Tensor<xpu, 2, DType>& dA, - const Tensor<xpu, 2, DType>& dB, const nnvm::NodeAttrs& attrs) { + static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A, + const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA, + const Tensor<xpu, 3, DType>& dB, + Stream<xpu>* s, const nnvm::NodeAttrs& attrs) { const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed); - (param.transpose_a ? gemm::op(B, dC, dA, DType(param.alpha), DType(0), param.transpose_b, true) - : gemm::op(dC, B, dA, DType(param.alpha), DType(0), false, !param.transpose_b)); - (param.transpose_b ? gemm::op(dC, A, dB, DType(param.alpha), DType(0), true, param.transpose_a) - : gemm::op(A, dC, dB, DType(param.alpha), DType(0), !param.transpose_a, false)); + bool tA(param.transpose_a), tB(param.transpose_b); + (tA ? gemm::op(B, dC, dA, DType(param.alpha), DType(0), tB, true, s) + : gemm::op(dC, B, dA, DType(param.alpha), DType(0), false, !tB, s)); + (tB ? gemm::op(dC, A, dB, DType(param.alpha), DType(0), true, tA, s) + : gemm::op(A, dC, dB, DType(param.alpha), DType(0), !tA, false, s)); } }; struct potrf_backward { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& dL, const Tensor<xpu, 2, DType>& L, - const Tensor<xpu, 2, DType>& dA, const nnvm::NodeAttrs& attrs) { + static void op(const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& L, + const Tensor<xpu, 3, DType>& dA, + Stream<xpu>* s, const nnvm::NodeAttrs& attrs) { // Backward of L = potrf(A). // dA = 0.5 * L**T * symm(L**T * dL # E) * L**(-1) where // '#' denotes Hadamard product @@ -269,81 +211,96 @@ struct potrf_backward { // symm(X) = 0.5 * (X + X**T) // Hadamard product and symm can be realized by a single copy from lower to upper triangle. if ( dL.dptr_ != dA.dptr_ ) { - Copy(dA, dL); + Copy(dA, dL, s); } - trmm::op(L, dA, DType(1.0), false, true); - CopyLowerToUpper(dA.dptr_, dA.size(0)); - trsm::op(L, dA, DType(1.0), false, true); - trsm::op(L, dA, DType(0.5), true, false); + trmm::op(L, dA, DType(1.0), false, true, s); + using namespace mxnet_op; + Kernel<CopyLowerToUpper, xpu>::Launch + (s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_, dA.dptr_); + trsm::op(L, dA, DType(1.0), false, true, s); + trsm::op(L, dA, DType(0.5), true, false, s); } }; struct potri_backward { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& dA, const Tensor<xpu, 2, DType>& L, - const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& dL, - const nnvm::NodeAttrs& attrs) { + static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& L, + const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL, + Stream<xpu>* s, const nnvm::NodeAttrs& attrs) { // Backward of A = potri(L). // dL = -2 * tril(A * dA * L**(-T)), where tril() extracts lower triangle and diagonal. 
- gemm::op(A, dA, dL, DType(1.0), DType(0), false, false); - trsm::op(L, dL, DType(-2.0), true, true); - ZeroUpper(dL.dptr_, dL.size(0)); + gemm::op(A, dA, dL, DType(1.0), DType(0), false, false, s); + trsm::op(L, dL, DType(-2.0), true, true, s); + using namespace mxnet_op; + Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_); } }; struct trsm_backward { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& dB, const Tensor<xpu, 2, DType>& L, - const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, - const Tensor<xpu, 2, DType>& dL, const Tensor<xpu, 2, DType>& dA, - const nnvm::NodeAttrs& attrs) { + static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L, + const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, + const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& dA, + Stream<xpu>* s, const nnvm::NodeAttrs& attrs) { // Backward of B = trsm(L,A). const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed); // Compute dA - if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB); - trsm::op(L, dA, DType(param.alpha), param.rightside, !param.transpose); + if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB, s); + trsm::op(L, dA, DType(param.alpha), param.rightside, !param.transpose, s); // Compute dL const bool da_left(param.rightside == param.transpose); - (da_left ? - gemm::op(dA, B, dL, DType(-1.0/param.alpha), DType(0), param.transpose, !param.transpose) - : gemm::op(B, dA, dL, DType(-1.0/param.alpha), DType(0), !param.transpose, param.transpose)); - ZeroUpper(dL.dptr_, dL.size(0)); + DType scale(-1.0/param.alpha); + (da_left ? gemm::op(dA, B, dL, scale, DType(0), param.transpose, !param.transpose, s) + : gemm::op(B, dA, dL, scale, DType(0), !param.transpose, param.transpose, s)); + using namespace mxnet_op; + Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_); } }; struct trmm_backward { template<typename xpu, typename DType> - static void op(const Tensor<xpu, 2, DType>& dB, const Tensor<xpu, 2, DType>& L, - const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& B, - const Tensor<xpu, 2, DType>& dL, const Tensor<xpu, 2, DType>& dA, - const nnvm::NodeAttrs& attrs) { + static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L, + const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, + const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& dA, + Stream<xpu>* s, const nnvm::NodeAttrs& attrs) { // Backward of B = trmm(L,A). const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed); // Compute dL const bool db_left(param.rightside == param.transpose); - (db_left ? gemm::op(dB, A, dL, DType(param.alpha), DType(0), param.transpose, !param.transpose) - : gemm::op(A, dB, dL, DType(param.alpha), DType(0), !param.transpose, param.transpose)); - ZeroUpper(dL.dptr_, dL.size(0)); + DType scale(param.alpha); + (db_left ? 
gemm::op(dB, A, dL, scale, DType(0), param.transpose, !param.transpose, s) + : gemm::op(A, dB, dL, scale, DType(0), !param.transpose, param.transpose, s)); + using namespace mxnet_op; + Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_); // Compute dA - if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB); - trmm::op(L, dA, DType(param.alpha), param.rightside, !param.transpose); + if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB, s); + trmm::op(L, dA, scale, param.rightside, !param.transpose, s); } }; +struct BackwardSumLogDiag { + template<typename DType> + MSHADOW_XINLINE static void Map(int i, int N, int stride, DType* dB, DType* A, DType* dA) { + const int offset(i * N * stride); + for ( int j = 0; j < N; ++j ) { + dA[offset+j*(stride+1)] = dB[i]/A[offset+j*(stride+1)]; + } + } +}; struct sumlogdiag_backward { template<typename xpu, typename DType> - static void op(const DType& dB, const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DType>& dA, - const nnvm::NodeAttrs& attrs, bool add) { + static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& A, + const Tensor<xpu, 3, DType>& dA, + Stream<xpu>* s, const nnvm::NodeAttrs& attrs) { // Backward of B = sumlogdiag(A). - const int N(A.size(0)); - if ( !add ) { - for ( int i = 0; i < N*N; ++i ) { - dA.dptr_[i] = 0; - } - } - for ( int i = 0; i < N; ++i ) { - dA.dptr_[i*(N+1)] += dB / A.dptr_[i*N+i]; - } + // dB is actually a 1-d tensor but we convert it to a 3-D one before calling + // this function as the LaOpCaller-adapters can only deal with a uniform + // dimension for all tensor inputs. This doesn't matter as we will interpret + // it correctly internally in this function. + using namespace mxnet_op; + Kernel<Scale, xpu>::Launch(s, dA.MSize(), DType(0), dA.dptr_); + Kernel<BackwardSumLogDiag, xpu>::Launch + (s, A.size(0), A.size(1), A.stride_, dB.dptr_, A.dptr_, dA.dptr_); } }; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 718e3df..7d56b46 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3412,15 +3412,8 @@ def test_deformable_psroipooling(): def test_laop(): - return - - # Currently no support for GPU. Will be added soon - # so keep these tests here in this file and activate - # gpu-testing when it is ready. - dev = default_context() - if dev.device_type == 'gpu': - return + # enable numerical checking of gradients grad_check = 1 data1 = mx.symbol.Variable('data1') -- To stop receiving notification emails like this one, please contact ['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].
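A note for readers tracing the new backward routines: the gradients that potrf_backward and potri_backward realize through their trmm/trsm/gemm call sequences can be written out explicitly. The formulas below are read off those call sequences rather than quoted from the commit; copyltu(X) denotes mirroring the lower triangle onto the upper one, exactly what the CopyLowerToUpper kernel does, and tril(X) keeps the lower triangle including the diagonal.

\[
\text{For } L = \operatorname{potrf}(A): \qquad
\frac{\partial \ell}{\partial A} \;=\; \tfrac{1}{2}\, L^{-T}\, \operatorname{copyltu}\!\Big(L^{T}\,\frac{\partial \ell}{\partial L}\Big)\, L^{-1}
\]
\[
\text{For } A = \operatorname{potri}(L): \qquad
\frac{\partial \ell}{\partial L} \;=\; -2\, \operatorname{tril}\!\Big(A\,\frac{\partial \ell}{\partial A}\, L^{-T}\Big)
\]

In the code, the first formula is assembled step by step: a left-side trmm multiplies by L transposed, the CopyLowerToUpper kernel realizes copyltu, a left-side transposed trsm applies the inverse transpose, and a final right-side trsm with scale 0.5 applies the remaining factor. The second formula is a gemm, a right-side transposed trsm with scale -2, and a ZeroUpper launch.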
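The flat-index arithmetic used by the new element-wise kernels (ZeroUpper, CopyLowerToUpper, ForwardSumLogDiag, BackwardSumLogDiag) can also be modelled in a few lines of NumPy. The sketch below is illustrative only and is not part of the commit; it assumes a batch of row-major, contiguous N x N matrices (so stride == N and matrix_size == N * stride), and all function names are placeholders.

import numpy as np

def zero_upper(data, n):
    # Mirror of the ZeroUpper kernel: for flat index i,
    # row = (i % matrix_size) // stride, col = i % stride,
    # and entries with row < col are set to zero.
    out = data.copy()
    matrix_size, stride = n * n, n
    for i in range(out.size):
        row = (i % matrix_size) // stride
        col = i % stride
        if row < col:
            out[i] = 0
    return out

def copy_lower_to_upper(data, n):
    # Mirror of the CopyLowerToUpper kernel: element (row, col) with
    # row > col is copied to (col, row), i.e. to offset i + (col - row) * (stride - 1).
    out = data.copy()
    matrix_size, stride = n * n, n
    for i in range(out.size):
        row = (i % matrix_size) // stride
        col = i % stride
        if row > col:
            out[i + (col - row) * (stride - 1)] = out[i]
    return out

def sumlogdiag_forward(batch):
    # B[k] = sum_j log(A[k, j, j]) for a (batch, n, n) array A.
    return np.log(np.diagonal(batch, axis1=1, axis2=2)).sum(axis=1)

def sumlogdiag_backward(d_b, batch):
    # dA[k, j, j] = dB[k] / A[k, j, j]; all off-diagonal gradients are zero.
    d_a = np.zeros_like(batch)
    j = np.arange(batch.shape[1])
    d_a[:, j, j] = d_b[:, None] / batch[:, j, j]
    return d_a

if __name__ == "__main__":
    k, n = 4, 3
    rng = np.random.RandomState(0)
    a = rng.rand(k, n, n) + np.eye(n)          # keep the diagonal away from zero
    flat = a.ravel()
    # ZeroUpper on the flat buffer matches np.tril applied per matrix.
    assert np.allclose(zero_upper(flat, n).reshape(k, n, n), np.tril(a))
    # CopyLowerToUpper yields symmetric matrices whose lower triangle is unchanged.
    mirrored = copy_lower_to_upper(flat, n).reshape(k, n, n)
    assert np.allclose(mirrored, np.transpose(mirrored, (0, 2, 1)))
    assert np.allclose(np.tril(mirrored), np.tril(a))
    # Finite-difference check of the sumlogdiag gradient at one diagonal entry.
    d_b = rng.rand(k)
    d_a = sumlogdiag_backward(d_b, a)
    eps = 1e-6
    a2 = a.copy(); a2[:, 1, 1] += eps
    fd = (sumlogdiag_forward(a2) - sumlogdiag_forward(a)) / eps
    assert np.allclose(d_a[:, 1, 1], d_b * fd, rtol=1e-4)
    print("kernel index arithmetic and sumlogdiag gradient agree with NumPy")

Passing the stride explicitly is what lets the same index formulas work when matrix rows are padded (stride > N); the contiguous case above is simply the special case stride == N.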