This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a36bf57  Add cuda rtc module (#8017)
a36bf57 is described below

commit a36bf573ad82550dbb6692a89d7ddd1d5e4487fd
Author: Eric Junyuan Xie <[email protected]>
AuthorDate: Tue Sep 26 12:10:26 2017 -0700

    Add cuda rtc module (#8017)
    
    * Add cuda rtc module
    
    * add to docs
    
    * commit fix
    
    * fix
    
    * fix
    
    * Update Jenkinsfile
---
 CMakeLists.txt                        |   1 -
 Jenkinsfile                           |   4 +
 Makefile                              |   9 +-
 docs/api/python/index.md              |   9 ++
 docs/api/python/rtc/rtc.md            |  29 ++++
 include/mxnet/c_api.h                 |  57 +++++++
 include/mxnet/mxrtc.h                 | 107 --------------
 include/mxnet/rtc.h                   | 136 +++++++++++++++++
 make/config.mk                        |   3 -
 make/osx.mk                           |   3 -
 python/mxnet/__init__.py              |   2 +-
 python/mxnet/base.py                  |   2 +
 python/mxnet/rtc.py                   | 271 ++++++++++++++++++++++++----------
 src/c_api/c_api.cc                    | 127 +++++++++++-----
 src/common/cuda_utils.h               |  34 +++++
 src/common/mxrtc.cc                   | 159 --------------------
 src/common/rtc.cc                     | 188 +++++++++++++++++++++++
 tests/python/gpu/test_operator_gpu.py |  29 ++++
 18 files changed, 774 insertions(+), 396 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 132b0e1..c20759c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -363,7 +363,6 @@ if(USE_CUDA)
   endif()
   list(APPEND SOURCE ${cuda_objs} ${CUDA})
   add_definitions(-DMXNET_USE_CUDA=1)
-  add_definitions(-DMXNET_USE_NVRTC=1)
   if(CUDA_LIBRARY_PATH)
     if(IS_CONTAINER_BUILD)
       # In case of building on a production-like build container which may not 
have Cuda installed
diff --git a/Jenkinsfile b/Jenkinsfile
index d8cd2f5..eb82e00 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -370,7 +370,11 @@ try {
                     init_git()
                     unpack_lib('gpu')
                     timeout(time: max_time, unit: 'MINUTES') {
+                      try {
                         sh "${docker_run} gpu ./perl-package/test.sh"
+                      } catch (exc) {
+                        error "Perl GPU test failed."
+                      }
                     }
                 }
             }
diff --git a/Makefile b/Makefile
index 54df33f..5626727 100644
--- a/Makefile
+++ b/Makefile
@@ -272,7 +272,7 @@ ALL_DEP = $(OBJ) $(EXTRA_OBJ) $(PLUGIN_OBJ) $(LIB_DEP)
 ifeq ($(USE_CUDA), 1)
        CFLAGS += -I$(ROOTDIR)/cub
        ALL_DEP += $(CUOBJ) $(EXTRA_CUOBJ) $(PLUGIN_CUOBJ)
-       LDFLAGS += -lcuda -lcufft
+       LDFLAGS += -lcuda -lcufft -lnvrtc
        SCALA_PKG_PROFILE := $(SCALA_PKG_PROFILE)-gpu
 else
        SCALA_PKG_PROFILE := $(SCALA_PKG_PROFILE)-cpu
@@ -281,13 +281,6 @@ endif
 # For quick compile test, used smaller subset
 ALLX_DEP= $(ALL_DEP)
 
-ifeq ($(USE_NVRTC), 1)
-       LDFLAGS += -lnvrtc
-       CFLAGS += -DMXNET_USE_NVRTC=1
-else
-       CFLAGS += -DMXNET_USE_NVRTC=0
-endif
-
 build/src/%.o: src/%.cc
        @mkdir -p $(@D)
        $(CXX) -std=c++11 -c $(CFLAGS) -MMD -c $< -o $@
diff --git a/docs/api/python/index.md b/docs/api/python/index.md
index 75aed07..e7f8d45 100644
--- a/docs/api/python/index.md
+++ b/docs/api/python/index.md
@@ -134,3 +134,12 @@ imported by running:
 
    metric/metric.md
 ```
+
+## Run-Time Compilation API
+
+```eval_rst
+.. toctree::
+   :maxdepth 1
+
+   rtc/rtc.md
+```
diff --git a/docs/api/python/rtc/rtc.md b/docs/api/python/rtc/rtc.md
new file mode 100644
index 0000000..bb1c314
--- /dev/null
+++ b/docs/api/python/rtc/rtc.md
@@ -0,0 +1,29 @@
+# Run-Time Compilation API
+
+```eval_rst
+.. currentmodule:: mxnet.rtc
+```
+
+## Overview
+
+The RTC package contains tools for compiling and running CUDA code from python
+frontend. The compiled kernels can be used stand-alone or combined with
+`autograd.Function` or `operator.CustomOpProp` to support differentiation.
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    mxnet.rtc
+```
+
+## API Reference
+
+<script type="text/javascript" 
src='../../_static/js/auto_module_index.js'></script>
+
+```eval_rst
+.. automodule:: mxnet.rtc
+    :members:
+```
+
+<script>auto_index("api-reference");</script>
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 1a2b82a..4f4afa3 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -84,6 +84,10 @@ typedef void *KVStoreHandle;
 typedef void *RecordIOHandle;
 /*! \brief handle to MXRtc*/
 typedef void *RtcHandle;
+/*! \brief handle to rtc cuda module*/
+typedef void *CudaModuleHandle;
+/*! \brief handle to rtc cuda kernel*/
+typedef void *CudaKernelHandle;
 
 typedef void (*ExecutorMonitorCallback)(const char*,
                                         NDArrayHandle,
@@ -1922,6 +1926,59 @@ MXNET_DLL int MXCustomOpRegister(const char* op_type, 
CustomOpPropCreator creato
 MXNET_DLL int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs,
                                      int num_outputs, NDArrayHandle *outputs,
                                      struct MXCallbackList *callbacks);
+/*
+ * \brief create cuda rtc module
+ * \param source cuda source code
+ * \param num_options number of compiler flags
+ * \param options compiler flags
+ * \param num_exports number of exported function names
+ * \param exported function names
+ * \param out handle to created module
+ */
+MXNET_DLL int MXRtcCudaModuleCreate(const char* source, int num_options,
+                                    const char** options, int num_exports,
+                                    const char** exports, CudaModuleHandle 
*out);
+/*
+ * \brief delete cuda rtc module
+ * \param handle handle to cuda module
+ */
+MXNET_DLL int MXRtcCudaModuleFree(CudaModuleHandle handle);
+/*
+ * \brief get kernel from module
+ * \param handle handle to cuda module
+ * \param name name of kernel function
+ * \param num_args number of arguments
+ * \param is_ndarray whether argument is ndarray
+ * \param is_const whether argument is constant
+ * \param arg_types data type of arguments
+ * \param out created kernel
+ */
+MXNET_DLL int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name,
+                                    int num_args, int* is_ndarray, int* 
is_const,
+                                    int* arg_types, CudaKernelHandle *out);
+/*
+ * \brief delete kernel
+ * \param handle handle to previously created kernel
+ */
+MXNET_DLL int MXRtcCudaKernelFree(CudaKernelHandle handle);
+/*
+ * \brief launch cuda kernel
+ * \param handle handle to kernel
+ * \param dev_id (GPU) device id
+ * \param args pointer to arguments
+ * \param grid_dim_x grid dimension x
+ * \param grid_dim_y grid dimension y
+ * \param grid_dim_z grid dimension z
+ * \param block_dim_x block dimension x
+ * \param block_dim_y block dimension y
+ * \param block_dim_z block dimension z
+ * \param shared_mem size of dynamically allocated shared memory
+ */
+MXNET_DLL int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** 
args,
+                                  mx_uint grid_dim_x, mx_uint grid_dim_y,
+                                  mx_uint grid_dim_z, mx_uint block_dim_x,
+                                  mx_uint block_dim_y, mx_uint block_dim_z,
+                                  mx_uint shared_mem);
 
 #ifdef __cplusplus
 }
