This is an automated email from the ASF dual-hosted git repository.

samskalicky pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 8dcc653  [1.x] Backporting TensorRT-Gluon Partition API (and TensorRT 
7 support) (#18916)
8dcc653 is described below

commit 8dcc653048e1bfe2f653d50eaf630748aba6f835
Author: Serge Panev <[email protected]>
AuthorDate: Wed Aug 19 00:22:02 2020 -0700

    [1.x] Backporting TensorRT-Gluon Partition API (and TensorRT 7 support) 
(#18916)
    
    * [1.x] Backporting TensorRT and Gluon changes
    
    Signed-off-by: Serge Panev <[email protected]>
    
    * Remove test from Jenkins
    
    Signed-off-by: Serge Panev <[email protected]>
    
    * Fix test
    
    Signed-off-by: Serge Panev <[email protected]>
---
 3rdparty/onnx-tensorrt                             |   2 +-
 CMakeLists.txt                                     |   8 +-
 ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt     |   8 +-
 ci/docker/install/tensorrt.sh                      |  15 +-
 ci/docker/runtime_functions.sh                     |  29 +-
 ci/jenkins/Jenkins_steps.groovy                    |  20 +-
 ci/jenkins/Jenkinsfile_unix_gpu                    |   1 -
 example/extensions/lib_pass/test_pass.py           |  28 +-
 example/extensions/lib_subgraph/test_subgraph.py   |  25 +-
 include/mxnet/c_api.h                              |  30 ++
 perl-package/AI-MXNetCAPI/mxnet.i                  |  11 +
 python/mxnet/gluon/block.py                        | 108 +++--
 python/mxnet/symbol/symbol.py                      | 168 +++++--
 src/c_api/c_api_symbolic.cc                        | 120 +++--
 src/operator/subgraph/build_subgraph.cc            |  20 +-
 src/operator/subgraph/tensorrt/nnvm_to_onnx.cc     |   2 +-
 src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc |   8 +-
 src/operator/subgraph/tensorrt/tensorrt-inl.h      |  48 +-
 src/operator/subgraph/tensorrt/tensorrt.cu         |   6 +-
 tests/python/tensorrt/lenet5_train.py              |  99 ----
 tests/python/tensorrt/test_cvnets.py               | 174 -------
 tests/python/tensorrt/test_ops.py                  | 517 ---------------------
 tests/python/tensorrt/test_resnet18.py             |  74 ---
 tests/python/tensorrt/test_tensorrt_lenet5.py      | 121 -----
 tests/python/unittest/test_extensions.py           |   6 +-
 tests/python/unittest/test_subgraph_op.py          |  14 +-
 26 files changed, 447 insertions(+), 1215 deletions(-)

diff --git a/3rdparty/onnx-tensorrt b/3rdparty/onnx-tensorrt
index f4745fc..2eb74d9 160000
--- a/3rdparty/onnx-tensorrt
+++ b/3rdparty/onnx-tensorrt
@@ -1 +1 @@
-Subproject commit f4745fcaff868a519834917c657f105a8eef2f53
+Subproject commit 2eb74d933f89e1590fdbfc64971a36e5f72df720
diff --git a/CMakeLists.txt b/CMakeLists.txt
index f861686..7e1ef2a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -239,6 +239,7 @@ if(USE_TENSORRT)
   include_directories(3rdparty/onnx-tensorrt/third_party/onnx/)
   add_definitions(-DMXNET_USE_TENSORRT=1)
   add_definitions(-DONNX_NAMESPACE=onnx)
+  add_definitions(-DONNX_ML=1)
 
   find_package(Protobuf REQUIRED)
 
@@ -248,14 +249,11 @@ if(USE_TENSORRT)
   find_library(ONNX_PROTO_LIBRARY NAMES libonnx_proto.so REQUIRED
           PATHS ${ONNX_PATH}
           DOC "Path to onnx_proto library.")
-  find_library(ONNX_TRT_RUNTIME_LIBRARY NAMES libnvonnxparser_runtime.so 
REQUIRED
-          PATHS ${ONNX_TRT_PATH}
-          DOC "Path to onnx_proto library.")
   find_library(ONNX_TRT_PARSER_LIBRARY NAMES libnvonnxparser.so REQUIRED
           PATHS ${ONNX_TRT_PATH}
-          DOC "Path to onnx_proto library.")
+          DOC "Path to onnx-tensorrt parser library (libnvonnxparser).")
 
-  list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY} 
${ONNX_TRT_RUNTIME_LIBRARY}
+  list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY}
           ${ONNX_PROTO_LIBRARY} ${ONNX_LIBRARY} ${PROTOBUF_LIBRARY})
 endif()
 
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt 
b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
index 90bd772..9556fee 100644
--- a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -18,7 +18,7 @@
 #
 # Dockerfile to run MXNet on Ubuntu 16.04 for CPU
 
-FROM nvidia/cuda:10.0-devel
+FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04
 
 WORKDIR /work/deps
 
@@ -36,12 +36,8 @@ ARG USER_ID=0
 COPY install/ubuntu_adduser.sh /work/
 RUN /work/ubuntu_adduser.sh
 
-ENV CUDNN_VERSION=7.5.0.56
-COPY install/ubuntu_cudnn.sh /work/
-RUN /work/ubuntu_cudnn.sh
-
 COPY runtime_functions.sh /work/
 
 WORKDIR /work/mxnet
 ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
-ENV 
CPLUS_INCLUDE_PATH=${CPLUS_INCLUDE_PATH}:/usr/local/cuda-10.0/targets/x86_64-linux/include/
+ENV 
CPLUS_INCLUDE_PATH=${CPLUS_INCLUDE_PATH}:/usr/local/cuda-10.2/targets/x86_64-linux/include/
diff --git a/ci/docker/install/tensorrt.sh b/ci/docker/install/tensorrt.sh
index e98c764..29d8ad1 100755
--- a/ci/docker/install/tensorrt.sh
+++ b/ci/docker/install/tensorrt.sh
@@ -18,7 +18,7 @@
 # under the License.
 
 # Install gluoncv since we're testing Gluon models as well
-pip3 install gluoncv==0.2.0
+pip3 install gluoncv==0.4.0
 
 # Install Protobuf
 # Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt)
@@ -40,10 +40,11 @@ popd
 
 # Install TensorRT
 echo "TensorRT build enabled. Installing TensorRT."
-wget -qO tensorrt.deb 
https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvidia-machine-learning-repo-ubuntu1604_1.0.0-1_amd64.deb
-dpkg -i tensorrt.deb
 apt-get update
-apt-get install -y --allow-downgrades libnvinfer5=5.1.5-1+cuda10.0
-apt-get install -y --allow-downgrades libnvinfer-dev=5.1.5-1+cuda10.0
-apt-mark hold libnvinfer5 libnvinfer-dev
-rm tensorrt.deb
+TRT_VERSION="7.0.0-1+cuda10.2"
+TRT_MAJOR_VERSION=7
+apt-get install -y libnvinfer${TRT_MAJOR_VERSION}=${TRT_VERSION} \
+                   libnvinfer-dev=${TRT_VERSION} \
+                   libnvinfer-plugin${TRT_MAJOR_VERSION}=${TRT_VERSION} \
+                   libnvinfer-plugin-dev=${TRT_VERSION}
+apt-mark hold libnvinfer${TRT_MAJOR_VERSION} libnvinfer-dev libnvinfer-plugin${TRT_MAJOR_VERSION} libnvinfer-plugin-dev
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 4523e1f..b5cbb9a 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -708,6 +708,8 @@ build_ubuntu_gpu_tensorrt() {
 
     build_ccache_wrappers
 
+    export ONNX_NAMESPACE=onnx
+
     # Build ONNX
     pushd .
     echo "Installing ONNX."
@@ -715,14 +717,11 @@ build_ubuntu_gpu_tensorrt() {
     rm -rf build
     mkdir -p build
     cd build
-    cmake \
-        -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER}\
-        -DBUILD_SHARED_LIBS=ON ..\
-        -G Ninja
-    ninja -j 1 -v onnx/onnx.proto
-    ninja -j 1 -v
+    cmake -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER} 
-DBUILD_SHARED_LIBS=ON ..
+    make -j$(nproc)
     export LIBRARY_PATH=`pwd`:`pwd`/onnx/:$LIBRARY_PATH
     export CPLUS_INCLUDE_PATH=`pwd`:$CPLUS_INCLUDE_PATH
+    export CXXFLAGS=-I`pwd`
     popd
 
     # Build ONNX-TensorRT
