This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 77254f2 Infra to use tvm write op kernels (#15550)
77254f2 is described below
commit 77254f2a829d481d61e015e900923f1edb4af39f
Author: Yizhi Liu <[email protected]>
AuthorDate: Sun Jul 21 20:58:28 2019 -0700
Infra to use tvm write op kernels (#15550)
* infra to use tvm write op kernels
* add cmake support for tvm op
* fix header lint
* cleanup USE_TVM_OP logic in Makefile
* add doc, cmake def, etc.
* fix doc
* test rand shape
* add with_seed to test
* improve err msg. add #if
---
3rdparty/dlpack | 2 +-
3rdparty/dmlc-core | 2 +-
3rdparty/tvm | 2 +-
CMakeLists.txt | 23 ++++++
Makefile | 38 ++++++++++
cmake/BuildTVM.cmake | 135 ++++++++++++++++++++++++++++++++++
contrib/tvmop/__init__.py | 22 ++++++
contrib/tvmop/basic/__init__.py | 19 +++++
contrib/tvmop/basic/ufunc.py | 50 +++++++++++++
contrib/tvmop/compile.py | 59 +++++++++++++++
contrib/tvmop/opdef.py | 111 ++++++++++++++++++++++++++++
contrib/tvmop/prepare_tvm.sh | 63 ++++++++++++++++
contrib/tvmop/utils.py | 20 +++++
include/mxnet/c_api.h | 9 +++
include/mxnet/libinfo.h | 7 ++
make/config.mk | 3 +
make/osx.mk | 3 +
python/mxnet/base.py | 7 +-
python/mxnet/libinfo.py | 10 +--
src/c_api/c_api.cc | 9 +++
src/libinfo.cc | 4 +
src/operator/contrib/tvmop/ufunc.cc | 66 +++++++++++++++++
src/operator/tvmop/op_module.cc | 117 +++++++++++++++++++++++++++++
src/operator/tvmop/op_module.h | 63 ++++++++++++++++
tests/python/gpu/test_operator_gpu.py | 1 +
tests/python/unittest/test_tvm_op.py | 38 ++++++++++
26 files changed, 874 insertions(+), 9 deletions(-)
diff --git a/3rdparty/dlpack b/3rdparty/dlpack
index 10892ac..b90e939 160000
--- a/3rdparty/dlpack
+++ b/3rdparty/dlpack
@@ -1 +1 @@
-Subproject commit 10892ac964f1af7c81aae145cd3fab78bbccd297
+Subproject commit b90e939072066c160b18ea1e7156537b8d3710f6
diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core
index 3943914..f1ff6cc 160000
--- a/3rdparty/dmlc-core
+++ b/3rdparty/dmlc-core
@@ -1 +1 @@
-Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f
+Subproject commit f1ff6cc117f4e95169a9f62be549c8fe3e15c20f
diff --git a/3rdparty/tvm b/3rdparty/tvm
index 21935dc..afd4b3e 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 21935dcbf56ad3bd66ebff9891a6bc3865b8106d
+Subproject commit afd4b3e4450984358e9d79a7e8e578483cb7b017
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 19a93c7..7c479f7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -42,6 +42,7 @@ mxnet_option(USE_MXNET_LIB_NAMING "Use MXNet library naming
conventions." ON)
mxnet_option(USE_GPROF "Compile with gprof (profiling) flag" OFF)
mxnet_option(USE_CXX14_IF_AVAILABLE "Build with C++14 if the compiler supports
it" OFF)
mxnet_option(USE_VTUNE "Enable use of Intel Amplifier XE (VTune)"
OFF) # one could set VTUNE_ROOT for search path
+mxnet_option(USE_TVM_OP "Enable use of TVM operator build system."
OFF)
mxnet_option(ENABLE_CUDA_RTC "Build with CUDA runtime compilation
support" ON)
mxnet_option(BUILD_CPP_EXAMPLES "Build cpp examples" ON)
mxnet_option(INSTALL_EXAMPLES "Install the example source files." OFF)
@@ -733,6 +734,28 @@ if(USE_DIST_KVSTORE)
list(APPEND mxnet_LINKER_LIBS ${pslite_LINKER_LIBS})
endif()
+if(USE_TVM_OP)
+ add_definitions(-DMXNET_USE_TVM_OP=1)
+ list(APPEND mxnet_LINKER_LIBS
${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm/libtvm_runtime.so)
+ include(cmake/BuildTVM.cmake)
+ add_subdirectory("3rdparty/tvm")
+
+ if(NOT Python3_EXECUTABLE)
+ find_package(PythonInterp 3 REQUIRED)
+ set(Python3_EXECUTABLE ${PYTHON_EXECUTABLE} CACHE FILEPATH "Path to the
python3 executable")
+ if(NOT Python3_EXECUTABLE)
+ message(FATAL_ERROR "No python3 interpreter found to build TVM
operators")
+ endif()
+ endif()
+
+ add_custom_command(TARGET mxnet POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E env
+
PYTHONPATH="${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python:${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/topi/python:${CMAKE_CURRENT_SOURCE_DIR}/contrib"
+ LD_LIBRARY_PATH="${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm/build"
+ ${Python3_EXECUTABLE}
${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/compile.py
-o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so
+ )
+endif()
+
target_link_libraries(mxnet PUBLIC ${mxnet_LINKER_LIBS})
if(USE_PLUGINS_WARPCTC)
diff --git a/Makefile b/Makefile
index c5ac4eb..ce840a3 100644
--- a/Makefile
+++ b/Makefile
@@ -52,6 +52,14 @@ ifndef AMALGAMATION_PATH
AMALGAMATION_PATH = $(ROOTDIR)/amalgamation
endif
+ifndef TVM_PATH
+ TVM_PATH = $(TPARTYDIR)/tvm
+endif
+
+ifndef LLVM_PATH
+ LLVM_PATH = $(TVM_PATH)/build/llvm
+endif
+
ifneq ($(USE_OPENMP), 1)
export NO_OPENMP = 1
endif
@@ -589,6 +597,35 @@ $(DMLC_CORE)/libdmlc.a: DMLCCORE
DMLCCORE:
+ cd $(DMLC_CORE); $(MAKE) libdmlc.a USE_SSE=$(USE_SSE)
config=$(ROOTDIR)/$(config); cd $(ROOTDIR)
+ifeq ($(USE_TVM_OP), 1)
+LIB_DEP += lib/libtvm_runtime.so lib/libtvmop.so
+CFLAGS += -I$(TVM_PATH)/include -DMXNET_USE_TVM_OP=1
+LDFLAGS += -L$(TVM_PATH)/build -ltvm_runtime
+
+TVM_USE_CUDA := OFF
+ifeq ($(USE_CUDA), 1)
+ TVM_USE_CUDA := ON
+ ifneq ($(USE_CUDA_PATH), NONE)
+ TVM_USE_CUDA := $(USE_CUDA_PATH)
+ endif
+endif
+lib/libtvm_runtime.so:
+ echo "Compile TVM"
+ [ -e $(LLVM_PATH)/bin/llvm-config ] || sh
$(ROOTDIR)/contrib/tvmop/prepare_tvm.sh; \
+ cd $(TVM_PATH)/build; \
+ cmake -DUSE_LLVM="$(LLVM_PATH)/bin/llvm-config" \
+ -DUSE_SORT=OFF -DUSE_CUDA=$(TVM_USE_CUDA) -DUSE_CUDNN=OFF ..;
\
+ $(MAKE) VERBOSE=1; \
+ cp $(TVM_PATH)/build/libtvm_runtime.so
$(ROOTDIR)/lib/libtvm_runtime.so; \
+ cd $(ROOTDIR)
+
+lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py
contrib/tvmop/*.py)
+ echo "Compile TVM operators"
+
PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib:$PYTHONPATH
\
+ LD_LIBRARY_PATH=lib \
+ python3 $(ROOTDIR)/contrib/tvmop/compile.py -o
$(ROOTDIR)/lib/libtvmop.so
+endif
+
NNVM_INC = $(wildcard $(NNVM_PATH)/include/*/*.h)
NNVM_SRC = $(wildcard $(NNVM_PATH)/src/*/*/*.cc $(NNVM_PATH)/src/*/*.cc
$(NNVM_PATH)/src/*.cc)
$(NNVM_PATH)/lib/libnnvm.a: $(NNVM_INC) $(NNVM_SRC)
@@ -726,6 +763,7 @@ clean: rclean cyclean $(EXTRA_PACKAGES_CLEAN)
cd $(DMLC_CORE); $(MAKE) clean; cd -
cd $(PS_PATH); $(MAKE) clean; cd -
cd $(NNVM_PATH); $(MAKE) clean; cd -
+ cd $(TVM_PATH); $(MAKE) clean; cd -
cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
$(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %,
%/*/*.d, $(EXTRA_OPERATORS))
$(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %,
%/*/*.o, $(EXTRA_OPERATORS))
diff --git a/cmake/BuildTVM.cmake b/cmake/BuildTVM.cmake
new file mode 100644
index 0000000..ad8517c
--- /dev/null
+++ b/cmake/BuildTVM.cmake
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+message(STATUS "Prepare external packages for TVM...")
+execute_process(COMMAND
"${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/prepare_tvm.sh")
+
+# Whether enable ROCM runtime
+#
+# Possible values:
+# - ON: enable ROCM with cmake's auto search
+# - OFF: disable ROCM
+# - /path/to/rocm: use specific path to rocm
+set(USE_ROCM OFF)
+
+# Whether enable SDAccel runtime
+set(USE_SDACCEL OFF)
+
+# Whether enable Intel FPGA SDK for OpenCL (AOCL) runtime
+set(USE_AOCL OFF)
+
+# Whether enable OpenCL runtime
+set(USE_OPENCL OFF)
+
+# Whether enable Metal runtime
+set(USE_METAL OFF)
+
+# Whether enable Vulkan runtime
+#
+# Possible values:
+# - ON: enable Vulkan with cmake's auto search
+# - OFF: disable vulkan
+# - /path/to/vulkan-sdk: use specific path to vulkan-sdk
+set(USE_VULKAN OFF)
+
+# Whether enable OpenGL runtime
+set(USE_OPENGL OFF)
+
+# Whether to enable SGX runtime
+#
+# Possible values for USE_SGX:
+# - /path/to/sgxsdk: path to Intel SGX SDK
+# - OFF: disable SGX
+#
+# SGX_MODE := HW|SIM
+set(USE_SGX OFF)
+set(SGX_MODE "SIM")
+set(RUST_SGX_SDK "/path/to/rust-sgx-sdk")
+
+# Whether enable RPC runtime
+set(USE_RPC ON)
+
+# Whether embed stackvm into the runtime
+set(USE_STACKVM_RUNTIME OFF)
+
+# Whether enable tiny embedded graph runtime.
+set(USE_GRAPH_RUNTIME ON)
+
+# Whether enable additional graph debug functions
+set(USE_GRAPH_RUNTIME_DEBUG OFF)
+
+# Whether build with LLVM support
+# Requires LLVM version >= 4.0
+#
+# Possible values:
+# - ON: enable llvm with cmake's find search
+# - OFF: disable llvm
+# - /path/to/llvm-config: enable specific LLVM when multiple llvm-dev is
available.
+set(USE_LLVM
"${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/build/llvm/bin/llvm-config")
+
+#---------------------------------------------
+# Contrib libraries
+#---------------------------------------------
+# Whether use BLAS, choices: openblas, mkl, atlas, apple
+set(USE_BLAS none)
+
+# /path/to/mkl: mkl root path when use mkl blas library
+# set(USE_MKL_PATH /opt/intel/mkl) for UNIX
+# set(USE_MKL_PATH ../IntelSWTools/compilers_and_libraries_2018/windows/mkl)
for WIN32
+set(USE_MKL_PATH none)
+
+# Whether use contrib.random in runtime
+set(USE_RANDOM OFF)
+
+# Whether use NNPack
+set(USE_NNPACK OFF)
+
+# Whether use CuDNN
+if(USE_CUDNN AND USE_CUDA)
+ detect_cuDNN()
+ if(HAVE_CUDNN)
+ set(USE_CUDNN ON)
+ else()
+ set(USE_CUDNN OFF)
+ endif()
+else()
+ set(USE_CUDNN OFF)
+endif()
+
+# Whether use cuBLAS
+set(USE_CUBLAS OFF)
+
+# Whether use MIOpen
+set(USE_MIOPEN OFF)
+
+# Whether use MPS
+set(USE_MPS OFF)
+
+# Whether use rocBlas
+set(USE_ROCBLAS OFF)
+
+# Whether use contrib sort
+set(USE_SORT OFF)
+
+# Build ANTLR parser for Relay text format
+set(USE_ANTLR OFF)
+
+# Build TSIM for VTA
+set(USE_VTA_TSIM OFF)
+
+# Whether use Relay debug mode
+set(USE_RELAY_DEBUG OFF)
diff --git a/contrib/tvmop/__init__.py b/contrib/tvmop/__init__.py
new file mode 100644
index 0000000..31189d4
--- /dev/null
+++ b/contrib/tvmop/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+from .opdef import defop
+from .utils import AllTypes, RealTypes
+
+from . import basic
diff --git a/contrib/tvmop/basic/__init__.py b/contrib/tvmop/basic/__init__.py
new file mode 100644
index 0000000..fc0fa72
--- /dev/null
+++ b/contrib/tvmop/basic/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+from . import ufunc
diff --git a/contrib/tvmop/basic/ufunc.py b/contrib/tvmop/basic/ufunc.py
new file mode 100644
index 0000000..0419e5f
--- /dev/null
+++ b/contrib/tvmop/basic/ufunc.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+import tvm
+from .. import defop, AllTypes
+
+def compute_add(dtype, ndim):
+ A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A',
dtype=dtype)
+ B = tvm.placeholder([tvm.var() for _ in range(ndim)], name='B',
dtype=dtype)
+ C = tvm.compute([tvm.var() for _ in range(ndim)],
+ lambda *index: A[index] + B[index], name='C')
+ s = tvm.create_schedule(C.op)
+ return s, A, B, C
+
+@defop(name="vadd", target="cpu", auto_broadcast=True,
+ dtype=AllTypes, ndim=list(range(1, 6)))
+def vadd(dtype, ndim):
+ s, A, B, C = compute_add(dtype, ndim)
+ axes = [axis for axis in C.op.axis]
+ fused = s[C].fuse(*axes)
+ s[C].parallel(fused)
+
+ return s, [A, B, C]
+
+@defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
+ dtype=["float32", "float64"], ndim=list(range(1, 6)))
+def vadd_gpu(dtype, ndim):
+ s, A, B, C = compute_add(dtype, ndim)
+ s = tvm.create_schedule(C.op)
+ axes = [axis for axis in C.op.axis]
+ fused = s[C].fuse(*axes)
+ bx, tx = s[C].split(fused, factor=64)
+ s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+ s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+ return s, [A, B, C]
diff --git a/contrib/tvmop/compile.py b/contrib/tvmop/compile.py
new file mode 100644
index 0000000..94274fe
--- /dev/null
+++ b/contrib/tvmop/compile.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""TVM Operator compile entry point"""
+import tvm
+
+import os
+import argparse
+from tvmop.opdef import __OP_DEF__
+
+def get_target(device):
+ if device == "cpu":
+ return "llvm"
+ elif device == "cuda" or device == "gpu":
+ return "cuda"
+ assert False, "Unknown device " + device
+
+
+if __name__ == "__main__":
+ import sys
+ sys.path.append(os.path.dirname(sys.path[0]))
+ parser = argparse.ArgumentParser(description="Generate tvm operators")
+ parser.add_argument("-o", action="store", required=True,
dest="target_path",
+ help="Target path which stores compiled library")
+ arguments = parser.parse_args()
+
+ func_list_llvm = []
+ func_list_cuda = []
+
+ # TODO: attach instruction features to the library, e.g., avx-512, etc.
+ for operator_def in __OP_DEF__:
+ for sch, args in operator_def.invoke_all():
+ if tvm.module.enabled(get_target(operator_def.target)):
+ func_list = func_list_llvm if operator_def.target == "cpu"
else func_list_cuda
+ func_lower = tvm.lower(sch, args,
+ name=operator_def.get_op_name(args),
+ binds=operator_def.get_binds(args))
+ func_list.append(func_lower)
+
+ lowered_funcs = {get_target("cpu") : func_list_llvm}
+ if len(func_list_cuda) > 0:
+ lowered_funcs[get_target("cuda")] = func_list_cuda
+ func_binary = tvm.build(lowered_funcs, name="tvmop")
+ func_binary.export_library(arguments.target_path)
diff --git a/contrib/tvmop/opdef.py b/contrib/tvmop/opdef.py
new file mode 100644
index 0000000..c658245
--- /dev/null
+++ b/contrib/tvmop/opdef.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+import tvm
+from itertools import product
+
+__OP_DEF__ = []
+
+class OpDef:
+ """Specify the properties of an operator and
+ construct the value combination of the arguments
+ e.g., ldtype=["float32", "int32"], rdtype=["float16", "int16"],
+ then the argument combination is
+ [
+ {"ldtype": "float32", "rdtype": "float16"},
+ {"ldtype": "float32", "rdtype": "int16"},
+ {"ldtype": "int32", "rdtype": "float16"},
+ {"ldtype": "int32", "rdtype": "int16"},
+ ]
+
+ Parameters
+ ----------
+ func : function
+ The function to define the operator (in tvm compute and schedule).
+ It will get the argument combination extracted by this class.
+ name : str
+ function name.
+ target : str
+ {"cpu", "gpu", "cuda"}
+ auto_broadcast : bool
+ auto_broadcast=True allows one to implement broadcast computation
+ without considering whether dimension size equals to one.
+ TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape
equals 1.
+ """
+ def __init__(self, func, name, target, auto_broadcast, **kwargs):
+ # construct the value combination of the arguments
+ # e.g., ldtype=["float32", "int32"], rdtype=["float16", "int16"]
+ # arg_combination = [
+ # {"ldtype": "float32", "rdtype": "float16"},
+ # {"ldtype": "float32", "rdtype": "int16"},
+ # {"ldtype": "int32", "rdtype": "float16"},
+ # {"ldtype": "int32", "rdtype": "int16"},
+ # ]
+ args = [k for k in kwargs]
+ values = [kwargs[k] if isinstance(kwargs[k], (list, tuple)) else
[kwargs[k]]
+ for k in args]
+ cart_product = product(*values)
+ self.arg_combination = [{k: v for k, v in zip(args, comb_values)}
+ for comb_values in cart_product]
+ self.func = func
+ self.name = name
+ self.target = target
+ self.auto_broadcast = auto_broadcast
+
+ def __call__(self, *args, **kwargs):
+ return self.func(*args, **kwargs)
+
+ def invoke_all(self):
+ for each_kwargs in self.arg_combination:
+ yield self.func(**each_kwargs)
+
+ def get_op_name(self, args):
+ return self.name + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for
arg in args])
+
+ def get_binds(self, args):
+ if self.auto_broadcast:
+ return {arg: tvm.decl_buffer(arg.shape, arg.dtype,
buffer_type="auto_broadcast")
+ for arg in args}
+ return None
+
+
+def defop(name, target=None, auto_broadcast=False, **kwargs):
+ """Decorator to define a tvm operator.
+ Parameters
+ ----------
+ name : str
+ function name
+ target : str
+ {"cpu", "gpu", "cuda"}
+ auto_broadcast : bool
+ auto_broadcast=True allows one to implement broadcast computation
+ without considering whether dimension size equals to one.
+ TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape
equals 1.
+ Returns
+ -------
+ fdef : function
+ A wrapped operator definition function, which returns (schedule,
[tensors])
+ """
+ assert name is not None and len(name) > 0
+ target = "cpu" if target is None else target
+ def _defop(func):
+ opdef = OpDef(func, name, target, auto_broadcast, **kwargs)
+ __OP_DEF__.append(opdef)
+ return opdef
+ return _defop
+
diff --git a/contrib/tvmop/prepare_tvm.sh b/contrib/tvmop/prepare_tvm.sh
new file mode 100644
index 0000000..7ebe256
--- /dev/null
+++ b/contrib/tvmop/prepare_tvm.sh
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#!/bin/sh
+
+LLVM_VERSION="8.0.0"
+LLVM_ROOT="http://releases.llvm.org/${LLVM_VERSION}/"
+LLVM_PKG="clang+llvm-${LLVM_VERSION}-x86_64-linux-gnu"
+
+os=`uname`
+if [ "$os" = "Linux" ] && [ "$(arch)" = "x86_64" ]; then
+ DISTRIB_ID=$(cat /etc/*-release | grep DISTRIB_ID | sed 's/DISTRIB_ID=//g'
| tr '[:upper:]' '[:lower:]')
+ DISTRIB_RELEASE=$(cat /etc/*-release | grep DISTRIB_RELEASE | sed
's/DISTRIB_RELEASE=//g' | tr '[:upper:]' '[:lower:]')
+ if [ "$DISTRIB_ID" = "ubuntu" ]; then
+ LLVM_PKG=${LLVM_PKG}-${DISTRIB_ID}-${DISTRIB_RELEASE}
+ else
+ echo "Downloading LLVM only supports Ubuntu. Please download manually."
+ exit 1
+ fi
+else
+ echo "Cannot identify operating system. Try downloading LLVM manually."
+ exit 1
+fi
+
+LLVM_URL=${LLVM_ROOT}${LLVM_PKG}.tar.xz
+
+TVM_PATH=`dirname $0`/../../3rdparty/tvm
+DST=${TVM_PATH}/build
+rm -rf $DST
+mkdir -p $DST
+DST=`cd $DST; pwd`
+
+if [ -x "$(command -v curl)" ]; then
+ curl -L -o "${DST}/${LLVM_PKG}.tar.xz" "$LLVM_URL"
+elif [ -x "$(command -v wget)" ]; then
+ wget -O "${DST}/${LLVM_PKG}.tar.xz" "$LLVM_URL"
+else
+ echo "curl or wget not available"
+ exit 1
+fi
+
+if [ \! $? ]; then
+ echo "Download from $LLVM_URL to $DST failed"
+ exit 1
+fi
+
+tar -xf "$DST/${LLVM_PKG}.tar.xz" -C $DST
+mv $DST/${LLVM_PKG} $DST/llvm
+echo "Downloaded and unpacked LLVM libraries to $DST"
diff --git a/contrib/tvmop/utils.py b/contrib/tvmop/utils.py
new file mode 100644
index 0000000..0b2416b
--- /dev/null
+++ b/contrib/tvmop/utils.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+AllTypes = ["float32", "float64", "float16", "uint8", "int8", "int32", "int64"]
+RealTypes = ["float32", "float64", "float16"]
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index bd30e44..058f859 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -506,6 +506,15 @@ MXNET_DLL int MXGetGPUMemoryInformation64(int dev,
uint64_t *free_mem, uint64_t
*/
MXNET_DLL int MXGetVersion(int *out);
+/*!
+ * \brief Load TVM operator from the binary library
+ * \param libpath TVM operators lib file
+ * \return 0 when success, -1 when failure happens
+ */
+#if MXNET_USE_TVM_OP
+MXNET_DLL int MXLoadTVMOp(const char *libpath);
+#endif // MXNET_USE_TVM_OP
+
//-------------------------------------
// Part 1: NDArray creation and deletion
diff --git a/include/mxnet/libinfo.h b/include/mxnet/libinfo.h
index 8b58a39..1972688 100644
--- a/include/mxnet/libinfo.h
+++ b/include/mxnet/libinfo.h
@@ -127,6 +127,10 @@
#define MXNET_USE_INT64_TENSOR_SIZE MSHADOW_INT64_TENSOR_SIZE
#endif
+#ifndef MXNET_USE_TVM_OP
+#define MXNET_USE_TVM_OP 0
+#endif
+
namespace mxnet {
namespace features {
// Check compile flags such as CMakeLists.txt
@@ -185,6 +189,9 @@ enum : unsigned {
SIGNAL_HANDLER,
DEBUG,
+ // TVM operator
+ TVM_OP,
+
// size indicator
MAX_FEATURES
};
diff --git a/make/config.mk b/make/config.mk
index 4bddb8b..982d15b 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -62,6 +62,9 @@ ADD_LDFLAGS =
# the additional compile flags you want to add
ADD_CFLAGS =
+# whether to build operators written in TVM
+USE_TVM_OP = 0
+
#---------------------------------------------
# matrix computation libraries for CPU/GPU
#---------------------------------------------
diff --git a/make/osx.mk b/make/osx.mk
index 0b5842e..25f3ba6 100644
--- a/make/osx.mk
+++ b/make/osx.mk
@@ -53,6 +53,9 @@ ADD_LDFLAGS =
# the additional compile flags you want to add
ADD_CFLAGS =
+# whether to build operators written in TVM
+USE_TVM_OP = 0
+
#---------------------------------------------
# matrix computation libraries for CPU/GPU
#---------------------------------------------
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 73fae48..bf80263 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -16,7 +16,7 @@
# under the License.
# coding: utf-8
-# pylint: disable=invalid-name, no-member, trailing-comma-tuple,
bad-mcs-classmethod-argument, unnecessary-pass
+# pylint: disable=invalid-name, no-member, trailing-comma-tuple,
bad-mcs-classmethod-argument, unnecessary-pass, wrong-import-position
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import
@@ -734,3 +734,8 @@ def _generate_op_module_signature(root_namespace,
module_name, op_code_gen_func)
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
+
+from .runtime import Features
+if Features().is_enabled("TVM_OP"):
+ _LIB_TVM_OP = libinfo.find_lib_path("libtvmop")
+ check_call(_LIB.MXLoadTVMOp(c_str(_LIB_TVM_OP[0])))
diff --git a/python/mxnet/libinfo.py b/python/mxnet/libinfo.py
index ff795f9..fb0859b 100644
--- a/python/mxnet/libinfo.py
+++ b/python/mxnet/libinfo.py
@@ -23,7 +23,7 @@ import platform
import logging
-def find_lib_path():
+def find_lib_path(prefix='libmxnet'):
"""Find MXNet dynamic library files.
Returns
@@ -61,13 +61,13 @@ def find_lib_path():
dll_path[0:0] = [p.strip() for p in
os.environ['LD_LIBRARY_PATH'].split(":")]
if os.name == 'nt':
os.environ['PATH'] = os.path.dirname(__file__) + ';' +
os.environ['PATH']
- dll_path = [os.path.join(p, 'libmxnet.dll') for p in dll_path]
+ dll_path = [os.path.join(p, prefix + '.dll') for p in dll_path]
elif platform.system() == 'Darwin':
- dll_path = [os.path.join(p, 'libmxnet.dylib') for p in dll_path] + \
- [os.path.join(p, 'libmxnet.so') for p in dll_path]
+ dll_path = [os.path.join(p, prefix + '.dylib') for p in dll_path] + \
+ [os.path.join(p, prefix + '.so') for p in dll_path]
else:
dll_path.append('../../../')
- dll_path = [os.path.join(p, 'libmxnet.so') for p in dll_path]
+ dll_path = [os.path.join(p, prefix + '.so') for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0:
raise RuntimeError('Cannot find the MXNet library.\n' +
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 35bd3ee..5207bdf 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -48,6 +48,7 @@
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../operator/tensor/matrix_op-inl.h"
+#include "../operator/tvmop/op_module.h"
#include "../common/utils.h"
using namespace mxnet;
@@ -159,6 +160,14 @@ int MXGetVersion(int *out) {
API_END();
}
+#if MXNET_USE_TVM_OP
+int MXLoadTVMOp(const char *libpath) {
+ API_BEGIN();
+ tvm::runtime::TVMOpModule::Get()->Load(libpath);
+ API_END();
+}
+#endif // MXNET_USE_TVM_OP
+
int MXNDArrayCreateNone(NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray();
diff --git a/src/libinfo.cc b/src/libinfo.cc
index f67b45e..b31d7e4 100644
--- a/src/libinfo.cc
+++ b/src/libinfo.cc
@@ -89,6 +89,9 @@ class FeatureSet {
feature_bits.set(INT64_TENSOR_SIZE, MXNET_USE_INT64_TENSOR_SIZE);
feature_bits.set(SIGNAL_HANDLER, MXNET_USE_SIGNAL_HANDLER);
+ // TVM operators
+ feature_bits.set(TVM_OP, MXNET_USE_TVM_OP);
+
#ifndef NDEBUG
feature_bits.set(DEBUG);
#endif
@@ -159,6 +162,7 @@ const std::vector<std::string> EnumNames::names = {
"INT64_TENSOR_SIZE",
"SIGNAL_HANDLER",
"DEBUG",
+ "TVM_OP",
};
} // namespace features
diff --git a/src/operator/contrib/tvmop/ufunc.cc
b/src/operator/contrib/tvmop/ufunc.cc
new file mode 100644
index 0000000..faba671
--- /dev/null
+++ b/src/operator/contrib/tvmop/ufunc.cc
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file ufunc.cc
+ * \brief
+ * \author Yizhi Liu
+ */
+#ifdef MXNET_USE_TVM_OP
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <mxnet/base.h>
+#include "../../tensor/elemwise_binary_broadcast_op.h"
+#include "../../tvmop/op_module.h"
+#include "../../tensor/elemwise_binary_op.h"
+
+namespace mxnet {
+namespace op {
+
+static constexpr char func_vadd_cpu[] = "vadd";
+static constexpr char func_vadd_gpu[] = "cuda_vadd";
+
+template<const char* func>
+void TVMBroadcastCompute(const nnvm::NodeAttrs& attrs,
+ const mxnet::OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ CHECK_EQ(inputs.size(), 2U);
+ CHECK_EQ(outputs.size(), 1U);
+ tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], inputs[1],
outputs[0]});
+}
+
+NNVM_REGISTER_OP(_contrib_tvm_vadd)
+ .set_num_inputs(2)
+ .set_num_outputs(1)
+ .add_argument("a", "NDArray-or-Symbol", "first input")
+ .add_argument("b", "NDArray-or-Symbol", "second input")
+ .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
+ .set_attr<nnvm::FInferType>("FInferType", mxnet::op::ElemwiseType<2, 1>)
+ .set_attr<mxnet::FCompute>("FCompute<cpu>",
mxnet::op::TVMBroadcastCompute<func_vadd_cpu>)
+#if MXNET_USE_CUDA
+ .set_attr<mxnet::FCompute>("FCompute<gpu>",
mxnet::op::TVMBroadcastCompute<func_vadd_gpu>);
+#endif // MXNET_USE_CUDA
+
+} // namespace op
+} // namespace mxnet
+#endif // MXNET_USE_TVM_OP
diff --git a/src/operator/tvmop/op_module.cc b/src/operator/tvmop/op_module.cc
new file mode 100644
index 0000000..d1d1c1d
--- /dev/null
+++ b/src/operator/tvmop/op_module.cc
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file op_module.cc
+ * \brief Invoke registered TVM operators.
+ * \author Yizhi Liu
+ */
+#if MXNET_USE_TVM_OP
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <string>
+#include <vector>
+#include "op_module.h"
+
+using namespace tvm::runtime;
+
+namespace tvm {
+namespace runtime {
+
+void TVMOpModule::Load(const std::string &filepath) {
+ static const PackedFunc *f_load = Registry::Get("module._LoadFromFile");
+ std::lock_guard<std::mutex> lock(mutex_);
+ Module module = (*f_load)(filepath, "");
+ module_ptr_ = std::make_shared<Module>();
+ *module_ptr_ = module;
+}
+
+PackedFunc GetFunction(const std::shared_ptr<Module> &module,
+ const std::string &op_name,
+ const std::vector<mxnet::TBlob> &args) {
+ std::ostringstream func_name;
+ func_name << op_name;
+ for (const auto &arg : args) {
+ switch (arg.type_flag_) {
+ case mshadow::kFloat32:
+ func_name << "float32";
+ break;
+ case mshadow::kFloat64:
+ func_name << "float64";
+ break;
+ case mshadow::kFloat16:
+ func_name << "float16";
+ break;
+ case mshadow::kUint8:
+ func_name << "uint8";
+ break;
+ case mshadow::kInt32:
+ func_name << "int32";
+ break;
+ case mshadow::kInt8:
+ func_name << "int8";
+ break;
+ case mshadow::kInt64:
+ func_name << "int64";
+ break;
+ default:
+ LOG(FATAL) << "Unknown dtype " << arg.type_flag_;
+ }
+ func_name << "_" << arg.shape_.ndim();
+ }
+ return module->GetFunction(func_name.str(), false);
+}
+
+void TVMOpModule::Call(const std::string &func_name,
+ const mxnet::OpContext& ctx,
+ const std::vector<mxnet::TBlob> &args) {
+ std::vector<int> type_codes;
+ std::vector<TVMValue> values;
+
+ type_codes.resize(args.size());
+ values.resize(args.size());
+ for (size_t i = 0; i < args.size(); ++i) {
+ type_codes[i] = kArrayHandle;
+ values[i].v_handle = const_cast<DLTensor *>(&(args[i].dltensor()));
+ }
+
+ TVMArgs tvm_args(&values[0], &type_codes[0], args.size());
+ TVMRetValue rv;
+
+#if MXNET_USE_CUDA
+  int dev_type = (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kGPU) ? kDLGPU : kDLCPU;
+ int dev_id = ctx.run_ctx.ctx.dev_id;
+ if (dev_type == kDLGPU) {
+    void *stream = static_cast<void *>(ctx.run_ctx.get_stream<mxnet::gpu>()->stream_);
+ TVMSetStream(dev_type, dev_id, stream);
+ }
+#endif
+ GetFunction(module_ptr_, func_name, args).CallPacked(tvm_args, &rv);
+#if MXNET_USE_CUDA
+ if (dev_type == kDLGPU) {
+ TVMSetStream(dev_type, dev_id, nullptr);
+ }
+#endif
+}
+
+} // namespace runtime
+} // namespace tvm
+#endif // MXNET_USE_TVM_OP
diff --git a/src/operator/tvmop/op_module.h b/src/operator/tvmop/op_module.h
new file mode 100644
index 0000000..04e97ef
--- /dev/null
+++ b/src/operator/tvmop/op_module.h
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file op_module.h
+ * \brief Invoke registered TVM operators.
+ * \author Yizhi Liu
+ */
+#ifndef MXNET_OPERATOR_TVMOP_OP_MODULE_H_
+#define MXNET_OPERATOR_TVMOP_OP_MODULE_H_
+
+#if MXNET_USE_TVM_OP
+#include <mxnet/base.h>
+#include <mxnet/op_attr_types.h>
+#include <mutex>
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+
+class Module;
+class TVMOpModule {
+ public:
+ // Load TVM operators binary
+ void Load(const std::string& filepath);
+
+ void Call(const std::string& func_name,
+ const mxnet::OpContext& ctx,
+ const std::vector<mxnet::TBlob>& args);
+
+ static TVMOpModule *Get() {
+ static TVMOpModule inst;
+ return &inst;
+ }
+
+ private:
+ std::mutex mutex_;
+ std::shared_ptr<Module> module_ptr_;
+};
+
+} // namespace runtime
+} // namespace tvm
+
+#endif // MXNET_USE_TVM_OP
+#endif // MXNET_OPERATOR_TVMOP_OP_MODULE_H_
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index f9814ab..91ba9fb 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -44,6 +44,7 @@ from test_sparse_operator import *
from test_ndarray import *
from test_subgraph_op import *
from test_contrib_operator import test_multibox_target_op
+from test_tvm_op import *
set_default_context(mx.gpu(0))
del test_support_vector_machine_l1_svm # noqa
diff --git a/tests/python/unittest/test_tvm_op.py b/tests/python/unittest/test_tvm_op.py
new file mode 100644
index 0000000..3ab2a25
--- /dev/null
+++ b/tests/python/unittest/test_tvm_op.py
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.test_utils import same, rand_shape_nd
+from mxnet.runtime import Features
+from common import with_seed
+
+_features = Features()
+
+@with_seed()
+def test_tvm_broadcast_add():
+ if _features.is_enabled("TVM_OP"):
+ a_shape = rand_shape_nd(4)
+ b_shape = (1,) + a_shape[1:2] + (1, 1)
+ a = mx.nd.normal(shape=a_shape)
+ b = mx.nd.normal(shape=b_shape)
+ c = mx.nd.contrib.tvm_vadd(a, b)
+ c_np = a.asnumpy() + b.asnumpy()
+ assert same(c.asnumpy(), c_np)
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()