diff --git a/include/mxnet/mxrtc.h b/include/mxnet/mxrtc.h
deleted file mode 100644
index 8d7facc..0000000
--- a/include/mxnet/mxrtc.h
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file mxrtc.h
- * \brief Wrapper for NVRTC
- * \author Junyuan Xie
- */
-#ifndef MXNET_MXRTC_H_
-#define MXNET_MXRTC_H_
-#include "./base.h"
-#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
-#include <nvrtc.h>
-#include <cuda.h>
-
-#include <vector>
-#include <string>
-#include <memory>
-#include <utility>
-#include <unordered_map>
-#include "./ndarray.h"
-
-namespace mxnet {
-
-/*!
- * \brief Runtime compile of cuda kernel code with NVRTC
- */
-class MXRtc {
- public:
-  /*!
-   * \brief Build a new kernel.
-   *
-   * If the same kernel has been compiled before it will be load from
-   * cache instead of compile again.
-   * \param name name of the kernel function.
-   * \param input list of input ndarrays and their name.
-   * \param output list of output ndarrays and their name.
-   * \param kernel cuda code.
-   */
-  MXRtc(const std::string& name,
-        std::vector<std::pair<std::string, NDArray> > const& input,
-        std::vector<std::pair<std::string, NDArray> > const& output,
-        const std::string& kernel);
-  /*!
-   * \brief launch a kernel with the engine.
-   * \param input list of input ndarray.
-   * \param output list of output ndarray.
-   * \param grid_dim_X kernel grid dimensions.
-   * \param grid_dim_Y kernel grid dimensions.
-   * \param grid_dim_Z kernel grid dimensions.
-   * \param block_dim_X kernel block dimensions.
-   * \param block_dim_Y kernel block dimensions.
-   * \param block_dim_Z kernel block dimensions.
-   */
-  void push(std::vector<NDArray> const& input,
-            std::vector<NDArray> const& output,
-            unsigned int  grid_dim_X,
-            unsigned int  grid_dim_Y,
-            unsigned int  grid_dim_Z,
-            unsigned int  block_dim_X,
-            unsigned int  block_dim_Y,
-            unsigned int  block_dim_Z);
-
- private:
-  static const char str_type[];
-  static std::unordered_map<std::string, char*> kernel_registry;
-
-  std::string name_;
-  index_t num_input_, num_output_;
-  std::string code_;
-  char* ptx_;
-  std::unordered_map<int, CUmodule> module_;
-  std::unordered_map<int, CUfunction> func_;
-
-  /*!
-   * \brief add supporting code to kernel.
-   */
-  std::string decorate(const std::string& name,
-                       std::vector<std::pair<std::string, NDArray> > const& 
input,
-                       std::vector<std::pair<std::string, NDArray> > const& 
output,
-                       const std::string kernel);
-  /*!
-   * \brief compile the kernel with nvrtc.
-   */
-  char* compile(const std::string& name, const std::string& code);
-};
-
-}  // namespace mxnet
-
-#endif  // MXNET_USE_CUDA && MXNET_USE_NVRTC
-#endif  // MXNET_MXRTC_H_
diff --git a/include/mxnet/rtc.h b/include/mxnet/rtc.h
new file mode 100644
index 0000000..747c0b5
--- /dev/null
+++ b/include/mxnet/rtc.h
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_RTC_H_
+#define MXNET_RTC_H_
+#include "./base.h"
+#if MXNET_USE_CUDA
+#include <nvrtc.h>
+#include <cuda.h>
+
+#include <vector>
+#include <string>
+#include <memory>
+#include <utility>
+#include <unordered_map>
+#include <unordered_set>
+#include "./ndarray.h"
+
+namespace mxnet {
+namespace rtc {
+
+/*! \brief Cuda runtime compile module. */
+class CudaModule {
+ private:
+  /*! \brief Structure for holding internal info. */
+  struct Chunk {
+    /*!
+     * \brief Constructs cuda module.
+     * \param source cuda source code.
+     * \param exports export symbols before mangling.
+     */
+    Chunk(const char* source,
+          const std::vector<std::string>& options,
+          const std::vector<std::string>& exports);
+    /*! \brief deconstrutor */
+    ~Chunk();
+    /*!
+     * \brief Get handle to cuda kernel from loaded module
+     * \param mangled_name mangled kernel name
+     * \param ctx context to run kernel on
+     * \return loaded function handle
+     */
+    CUfunction GetFunction(const std::string& mangled_name, const Context& 
ctx);
+    /*! \brief nvrtc program handle. */
+    nvrtcProgram prog_;
+    /*! \brief compiled cuda PTX */
+    char* ptx_;
+    /*! \brief lazily loaded cuda module */
+    std::unordered_map<int, CUmodule> mod_;
+    /*! \brief exported names */
+    std::unordered_set<std::string> exports_;
+  };
+  /*! \brief pointer to Chunk */
+  std::shared_ptr<Chunk> ptr_;
+
+ public:
+  /*! \brief cuda kernel argument descriptor */
+  struct ArgType {
+    /*! \brief whether argument is NDArray */
+    bool is_ndarray;
+    /*! \brief whether argument is constant (input) */
+    bool is_const;
+    /*! \brief data type of argument */
+    mshadow::TypeFlag dtype;
+  };
+  /*! \brief Cuda kernel */
+  class Kernel {
+   public:
+    /*! \brief Launch the kernel */
+    void Launch(const Context& ctx, const std::vector<dmlc::any>& args,
+                uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
+                uint32_t block_dim_x, uint32_t block_dim_y, uint32_t 
block_dim_z,
+                uint32_t shared_mem);
+    /*! \brief kernel interface signature */
+    const std::vector<ArgType>& signature() { return signature_; }
+
+   private:
+    friend class CudaModule;
+    /*!
+     * \brief constructor
+     * \param mod module of this kernel
+     * \param mangled_name mangled kernel name
+     * \param signature kernel argument signature
+     */
+    Kernel(const std::shared_ptr<Chunk>& mod,
+           const std::string& mangled_name,
+           const std::vector<ArgType>& signature);
+    /*! \brief mangled kernel name */
+    std::string mangled_name_;
+    /*! \brief kernel argument signature */
+    std::vector<ArgType> signature_;
+    /*! \brief module of this kernel */
+    std::shared_ptr<Chunk> mod_;
+    /*! \brief cached kernel function on each device */
+    std::unordered_map<int, CUfunction> func_;
+  };
+  /*!
+   * \brief CudaModule constructor
+   * \param source cuda source code.
+   * \param exports export symbols before mangling.
+   */
+  CudaModule(const char* source,
+             const std::vector<std::string>& options,
+             const std::vector<std::string>& exports)
+      : ptr_(std::make_shared<Chunk>(source, options, exports)) {}
+  /*!
+   * \brief Get cuda kernal from module by name
+   * \param name kernel name
+   * \param signature kernel signature
+   * \return shared pointer to cuda kernel
+   */
+  std::shared_ptr<Kernel> GetKernel(const std::string& name,
+                                    const std::vector<ArgType>& signature);
+};
+
+}  // namespace rtc
+}  // namespace mxnet
+
+#endif  // MXNET_USE_CUDA
+#endif  // MXNET_RTC_H_
diff --git a/make/config.mk b/make/config.mk
index d44898b..c5de898 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -57,9 +57,6 @@ USE_CUDA_PATH = NONE
 # whether use CuDNN R3 library
 USE_CUDNN = 0
 
