szha closed pull request #11803: Use feature detection instead of version detection
URL: https://github.com/apache/incubator-mxnet/pull/11803
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 0fb73b3c7dd..4df794bdfe3 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -20,44 +20,59 @@
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import
+import atexit
+import ctypes
+import inspect
import os
import sys
-import ctypes
-import atexit
import warnings
-import inspect
+
import numpy as np
+
from . import libinfo
+
warnings.filterwarnings('default', category=DeprecationWarning)
__all__ = ['MXNetError']
#----------------------------
# library loading
#----------------------------
-if sys.version_info[0] == 3:
- string_types = str,
- numeric_types = (float, int, np.generic)
- integer_types = (int, np.int32, np.int64)
+
+# pylint: disable=pointless-statement
+try:
+ basestring
+ long
+except NameError:
+ basestring = str
+ long = int
+# pylint: enable=pointless-statement
+
+integer_types = (int, long, np.int32, np.int64)
+numeric_types = (float, int, long, np.generic)
+string_types = basestring,
+
+if sys.version_info[0] > 2:
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
else:
- string_types = basestring,
- numeric_types = (float, int, long, np.generic)
- integer_types = (int, long, np.int32, np.int64)
py_str = lambda x: x
+
class _NullType(object):
"""Placeholder for arguments"""
def __repr__(self):
return '_Null'
+
_Null = _NullType()
+
class MXNetError(Exception):
"""Error that will be throwed by all mxnet functions."""
pass
+
class NotImplementedForSymbol(MXNetError):
"""Error: Not implemented for symbol"""
def __init__(self, function, alias, *args):
@@ -65,6 +80,7 @@ def __init__(self, function, alias, *args):
self.function = function.__name__
self.alias = alias
self.args = [str(type(a)) for a in args]
+
def __str__(self):
msg = 'Function {}'.format(self.function)
if self.alias:
@@ -74,6 +90,7 @@ def __str__(self):
msg += ' is not implemented for Symbol and only available in NDArray.'
return msg
+
class NotSupportedForSparseNDArray(MXNetError):
"""Error: Not supported for SparseNDArray"""
def __init__(self, function, alias, *args):
@@ -81,6 +98,7 @@ def __init__(self, function, alias, *args):
self.function = function.__name__
self.alias = alias
self.args = [str(type(a)) for a in args]
+
def __str__(self):
msg = 'Function {}'.format(self.function)
if self.alias:
@@ -90,6 +108,7 @@ def __str__(self):
msg += ' is not supported for SparseNDArray and only available in NDArray.'
return msg
+
class MXCallbackList(ctypes.Structure):
"""Structure that holds Callback information. Passed to CustomOpProp."""
_fields_ = [
@@ -98,6 +117,7 @@ class MXCallbackList(ctypes.Structure):
('contexts', ctypes.POINTER(ctypes.c_void_p))
]
+
# Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property
class _MXClassPropertyDescriptor(object):
def __init__(self, fget, fset=None):
@@ -125,6 +145,7 @@ def setter(self, func):
self.fset = func
return self
+
class _MXClassPropertyMetaClass(type):
def __setattr__(cls, key, value):
if key in cls.__dict__:
@@ -134,8 +155,9 @@ def __setattr__(cls, key, value):
return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value)
+
# with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py
-#pylint: disable=unused-argument
+# pylint: disable=unused-argument
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
@@ -150,7 +172,8 @@ def __new__(cls, name, this_bases, d):
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, 'temporary_class', (), {})
-#pylint: enable=unused-argument
+# pylint: enable=unused-argument
+
def classproperty(func):
if not isinstance(func, (classmethod, staticmethod)):
@@ -159,7 +182,6 @@ def classproperty(func):
return _MXClassPropertyDescriptor(func)
-
def _load_lib():
"""Load library by searching possible path."""
lib_path = libinfo.find_lib_path()
@@ -168,6 +190,7 @@ def _load_lib():
lib.MXGetLastError.restype = ctypes.c_char_p
return lib
+
# version number
__version__ = libinfo.__version__
# library instance of mxnet
@@ -192,6 +215,8 @@ def _load_lib():
CudaModuleHandle = ctypes.c_void_p
CudaKernelHandle = ctypes.c_void_p
ProfileHandle = ctypes.c_void_p
+
+
#----------------------------
# helper function definition
#----------------------------
@@ -346,6 +371,7 @@ def c_array_buf(ctype, buf):
"""
return (ctype * len(buf)).from_buffer(buf)
+
def c_handle_array(objs):
"""Create ctypes const void ** from a list of MXNet objects with handles.
@@ -363,6 +389,7 @@ def c_handle_array(objs):
arr[:] = [o.handle for o in objs]
return arr
+
def ctypes2buffer(cptr, length):
"""Convert ctypes pointer to buffer type.
@@ -386,6 +413,7 @@ def ctypes2buffer(cptr, length):
raise RuntimeError('memmove failed')
return res
+
def ctypes2numpy_shared(cptr, shape):
"""Convert a ctypes pointer to a numpy array.
@@ -456,6 +484,7 @@ def _notify_shutdown():
"""Notify MXNet about a shutdown."""
check_call(_LIB.MXNotifyShutdown())
+
atexit.register(_notify_shutdown)
@@ -585,7 +614,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
setattr(cur_module, function.__name__, function)
cur_module.__all__.append(function.__name__)
-
if op_name_prefix == '_contrib_':
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
@@ -616,17 +644,18 @@ def get_module_file(module_name):
"""Return the generated module file based on module name."""
path = os.path.dirname(__file__)
module_path = module_name.split('.')
- module_path[-1] = 'gen_'+module_path[-1]
+ module_path[-1] = 'gen_' + module_path[-1]
file_name = os.path.join(path, '..', *module_path) + '.py'
module_file = open(file_name, 'w')
dependencies = {'symbol': ['from ._internal import SymbolBase',
'from ..base import _Null'],
'ndarray': ['from ._internal import NDArrayBase',
'from ..base import _Null']}
- module_file.write('# File content is auto-generated. Do not modify.'+os.linesep)
- module_file.write('# pylint: skip-file'+os.linesep)
+ module_file.write('# File content is auto-generated. Do not modify.' + os.linesep)
+ module_file.write('# pylint: skip-file' + os.linesep)
module_file.write(os.linesep.join(dependencies[module_name.split('.')[1]]))
return module_file
+
def write_all_str(module_file, module_all_list):
"""Write the proper __all__ based on available operators."""
module_file.write(os.linesep)
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services