This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 1240577 Use feature detection in base.py (#11803)
1240577 is described below
commit 1240577dd629e97a740d65742a5008a6fc48a380
Author: cclauss <[email protected]>
AuthorDate: Wed Jul 18 22:02:00 2018 +0200
Use feature detection in base.py (#11803)
---
python/mxnet/base.py | 63 ++++++++++++++++++++++++++++++++++++++--------------
1 file changed, 46 insertions(+), 17 deletions(-)
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 0fb73b3..4df794b 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -20,44 +20,59 @@
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import
+import atexit
+import ctypes
+import inspect
import os
import sys
-import ctypes
-import atexit
import warnings
-import inspect
+
import numpy as np
+
from . import libinfo
+
warnings.filterwarnings('default', category=DeprecationWarning)
__all__ = ['MXNetError']
#----------------------------
# library loading
#----------------------------
-if sys.version_info[0] == 3:
- string_types = str,
- numeric_types = (float, int, np.generic)
- integer_types = (int, np.int32, np.int64)
+
+# pylint: disable=pointless-statement
+try:
+ basestring
+ long
+except NameError:
+ basestring = str
+ long = int
+# pylint: enable=pointless-statement
+
+integer_types = (int, long, np.int32, np.int64)
+numeric_types = (float, int, long, np.generic)
+string_types = basestring,
+
+if sys.version_info[0] > 2:
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
else:
- string_types = basestring,
- numeric_types = (float, int, long, np.generic)
- integer_types = (int, long, np.int32, np.int64)
py_str = lambda x: x
+
class _NullType(object):
"""Placeholder for arguments"""
def __repr__(self):
return '_Null'
+
_Null = _NullType()
+
class MXNetError(Exception):
"""Error that will be throwed by all mxnet functions."""
pass
+
class NotImplementedForSymbol(MXNetError):
"""Error: Not implemented for symbol"""
def __init__(self, function, alias, *args):
@@ -65,6 +80,7 @@ class NotImplementedForSymbol(MXNetError):
self.function = function.__name__
self.alias = alias
self.args = [str(type(a)) for a in args]
+
def __str__(self):
msg = 'Function {}'.format(self.function)
if self.alias:
@@ -74,6 +90,7 @@ class NotImplementedForSymbol(MXNetError):
msg += ' is not implemented for Symbol and only available in NDArray.'
return msg
+
class NotSupportedForSparseNDArray(MXNetError):
"""Error: Not supported for SparseNDArray"""
def __init__(self, function, alias, *args):
@@ -81,6 +98,7 @@ class NotSupportedForSparseNDArray(MXNetError):
self.function = function.__name__
self.alias = alias
self.args = [str(type(a)) for a in args]
+
def __str__(self):
msg = 'Function {}'.format(self.function)
if self.alias:
@@ -90,6 +108,7 @@ class NotSupportedForSparseNDArray(MXNetError):
msg += ' is not supported for SparseNDArray and only available in NDArray.'
return msg
+
class MXCallbackList(ctypes.Structure):
"""Structure that holds Callback information. Passed to CustomOpProp."""
_fields_ = [
@@ -98,6 +117,7 @@ class MXCallbackList(ctypes.Structure):
('contexts', ctypes.POINTER(ctypes.c_void_p))
]
+
# Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property
class _MXClassPropertyDescriptor(object):
def __init__(self, fget, fset=None):
@@ -125,6 +145,7 @@ class _MXClassPropertyDescriptor(object):
self.fset = func
return self
+
class _MXClassPropertyMetaClass(type):
def __setattr__(cls, key, value):
if key in cls.__dict__:
@@ -134,8 +155,9 @@ class _MXClassPropertyMetaClass(type):
return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value)
+
# with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py
-#pylint: disable=unused-argument
+# pylint: disable=unused-argument
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
@@ -150,7 +172,8 @@ def with_metaclass(meta, *bases):
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, 'temporary_class', (), {})
-#pylint: enable=unused-argument
+# pylint: enable=unused-argument
+
def classproperty(func):
if not isinstance(func, (classmethod, staticmethod)):
@@ -159,7 +182,6 @@ def classproperty(func):
return _MXClassPropertyDescriptor(func)
-
def _load_lib():
"""Load library by searching possible path."""
lib_path = libinfo.find_lib_path()
@@ -168,6 +190,7 @@ def _load_lib():
lib.MXGetLastError.restype = ctypes.c_char_p
return lib
+
# version number
__version__ = libinfo.__version__
# library instance of mxnet
@@ -192,6 +215,8 @@ RtcHandle = ctypes.c_void_p
CudaModuleHandle = ctypes.c_void_p
CudaKernelHandle = ctypes.c_void_p
ProfileHandle = ctypes.c_void_p
+
+
#----------------------------
# helper function definition
#----------------------------
@@ -346,6 +371,7 @@ def c_array_buf(ctype, buf):
"""
return (ctype * len(buf)).from_buffer(buf)
+
def c_handle_array(objs):
"""Create ctypes const void ** from a list of MXNet objects with handles.
@@ -363,6 +389,7 @@ def c_handle_array(objs):
arr[:] = [o.handle for o in objs]
return arr
+
def ctypes2buffer(cptr, length):
"""Convert ctypes pointer to buffer type.
@@ -386,6 +413,7 @@ def ctypes2buffer(cptr, length):
raise RuntimeError('memmove failed')
return res
+
def ctypes2numpy_shared(cptr, shape):
"""Convert a ctypes pointer to a numpy array.
@@ -456,6 +484,7 @@ def _notify_shutdown():
"""Notify MXNet about a shutdown."""
check_call(_LIB.MXNotifyShutdown())
+
atexit.register(_notify_shutdown)
@@ -585,7 +614,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
setattr(cur_module, function.__name__, function)
cur_module.__all__.append(function.__name__)
-
if op_name_prefix == '_contrib_':
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
@@ -616,17 +644,18 @@ def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func)
"""Return the generated module file based on module name."""
path = os.path.dirname(__file__)
module_path = module_name.split('.')
- module_path[-1] = 'gen_'+module_path[-1]
+ module_path[-1] = 'gen_' + module_path[-1]
file_name = os.path.join(path, '..', *module_path) + '.py'
module_file = open(file_name, 'w')
dependencies = {'symbol': ['from ._internal import SymbolBase',
'from ..base import _Null'],
'ndarray': ['from ._internal import NDArrayBase',
'from ..base import _Null']}
- module_file.write('# File content is auto-generated. Do not modify.'+os.linesep)
- module_file.write('# pylint: skip-file'+os.linesep)
+ module_file.write('# File content is auto-generated. Do not modify.' + os.linesep)
+ module_file.write('# pylint: skip-file' + os.linesep)
module_file.write(os.linesep.join(dependencies[module_name.split('.')[1]]))
return module_file
+
def write_all_str(module_file, module_all_list):
"""Write the proper __all__ based on available operators."""
module_file.write(os.linesep)