-# whether use cuda runtime compiling for writing kernels in native language 
(i.e. Python)
-USE_NVRTC = 0
-
 # whether use opencv during compilation
 # you can disable it, however, you will not able to use
 # imbin iterator
diff --git a/make/osx.mk b/make/osx.mk
index 650e284..d9ce6f2 100644
--- a/make/osx.mk
+++ b/make/osx.mk
@@ -51,9 +51,6 @@ USE_CUDA_PATH = NONE
 # whether use CUDNN R3 library
 USE_CUDNN = 0
 
-# whether use cuda runtime compiling for writing kernels in native language 
(i.e. Python)
-USE_NVRTC = 0
-
 # whether use opencv during compilation
 # you can disable it, however, you will not able to use
 # imbin iterator
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 72dc2b2..cf0ba37 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -54,7 +54,7 @@ from . import lr_scheduler
 from . import kvstore as kv
 from . import kvstore_server
 # Runtime compile module
-from .rtc import Rtc as rtc
+from . import rtc
 # Attribute scope to add attributes to symbolic graphs
 from .attribute import AttrScope
 
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index e422dad..fc07853 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -125,6 +125,8 @@ DataIterHandle = ctypes.c_void_p
 KVStoreHandle = ctypes.c_void_p
 RecordIOHandle = ctypes.c_void_p
 RtcHandle = ctypes.c_void_p
+CudaModuleHandle = ctypes.c_void_p
+CudaKernelHandle = ctypes.c_void_p
 #----------------------------
 # helper function definition
 #----------------------------
diff --git a/python/mxnet/rtc.py b/python/mxnet/rtc.py
index 9da38c6..aff4588 100644
--- a/python/mxnet/rtc.py
+++ b/python/mxnet/rtc.py
@@ -18,91 +18,212 @@
 """Interface to runtime cuda kernel compile module."""
 from __future__ import absolute_import
 
+import re
 import ctypes
-from .base import _LIB, NDArrayHandle, RtcHandle, mx_uint, c_array, check_call
+import numpy as np
+
+from .base import _LIB, mx_uint, c_array, check_call
+from .base import c_str, CudaModuleHandle, CudaKernelHandle, numeric_types, 
string_types
+from .ndarray import _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, NDArray
+
+_DTYPE_CPP_TO_NP = {
+    'float': np.float32,
+    'double': np.float64,
+    '__half': np.float16,
+    'uint8_t': np.uint8,
+    'int': np.int32,
+    'int32_t': np.int32,
+    'int8_t': np.int8,
+    'char': np.int8,
+    'int64_t': np.int64,
+}
+
+class CudaModule(object):
+    r"""Compile and run CUDA code from Python.
+
+    In CUDA 7.5, you need to prepend your kernel definitions
+    with 'extern "C"' to avoid name mangling::
+
+        source = r'''
+        extern "C" __global__ void axpy(const float *x, float *y, float alpha) 
{
+            int i = threadIdx.x + blockIdx.x * blockDim.x;
+            y[i] += alpha * x[i];
+        }
+        '''
+        module = mx.rtc.CudaModule(source)
+        func = module.get_kernel("axpy", "const float *x, float *y, float 
alpha")
+        x = mx.nd.ones((10,), ctx=mx.gpu(0))
+        y = mx.nd.zeros((10,), ctx=mx.gpu(0))
+        func.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
+        print(y)
+
+    Starting from CUDA 8.0, you can instead export functions by name.
+    This also allows you to use templates::
+
+        source = r'''
+        template<typename DType>
+        __global__ void axpy(const DType *x, DType *y, DType alpha) {
+            int i = threadIdx.x + blockIdx.x * blockDim.x;
+            y[i] += alpha * x[i];
+        }
+        '''
+        module = mx.rtc.CudaModule(source, exports=['axpy<float>', 
'axpy<double>'])
+        func32 = module.get_kernel("axpy<float>", "const float *x, float *y, 
float alpha")
+        x = mx.nd.ones((10,), dtype='float32', ctx=mx.gpu(0))
+        y = mx.nd.zeros((10,), dtype='float32', ctx=mx.gpu(0))
+        func32.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
+        print(y)
+
+        func64 = module.get_kernel("axpy<double>", "const double *x, double 
*y, double alpha")
+        x = mx.nd.ones((10,), dtype='float64', ctx=mx.gpu(0))
+        y = mx.nd.zeros((10,), dtype='float64', ctx=mx.gpu(0))
+        func32.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
+        print(y)
 
-class Rtc(object):
-    """MXRtc object in mxnet.
-    This class allow you to write CUDA kernels in Python
-    and call them with NDArray.
 
     Parameters
     ----------