@@ -730,15 +729,14 @@ build_ubuntu_gpu_tensorrt() {
     cd 3rdparty/onnx-tensorrt/
     mkdir -p build
     cd build
-    cmake ..
+    cmake -DONNX_NAMESPACE=$ONNX_NAMESPACE ..
     make -j$(nproc)
     export LIBRARY_PATH=`pwd`:$LIBRARY_PATH
     popd
 
     mkdir -p /work/mxnet/lib/
     cp 3rdparty/onnx-tensorrt/third_party/onnx/build/*.so /work/mxnet/lib/
-    cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 
/work/mxnet/lib/
-    cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/
+    cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so* /work/mxnet/lib/
 
     cd /work/build
     cmake -DUSE_CUDA=1                            \
@@ -1071,19 +1069,6 @@ unittest_ubuntu_python3_gpu_nocudnn() {
     nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit 
--xunit-file nosetests_gpu.xml --verbose tests/python/gpu
 }
 
-unittest_ubuntu_tensorrt_gpu() {
-    set -ex
-    export PYTHONPATH=./python/
-    export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
-    export MXNET_SUBGRAPH_VERBOSE=0
-    export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
-    export CUDNN_VERSION=${CUDNN_VERSION:-7.0.3}
-    export MXNET_ENABLE_CYTHON=0
-    export DMLC_LOG_STACK_TRACE_DEPTH=10
-    tests/python/tensorrt/lenet5_train.py
-    nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS $NOSE_TIMER_ARGUMENTS --with-xunit 
--xunit-file nosetests_trt_gpu.xml --verbose --nocapture tests/python/tensorrt/
-}
-
 # quantization gpu currently only runs on P3 instances
 # need to separte it from unittest_ubuntu_python3_gpu()
 unittest_ubuntu_python3_quantization_gpu() {
diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy
index c4fd96e..1cc91e4 100644
--- a/ci/jenkins/Jenkins_steps.groovy
+++ b/ci/jenkins/Jenkins_steps.groovy
@@ -34,7 +34,7 @@ mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, 
build/3rdparty/tvm/l
 mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, 
build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, 
build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, 
build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, 
build/tests/mxnet_unit_tests'
 mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, 
build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, 
build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, 
build/3rdparty/openmp/runtime/src/libomp.so'
 mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, 
lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, 
build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 
3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
-mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, 
build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, 
lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
+mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, 
build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser.so*, 
lib/libonnx_proto.so, lib/libonnx.so'
 mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, 
lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, 
build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 
3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 
3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, 
build/cpp-package/example/*, python/mxnet/_cy3/*.so, 
python/mxnet/_ffi/_cy3/*.so'
 mx_lib_cpp_capi = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, 
lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libmkldnn.so.1, 
lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 
3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, 
deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, 
python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so, 
build/tests/cpp/mxnet_unit_tests'
 mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, 
build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, 
build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 
3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, 
deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, 
python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so'
@@ -853,24 +853,6 @@ def test_unix_python3_mkldnn_nocudnn_gpu() {
     }]
 }
 
-def test_unix_python3_tensorrt_gpu() {
-    return ['Python3: TensorRT GPU': {
-      node(NODE_LINUX_GPU_P3) {
-        ws('workspace/build-tensorrt') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            try {
-              utils.unpack_and_init('tensorrt', mx_tensorrt_lib)
-              utils.docker_run('ubuntu_gpu_tensorrt', 
'unittest_ubuntu_tensorrt_gpu', true)
-              utils.publish_test_coverage()
-            } finally {
-              utils.collect_test_results_unix('nosetests_tensorrt.xml', 
'nosetests_python3_tensorrt_gpu.xml')
-            }
-          }
-        }
-      }
-    }]
-}
-
 def test_unix_python3_integration_gpu() {
     return ['Python Integration GPU': {
       node(NODE_LINUX_GPU_G4) {
diff --git a/ci/jenkins/Jenkinsfile_unix_gpu b/ci/jenkins/Jenkinsfile_unix_gpu
index 5e26a9f..f219440 100644
--- a/ci/jenkins/Jenkinsfile_unix_gpu
+++ b/ci/jenkins/Jenkinsfile_unix_gpu
@@ -50,7 +50,6 @@ core_logic: {
     custom_steps.test_unix_python3_quantize_gpu(),
     custom_steps.test_unix_python3_mkldnn_gpu(),
     custom_steps.test_unix_python3_mkldnn_nocudnn_gpu(),
-    custom_steps.test_unix_python3_tensorrt_gpu(),
     custom_steps.test_unix_perl_gpu(),
     custom_steps.test_unix_r_gpu(),
     custom_steps.test_unix_cpp_gpu(),
diff --git a/example/extensions/lib_pass/test_pass.py 
b/example/extensions/lib_pass/test_pass.py
index 01d6edd..66411a6 100644
--- a/example/extensions/lib_pass/test_pass.py
+++ b/example/extensions/lib_pass/test_pass.py
@@ -48,34 +48,16 @@ d = mx.sym.exp(c)
 sym = mx.sym.log(d)
 
 def test_model(pass_name):
+    args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
     # execute in MXNet
     print('-------------------------------')
     print('Testing regular MXNet execution')
-
-    exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 
'b':mx.nd.ones((3,2))})
-    out = exe.forward()
+    inputs = [a,b]
+    sym_block = nn.SymbolBlock(sym, inputs)
+    sym_block.initialize()
+    out = sym_block(mx.nd.ones((3,2)),mx.nd.ones((3,2)))
     print(out)
 
-    # Symbol optimize_for
-    # with propogating shapes/types
-    print('-------------------------------')
-    print('Testing pass "%s" with shapes/types' % pass_name)
-    arg_array = [mx.nd.ones((3,2),dtype='float32'), 
mx.nd.ones((3,2),dtype='float32')]
-    aux = []
-    mysym2 = sym.optimize_for(pass_name,arg_array,aux)
-    print(mysym2.tojson())
-    exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 
'b':mx.nd.ones((3,2))})
-    out2 = exe2.forward()
-    print(out2)
-
-    # without propogating shapes/types
-    print('-------------------------------')
-    print('Testing pass "%s" without shapes/types' % pass_name)
-    mysym3 = sym.optimize_for(pass_name, myOpt='yello')
-    exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 
'b':mx.nd.ones((3,2))})
-    out3 = exe3.forward()
-    print(out3)
-
     # Gluon Hybridize
     print('-------------------------------')
     print('Testing pass "%s" Gluon Hybridize with shapes/types' % pass_name)
diff --git a/example/extensions/lib_subgraph/test_subgraph.py 
b/example/extensions/lib_subgraph/test_subgraph.py
index 267a417..5294e1c 100644
--- a/example/extensions/lib_subgraph/test_subgraph.py
+++ b/example/extensions/lib_subgraph/test_subgraph.py
@@ -49,32 +49,31 @@ d2 = mx.sym.exp(a)
 sym2 = mx.sym.log(d2)
 
 def test(backend):
+    args = {'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
     ###############################################
     # Test with subgraph not consuming params
     ###############################################
     #execute in MXNet
     print('-------------------------------')
     print('Testing regular MXNet execution')
-    exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 
'b':mx.nd.ones((3,2))})
+    exe = sym.bind(ctx=mx.cpu(), args=args)
     out = exe.forward()
     print(out)
 
     # with propogating shapes/types
     print('-------------------------------')
     print('Testing %s partitioning with shapes/types' % backend)
-    arg_array = [mx.nd.ones((3,2),dtype='float32'), 
mx.nd.ones((3,2),dtype='float32')]
-    mysym2 = sym.optimize_for(backend,arg_array)
+    mysym2 = sym.optimize_for(backend,args)
     print(mysym2.tojson())
-    exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 
'b':mx.nd.ones((3,2))})
+    exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
     out2 = exe2.forward()
     print(out2)
 
     # with propogating shapes/types, rejecting subgraph
     print('-------------------------------')
     print('Testing %s partitioning with shapes/types - rejecting subgraph' % 
backend)
-    arg_array = [mx.nd.ones((3,2),dtype='float32'), 
mx.nd.ones((3,2),dtype='float32')]
-    mysym2 = sym.optimize_for(backend, arg_array, reject=True)
-    exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 
'b':mx.nd.ones((3,2))})
+    mysym2 = sym.optimize_for(backend, args, reject=True)
+    exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
     out2 = exe2.forward()
     print(out2)
 
@@ -82,7 +81,7 @@ def test(backend):
     print('-------------------------------')
     print('Testing %s partitioning without shapes/types' % backend)
     mysym3 = sym.optimize_for(backend, myOpt='yello')
-    exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 
'b':mx.nd.ones((3,2))})
+    exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
     out3 = exe3.forward()
     print(out3)
 
@@ -115,20 +114,20 @@ def test(backend):
     ###############################################
     # Test with subgraph directly consuming params
     ###############################################
+    args = {'a':mx.nd.ones((3,2))}
     #execute in MXNet
     print('-------------------------------')
     print('Testing regular MXNet execution')
-    exe5 = sym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
+    exe5 = sym2.bind(ctx=mx.cpu(), args=args)
     out5 = exe5.forward()
     print(out5)
 
     # with propogating shapes/types
     print('-------------------------------')
     print('Testing %s partitioning with shapes/types' % backend)
-    arg_array = [mx.nd.ones((3,2),dtype='float32')]
-    mysym6 = sym2.optimize_for(backend, arg_array, reqArgs=True)
+    mysym6 = sym2.optimize_for(backend, args, reqArgs=True)
     print(mysym6.tojson())
-    exe6 = mysym6.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
+    exe6 = mysym6.bind(ctx=mx.cpu(), args=args)
     out6 = exe6.forward()
     print(out6)
 
@@ -136,7 +135,7 @@ def test(backend):
     print('-------------------------------')
     print('Testing %s partitioning without shapes/types' % backend)
     mysym7 = sym2.optimize_for(backend, reqArgs=True)
-    exe7 = mysym7.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
+    exe7 = mysym7.bind(ctx=mx.cpu(), args=args)
     out7 = exe7.forward()
     print(out7)
 
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index cfb2400..98a7a70 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2166,6 +2166,25 @@ MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle 
sym_handle, SymbolHandle
  * \param num_options number of key value pairs
  * \param keys keys for options
  * \param vals values corresponding to keys
+ * \param num_input_shapes number of input shapes
+ * \param input_shape_names names of the input shapes
+ * \param input_shape_data pointer to the contiguous (concatenated) shape dimensions
+ * \param input_shape_idx array of per-shape starting indices; the shape length 
for the i-th input shape
+ * is calculated as input_shape_idx[i+1] - input_shape_idx[i]
+ * \param num_input_dtypes number of input data types
+ * \param input_dtype_names array of names of the input data types
+ * \param input_dtypes array of values of the input data types
+ * \param num_input_stypes number of input storage types
+ * \param input_stype_names array of names of the input storage types
+ * \param input_stypes array of values of input storage types
+ * \param skip_infer if the optimization should skip the attribute inferences
+ * (to use if the backend does not require shape inference)
+ * \param new_args_cnt pointer to a number to store the number of new args
+ * \param new_args_handle pointer to an array to store the new args handles
+ * \param new_arg_names_handle pointer to an array to store the new args names
+ * \param new_aux_cnt pointer to a number to store the number of new aux
+ * \param new_aux_handle pointer to an array to store the new aux handles
+ * \param new_aux_names_handle pointer to an array to store the new aux names
  */
 MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
                                    const char* backend_name,
@@ -2178,6 +2197,17 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle 
sym_handle,
                                    const mx_uint num_options,
                                    const char** keys,
                                    const char** vals,
+                                   const uint32_t num_input_shapes,
+                                   const char** input_shape_names,
+                                   const int64_t* input_shape_data,
+                                   const uint32_t* input_shape_idx,
+                                   const uint32_t num_input_dtypes,
+                                   const char** input_dtype_names,
+                                   const int* input_dtypes,
+                                   const uint32_t num_input_stypes,
+                                   const char** input_stype_names,
+                                   const int* input_stypes,
+                                   bool skip_infer,
                                    int* new_args_cnt,
                                    NDArrayHandle** new_args_handle,
                                    char*** new_arg_names_handle,
diff --git a/perl-package/AI-MXNetCAPI/mxnet.i 
b/perl-package/AI-MXNetCAPI/mxnet.i
index 9602b08..59346ef 100644
--- a/perl-package/AI-MXNetCAPI/mxnet.i
+++ b/perl-package/AI-MXNetCAPI/mxnet.i
@@ -1637,6 +1637,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
                          const mx_uint in,
                          const char** keys,
                          const char** vals,
+                         const uint32_t num_input_shapes,
+                         const char** input_shape_names,
+                         const int64_t* input_shape_data,
+                         const uint32_t* input_shape_idx,
+                         const uint32_t num_input_dtypes,
+                         const char** input_dtype_names,
+                         const int* input_dtypes,
+                         const uint32_t num_input_stypes,
+                         const char** input_stype_names,
+                         const int* input_stypes,
+                         bool skip_infer,
                          int* new_args_cnt,
                          NDArrayHandle** new_args_handle,
                          char*** new_arg_names_handle,
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index d7afd8a..9772e23 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -949,41 +949,70 @@ class HybridBlock(Block):
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', 
param_indices)] + \
-                self._flags
-
         args, _ = _flatten(args, "input")
         try:
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i.data()
+            for name in input_names:
+                if name in params:
+                    params[name].data()
         except DeferredInitializationError:
             self._deferred_infer_shape(*args)
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i._finish_deferred_init()
+            for name in input_names:
+                if name in params:
+                    params[name]._finish_deferred_init()
 
+        arg_dict, aux_dict = dict(), dict()
         if self._backend:
             ctx = args[0].context
             # get list of params in the order of out.list_arguments
-            arg_array = [args[data_names[name]] if name in data_names.keys() 
else params[name].data()
-                         for name in out.list_arguments()]
-            aux_array = [args[data_names[name]] if name in data_names.keys() 
else params[name].data()
-                         for name in out.list_auxiliary_states()]
+            arg_dict.update({name:args[data_names[name]] if name in 
data_names.keys() else params[name].data()
+                             for name in out.list_arguments()})
+            aux_dict.update({name:args[data_names[name]] if name in 
data_names.keys() else params[name].data()
+                             for name in out.list_auxiliary_states()})
             # Partition the graph.
-            out = out.optimize_for(self._backend, arg_array, aux_array, ctx, 
**self._backend_opts)
+            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, 
**self._backend_opts)
+
             #update cached graph with partitioned graph
             self._cached_graph = data, out
+
+        input_names = out.list_inputs()
+        data_indices = []
+        param_indices = []
+
+        # In the default case, _cached_op_args contains all the parameters 
from params (the sets are identical)
+        # In the case of Partition API optimized graph _cached_op_args might 
contain some parameters from params,
+        # might contain some new parameters created during optimization and 
added to `arg_dict/aux_dict`,
+        # and might not contain some parameters that were deleted during 
optimization.
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            pair = None
+            if name in data_names:
+                data_indices.append(i)
+                pair = (True, data_names[name])
+            else:
+                param_indices.append(i)
+                if name in params:
+                    param = params[name]
+                else:
+                    # The param is missing from the original params 
dictionary, which means the param must have
+                    # been added by the Partition API backend
+                    if name in arg_dict:
+                        param_data = arg_dict[name]
+                    elif name in aux_dict:
+                        param_data = aux_dict[name]
+                    else:
+                        raise RuntimeError('A parameter was added to the graph 
during optimization but it was not '
+                                           'added to the parameter dicts.\n'
+                                           'Please check the backend.')
+
+                    param = Parameter(name)
+                    param._load_init(param_data, args[0].context)
+                pair = (False, param)
+
+            self._cached_op_args.append(pair)
+
+        flags = [('data_indices', data_indices), ('param_indices', 
param_indices)] + \
+                self._flags
+
         self._cached_op = ndarray.CachedOp(out, flags)
 
 