-    name : str
-        Name of the kernel.
-    inputs : tuple of (str, mxnet.ndarray)
-        List of input names and ndarray.
-    outputs : tuple of (str, mxnet.ndarray)
-        List of output names and ndarray.
-    kernel : str
-        The actual kernel code.
-        Note that this is only the body of the kernel, i.e.
-        after { and before }. Rtc will decorate the kernel.
-        For example, if ``name = "mykernel"`` and
-        inputs = [('x', mx.nd.zeros((10,)))]
-        outputs = [('y', mx.nd.zeros((10,)))]
-        kernel = "y[threadIdx.x] = x[threadIdx.x];",
-        then the compiled kernel will be:
-        extern "C" __global__ mykernel(float *x, float *y) {
-            const int x_ndim = 1;
-            const int x_dims = { 10 };
-            const int y_ndim = 1;
-            const int y_dims = { 10 };
-
-            y[threadIdx.x] = x[threadIdx.x];
-        }
+    source : str
+        Complete source code.
+    options : tuple of str
+        Compiler flags. For example, use "-I/usr/local/cuda/include" to
+        add cuda headers to include path.
+    exports : tuple of str
+        Export kernel names.
     """
-    def __init__(self, name, inputs, outputs, kernel):
-        self.handle = RtcHandle()
-        input_names = ctypes.cast(c_array(ctypes.c_char_p, [i[0] for i in 
inputs]),
-                                  ctypes.POINTER(ctypes.c_char_p))
-        output_names = ctypes.cast(c_array(ctypes.c_char_p, [i[0] for i in 
outputs]),
-                                   ctypes.POINTER(ctypes.c_char_p))
-        input_nds = ctypes.cast(c_array(NDArrayHandle, [i[1].handle for i in 
inputs]),
-                                ctypes.POINTER(NDArrayHandle))
-        output_nds = ctypes.cast(c_array(NDArrayHandle, [i[1].handle for i in 
outputs]),
-                                 ctypes.POINTER(NDArrayHandle))
-        check_call(_LIB.MXRtcCreate(ctypes.c_char_p(name),
-                                    mx_uint(len(inputs)),
-                                    mx_uint(len(outputs)),
-                                    input_names,
-                                    output_names,
-                                    input_nds,
-                                    output_nds,
-                                    ctypes.c_char_p(kernel),
-                                    ctypes.byref(self.handle)))
+    def __init__(self, source, options=(), exports=()):
+        if isinstance(options, string_types):
+            options = (options,)
+        if isinstance(exports, string_types):
+            exports = (exports,)
+        self.handle = CudaModuleHandle()
+        check_call(_LIB.MXRtcCudaModuleCreate(
+            c_str(source),
+            len(options),
+            c_array(ctypes.c_char_p, [c_str(opt) for opt in options]),
+            len(exports),
+            c_array(ctypes.c_char_p, [c_str(name) for name in exports]),
+            ctypes.byref(self.handle)))
 
     def __del__(self):
-        check_call(_LIB.MXRtcFree(self.handle))
+        check_call(_LIB.MXRtcCudaModuleFree(self.handle))
 
-    def push(self, inputs, outputs, grid_dims, block_dims):
-        """Run the kernel.
+    def get_kernel(self, name, signature):
+        r"""Get CUDA kernel from compiled module.
 
         Parameters
         ----------
-        inputs : list of NDArray
-            List of inputs. Can contain different NDArrays than those used for 
the constructor,
-            but its elements must have the same shapes and appear in the same 
order.
-        outputs : list of NDArray
-            List of outputs. Can contain different ndarrays than used for the 
constructor,
-            but must have the same shapes and appear in the same order.
-        grid_dims : tuple of 3 uint
-            Grid dimension for kernel launch.
-        block_dims : tuple of 3 uint
-            Block dimension for kernel launch.
+        name : str
+            String name of the kernel.
+        signature : str
+            Function signature for the kernel. For example, if a kernel is
+            declared as::
+
+                extern "C" __global__ void axpy(const float *x, double *y, int 
alpha)
+
+            Then its signature should be::
+
+                const float *x, double *y, int alpha
+
+            or::
+
+                const float *, double *, int
+
+            Note that `*` in signature marks an argument as array and
+            `const` marks an argument as constant (input) array.
+
+        Returns
+        -------
+        CudaKernel
+            CUDA kernels that can be launched on GPUs.
+        """
+        hdl = CudaKernelHandle()
+        is_ndarray = []
+        is_const = []
+        dtypes = []
+        pattern = 
re.compile(r"""^\s*(const)?\s*([\w_]+)\s*(\*)?\s*([\w_]+)?\s*$""")
+        args = re.sub(r"\s+", " ", signature).split(",")
+        for arg in args:
+            match = pattern.match(arg)
+            if not match or match.groups()[1] == 'const':
+                raise ValueError(
+                    'Invalid function prototype "%s". Must be in the '
+                    'form of "(const) type (*) (name)"'%arg)
+            is_const.append(bool(match.groups()[0]))
+            dtype = match.groups()[1]
+            is_ndarray.append(bool(match.groups()[2]))
+            if dtype not in _DTYPE_CPP_TO_NP:
+                raise TypeError(
+                    "Unsupported kernel argument type %s. Supported types are: 
%s."%(
+                        arg, ','.join(_DTYPE_CPP_TO_NP.keys())))
+            dtypes.append(_DTYPE_NP_TO_MX[_DTYPE_CPP_TO_NP[dtype]])
+
+        check_call(_LIB.MXRtcCudaKernelCreate(
+            self.handle,
+            c_str(name),
+            len(dtypes),
+            c_array(ctypes.c_int, [ctypes.c_int(i) for i in is_ndarray]),
+            c_array(ctypes.c_int, [ctypes.c_int(i) for i in is_const]),
+            c_array(ctypes.c_int, [ctypes.c_int(i) for i in dtypes]),
+            ctypes.byref(hdl)))
+
+        return CudaKernel(hdl, name, is_ndarray, dtypes)
+
+class CudaKernel(object):
+    """Constructs CUDA kernel. Should be created by `CudaModule.get_kernel`,
+    not intended to be used by users."""
+    def __init__(self, handle, name, is_ndarray, dtypes):
+        self.handle = handle
+        self._name = name
+        self._is_ndarray = is_ndarray
+        self._dtypes = [_DTYPE_MX_TO_NP[i] for i in dtypes]
+
+    def __del__(self):
+        check_call(_LIB.MXRtcCudaKernelFree(self.handle))
+
+    def launch(self, args, ctx, grid_dims, block_dims, shared_mem=0):
+        """Launch cuda kernel.
+
+        Parameters
+        ----------
+        args : tuple of NDArray or numbers
+            List of arguments for kernel. NDArrays are expected for pointer
+            types (e.g. `float*`, `double*`) while numbers are expected for
+            non-pointer types (e.g. `int`, `float`).
+        ctx : Context
+            The context to launch kernel on. Must be GPU context.
+        grid_dims : tuple of 3 integers
+            Grid dimensions for CUDA kernel.
+        block_dims : tuple of 3 integers
+            Block dimensions for CUDA kernel.
+        shared_mem : integer, optional
+            Size of dynamically allocated shared memory. Defaults to 0.
         """
-        input_nds = ctypes.cast(c_array(NDArrayHandle, [i.handle for i in 
inputs]),
-                                ctypes.POINTER(NDArrayHandle))
-        output_nds = ctypes.cast(c_array(NDArrayHandle, [i.handle for i in 
outputs]),
-                                 ctypes.POINTER(NDArrayHandle))
-        check_call(_LIB.MXRtcPush(self.handle,
-                                  mx_uint(len(inputs)),
-                                  mx_uint(len(outputs)),
-                                  input_nds,
-                                  output_nds,
-                                  mx_uint(grid_dims[0]),
-                                  mx_uint(grid_dims[1]),
-                                  mx_uint(grid_dims[2]),
-                                  mx_uint(block_dims[0]),
-                                  mx_uint(block_dims[1]),
-                                  mx_uint(block_dims[2])))
+        assert ctx.device_type == 'gpu', "Cuda kernel can only be launched on 
GPU"
+        assert len(grid_dims) == 3, "grid_dims must be a tuple of 3 integers"
+        assert len(block_dims) == 3, "grid_dims must be a tuple of 3 integers"
+        assert len(args) == len(self._dtypes), \
+            "CudaKernel(%s) expects %d arguments but got %d"%(
+                self._name, len(self._dtypes), len(args))
+        void_args = []
+        ref_holder = []
+        for i, (arg, is_nd, dtype) in enumerate(zip(args, self._is_ndarray, 
self._dtypes)):
+            if is_nd:
+                assert isinstance(arg, NDArray), \
+                    "The %d-th argument is expected to be a NDArray but got 
%s"%(
+                        i, type(arg))
+                void_args.append(arg.handle)
+            else:
+                assert isinstance(arg, numeric_types), \
+                    "The %d-th argument is expected to be a number, but got 
%s"%(
+                        i, type(arg))
+                ref_holder.append(np.array(arg, dtype=dtype))
+                
void_args.append(ref_holder[-1].ctypes.data_as(ctypes.c_void_p))
+
+        check_call(_LIB.MXRtcCudaKernelCall(
+            self.handle,
+            ctx.device_id,
+            c_array(ctypes.c_void_p, void_args),
+            mx_uint(grid_dims[0]), mx_uint(grid_dims[1]), 
mx_uint(grid_dims[2]),
+            mx_uint(block_dims[0]), mx_uint(block_dims[1]), 
mx_uint(block_dims[2]),
+            mx_uint(shared_mem)))
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 29df716..8ab7f1f 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -33,7 +33,7 @@
 #include <mxnet/io.h>
 #include <mxnet/c_api.h>
 #include <mxnet/kvstore.h>
-#include <mxnet/mxrtc.h>
+#include <mxnet/rtc.h>
 #include <vector>
 #include <sstream>
 #include <string>
@@ -1102,21 +1102,7 @@ int MXRtcCreate(char* name, mx_uint num_input, mx_uint 
num_output,
                 NDArrayHandle* inputs, NDArrayHandle* outputs,
                 char* kernel, RtcHandle *out) {
   API_BEGIN();
-#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
-  std::vector<std::pair<std::string, NDArray> > input, output;
-  for (mx_uint i = 0; i < num_input; ++i) {
-    input.push_back(std::pair<std::string, NDArray>(input_names[i],
-                                                    
*reinterpret_cast<NDArray*>(inputs[i])));
-  }
-  for (mx_uint i = 0; i < num_output; ++i) {
-    output.push_back(std::pair<std::string, NDArray>(output_names[i],
-                                                     
*reinterpret_cast<NDArray*>(outputs[i])));
-  }
-  MXRtc *rtc = new MXRtc(name, input, output, kernel);
-  *out = reinterpret_cast<RtcHandle>(rtc);
-#else
-  LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
-#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
+  LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
   API_END();
 }
 
@@ -1129,34 +1115,13 @@ int MXRtcPush(RtcHandle handle, mx_uint num_input, 
mx_uint num_output,
               mx_uint blockDimY,
               mx_uint blockDimZ) {
   API_BEGIN();
-#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
-  std::vector<NDArray> input, output;
-  for (mx_uint i = 0; i < num_input; ++i) {
-    input.push_back(*reinterpret_cast<NDArray*>(inputs[i]));
-  }
-  for (mx_uint i = 0; i < num_output; ++i) {
-    output.push_back(*reinterpret_cast<NDArray*>(outputs[i]));
-  }
-  reinterpret_cast<MXRtc*>(handle)->push(input, output,
-                                         gridDimX,
-                                         gridDimY,
-                                         gridDimZ,
-                                         blockDimX,
-                                         blockDimY,
-                                         blockDimZ);
-#else
-  LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
-#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
+  LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
   API_END();
 }
 
 int MXRtcFree(RtcHandle handle) {
   API_BEGIN();
-#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
-  delete reinterpret_cast<MXRtc*>(handle);
-#else
-  LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
-#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
+  LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule";
   API_END();
 }
 
@@ -1165,3 +1130,87 @@ int MXCustomOpRegister(const char* op_type, 
CustomOpPropCreator creator) {
   mxnet::op::custom::Registry::Get()->Register(op_type, creator);
   API_END();
 }