@@ -1203,12 +1232,14 @@ class HybridBlock(Block):
         arg_names = set(sym.list_arguments())
         aux_names = set(sym.list_auxiliary_states())
         arg_dict = {}
-        for name, param in self.collect_params().items():
-            if name in arg_names:
-                arg_dict['arg:%s'%name] = param._reduce()
-            else:
-                assert name in aux_names
-                arg_dict['aux:%s'%name] = param._reduce()
+        for is_arg, param in self._cached_op_args:
+            if not is_arg:
+                name = param.name
+                if name in arg_names:
+                    arg_dict['arg:{}'.format(name)] = param._reduce()
+                else:
+                    assert name in aux_names
+                    arg_dict['aux:{}'.format(name)] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
         save_fn('%s-%04d.params'%(path, epoch), arg_dict)
 
@@ -1479,6 +1510,23 @@ class SymbolBlock(HybridBlock):
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
 
+    def reset_ctx(self, ctx):
+        """Re-assign all Parameters to other contexts. If the Block is 
hybridized, it will reset the _cached_op_args.
+        Parameters
+        ----------
+        ctx : Context or list of Context, default 
:py:meth:`context.current_context()`.
+            Assign Parameter to given context. If ctx is a list of Context, a
+            copy will be made for each context.
+        """
+        params = self.collect_params()
+        if self._cached_op:
+            for is_arg, p in self._cached_op_args:
+                # resetting parameters created by the partitioning backend
+                if not is_arg and p.name not in params:
+                    p.reset_ctx(ctx)
+        for p in params.values():
+            p.reset_ctx(ctx)
+
 def _infer_param_types(in_params, out_params, arg_params, aux_params, 
default_dtype=mx_real_t):
     """Utility function that helps in inferring DType of args and auxs params
     from given input param.
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index e90fb9b..0f8cccd 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1446,7 +1446,8 @@ class Symbol(SymbolBase):
 
 
     # pylint: disable=too-many-locals
-    def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
+    def optimize_for(self, backend, args=None, aux=None, ctx=None,
+                     shape_dict=None, type_dict=None, stype_dict=None, 
skip_infer=False, **kwargs):
         """Partitions current symbol and optimizes it for a given backend,
         returns new partitioned symbol.
 
@@ -1455,23 +1456,35 @@ class Symbol(SymbolBase):
         backend : str
             The name of backend, as registered in `SubgraphBackendRegistry`
 
-        args : list of NDArray or dict of str to NDArray, optional
+        args : dict of str to NDArray, optional
             Input arguments to the symbol, required to infer shapes/types 
before partitioning
-
-            - If type is a list of `NDArray`, the order is the same as that of 
`list_arguments()`.
             - If type is a dict of str to `NDArray`, then it maps the name of 
arguments
-              to the corresponding `NDArray`.
+              to the corresponding `NDArray`. Non defined arguments' 
`NDArray`s don't have to be
+              specified in the dict.
 
-        aux : list of NDArray or dict of str to NDArray, optional
+        aux : dict of str to NDArray, optional
             Input auxiliary arguments to the symbol
-
-            - If type is a list of `NDArray`, the order is the same as that of 
`list_arguments()`.
             - If type is a dict of str to `NDArray`, then it maps the name of 
arguments
               to the corresponding `NDArray`.
 
         ctx : Context, optional
             Device context, used to infer stypes
 
+        shape_dict  : Dict of str->tuple, optional
+            Input shape dictionary.
+            Used iff input NDArray is not in `args`.
+
+        type_dict  : Dict of str->numpy.dtype, optional
+            Input type dictionary.
+            Used iff input NDArray is not in `args`.
+
+        stype_dict  : Dict of str->str, optional
+            Input storage type dictionary.
+            Used iff input NDArray is not in `args`.
+
+        skip_infer : bool, optional
+            If True, the optimization skips the shape, type and storage type 
inference pass.
+
         kwargs : optional arguments
             Passed on to `PrePartition` and `PostPartition` functions of 
`SubgraphProperty`
 