+
+
+int MXRtcCudaModuleCreate(const char* source, int num_options,
+                          const char** options, int num_exports,
+                          const char** exports, CudaModuleHandle *out) {
+  API_BEGIN();
+#if MXNET_USE_CUDA
+  std::vector<std::string> str_opts;
+  for (int i = 0; i < num_options; ++i) str_opts.emplace_back(options[i]);
+  std::vector<std::string> str_exports;
+  for (int i = 0; i < num_exports; ++i) str_exports.emplace_back(exports[i]);
+  *out = new rtc::CudaModule(source, str_opts, str_exports);
+#else
+  LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
+#endif
+  API_END();
+}
+
+int MXRtcCudaModuleFree(CudaModuleHandle handle) {
+  API_BEGIN();
+#if MXNET_USE_CUDA
+  delete reinterpret_cast<rtc::CudaModule*>(handle);
+#else
+  LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
+#endif
+  API_END();
+}
+
+int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name, int 
num_args,
+                          int* is_ndarray, int* is_const, int* arg_types,
+                          CudaKernelHandle *out) {
+  API_BEGIN();
+#if MXNET_USE_CUDA
+  auto module = reinterpret_cast<rtc::CudaModule*>(handle);
+  std::vector<rtc::CudaModule::ArgType> signature;
+  for (int i = 0; i < num_args; ++i) {
+    signature.push_back(rtc::CudaModule::ArgType{
+        static_cast<bool>(is_ndarray[i]), static_cast<bool>(is_const[i]),
+        static_cast<mshadow::TypeFlag>(arg_types[i])});
+  }
+  auto kernel = module->GetKernel(name, signature);
+  *out = new std::shared_ptr<rtc::CudaModule::Kernel>(kernel);
+#else
+  LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
+#endif
+  API_END();
+}
+
+int MXRtcCudaKernelFree(CudaKernelHandle handle) {
+  API_BEGIN();
+#if MXNET_USE_CUDA
+  delete reinterpret_cast<std::shared_ptr<rtc::CudaModule::Kernel>*>(handle);
+#else
+  LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
+#endif
+  API_END();
+}
+
+int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args,
+                        mx_uint grid_dim_x, mx_uint grid_dim_y,
+                        mx_uint grid_dim_z, mx_uint block_dim_x,
+                        mx_uint block_dim_y, mx_uint block_dim_z,
+                        mx_uint shared_mem) {
+  API_BEGIN();
+#if MXNET_USE_CUDA
+  auto kernel = 
reinterpret_cast<std::shared_ptr<rtc::CudaModule::Kernel>*>(handle);
+  const auto& signature = (*kernel)->signature();
+  std::vector<dmlc::any> any_args;
+  for (size_t i = 0; i < signature.size(); ++i) {
+    if (signature[i].is_ndarray) {
+      any_args.emplace_back(*static_cast<NDArray*>(args[i]));
+    } else {
+      MSHADOW_TYPE_SWITCH(signature[i].dtype, DType, {
+        any_args.emplace_back(*static_cast<DType*>(args[i]));
+      });
+    }
+  }
+  (*kernel)->Launch(Context::GPU(dev_id), any_args, grid_dim_x, grid_dim_y,
+                    grid_dim_z, block_dim_x, block_dim_y, block_dim_z, 
shared_mem);
+#else
+  LOG(FATAL) << "Compile with USE_CUDA=1 to use GPU.";
+#endif
+  API_END();
+}
diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 0f63895..c135ff8 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -229,6 +229,40 @@ inline DType __device__ CudaMin(DType a, DType b) {
         << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
   }
 
+/*!
+ * \brief Protected NVRTC call.
+ * \param func Expression to call.
+ *
+ * It checks for NVRTC errors after invocation of the expression.
+ */
+#define NVRTC_CALL(x)                                   \
+  {                                                     \
+    nvrtcResult result = x;                             \
+    CHECK_EQ(result, NVRTC_SUCCESS)                     \
+      << #x " failed with error "                       \
+      << nvrtcGetErrorString(result);                   \
+  }
+
+/*!
+ * \brief Protected CUDA driver call.
+ * \param func Expression to call.
+ *
+ * It checks for CUDA driver errors after invocation of the expression.
+ */
+#define CUDA_DRIVER_CALL(func)                                          \
+  {                                                                     \
+    CUresult e = (func);                                                \
+    if (e != CUDA_SUCCESS) {                                            \
+      char const * err_msg = nullptr;                                         \
+      if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) {  \
+        LOG(FATAL) << "CUDA Driver: Unknown error " << e;               \
+      } else {                                                          \
+        LOG(FATAL) << "CUDA Driver: " << err_msg;                       \
+      }                                                                 \
+    }                                                                   \
+  }
+
+
 #if !defined(_MSC_VER)
 #define CUDA_UNROLL _Pragma("unroll")
 #define CUDA_NOUNROLL _Pragma("nounroll")
diff --git a/src/common/mxrtc.cc b/src/common/mxrtc.cc
deleted file mode 100644
index e72ac0b..0000000
--- a/src/common/mxrtc.cc
+++ /dev/null
@@ -1,159 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file mxrtc.cc
- * \brief Wrapper for NVRTC
- * \author Junyuan Xie
- */
-#include <mxnet/mxrtc.h>
-#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
-namespace mxnet {
-const char MXRtc::str_type[] = "float";
-std::unordered_map<std::string, char*> MXRtc::kernel_registry;
-
-MXRtc::MXRtc(const std::string& name,
-             std::vector<std::pair<std::string, NDArray> > const& input,
-             std::vector<std::pair<std::string, NDArray> > const& output,
-             const std::string& kernel) {
-    name_ = name;
-    num_input_ = input.size();
-    num_output_ = output.size();
-    code_ = decorate(name, input, output, kernel);
-    if (MXRtc::kernel_registry.find(code_) != MXRtc::kernel_registry.end()) {
-        ptx_ = MXRtc::kernel_registry[code_];
-    } else {
-        ptx_ = compile(name, code_);
-    }
-}
-
-void MXRtc::push(std::vector<NDArray> const& input,
-                 std::vector<NDArray> const& output,
-                 unsigned int grid_dim_X,
-                 unsigned int grid_dim_Y,
-                 unsigned int grid_dim_Z,
-                 unsigned int block_dim_X,
-                 unsigned int block_dim_Y,
-                 unsigned int block_dim_Z) {
-    CHECK_EQ(num_input_, input.size());
-    CHECK_EQ(num_output_, output.size());
-    CHECK(output.size());
-    cudaError_enum err;
-    CUfunction func;
-    int dev_id = output[0].ctx().dev_id;
-    if (func_.find(dev_id) != func_.end()) {
-        func = func_[dev_id];
-    } else {
-        CUmodule module;
-        CHECK_EQ(err = cuModuleLoadDataEx(&module, ptx_, 0, 0, 0), 
CUDA_SUCCESS)
-            << "CudaError: " << err;
-        CHECK_EQ(err = cuModuleGetFunction(&func, module, name_.c_str()), 
CUDA_SUCCESS)
-            << "CudaError: " << err;
-        module_[dev_id] = module;
-        func_[dev_id] = func;
-    }
-    auto op = [this, func, input, output,
-               grid_dim_X, grid_dim_Y, grid_dim_Z,
-               block_dim_X, block_dim_Y, block_dim_Z](RunContext rctx) {
-        std::vector<float*> float_args;
-        for (auto& i : input) 
float_args.push_back(static_cast<float*>(i.data().dptr_));
-        for (auto& i : output) 
float_args.push_back(static_cast<float*>(i.data().dptr_));
-        std::vector<void*> args;
-        for (auto& i : float_args) args.push_back(&i);
-        cudaError_enum err;
-        cudaError_t cuerr;
-        CHECK_EQ(err = cuLaunchKernel(func,
-                                grid_dim_X, grid_dim_Y, grid_dim_Z,
-                                block_dim_X, block_dim_Y, block_dim_Z,
-                                0, rctx.get_stream<mshadow::gpu>()->stream_,
-                                args.data(), 0), CUDA_SUCCESS) << "CudaError: 
" << err;
-        CHECK_EQ(cuerr = 
cudaStreamSynchronize(rctx.get_stream<mshadow::gpu>()->stream_),
-                 cudaSuccess) << "CudaError: " << cuerr;
-    };
-    std::vector<Engine::VarHandle> var_in, var_out;
-    for (auto& i : input) var_in.push_back(i.var());
-    for (auto& i : output) var_out.push_back(i.var());
-    Engine::Get()->PushSync(op, output[0].ctx(), var_in, var_out,
-            FnProperty::kNormal, 0, PROFILER_MESSAGE("MXRtc"));
-}
-
-std::string MXRtc::decorate(const std::string& name,
-                         std::vector<std::pair<std::string, NDArray> > const& 
input,
-                         std::vector<std::pair<std::string, NDArray> > const& 
output,
-                         const std::string kernel) {
-    std::string source;
-    source = source + "\nextern \"C\" __global__ void " + name + "(";
-    for (auto &i : input) {
-        source = source + "const " + str_type + "* " + i.first + ",";
-    }
-    for (auto &i : output) {
-        source = source + str_type + "* " + i.first + ",";
-    }
-    source.pop_back();
-    source = source + ") {\n";
-    for (auto &i : input) {
-        source = source + "const int " + i.first + "_ndim = " +
-                  std::to_string(i.second.shape().ndim()) + ";\n";
-        source = source + "const int " + i.first + "_dims[] = {";
-        for (index_t j = 0; j < i.second.shape().ndim(); ++j) {
-            source = source + std::to_string(i.second.shape()[j]) + ",";
-        }
-        source.pop_back();
-        source = source + "};\n";
-    }
-    for (auto &i : output) {
-        source = source + "const int " + i.first + "_ndim = " +
-                  std::to_string(i.second.shape().ndim()) + ";\n";
-        source = source + "const int " + i.first + "_dims[] = {";
-        for (index_t j = 0; j < i.second.shape().ndim(); ++j) {
-            source = source + std::to_string(i.second.shape()[j]) + ",";
-        }
-        source.pop_back();
-        source = source + "};\n";
-    }
-    source = source + kernel + "\n}\n";
-    return source;
-}
-
-char* MXRtc::compile(const std::string& name, const std::string& code) {
-    nvrtcProgram prog;
-    CHECK_EQ(nvrtcCreateProgram(&prog,
-                                code.c_str(),
-                                (name+".cu").c_str(),
-                                0,
-                                NULL,
-                                NULL), NVRTC_SUCCESS);
-    nvrtcResult compile_res = nvrtcCompileProgram(prog, 0, NULL);
-    size_t log_size;
-    CHECK_EQ(nvrtcGetProgramLogSize(prog, &log_size), NVRTC_SUCCESS);
-    char *log = new char[log_size];
-    CHECK_EQ(nvrtcGetProgramLog(prog, log), NVRTC_SUCCESS);
-    CHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
-
-    size_t ptx_size;
-    CHECK_EQ(nvrtcGetPTXSize(prog, &ptx_size), NVRTC_SUCCESS);
-    char *ptx = new char[ptx_size];
-    CHECK_EQ(nvrtcGetPTX(prog, ptx), NVRTC_SUCCESS);
-    CHECK_EQ(nvrtcDestroyProgram(&prog), NVRTC_SUCCESS);
-    return ptx;
-}
-
-}  // namespace mxnet
-
-#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
diff --git a/src/common/rtc.cc b/src/common/rtc.cc
new file mode 100644
index 0000000..cd26f0e
--- /dev/null
+++ b/src/common/rtc.cc
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <mxnet/rtc.h>
+#include <typeinfo>
+
+#include "../common/cuda_utils.h"
+#include "../operator/operator_common.h"
+
+#if MXNET_USE_CUDA
+
+namespace mxnet {
+namespace rtc {
+
+CudaModule::Chunk::Chunk(
+    const char* source,
+    const std::vector<std::string>& options,
+    const std::vector<std::string>& exports) {
+  NVRTC_CALL(nvrtcCreateProgram(&prog_, source, "source.cu", 0, NULL, NULL));
+  for (const auto& i : exports) exports_.insert(i);
+#if CUDA_VERSION >= 8000
+  for (const auto& func : exports) {
+    NVRTC_CALL(nvrtcAddNameExpression(prog_, func.c_str()));
+  }
+#else
+  CHECK_EQ(exports.size(), 0)
+      << "Exporting is only supported with CUDA 8.0 and above. "
+      << "For lower version of CUDA, please prepend your kernel defintiions "
+      << "with extern \"C\" instead.";
+#endif
+  std::vector<const char*> c_options;
+  for (const auto& i : options) c_options.push_back(i.c_str());
+  nvrtcResult compile_res = nvrtcCompileProgram(prog_, c_options.size(), 
c_options.data());
+  if (compile_res != NVRTC_SUCCESS) {
+    size_t err_size;
+    NVRTC_CALL(nvrtcGetProgramLogSize(prog_, &err_size));
+    std::vector<char> err(err_size);
+    NVRTC_CALL(nvrtcGetProgramLog(prog_, err.data()));
+    LOG(FATAL) << err.data();
+  }
+
+  size_t ptx_size;
+  NVRTC_CALL(nvrtcGetPTXSize(prog_, &ptx_size));
+  ptx_ = new char[ptx_size];
+  NVRTC_CALL(nvrtcGetPTX(prog_, ptx_));
+}
+
+
+CudaModule::Chunk::~Chunk() {
+  for (const auto& kv : mod_) {
+    CUDA_DRIVER_CALL(cuModuleUnload(kv.second));
+  }
+  NVRTC_CALL(nvrtcDestroyProgram(&prog_));
+  delete ptx_;
+}
+
+
+CUfunction CudaModule::Chunk::GetFunction(
+    const std::string& mangled_name,
+    const Context& ctx) {
+  CHECK_EQ(ctx.dev_mask(), gpu::kDevMask)
+      << "CUDA Runtime compilation only supports Nvidia GPU.";
+  auto iter = mod_.find(ctx.dev_id);
+  CUmodule module;
+  if (iter != mod_.end()) {
+    module = iter->second;
+  } else {
+    CUDA_CALL(cudaSetDevice(ctx.dev_id));
+    CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, ptx_, 0, 0, 0));
+    mod_[ctx.dev_id] = module;
+  }
+  CUfunction function;
+  auto err = cuModuleGetFunction(&function, module, mangled_name.c_str());
+  if (err == CUDA_ERROR_NOT_FOUND) {
+    LOG(FATAL) << "Cannot find cuda kernel with name '" << mangled_name
+               << "'. Please either prepend kernel definition "
+               << "with 'extern \"C\"' or add its name to exports "
+               << "when creating CudaModule.";
+  }
+  CUDA_DRIVER_CALL(err);
+  return function;
+}
+
+
+std::shared_ptr<CudaModule::Kernel> CudaModule::GetKernel(
+    const std::string& name, const std::vector<ArgType>& signature) {
+  std::string mangled_name = name;
+#if CUDA_VERSION >= 8000
+  if (ptr_->exports_.count(name)) {
+    const char * c_mangled_name;
+    NVRTC_CALL(nvrtcGetLoweredName(ptr_->prog_, name.c_str(), 
&c_mangled_name));
+    mangled_name = c_mangled_name;
+  }
+#endif
+  return std::shared_ptr<Kernel>(new Kernel(ptr_, mangled_name, signature));
+}
+
+
+CudaModule::Kernel::Kernel(
+    const std::shared_ptr<CudaModule::Chunk>& mod,
+    const std::string& mangled_name,
+    const std::vector<ArgType>& signature)
+      : mangled_name_(mangled_name), signature_(signature), mod_(mod) {
+}
+
+void CudaModule::Kernel::Launch(
+    const Context& ctx, const std::vector<dmlc::any>& args,
+    uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
+    uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z,
+    uint32_t shared_mem) {
+  CHECK_EQ(ctx.dev_mask(), gpu::kDevMask)
+      << "CUDA Runtime compilation only supports Nvidia GPU.";
+
+  auto mod = mod_;
+  auto arg_types = signature();
+
+  CUfunction function;
+  auto iter = func_.find(ctx.dev_id);
+  if (iter != func_.end()) {
+    function = iter->second;
+  } else {
+    function = mod_->GetFunction(mangled_name_, ctx);
+    func_[ctx.dev_id] = function;
+  }
+
+  std::vector<Engine::VarHandle> read_vars, write_vars;
+  for (size_t i = 0; i < arg_types.size(); ++i) {
+    if (!arg_types[i].is_ndarray) continue;
+    const auto& array = dmlc::get<NDArray>(args[i]);
+    CHECK_EQ(array.dtype(), arg_types[i].dtype)
+        << "The i-th argument is expected to be an NDArray of "
+        << op::type_string(arg_types[i].dtype) << " type, but got "
+        << op::type_string(array.dtype()) << " instead.";
+    if (arg_types[i].is_const) {
+      read_vars.emplace_back(array.var());
+    } else {
+      write_vars.emplace_back(array.var());
+    }
+  }
+
+  Engine::Get()->PushSync(
+    [function, mod, args, arg_types, grid_dim_x, grid_dim_y, grid_dim_z,
+     block_dim_x, block_dim_y, block_dim_z, shared_mem](RunContext rctx) {
+    std::vector<void*> p_args;
+    for (size_t i = 0; i < arg_types.size(); ++i) {
+      if (arg_types[i].is_ndarray) {
+        const auto& array = dmlc::get<NDArray>(args[i]);
+        
p_args.push_back(reinterpret_cast<void*>(const_cast<void**>(&array.data().dptr_)));
+      } else {
+        MSHADOW_TYPE_SWITCH(arg_types[i].dtype, DType, {
+          const auto& number = dmlc::get<DType>(args[i]);
+          p_args.push_back(const_cast<DType*>(&number));
+        });
+      }
+    }
+
+    mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
+    CUDA_DRIVER_CALL(cuLaunchKernel(
+        function, grid_dim_x, grid_dim_y, grid_dim_z,
+        block_dim_x, block_dim_y, block_dim_z,
+        shared_mem, s->stream_,
+        p_args.data(), 0));
+    CUDA_CALL(cudaStreamSynchronize(s->stream_));
+  }, ctx, read_vars, write_vars, FnProperty::kNormal, 0,
+  PROFILER_MESSAGE(mangled_name_.c_str()));
+}
+
+
+}  // namespace rtc
+}  // namespace mxnet
+
+#endif  // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
diff --git a/tests/python/gpu/test_operator_gpu.py 
b/tests/python/gpu/test_operator_gpu.py
index 745974f..ec65844 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1390,6 +1390,35 @@ def test_gluon_ctc_consistency():
     assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), 