@@ -1482,24 +1495,86 @@ class Symbol(SymbolBase):
         """
         out = SymbolHandle()
         assert isinstance(backend, str)
+        assert isinstance(args, dict) or args is None
+        assert isinstance(aux, dict) or aux is None
 
         if args is None or len(args) == 0:
             args_ = []
             args_handle = c_array(NDArrayHandle, [])
         else:
             args_handle, args_ = self._get_ndarray_inputs('args', args,
-                                                          
self.list_arguments(), False)
+                                                          
self.list_arguments(), True)
 
         if aux is None or len(aux) == 0:
             aux_ = []
             aux_handle = c_array(NDArrayHandle, [])
         else:
             aux_handle, aux_ = self._get_ndarray_inputs('aux_states', aux,
-                                                        
self.list_auxiliary_states(), False)
+                                                        
self.list_auxiliary_states(), True)
         if ctx is None:
             ctx = current_context()
         assert isinstance(ctx, Context)
 
+
+        # parse input data shape dict
+        num_input_shapes = 0
+        input_shape_names = ctypes.POINTER(ctypes.c_char_p)()
+        input_shape_data = ctypes.POINTER(mx_int64)()
+        input_shape_idx = ctypes.POINTER(mx_uint)()
+        if shape_dict is not None:
+            input_shape_names = []
+            input_shape_data = []
+            input_shape_idx = [0]
+            for k, v in shape_dict.items():
+                if isinstance(v, (tuple, list)):
+                    input_shape_names.append(k)
+                    input_shape_data.extend(v)
+                    input_shape_idx.append(len(input_shape_data))
+                else:
+                    raise ValueError(str(v) + " has to be a tuple or list.")
+            num_input_shapes = mx_uint(len(input_shape_names))
+            input_shape_names = c_str_array(input_shape_names)
+            input_shape_data = c_array_buf(mx_int64, array('q', 
input_shape_data))
+            input_shape_idx = c_array_buf(mx_uint, array('i', input_shape_idx))
+
+        # parse input data types dict
+        num_input_types = 0
+        input_type_names = ctypes.POINTER(ctypes.c_char_p)()  # provided type 
argument names
+        input_type_data = ctypes.POINTER(mx_uint)()  # provided types
+        if type_dict is not None:
+            input_type_names = []
+            input_type_data = []
+            for k, v in type_dict.items():
+                v = _numpy.dtype(v).type
+                if v in _DTYPE_NP_TO_MX:
+                    input_type_names.append(k)
+                    input_type_data.append(_DTYPE_NP_TO_MX[v])
+                else:
+                    raise ValueError(str(v) + " is not a MXNet type.")
+
+            num_input_types = mx_uint(len(input_type_names))
+            input_type_names = c_str_array(input_type_names)
+            input_type_data = c_array_buf(ctypes.c_int, array('i', 
input_type_data))
+
+        # parse input data storage types dict
+        num_input_stypes = 0
+        # provided storage type argument names
+        input_stype_names = ctypes.POINTER(ctypes.c_char_p)()
+        input_stype_data = ctypes.POINTER(mx_uint)()  # provided storage types
+        if stype_dict is not None:
+            input_stype_names = []
+            input_stype_data = []
+            for k, v in stype_dict.items():
+                if v in _STORAGE_TYPE_STR_TO_ID:
+                    input_stype_names.append(k)
+                    input_stype_data.append(_STORAGE_TYPE_STR_TO_ID[v])
+                else:
+                    raise ValueError(str(v) + " is not a MXNet storage type.")
+
+            num_input_stypes = mx_uint(len(input_stype_names))
+            input_stype_names = c_str_array(input_stype_names)
+            input_stype_data = c_array_buf(ctypes.c_int, array('i', 
input_stype_data))
+
         new_args_size = ctypes.c_uint()
         new_arg_names = ctypes.POINTER(ctypes.c_char_p)()
         new_args_handle = ctypes.POINTER(NDArrayHandle)()
@@ -1523,37 +1598,68 @@ class Symbol(SymbolBase):
                                              mx_uint(len(key_list)),
                                              c_str_array(key_list),
                                              c_str_array(val_list),
+                                             num_input_shapes,
+                                             input_shape_names,
+                                             input_shape_data,
+                                             input_shape_idx,
+                                             num_input_types,
+                                             input_type_names,
+                                             input_type_data,
+                                             num_input_stypes,
+                                             input_stype_names,
+                                             input_stype_data,
+                                             ctypes.c_bool(skip_infer),
                                              ctypes.byref(new_args_size),
                                              ctypes.byref(new_args_handle),
                                              ctypes.byref(new_arg_names),
                                              ctypes.byref(new_aux_size),
                                              ctypes.byref(new_aux_handle),
                                              ctypes.byref(new_aux_names)))
-        arg_names = self.list_arguments()
-        if isinstance(args, dict):
+        # add new args/aux
+        if not args is None:
             for i in range(new_args_size.value):
                 args[py_str(new_arg_names[i])] = 
NDArray(NDArrayHandle(new_args_handle[i]))
-        elif isinstance(args, list):
-            for i in range(new_args_size.value):
-                name = py_str(new_arg_names[i])
-                if name in arg_names:
-                    idx = arg_names.index(name)
-                    args[idx] = NDArray(NDArrayHandle(new_args_handle[i]))
-                else:
-                    args.append(NDArray(NDArrayHandle(new_args_handle[i])))
-        aux_names = self.list_auxiliary_states()
-        if isinstance(aux, dict):
+        elif new_args_size.value > 0:
+            raise RuntimeError('Cannot add new args in optimize_for since args 
is None\n' +
+                               'Provide a dictionary to the args argument to 
optimize_for')
+
+        if not aux is None:
             for i in range(new_aux_size.value):
                 aux[py_str(new_aux_names[i])] = 
NDArray(NDArrayHandle(new_aux_handle[i]))
-        elif isinstance(aux, list):
-            for i in range(new_aux_size.value):
-                name = py_str(new_aux_names[i])
-                if name in aux_names:
-                    idx = aux_names.index(name)
-                    aux[idx] = NDArray(NDArrayHandle(new_aux_handle[i]))
-                else:
-                    aux.append(NDArray(NDArrayHandle(new_aux_handle[i])))
-        return Symbol(out)
+        elif new_aux_size.value > 0:
+            raise RuntimeError('Cannot add new aux in optimize_for since aux 
is None\n' +
+                               'Provide a dictionary to the aux argument to 
optimize_for')
+
+        new_sym = Symbol(out)
+
+        arg_names = self.list_arguments()
+        new_arg_names = new_sym.list_arguments()
+        deleted_arg_names = set([item for item in arg_names
+                                 if item not in set(new_arg_names)])
+
+        if len(deleted_arg_names) > 0:
+            if args is not None:
+                for a_n in deleted_arg_names:
+                    if a_n in args:
+                        args.pop(a_n)
+            else:
+                warnings.warn('A param was deleted during optimization, but no 
args dictionary was provided.\n' +
+                              'Please ensure that your model weights match the 
newly optimized model.')
+
+        aux_names = self.list_auxiliary_states()
+        new_aux_names = new_sym.list_auxiliary_states()
+        deleted_aux_names = set([item for item in aux_names
+                                 if item not in set(new_aux_names)])
+        if len(deleted_aux_names) > 0:
+            if aux is not None:
+                for a_n in deleted_aux_names:
+                    if a_n in aux:
+                        aux.pop(a_n)
+            else:
+                warnings.warn('A param was deleted during optimization, but no 
args dictionary was provided.\n' +
+                              'Please ensure that your model weights match the 
newly optimized model.')
+
+        return new_sym
 
 
     # pylint: disable=too-many-locals
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 3b3d83c..29a773d 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -1350,6 +1350,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
                          const mx_uint num_options,
                          const char** keys,
                          const char** vals,
+                         const uint32_t num_input_shapes,
+                         const char** input_shape_names,
+                         const int64_t* input_shape_data,
+                         const uint32_t* input_shape_idx,
+                         const uint32_t num_input_dtypes,
+                         const char** input_dtype_names,
+                         const int* input_dtypes,
+                         const uint32_t num_input_stypes,
+                         const char** input_stype_names,
+                         const int* input_stypes,
+                         bool skip_infer,
                          int* new_args_cnt,
                          NDArrayHandle** new_args_handle,
                          char*** new_arg_names_handle,
@@ -1373,47 +1384,80 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
   if (args_len || aux_len) {
     NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
     NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
-    Context default_ctx = 
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
-    mxnet::ShapeVector arg_shapes(args_len + aux_len);
-    nnvm::DTypeVector arg_dtypes(args_len + aux_len);
-    StorageTypeVector arg_stypes(args_len + aux_len);
-    size_t args_top = 0, aux_top = 0;
-    // loop over inputs to symbol in order and add to args/aux if mutable
-    for (size_t i = 0; i < num_forward_inputs; ++i) {
-      const uint32_t nid = indexed_graph.input_nodes().at(i);
-      if (mutable_nodes.count(nid)) {
-        CHECK_LT(aux_top, aux_len)
-          << "Cannot find aux '" << input_names[i] << "' in provided aux to 
optimize_for";
-        const auto &in_arg = *(in_aux_ptr[aux_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
-      } else {
-        CHECK_LT(args_top, args_len)
-          << "Cannot find arg '" << input_names[i] << "' in provided args to 
optimize_for";
-        const auto &in_arg = *(in_args_ptr[args_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
+    if (!skip_infer) {
+      Context default_ctx = 
Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+      mxnet::ShapeVector arg_shapes(args_len + aux_len);
+      nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+      StorageTypeVector arg_stypes(args_len + aux_len);
+
+      // create the input shape, dtype and stype maps
+      std::unordered_map<std::string, mxnet::TShape> 
input_shape_map(num_input_shapes);
+      for (uint32_t i = 0; i < num_input_shapes; ++i) {
+        input_shape_map.emplace(input_shape_names[i],
+                    mxnet::TShape(input_shape_data + input_shape_idx[i],
+                    input_shape_data + input_shape_idx[i+1]));
+      }
+      std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
+      for (uint32_t i = 0; i < num_input_dtypes; ++i) {
+        input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
+      }
+      std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
+      for (uint32_t i = 0; i < num_input_stypes; ++i) {
+        input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
       }
-    }
 
-    g.attrs["context"] = std::make_shared<nnvm::any>(
-        exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+      size_t args_top = 0, aux_top = 0;
+      // loop over inputs to symbol in order and add to args/aux if mutable
+      for (size_t i = 0; i < num_forward_inputs; ++i) {
+        const uint32_t nid = indexed_graph.input_nodes().at(i);
+        if (mutable_nodes.count(nid)) {
+          CHECK_LT(aux_top, aux_len)
+            << "Cannot find aux '" << input_names[i] << "' in provided aux to 
optimize_for";
+          if (in_aux_ptr[aux_top] != nullptr) {
+            const auto &in_arg = *(in_aux_ptr[aux_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          }
+          aux_top++;
+        } else {
+          auto name = input_names[i];
+          CHECK_LT(args_top, args_len)
+            << "Cannot find arg '" << name << "' in provided args to 
optimize_for";
+          if (in_args_ptr[args_top] != nullptr) {
+            const auto &in_arg = *(in_args_ptr[args_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          } else {
+            // input_names[i] is not in args but can be in the optional
+            // shape/type/stype attribute dicts.
+            auto it_shape = input_shape_map.find(name);
+            if (it_shape != input_shape_map.end()) {
+              arg_shapes[i] = it_shape->second;
+            }
+            auto it_type = input_dtype_map.find(name);
+            if (it_type != input_dtype_map.end()) {
+              arg_dtypes[i] = it_type->second;
+            }
+            it_type = input_stype_map.find(name);
+            if (it_type != input_stype_map.end()) {
+              arg_stypes[i] = it_type->second;
+            }
+          }
+          args_top++;
+        }
+      }
 
-    // infer shapes
-    g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
-    // infer dtypes
-    g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
-    if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
-      common::HandleInferTypeError(num_forward_inputs, indexed_graph,
-                                   g.GetAttr<nnvm::DTypeVector>("dtype"));
-    }
-    // infer stypes
-    g = exec::InferStorageType(std::move(g), std::move(arg_stypes), 
"__storage_type__");
-    if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
-      common::HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
-                                          
g.GetAttr<StorageTypeVector>("storage_type"));
+      g.attrs["context"] = std::make_shared<nnvm::any>(
+          exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+
+      // infer shapes
+      g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+      // infer dtypes
+      g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+      // infer stypes
+      g = exec::InferStorageType(std::move(g), std::move(arg_stypes), 
"__storage_type__");
     }
     // set args/aux as attributes on graph so that subgraph property can use 
them
     std::vector<std::string> arg_names = 
sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
diff --git a/src/operator/subgraph/build_subgraph.cc 
b/src/operator/subgraph/build_subgraph.cc
index 2d5501d..7cf9671 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -226,9 +226,7 @@ bool LabelSubgraph(const nnvm::Graph& g, 
SubgraphSelectorV2Ptr subgraph_selector
     std::stack<const nnvm::Node*> s;
     s.push(descendant);
     size_t count = 0;
-    while (!s.empty()) {
-      CHECK_LT(count, indexed_graph.num_nodes()) << "Finding ancestor failed. 
There is probably"
-                                                    " a loop in the graph";
+    while (!s.empty() && count < indexed_graph.num_nodes()) {
       ++count;
       const nnvm::Node* top = s.top();
       s.pop();
@@ -276,10 +274,6 @@ bool LabelSubgraph(const nnvm::Graph& g, 
SubgraphSelectorV2Ptr subgraph_selector
 
   if (excluded_node_id != -1) {
     CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
-    CHECK_NE(excluded_node_id, static_cast<int>(snid))
-      << "A cycle is found in the computational graph between nodes "
-      << simple_nodes[excluded_node_id]->node->attrs.name << " and "
-      << simple_nodes[snid]->node->attrs.name;
     excluded_nodes->insert(simple_nodes[excluded_node_id].get());
     ResetNodeLabels(g, simple_nodes, subgraph_nodes);
     return false;
@@ -306,6 +300,7 @@ void PreSelectSubgraphNodes(const nnvm::Graph& g, 
SubgraphSelectorV2Ptr subgraph
                             const std::vector<BiDirectedNodePtr>& simple_nodes,
                             std::vector<BiDirectedNode*>* subgraph_nodes) {
   std::unordered_set<const BiDirectedNode*> excluded_nodes;
+  size_t n_excluded_nodes = 0;
   const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
   size_t count = 0;
   bool success = false;
@@ -313,7 +308,14 @@ void PreSelectSubgraphNodes(const nnvm::Graph& g, 
SubgraphSelectorV2Ptr subgraph
     success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes, 
subgraph_nodes,
                             &excluded_nodes);
     if (!success) {
-      CHECK(!excluded_nodes.empty());
+      // Labeling the subgraph failed because of a cycle.
+      // If the number of excluded nodes did not change since the last iteration,
+      // no valid subgraph exists for the current node snid, so stop retrying.
+      // Otherwise, keep trying with the newly excluded nodes tagged.
+      if (excluded_nodes.size() == n_excluded_nodes) {
+        break;
+      }
+      n_excluded_nodes = excluded_nodes.size();
       std::string excluded_node_names;
       for (auto node : excluded_nodes) {
         excluded_node_names += node->node->attrs.name + ", ";
@@ -428,7 +430,7 @@ void SortEntries(const std::unordered_map<const 
nnvm::NodeEntry*, size_t>& entry
 }
 
 /*!
- * \brief Given a subgraph, find the output entries of a subgraph.
+ * \brief Given a subgraph, find the input entries of a subgraph.
  * \param g pointer to the whole graph
  * \param simple_nods vector of simple nodes in top sorted order
  * \param subgraph_nodes vector of pointers of simples of a subgraph.
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc 
b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
index 19d0f26..4f80d27 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -31,7 +31,6 @@
 #include <mxnet/base.h>
 #include <nnvm/graph.h>
 #include <nnvm/pass_functions.h>
-#include <operator/nn/deconvolution-inl.h>
 
 #include "../../../common/utils.h"
 #include "../../../ndarray/ndarray_function.h"
@@ -39,6 +38,7 @@
 #include "../../nn/activation-inl.h"
 #include "../../nn/batch_norm-inl.h"
 #include "../../nn/convolution-inl.h"
+#include "../../nn/deconvolution-inl.h"
 #include "../../nn/fully_connected-inl.h"
 #include "../../nn/pooling-inl.h"
 #include "../../nn/concat-inl.h"
diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc 
b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
index b02d109..4f5bdcb 100644
--- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
+++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
@@ -35,13 +35,9 @@
 #include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <google/protobuf/text_format.h>
 #include <onnx-tensorrt/NvOnnxParser.h>
-#include <onnx-tensorrt/NvOnnxParserRuntime.h>
 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
 
-#include <onnx-tensorrt/PluginFactory.hpp>
-#include <onnx-tensorrt/plugin_common.hpp>
-
 using std::cout;
 using std::cerr;
 using std::endl;
@@ -78,7 +74,9 @@ std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
 
   auto trt_logger = std::unique_ptr<TRT_Logger>(new TRT_Logger(verbosity));
   auto trt_builder = InferObject(nvinfer1::createInferBuilder(*trt_logger));
-  auto trt_network = InferObject(trt_builder->createNetwork());
+  const auto explicitBatch = 1U << static_cast<uint32_t>(
+                             
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+  auto trt_network = InferObject(trt_builder->createNetworkV2(explicitBatch));
   auto trt_parser  = InferObject(nvonnxparser::createParser(*trt_network, 
*trt_logger));
   ::ONNX_NAMESPACE::ModelProto parsed_model;
   // We check for a valid parse, but the main effect is the side effect
diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h 
b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index dcafba5..369d7c3 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -268,6 +268,23 @@ class TensorrtProperty : public SubgraphProperty {
     return std::make_shared<TensorrtProperty>();
   }
 
+  void PrePartition(const nnvm::Graph& g,
+    const std::vector<std::pair<std::string, std::string>>& options_map) 
override {
+    auto& in_arg_names = g.GetAttr<std::vector<std::string>>("in_arg_names");
+    auto& in_aux_names = g.GetAttr<std::vector<std::string>>("in_aux_names");
+    NDArray **in_args_ptr = g.GetAttr<NDArray**>("in_args");
+    NDArray **in_aux_ptr = g.GetAttr<NDArray**>("in_aux");
+    in_args_dict.clear();
+    in_aux_dict.clear();
+    // Trusts the Python API to pass one (possibly null) NDArray pointer per arg/aux name
+    for (unsigned i = 0; i < in_arg_names.size(); ++i) {
+      in_args_dict[in_arg_names[i]] = in_args_ptr[i];
+    }
+    for (unsigned i = 0; i < in_aux_names.size(); ++i) {
+      in_aux_dict[in_aux_names[i]] = in_aux_ptr[i];
+    }
+  }
+
   nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
                                    const int subgraph_id) const override {
     nnvm::ObjectPtr n = nnvm::Node::Create();
@@ -281,16 +298,33 @@ class TensorrtProperty : public SubgraphProperty {
     n->attrs.op = Op::Get("_TensorRT");
     CHECK(n->attrs.op);
     n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+
+    // Map each subgraph input to a copy of its NDArray from args or aux, if one was provided
+    TRTParam param;
     std::ostringstream params_oss;
-    for (auto &e : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
-      params_oss << e << ";";
+    for (auto &param_name : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
+      NDArray *cache = nullptr;
+      auto it_args = in_args_dict.find(param_name);
+      if (it_args != in_args_dict.end()) {
+        cache = it_args->second;
+      } else {
+        auto it_aux = in_aux_dict.find(param_name);
+        if (it_aux != in_aux_dict.end()) {
+          cache = it_aux->second;
+        }
+      }
+      if (cache != nullptr) {
+        param.params_map.emplace(param_name, cache->Copy(Context()));
+        param.params_map[param_name].WaitToRead();
+        params_oss << param_name << ";";
+      }
     }
     auto tensorrt_params_names = params_oss.str();
-    tensorrt_params_names.pop_back();
-    n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
-    TRTParam param;
+    if (!tensorrt_params_names.empty()) {
+      tensorrt_params_names.pop_back();
+    }
     n->attrs.parsed = param;
-    n->op()->attr_parser(&(n->attrs));
+    n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
     return n;
   }
 
@@ -329,6 +363,8 @@ class TensorrtProperty : public SubgraphProperty {
     }
     subgraph_node->attrs.parsed = std::move(_params);
   }
+
+  std::unordered_map<std::string, NDArray*> in_args_dict, in_aux_dict;
 };
 
 
diff --git a/src/operator/subgraph/tensorrt/tensorrt.cu 
b/src/operator/subgraph/tensorrt/tensorrt.cu
index 4a5b23b..826f9a5 100644
--- a/src/operator/subgraph/tensorrt/tensorrt.cu
+++ b/src/operator/subgraph/tensorrt/tensorrt.cu
@@ -56,12 +56,12 @@ void TRTCompute(const OpStatePtr& state, const OpContext& 
ctx,
       param.bindings->at(i) = outputs[p.first].dptr_;
     }
   }
-  const int batch_size = static_cast<int>(inputs[0].shape_[0]);
-  param.trt_executor->enqueue(batch_size, param.bindings->data(), cuda_s, 
nullptr);
+  param.trt_executor->enqueueV2(param.bindings->data(), cuda_s, nullptr);
 }
 
 NNVM_REGISTER_OP(_TensorRT)
-.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute);
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/tensorrt/lenet5_train.py 
b/tests/python/tensorrt/lenet5_train.py
deleted file mode 100755
index a0ea447..0000000
--- a/tests/python/tensorrt/lenet5_train.py
+++ /dev/null
@@ -1,99 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import os
-import mxnet as mx
-import numpy as np
-
-def get_iters(mnist, batch_size):
-    """Get MNIST iterators."""
-    train_iter = mx.io.NDArrayIter(mnist['train_data'],
-                                   mnist['train_label'],
-                                   batch_size,
-                                   shuffle=True)
-    val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], 
batch_size)
-    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], 
batch_size)
-    all_test_labels = np.array(mnist['test_label'])
-    return train_iter, val_iter, test_iter, all_test_labels
-
-def lenet5():
-    """LeNet-5 Symbol"""
-    #pylint: disable=no-member
-    data = mx.sym.Variable('data')
-    data = mx.sym.Cast(data, 'float16')
-    conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20)
-    tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
-    pool1 = mx.sym.Pooling(data=tanh1, pool_type="max",
-                           kernel=(2, 2), stride=(2, 2))
-    # second conv
-    conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50)
-    tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
-    pool2 = mx.sym.Pooling(data=tanh2, pool_type="max",
-                           kernel=(2, 2), stride=(2, 2))
-    # first fullc
-    flatten = mx.sym.Flatten(data=pool2)
-    fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
-    tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
-    # second fullc
-    fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
-    fc2 = mx.sym.Cast(fc2, 'float32')
-    # loss
-    lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
-    #pylint: enable=no-member
-    return lenet
-
-
-def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter):
-    """train LeNet-5 model on MNIST data"""
-    ctx = mx.gpu(0)
-    lenet_model = mx.mod.Module(lenet5(), context=ctx)
-
-    lenet_model.fit(train_iter,
-                    eval_data=val_iter,
-                    optimizer='sgd',
-                    optimizer_params={'learning_rate': 0.1, 'momentum': 0.9},
-                    eval_metric='acc',
-                    batch_end_callback=mx.callback.Speedometer(batch_size, 1),
-                    num_epoch=num_epochs)
-
-    # predict accuracy for lenet
-    acc = mx.metric.Accuracy()
-    lenet_model.score(test_iter, acc)
-    accuracy = acc.get()[1]
-    assert accuracy > 0.95, "LeNet-5 training accuracy on MNIST was too low"
-    return lenet_model
-
-
-if __name__ == '__main__':
-    num_epochs = 10
-    batch_size = 128
-    model_name = 'lenet5'
-    model_dir = os.getenv("LENET_MODEL_DIR", "/tmp")
-    model_file = '%s/%s-symbol.json' % (model_dir, model_name)
-    params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs)
-
-    if not (os.path.exists(model_file) and os.path.exists(params_file)):
-        mnist = mx.test_utils.get_mnist()
-
-        _, _, _, all_test_labels = get_iters(mnist, batch_size)
-
-        trained_lenet = train_lenet5(num_epochs, batch_size,
-                                    *get_iters(mnist, batch_size)[:-1])
-        trained_lenet.save_checkpoint(model_name, num_epochs)
diff --git a/tests/python/tensorrt/test_cvnets.py 
b/tests/python/tensorrt/test_cvnets.py
deleted file mode 100644
index 4b8eb48..0000000
--- a/tests/python/tensorrt/test_cvnets.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import gc
-import gluoncv
-import mxnet as mx
-import numpy as np
-
-from mxnet import gluon
-from time import time
-
-from mxnet.gluon.data.vision import transforms
-
-
-def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128):
-    mx.contrib.tensorrt.set_use_fp16(False)
-    h, w = 32, 32
-    net = gluoncv.model_zoo.get_model(model_name, pretrained=True)
-    net.hybridize()
-    net.forward(mx.nd.zeros((batch_size, 3, h, w)))
-    net.export(model_name)
-    _sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, 0)
-    if use_tensorrt:
-        sym = _sym.get_backend_symbol('TensorRT')
-        arg_params, aux_params = mx.contrib.tensorrt.init_tensorrt_params(sym, 
arg_params,
-                                                                          
aux_params)
-    else:
-        sym = _sym
-    executor = sym.simple_bind(ctx=ctx, data=(batch_size, 3, h, w),
-                               softmax_label=(batch_size,),
-                               grad_req='null', force_rebind=True)
-    executor.copy_params_from(arg_params, aux_params)
-    return executor
-
-
-def cifar10_infer(model_name, use_tensorrt, num_workers, ctx=mx.gpu(0), 
batch_size=128):
-    executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size)
-
-    num_ex = 10000
-    all_preds = np.zeros([num_ex, 10])
-
-    all_label_test = np.zeros(num_ex)
-
-    transform_test = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 
0.2010])
-    ])
-
-    data_loader = lambda: gluon.data.DataLoader(
-        gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
-        batch_size=batch_size, shuffle=False, num_workers=num_workers)
-
-    val_data = data_loader()
-
-    for idx, (data, label) in enumerate(val_data):
-        # Skip last batch if it's undersized.
-        if data.shape[0] < batch_size:
-            continue
-        offset = idx * batch_size
-        all_label_test[offset:offset + batch_size] = label.asnumpy()
-
-        # warm-up, but don't use result
-        executor.forward(is_train=False, data=data)
-        executor.outputs[0].wait_to_read()
-
-    gc.collect()
-    val_data = data_loader()
-    example_ct = 0
-    start = time()
-
-    # if use_tensorrt:
-    for idx, (data, label) in enumerate(val_data):
-        # Skip last batch if it's undersized.
-        if data.shape[0] < batch_size:
-            continue
-        executor.forward(is_train=False, data=data)
-        preds = executor.outputs[0].asnumpy()
-        offset = idx * batch_size
-        all_preds[offset:offset + batch_size, :] = preds[:batch_size]
-        example_ct += batch_size
-
-    all_preds = np.argmax(all_preds, axis=1)
-    matches = (all_preds[:example_ct] == all_label_test[:example_ct]).sum()
-    duration = time() - start
-
-    return duration, 100.0 * matches / example_ct
-
-
-def run_experiment_for(model_name, batch_size, num_workers):
-    print("\n===========================================")
-    print("Model: %s" % model_name)
-    print("===========================================")
-    print("*** Running inference using pure MXNet ***\n")
-    mx_duration, mx_pct = cifar10_infer(model_name=model_name, 
batch_size=batch_size,
-                                        num_workers=num_workers, 
use_tensorrt=False)
-    print("\nMXNet: time elapsed: %.3fs, accuracy: %.2f%%" % (mx_duration, 
mx_pct))
-    print("\n*** Running inference using MXNet + TensorRT ***\n")
-    trt_duration, trt_pct = cifar10_infer(model_name=model_name, 
batch_size=batch_size,
-                                          num_workers=num_workers, 
use_tensorrt=True)
-    print("TensorRT: time elapsed: %.3fs, accuracy: %.2f%%" % (trt_duration, 
trt_pct))
-    speedup = mx_duration / trt_duration
-    print("TensorRT speed-up (not counting compilation): %.2fx" % speedup)
-
-    acc_diff = abs(mx_pct - trt_pct)
-    print("Absolute accuracy difference: %f" % acc_diff)
-    return speedup, acc_diff
-
-
-def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, 
num_workers=1):
-    original_use_fp16 = mx.contrib.tensorrt.get_use_fp16()
-    try:
-        models = [
-            'cifar_resnet20_v1',
-            'cifar_resnet56_v1',
-            'cifar_resnet110_v1',
-            'cifar_resnet20_v2',
-            'cifar_resnet56_v2',
-            'cifar_resnet110_v2',
-            'cifar_wideresnet16_10',
-            'cifar_wideresnet28_10',
-            'cifar_wideresnet40_8',
-            'cifar_resnext29_16x64d'
-        ]
-
-        num_models = len(models)
-
-        speedups = np.zeros(num_models, dtype=np.float32)
-        acc_diffs = np.zeros(num_models, dtype=np.float32)
-
-        test_start = time()
-
-        for idx, model in enumerate(models):
-            speedup, acc_diff = run_experiment_for(model, batch_size, 
num_workers)
-            speedups[idx] = speedup
-            acc_diffs[idx] = acc_diff
-            assert acc_diff < tolerance, "Accuracy difference between MXNet 
and TensorRT > %.2f%% for model %s" % (
-                tolerance, model)
-
-        print("Perf and correctness checks run on the following models:")
-        print(models)
-        mean_speedup = np.mean(speedups)
-        std_speedup = np.std(speedups)
-        print("\nSpeedups:")
-        print(speedups)
-        print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), 
np.max(speedups)))
-        print("Mean speedup: %.2f" % mean_speedup)
-        print("St. dev. of speedups: %.2f" % std_speedup)
-        print("\nAcc. differences: %s" % str(acc_diffs))
-
-        test_duration = time() - test_start
-
-        print("Test duration: %.2f seconds" % test_duration)
-    finally:
-        mx.contrib.tensorrt.set_use_fp16(original_use_fp16)
-
-
-if __name__ == '__main__':
-    import nose
-
-    nose.runmodule()
diff --git a/tests/python/tensorrt/test_ops.py 
b/tests/python/tensorrt/test_ops.py
deleted file mode 100644
index af1c453..0000000
--- a/tests/python/tensorrt/test_ops.py
+++ /dev/null
@@ -1,517 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import mxnet as mx
-import numpy as np
-from itertools import product
-import copy
-
-from numpy.testing import assert_allclose
-
-import sys
-import os
-curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import setup_module, with_seed
-
-def check_unsupported_single_sym(sym):
-    wrapped_sym = mx.sym.Group([mx.sym.identity(s) for s in sym])
-    trt_sym = wrapped_sym.get_backend_symbol('TensorRT')
-    assert len(wrapped_sym.get_internals()) == len(trt_sym.get_internals())
-
-def check_single_sym(sym, data_shapes, arg_params_shapes=None, 
aux_params_shapes=None,
-                     rtol_fp32=1e-5, atol_fp32=0., rtol_fp16=1e-3, 
atol_fp16=0.):
-    if arg_params_shapes is None:
-        arg_params_shapes = {}
-    if aux_params_shapes is None:
-        aux_params_shapes = {}
-    for i in range(3):
-        data = {k: mx.nd.array(np.random.rand(*v) + 0.01, dtype='float32', 
ctx=mx.cpu())
-                for k, v in data_shapes.items()}
-        arg_params = {k: mx.nd.array(np.random.rand(*v) + 0.01, 
dtype='float32', ctx=mx.cpu())
-                      for k, v in arg_params_shapes.items()}
-        aux_params = {k: mx.nd.array(np.random.rand(*v) + 0.01, 
dtype='float32', ctx=mx.cpu())
-                      for k, v in aux_params_shapes.items()}
-        wrapped_sym = mx.sym.Group([mx.sym.identity(s) for s in sym])
-
-        # Test FP32 MXNet Native
-        shapes = {}
-        shapes.update(data_shapes)
-        shapes.update(arg_params_shapes)
-        shapes.update(aux_params_shapes)
-        orig_executor = wrapped_sym.simple_bind(ctx=mx.gpu(0), grad_req='null',
-                                                force_rebind=True, **shapes)
-        orig_executor.copy_params_from(arg_params, aux_params)
-        orig_executor.forward(is_train=False, **data)
-        orig_outputs = [arr.asnumpy() for arr in orig_executor.outputs]
-
-        # Test FP32 MXNet-TRT
-        mx.contrib.tensorrt.set_use_fp16(False)
-        trt_sym = wrapped_sym.get_backend_symbol('TensorRT')
-        assert len(trt_sym.get_internals()) < len(wrapped_sym.get_internals())
-        remaining_arg_params, remaining_aux_params = \
-            mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, 
aux_params)
-        shapes = {}
-        shapes.update(data_shapes)
-        shapes.update({k: v.shape for k, v in remaining_arg_params.items()})
-        shapes.update({k: v.shape for k, v in remaining_aux_params.items()})
-        trt_fp32_executor = trt_sym.simple_bind(ctx=mx.gpu(0), grad_req='null',
-                                                force_rebind=True, **shapes)
-        trt_fp32_executor.copy_params_from(remaining_arg_params, 
remaining_aux_params)
-        trt_fp32_executor.forward(is_train=False, **data)
-        trt_fp32_outputs = [arr.asnumpy() for arr in trt_fp32_executor.outputs]
-
-        # Test FP16 MXNet-TRT
-        mx.contrib.tensorrt.set_use_fp16(True)
-        data = {k: v.astype('float16') for k, v in data.items()}
-        arg_params = {k: v.astype('float16') for k, v in arg_params.items()}
-        aux_params = {k: v.astype('float16') for k, v in aux_params.items()}
-        trt_sym = wrapped_sym.get_backend_symbol('TensorRT')
-        assert len(trt_sym.get_internals()) < len(wrapped_sym.get_internals())
-        remaining_arg_params, remaining_aux_params = \
-            mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, 
aux_params)
-        shapes = {}
-        shapes.update(data_shapes)
-        shapes.update({k: v.shape for k, v in remaining_arg_params.items()})
-        shapes.update({k: v.shape for k, v in remaining_aux_params.items()})
-
-        trt_fp16_executor = trt_sym.simple_bind(ctx=mx.gpu(0),
-                                                type_dict={k: 'float16' for k 
in shapes.keys()},
-                                                grad_req='null', 
force_rebind=True, **shapes)
-        trt_fp16_executor.copy_params_from(remaining_arg_params, 
remaining_aux_params)
-        trt_fp16_executor.forward(is_train=False, **data)
-        trt_fp16_outputs = [arr.asnumpy() for arr in trt_fp16_executor.outputs]
-
-        for j, (orig, fp16, fp32) in enumerate(zip(orig_outputs, 
trt_fp16_outputs, trt_fp32_outputs)):
-            abs_orig = abs(orig)
-            diff32 = abs(fp32 - orig)
-            diff16 = abs(fp16.astype('float32') - orig)
-            _atol32 = diff32 - rtol_fp32 * abs_orig
-            _atol16 = diff16 - rtol_fp16 * abs_orig
-            print("{}: diff32({:.2E}) | diff16({:.2E}) | atol32({:.2E}) | 
atol16({:.2E}) | orig.min({:.2E})".format(
-                  j, diff32.max(), diff16.max(), _atol32.max(), _atol16.max(), 
abs_orig.min()))
-            assert_allclose(fp32, orig, rtol=rtol_fp32, atol=atol_fp32)
-            assert_allclose(fp16, orig, rtol=rtol_fp16, atol=atol_fp16)
-
-@with_seed()
-def test_noop():
-    data = mx.sym.Variable('data')
-    check_unsupported_single_sym(data)
-
-
-@with_seed()
-def test_identity():
-    data = mx.sym.Variable('data')
-    sym = mx.sym.identity(data)
-    check_single_sym(sym, data_shapes={'data': (8,3,32,32)},
-                     rtol_fp32=0., atol_fp32=0., rtol_fp16=1e-3, 
atol_fp16=1e-7)
-
-
-@with_seed()
-def test_convolution2d():
-    data = mx.sym.Variable('data')
-    weight = mx.sym.Variable('weight')
-    bias = mx.sym.Variable('bias')
-    data_shape = (8,3,16,16)
-    num_filter = 7
-    for kernel in [(3, 3), (1, 1), (3, 1)]:
-        for stride in [(1, 1), (2, 2), (2, 1)]:
-            if stride[0] > kernel[0] or stride[1] > kernel[1]: # doesn't make 
any sense
-                continue
-            if kernel == (3, 3) and stride == (1, 1):
-                atol_fp32 = 0.
-                rtol_fp32 = 1e-5
-                atol_fp16 = 0.
-                rtol_fp16 = 1e-2
-            else:
-                atol_fp32 = 0.
-                rtol_fp32 = 0.
-                atol_fp16 = 0.
-                rtol_fp16 = 1e-2
-            for pad in [(1, 1), (0, 0), (1, 0)]:
-                for group in [1, 2]:
-                    for layout in ['NCHW', 'NHWC']:
-                        weight_shape = (num_filter, data_shape[1]) + kernel
-                        bias_shape = (num_filter,)
-                        sym = mx.sym.Convolution(data, weight=weight, 
bias=bias, kernel=kernel,
-                                                 stride=stride, pad=pad, 
num_filter=num_filter,
-                                                 no_bias=False, layout=layout)
-                        if layout == 'NCHW':
-                            print("kernel: {} | stride: {} | pad: {} | group: 
{} | layout: {} | with_bias".format(
-                                  kernel, stride, pad, group, layout))
-                            check_single_sym(sym, {'data': data_shape},
-                                             {'weight': weight_shape, 'bias': 
bias_shape},
-                                             rtol_fp32=rtol_fp32, 
atol_fp32=atol_fp32,
-                                             rtol_fp16=rtol_fp16, 
atol_fp16=atol_fp16)
-                        else:
-                            check_unsupported_single_sym(sym)
-                        sym = mx.sym.Convolution(data, weight=weight, 
kernel=kernel, stride=stride,
-                                                 pad=pad, 
num_filter=num_filter, no_bias=True,
-                                                 layout=layout)
-                        if layout == 'NCHW':
-                            print("kernel: {} | stride: {} | pad: {} | group: 
{} | layout: {} | without_bias".format(
-                                  kernel, stride, pad, group, layout))
-                            check_single_sym(sym, {'data': data_shape},
-                                             {'weight': weight_shape},
-                                             rtol_fp32=rtol_fp32, 
atol_fp32=atol_fp32,
-                                             rtol_fp16=rtol_fp16, 
atol_fp16=atol_fp16)
-                        else:
-                            check_unsupported_single_sym(sym)
-
-@with_seed()
-def test_deconvolution2d():
-    data = mx.sym.Variable('data')
-    weight = mx.sym.Variable('weight')
-    bias = mx.sym.Variable('bias')
-    data_shape = (8,3,16,16)
-    num_filter = 7
-    for kernel in [(3, 3), (1, 1), (3, 1)]:
-        for stride in [(1, 1), (2, 2), (2, 1)]:
-            if stride[0] > kernel[0] or stride[1] > kernel[1]: # doesn't make 
any sense
-                continue
-            if kernel == (3, 3) and stride == (1, 1):
-                atol_fp32 = 0.
-                rtol_fp32 = 5e-5
-                atol_fp16 = 0.
-                rtol_fp16 = 1e-2
-            else:
-                atol_fp32 = 0.
-                rtol_fp32 = 1e-6
-                atol_fp16 = 0.
-                rtol_fp16 = 1e-2
-            for pad in [(1, 1), (0, 0), (1, 0)]:
-                for group in [1, 2]:
-                    for layout in ['NCHW', 'NHWC']:
-                        weight_shape = (data_shape[1], num_filter) + kernel
-                        bias_shape = (num_filter,)
-                        sym = mx.sym.Deconvolution(data, weight=weight, 
bias=bias, kernel=kernel,
-                                                 stride=stride, pad=pad, 
num_filter=num_filter,
-                                                 no_bias=False, layout=layout)
-                        if layout == 'NCHW':
-                            print("kernel: {} | stride: {} | pad: {} | group: 
{} | layout: {} | with_bias".format(
-                                  kernel, stride, pad, group, layout))
-                            check_single_sym(sym, {'data': data_shape},
-                                             {'weight': weight_shape, 'bias': 
bias_shape},
-                                             rtol_fp32=rtol_fp32, 
atol_fp32=atol_fp32,
-                                             rtol_fp16=rtol_fp16, 
atol_fp16=atol_fp16)
-                        else:
-                            check_unsupported_single_sym(sym)
-                        sym = mx.sym.Deconvolution(data, weight=weight, 
kernel=kernel, stride=stride,
-                                                 pad=pad, 
num_filter=num_filter, no_bias=True,
-                                                 layout=layout)
-                        if layout == 'NCHW':
-                            print("kernel: {} | stride: {} | pad: {} | group: 
{} | layout: {} | without_bias".format(
-                                  kernel, stride, pad, group, layout))
-                            check_single_sym(sym, {'data': data_shape},
-                                             {'weight': weight_shape},
-                                             rtol_fp32=rtol_fp32, 
atol_fp32=atol_fp32,
-                                             rtol_fp16=rtol_fp16, 
atol_fp16=atol_fp16)
-                        else:
-                            check_unsupported_single_sym(sym)
-
-@with_seed()
-def test_fully_connected(): # TODO(cfujitsang): take care of flatten option
-    data = mx.sym.Variable('data')
-    weight = mx.sym.Variable('weight')
-    bias = mx.sym.Variable('bias')
-    data_shape = (8,64)
-    num_hidden = 7
-    weight_shape = (num_hidden, data_shape[1])
-    bias_shape = (num_hidden,)
-    sym = mx.sym.FullyConnected(data, weight=weight, bias=bias, no_bias=False,
-                                num_hidden=num_hidden)
-    check_single_sym(sym, {'data': data_shape}, {'weight': weight_shape, 
'bias': bias_shape},
-                     rtol_fp16=5e-3, atol_fp16=0.)
-    sym = mx.sym.FullyConnected(data, weight=weight, no_bias=True, 
num_hidden=num_hidden)
-    check_unsupported_single_sym(sym)
-
-
-@with_seed()
-def test_relu():
-    data = mx.sym.Variable('data')
-    sym = mx.sym.relu(data)
-    for data_shape in [(10, 32), (10, 3, 32), (10, 3, 32, 32), (10, 3, 7, 32, 
32)]:
-        check_single_sym(sym, {'data': data_shape}, rtol_fp32=0., atol_fp32=0.,
-                         rtol_fp16=1e-3, atol_fp16=1e-7)
-
-
-@with_seed()
-def test_activation():
-    data = mx.sym.Variable('data')
-    for act_type in ['relu', 'sigmoid', 'tanh']:
-        sym = mx.sym.Activation(data, act_type=act_type)
-        for data_shape in [(10, 32), (10, 3, 32), (10, 3, 32, 32), 
(10,3,7,32,32)]:
-            check_single_sym(sym, {'data': data_shape}, rtol_fp32=0., 
atol_fp32=0.,
-                             rtol_fp16=1e-3, atol_fp16=1e-7)
-    for act_type in ['softrelu', 'softsign']:
-        sym = mx.sym.Activation(data, act_type=act_type)
-        check_unsupported_single_sym(sym)
-
-
-@with_seed()
-def test_pooling2d():
-    data = mx.sym.Variable('data')
-    data_shape = (4, 3, 32,32)
-    for pool_type in ['max', 'avg', 'lp', 'sum']:
-        if pool_type == 'max':
-            rtol_fp32 = 1e-6
-            atol_fp32 = 0.
-            rtol_fp16 = 1e-3
-            atol_fp16 = 0.
-        else:
-            rtol_fp32 = 5e-6
-            atol_fp32 = 0.
-            rtol_fp16 = 1e-3
-            atol_fp16 = 0.
-        for layout in ['NHWC', 'NCHW']:
-            for (stride, pad, kernel, count_include_pad, pooling_convention) \
-                 in product([(2,2), (2,1)], [(0,0), (1,1)], [(2,2), (3,2)],
-                            [True, False], ['valid', 'full']):
-                print("pool_type: {} | layout: {} | stride: {} | pad: {} | 
".format(
-                      pool_type, layout, stride, pad) +
-                      "kernel: {} | count_include_pad: {} | 
pooling_convention: {}".format(
-                      kernel, count_include_pad, pooling_convention))
-                sym = mx.sym.Pooling(data, kernel=kernel, pool_type=pool_type, 
stride=stride,
-                                     pad=pad, layout=layout, 
count_include_pad=count_include_pad,
-                                     pooling_convention=pooling_convention)
-                if (layout == 'NHWC') or \
-                    pool_type not in ('max', 'avg') or \
-                    pooling_convention != 'valid' or \
-                    (pool_type == 'avg' and count_include_pad):
-                    check_unsupported_single_sym(sym)
-                else:
-                    check_single_sym(sym, {'data': data_shape},
-                                     rtol_fp32=rtol_fp32, atol_fp32=atol_fp32,
-                                     rtol_fp16=rtol_fp16, atol_fp16=atol_fp16)
-            print("pool_type: {} | layout: {} | global_pool".format(pool_type, 
layout))
-            sym = mx.sym.Pooling(data, global_pool=True, pool_type=pool_type, 
layout=layout)
-            if layout == 'NHWC' or pool_type not in ('max', 'avg'):
-                check_unsupported_single_sym(sym)
-            else:
-                if pool_type == 'max':
-                    rtol_fp32 = 0.
-                    atol_fp32 = 0.
-                    rtol_fp16 = 1e-3
-                    atol_fp16 = 0.
-                else:
-                    rtol_fp32 = 1e-5
-                    atol_fp32 = 0.
-                    rtol_fp16 = 1e-3
-                    atol_fp16 = 0.
-                check_single_sym(sym, {'data': data_shape}, 
rtol_fp32=rtol_fp32,
-                                 atol_fp32=atol_fp32, rtol_fp16=rtol_fp16, 
atol_fp16=atol_fp16)
-
-
-@with_seed()
-def test_softmax_output():
-    data = mx.sym.Variable('data')
-    label = mx.sym.Variable('label')
-    data_shape = (8, 100)
-    label_shape = (8, 100)
-    sym = mx.sym.SoftmaxOutput(data, label)
-    check_single_sym(sym, {'data': data_shape, 'label': label_shape},
-                     rtol_fp32=1e-6, atol_fp32=0., rtol_fp16=5e-3, 
atol_fp16=0.)
-    sym = mx.sym.SoftmaxOutput(data)
-    check_single_sym(sym, {'data': data_shape},
-                     rtol_fp32=1e-6, atol_fp32=0., rtol_fp16=5e-3, 
atol_fp16=0.)
-
-
-
-def check_batch_norm(sym, data_shapes, arg_params_shapes=None, 
aux_params_shapes=None,
-                     rtol_fp32=1e-5, atol_fp32=1e-7, rtol_fp16=1e-2, 
atol_fp16=1e-3):
-    if arg_params_shapes is None:
-        arg_params_shapes = {}
-    if aux_params_shapes is None:
-        aux_params_shapes = {}
-    for i in range(3):
-        data = {
-            'data': mx.nd.array(np.random.rand(*data_shapes['data']) + 0.01,
-                                dtype='float32', ctx=mx.cpu())
-        }
-        arg_params = {
-            'gamma': mx.nd.array(np.random.rand(*arg_params_shapes['gamma']) * 
0.1 + 1.,
-                                 dtype='float32', ctx=mx.cpu()),
-            'beta': mx.nd.array(np.random.rand(*arg_params_shapes['beta']),
-                                dtype='float32', ctx=mx.cpu())
-        }
-        aux_params = {
-            'moving_mean': mx.nd.array(
-                0.45 + np.random.rand(*aux_params_shapes['moving_mean']) * 0.1 
+ 0.01,
-                                      dtype='float32', ctx=mx.cpu()),
-            'moving_var': mx.nd.array(
-                0.95 + np.random.rand(*aux_params_shapes['moving_var']) * 0.1,
-                                      dtype='float32', ctx=mx.cpu())
-        }
-        wrapped_sym = mx.sym.Group([mx.sym.identity(s) for s in sym])
-
-        # Test FP32 MXNet Native
-        shapes = {}
-        shapes.update(data_shapes)
-        shapes.update(arg_params_shapes)
-        shapes.update(aux_params_shapes)
-        orig_executor = wrapped_sym.simple_bind(ctx=mx.gpu(0), grad_req='null',
-                                                force_rebind=True, **shapes)
-        orig_executor.copy_params_from(arg_params, aux_params)
-        orig_executor.forward(is_train=False, **data)
-        orig_outputs = [arr.asnumpy() for arr in orig_executor.outputs]
-
-        # Test FP32 MXNet-TRT
-        mx.contrib.tensorrt.set_use_fp16(False)
-        trt_sym = wrapped_sym.get_backend_symbol('TensorRT')
-        assert len(trt_sym.get_internals()) < len(wrapped_sym.get_internals())
-        remaining_arg_params, remaining_aux_params = \
-            mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, 
aux_params)
-        shapes = {}
-        shapes.update(data_shapes)
-        shapes.update({k: v.shape for k, v in remaining_arg_params.items()})
-        shapes.update({k: v.shape for k, v in remaining_aux_params.items()})
-        trt_fp32_executor = trt_sym.simple_bind(ctx=mx.gpu(0), grad_req='null',
-                                                force_rebind=True, **shapes)
-        trt_fp32_executor.copy_params_from(remaining_arg_params, 
remaining_aux_params)
-        trt_fp32_executor.forward(is_train=False, **data)
-        trt_fp32_outputs = [arr.asnumpy() for arr in trt_fp32_executor.outputs]
-
-        # Test FP16 MXNet-TRT
-        mx.contrib.tensorrt.set_use_fp16(True)
-        data = {k: v.astype('float16') for k, v in data.items()}
-        arg_params = {k: v.astype('float32') for k, v in arg_params.items()}
-        aux_params = {k: v.astype('float32') for k, v in aux_params.items()}
-        trt_sym = wrapped_sym.get_backend_symbol('TensorRT')
-        remaining_arg_params, remaining_aux_params = \
-            mx.contrib.tensorrt.init_tensorrt_params(trt_sym, arg_params, 
aux_params)
-        shapes = {}
-        shapes.update(data_shapes)
-        shapes.update({k: v.shape for k, v in remaining_arg_params.items()})
-        shapes.update({k: v.shape for k, v in remaining_aux_params.items()})
-
-        trt_fp16_executor = trt_sym.simple_bind(ctx=mx.gpu(0),
-                                                type_dict={k: 'float16' for k 
in shapes.keys()},
-                                                grad_req='null', 
force_rebind=True, **shapes)
-        trt_fp16_executor.copy_params_from(remaining_arg_params, 
remaining_aux_params)
-        trt_fp16_executor.forward(is_train=False, **data)
-        trt_fp16_outputs = [arr.asnumpy() for arr in trt_fp16_executor.outputs]
-
-
-        for j, (orig, fp16, fp32) in enumerate(zip(orig_outputs,
-                                                   trt_fp16_outputs,
-                                                   trt_fp32_outputs)):
-            abs_orig = abs(orig)
-            diff32 = abs(fp32 - orig)
-            diff16 = abs(fp16.astype('float32') - orig)
-            _atol32 = diff32 - rtol_fp32 * abs_orig
-            _atol16 = diff16 - rtol_fp16 * abs_orig
-            print("{}: diff32({:.2E}) | diff16({:.2E}) | atol32({:.2E}) | 
atol16({:.2E}) | orig.min({:.2E})".format(
-                  j, diff32.max(), diff16.max(), _atol32.max(), _atol16.max(), 
abs_orig.min()))
-            assert_allclose(fp32, orig, rtol=rtol_fp32, atol=atol_fp32)
-            assert_allclose(fp16.astype('float32'), orig, rtol=rtol_fp16, 
atol=atol_fp16)
-
-@with_seed()
-def test_batch_norm():
-    data = mx.sym.Variable('data')
-    gamma = mx.sym.Variable('gamma')
-    beta = mx.sym.Variable('beta')
-    moving_mean = mx.sym.Variable('moving_mean')
-    moving_var = mx.sym.Variable('moving_var')
-    data_shape = (4,3,32,32)
-    gamma_shape = (3,)
-    beta_shape = (3,)
-    moving_mean_shape = (3,)
-    moving_var_shape = (3,)
-    for fix_gamma in [True, False]:
-        for use_global_stats in [True, False]:
-            for axis in [0, 1, 2, 3]:
-                sym = mx.sym.BatchNorm(data, gamma=gamma, beta=beta, 
moving_mean=moving_mean,
-                                       fix_gamma=fix_gamma, 
moving_var=moving_var, momentum=0.9,
-                                       axis=axis, 
use_global_stats=use_global_stats, eps=1e-5)
-                if axis == 1:
-                    check_batch_norm(sym,
-                        {'data': data_shape}, {'gamma': gamma_shape, 'beta': 
beta_shape},
-                        {'moving_mean': moving_mean_shape, 'moving_var': 
moving_var_shape},
-                        atol_fp32=2e-7)
-                else:
-                    check_unsupported_single_sym(sym)
-
-
-@with_seed()
-def test_clip():
-    data = mx.sym.Variable('data')
-    sym = mx.sym.clip(data, 0.25, 0.75)
-    for data_shape in [(10, 32), (10, 3, 32), (10, 3, 32, 32), (10,3,7,32,32)]:
-        check_single_sym(sym, {'data': data_shape},
-                         rtol_fp32=0., atol_fp32=0.,
-                         rtol_fp16=1e-3, atol_fp16=0.)
-
-
-@with_seed()
-def test_concat():
-    lhs = mx.sym.Variable('lhs')
-    rhs = mx.sym.Variable('rhs')
-    shape = [3, 5, 7, 9]
-    lhs_shape = tuple(shape)
-    for axis in range(1, 4):
-        sym = mx.sym.concat(lhs, rhs, dim=axis)
-        rhs_shape = copy.copy(shape)
-        rhs_shape[axis] = 1
-        rhs_shape = tuple(rhs_shape)
-        check_single_sym(sym, {'lhs': lhs_shape, 'rhs': rhs_shape},
-                         rtol_fp32=0., atol_fp32=0., rtol_fp16=1e-3, 
atol_fp16=1e-7)
-
-
-@with_seed()
-def test_elemwise_ops():
-    lhs = mx.sym.Variable('lhs')
-    rhs = mx.sym.Variable('rhs')
-    shape = (3, 5, 7, 9)
-    lhs_shape = tuple(shape)
-    sym = mx.sym.elemwise_add(lhs, rhs)
-    check_single_sym(sym, {'lhs': shape, 'rhs': shape},
-                     rtol_fp32=0., atol_fp32=0.)
-
-    sym = mx.sym.elemwise_sub(lhs, rhs)
-    # TODO(cfujitsang): is atol_fp16 ok ?
-    check_single_sym(sym, {'lhs': shape, 'rhs': shape},
-                     rtol_fp32=0., atol_fp32=0., rtol_fp16=1e-3, 
atol_fp16=1e-3)
-
-    sym = mx.sym.elemwise_mul(lhs, rhs)
-    check_single_sym(sym, {'lhs': shape, 'rhs': shape},
-                     rtol_fp32=0., atol_fp32=0., rtol_fp16=5e-3, 
atol_fp16=1e-7)
-
-@with_seed()
-def test_flatten():
-    data = mx.sym.Variable('data')
-    sym = mx.sym.flatten(data)
-    for data_shape in [(3, 5, 7), (3, 5, 7, 9), (3, 5, 7, 9, 11)]:
-        check_single_sym(sym, {'data': data_shape},
-                         rtol_fp32=0., atol_fp32=0., atol_fp16=1e-7)
-
-@with_seed()
-def test_dropout():
-    data = mx.sym.Variable('data')
-    for data_shape in [(3, 5), (3, 5, 7), (3, 5, 7, 9)]:
-        for mode in ['training', 'always']:
-            sym = mx.sym.Dropout(data, p=0.7, mode=mode)
-            if mode == 'training':
-                check_single_sym(sym, {'data': data_shape},
-                                 rtol_fp32=0., atol_fp32=0., atol_fp16=1e-7)
-            else:
-                check_unsupported_single_sym(sym)
-            sym = mx.sym.Dropout(data, p=0.7, mode=mode, axes=(0,))
-            check_unsupported_single_sym(sym)
-
-if __name__ == "__main__":
-    import nose
-    nose.runmodule()
diff --git a/tests/python/tensorrt/test_resnet18.py 
b/tests/python/tensorrt/test_resnet18.py
deleted file mode 100644
index 9fd99ab..0000000
--- a/tests/python/tensorrt/test_resnet18.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from mxnet.gluon.model_zoo import vision
-from mxnet.test_utils import assert_almost_equal
-import mxnet as mx
-import numpy as np
-import os
-
-batch_shape = (1, 3, 224, 224)
-url = 
'https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true'
-model_file_name = 'resnet18_v2_trt_test'
-
-def get_image(image_url):
-    fname = mx.test_utils.download(image_url, 
fname=image_url.split('/')[-1].split('?')[0])
-    img = mx.image.imread(fname)
-    img = mx.image.imresize(img, 224, 224)  # Resize
-    img = img.transpose((2, 0, 1))  # Channel first
-    img = img.expand_dims(axis=0)  # Batchify
-    img = mx.nd.cast(img, dtype=np.float32)
-    return img / 255.0
-
-def test_tensorrt_resnet18_feature_vect():
-    print("downloading sample input")
-    input_data = get_image(url)
-    gluon_resnet18 = vision.resnet18_v2(pretrained=True)
-    gluon_resnet18.hybridize()
-    gluon_resnet18.forward(input_data)
-    gluon_resnet18.export(model_file_name)
-    sym, arg_params, aux_params = mx.model.load_checkpoint(model_file_name, 0)
-
-    executor = sym.simple_bind(ctx=mx.gpu(), data=batch_shape,
-                               grad_req='null', force_rebind=True)
-    executor.copy_params_from(arg_params, aux_params)
-    y = executor.forward(is_train=False, data=input_data)
-    trt_sym = sym.get_backend_symbol('TensorRT')
-    arg_params, aux_params = mx.contrib.tensorrt.init_tensorrt_params(trt_sym, 
arg_params, aux_params)
-    original_precision_value = mx.contrib.tensorrt.get_use_fp16()
-    try:
-        mx.contrib.tensorrt.set_use_fp16(True)
-        executor = trt_sym.simple_bind(ctx=mx.gpu(), data=batch_shape,
-                                       grad_req='null', force_rebind=True)
-        executor.copy_params_from(arg_params, aux_params)
-        y_trt = executor.forward(is_train=False, data=input_data)
-        mx.contrib.tensorrt.set_use_fp16(False)
-        executor = trt_sym.simple_bind(ctx=mx.gpu(), data=batch_shape,
-                                       grad_req='null', force_rebind=True)
-        executor.copy_params_from(arg_params, aux_params)
-        y_trt_fp32 = executor.forward(is_train=False, data=input_data)
-        no_trt_output = y[0].asnumpy()[0]
-        trt_output = y_trt[0].asnumpy()[0]
-        trt_fp32_output = y_trt_fp32[0].asnumpy()[0]
-        assert_almost_equal(no_trt_output, trt_output, 1e-1, 1e-2)
-        assert_almost_equal(no_trt_output, trt_fp32_output, 1e-4, 1e-4)
-    finally:
-        mx.contrib.tensorrt.set_use_fp16(original_precision_value)
-
-if __name__ == '__main__':
-    import nose
-    nose.runmodule()
diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py 
b/tests/python/tensorrt/test_tensorrt_lenet5.py
deleted file mode 100644
index 78f41ca..0000000
--- a/tests/python/tensorrt/test_tensorrt_lenet5.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import os
-import numpy as np
-import mxnet as mx
-from ctypes.util import find_library
-
-def check_tensorrt_installation():
-    assert find_library('nvinfer') is not None, "Can't find the TensorRT 
shared library"
-
-def get_iters(mnist, batch_size):
-    """Get MNIST iterators."""
-    train_iter = mx.io.NDArrayIter(mnist['train_data'],
-                                   mnist['train_label'],
-                                   batch_size,
-                                   shuffle=True)
-    val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], 
batch_size)
-    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], 
batch_size)
-    all_test_labels = np.array(mnist['test_label'])
-    return train_iter, val_iter, test_iter, all_test_labels
-
-def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, 
batch_size, use_tensorrt):
-    """Run inference with either MXNet or TensorRT"""
-
-    data_size = (batch_size,) + mnist['test_data'].shape[1:]
-    type_dict = {'data': 'float32', 'softmax_label': 'float32'}
-
-    if use_tensorrt:
-        _sym = sym.get_backend_symbol('TensorRT')
-        arg_params, aux_params = 
mx.contrib.tensorrt.init_tensorrt_params(_sym, arg_params,
-                                                                          
aux_params)
-    else:
-        _sym = sym
-    for k, v in arg_params.items():
-        type_dict[k] = v.dtype
-    for k, v in aux_params.items():
-        type_dict[k] = v.dtype
-    executor = _sym.simple_bind(ctx=mx.gpu(0),
-                                type_dict=type_dict,
-                                data=data_size,
-                                softmax_label=(batch_size,),
-                                grad_req='null',
-                                force_rebind=True)
-    executor.copy_params_from(arg_params, aux_params)
-
-    # Get this value from all_test_labels
-    # Also get classes from the dataset
-    num_ex = 10000
-    all_preds = np.zeros([num_ex, 10])
-    test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], 
batch_size)
-
-    example_ct = 0
-
-    for idx, dbatch in enumerate(test_iter):
-        executor.arg_dict["data"][:] = dbatch.data[0]
-        executor.forward(is_train=False)
-        offset = idx*batch_size
-        extent = batch_size if num_ex - offset > batch_size else num_ex - 
offset
-        all_preds[offset:offset+extent, :] = 
executor.outputs[0].asnumpy()[:extent]
-        example_ct += extent
-
-    all_preds = np.argmax(all_preds, axis=1)
-    matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum()
-
-    percentage = 100.0 * matches / example_ct
-
-    return percentage
-
-
-def test_tensorrt_inference():
-    """Run LeNet-5 inference comparison between MXNet and TensorRT."""
-    check_tensorrt_installation()
-    mnist = mx.test_utils.get_mnist()
-    num_epochs = 10
-    batch_size = 128
-    model_name = 'lenet5'
-    model_dir = os.getenv("LENET_MODEL_DIR", "/tmp")
-    model_file = '%s/%s-symbol.json' % (model_dir, model_name)
-    params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs)
-
-    _, _, _, all_test_labels = get_iters(mnist, batch_size)
-
-    # Load serialized MXNet model (model-symbol.json + model-epoch.params)
-    sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, 
num_epochs)
-
-    print("LeNet-5 test")
-    print("Running inference in MXNet")
-    mx_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels,
-                           batch_size=batch_size, use_tensorrt=False)
-
-    print("Running inference in MXNet-TensorRT")
-    trt_pct = run_inference(sym, arg_params, aux_params, mnist, 
all_test_labels,
-                            batch_size=batch_size, use_tensorrt=True)
-
-    print("MXNet accuracy: %f" % mx_pct)
-    print("MXNet-TensorRT accuracy: %f" % trt_pct)
-
-    absolute_accuracy_diff = abs(mx_pct - trt_pct)
-    epsilon = 3e-2
-    assert absolute_accuracy_diff < epsilon, \
-        """Absolute diff. between MXNet & TensorRT accuracy (%f) exceeds 
threshold (%f):
-           MXNet = %f, TensorRT = %f""" % (absolute_accuracy_diff, epsilon, 
mx_pct, trt_pct)
-
-if __name__ == '__main__':
-    import nose
-    nose.runmodule()
diff --git a/tests/python/unittest/test_extensions.py 
b/tests/python/unittest/test_extensions.py
index d00f149..9c62b7f 100644
--- a/tests/python/unittest/test_extensions.py
+++ b/tests/python/unittest/test_extensions.py
@@ -130,8 +130,6 @@ def test_subgraph():
     sym = mx.sym.log(d)
 
     args = {'a':mx.nd.ones((3,2),ctx=mx.cpu()), 
'b':mx.nd.ones((3,2),ctx=mx.cpu())}
-    arg_array = [mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu()),
-                 mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu())]
 
     # baseline - regular execution in MXNet
     exe = sym.bind(ctx=mx.cpu(), args=args)
@@ -147,14 +145,14 @@ def test_subgraph():
 
     # with propogating shapes/types, rejecting subgraph
     # this tests creating the subgraph and having the subgraph prop reject it
-    mysym2 = sym.optimize_for("myProp", arg_array, reject=True)
+    mysym2 = sym.optimize_for("myProp", args, reject=True)
     exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
     out2 = exe2.forward()
     # check that result matches one executed by MXNet
     assert_almost_equal(out[0].asnumpy(), out2[0].asnumpy(), rtol=1e-3, 
atol=1e-3)
 
     # with propogating shapes/types
-    mysym3 = sym.optimize_for("myProp",arg_array)
+    mysym3 = sym.optimize_for("myProp",args)
     exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
     out3 = exe3.forward()
     # check that result matches one executed by MXNet
diff --git a/tests/python/unittest/test_subgraph_op.py 
b/tests/python/unittest/test_subgraph_op.py
index e414a98..81665f2 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -327,18 +327,20 @@ def check_subgraph_exe8(sym, subgraph_backend, op_names):
     then bind and compare results of the partitioned sym and the original 
sym."""
     # bind
     arg_shapes, _, aux_shapes = sym.infer_shape()
-    arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
-    aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
-    exe1 = sym.bind(ctx=mx.current_context(), args=arg_array, 
aux_states=aux_array, grad_req='null')
+    arg_names = sym.list_arguments()
+    aux_names = sym.list_auxiliary_states()
+    arg_dict = {name:mx.nd.random.uniform(shape=shape) for name,shape in 
zip(arg_names,arg_shapes)}
+    aux_dict = {name:mx.nd.random.uniform(shape=shape) for name,shape in 
zip(aux_names,aux_shapes)}
+    exe1 = sym.bind(ctx=mx.current_context(), args=arg_dict, 
aux_states=aux_dict, grad_req='null')
     exe1.forward()
 
     # infer shape/type before partition before bind
     check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), 
mx_uint(len(op_names)),
-                                                 c_str_array(op_names)))
-    part_sym = sym.optimize_for(subgraph_backend, arg_array, aux_array)
+                                                   c_str_array(op_names)))
+    part_sym = sym.optimize_for(subgraph_backend, arg_dict, aux_dict)
     check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))
 
-    exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_array, 
aux_states=aux_array, grad_req='null')
+    exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_dict, 
aux_states=aux_dict, grad_req='null')
     exe2.forward()
     
     # compare outputs

Reply via email to