atol=1e-3, rtol=1e-3)
 
 
+def test_cuda_rtc():
+    source = r'''
+    extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
+        int i = threadIdx.x + blockIdx.x * blockDim.x;
+        y[i] += alpha * x[i];
+    }
+
+    extern "C" __global__ void saxpy(const float *x, float *y, float alpha) {
+        extern __shared__ float smem[];
+        int i = threadIdx.x + blockIdx.x * blockDim.x;
+        smem[threadIdx.x] = x[i];
+        y[i] += alpha * smem[threadIdx.x];
+    }
+    '''
+    module = mx.rtc.CudaModule(source)
+    axpy = module.get_kernel("axpy", "const float *x, float *y, float alpha")
+    x = mx.nd.ones((10,), ctx=mx.gpu(0))
+    y = mx.nd.zeros((10,), ctx=mx.gpu(0))
+    axpy.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
+    assert (y.asnumpy() == 3).all()
+
+    saxpy = module.get_kernel("saxpy", "const float *x, float *y, float alpha")
+    saxpy.launch([x, y, 4.0], mx.gpu(0), (1, 1, 1), (10, 1, 1), 10)
+    assert (y.asnumpy() == 7).all()
+
+    saxpy.launch([x, y, 5.0], mx.gpu(0), (2, 1, 1), (5, 1, 1), 5)
+    assert (y.asnumpy() == 12).all()
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].

Reply via email to