This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1240577  Use feature detection in base.py (#11803)
1240577 is described below

commit 1240577dd629e97a740d65742a5008a6fc48a380
Author: cclauss <[email protected]>
AuthorDate: Wed Jul 18 22:02:00 2018 +0200

    Use feature detection in base.py (#11803)
---
 python/mxnet/base.py | 63 ++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 46 insertions(+), 17 deletions(-)

diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 0fb73b3..4df794b 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -20,44 +20,59 @@
 """ctypes library of mxnet and helper functions."""
 from __future__ import absolute_import
 
+import atexit
+import ctypes
+import inspect
 import os
 import sys
-import ctypes
-import atexit
 import warnings
-import inspect
+
 import numpy as np
+
 from . import libinfo
+
 warnings.filterwarnings('default', category=DeprecationWarning)
 
 __all__ = ['MXNetError']
 #----------------------------
 # library loading
 #----------------------------
-if sys.version_info[0] == 3:
-    string_types = str,
-    numeric_types = (float, int, np.generic)
-    integer_types = (int, np.int32, np.int64)
+
+# pylint: disable=pointless-statement
+try:
+    basestring
+    long
+except NameError:
+    basestring = str
+    long = int
+# pylint: enable=pointless-statement
+
+integer_types = (int, long, np.int32, np.int64)
+numeric_types = (float, int, long, np.generic)
+string_types = basestring,
+
+if sys.version_info[0] > 2:
     # this function is needed for python3
     # to convert ctypes.char_p .value back to python str
     py_str = lambda x: x.decode('utf-8')
 else:
-    string_types = basestring,
-    numeric_types = (float, int, long, np.generic)
-    integer_types = (int, long, np.int32, np.int64)
     py_str = lambda x: x
 
+
 class _NullType(object):
     """Placeholder for arguments"""
     def __repr__(self):
         return '_Null'
 
+
 _Null = _NullType()
 
+
 class MXNetError(Exception):
     """Error that will be throwed by all mxnet functions."""
     pass
 
+
 class NotImplementedForSymbol(MXNetError):
     """Error: Not implemented for symbol"""
     def __init__(self, function, alias, *args):
@@ -65,6 +80,7 @@ class NotImplementedForSymbol(MXNetError):
         self.function = function.__name__
         self.alias = alias
         self.args = [str(type(a)) for a in args]
+
     def __str__(self):
         msg = 'Function {}'.format(self.function)
         if self.alias:
@@ -74,6 +90,7 @@ class NotImplementedForSymbol(MXNetError):
         msg += ' is not implemented for Symbol and only available in NDArray.'
         return msg
 
+
 class NotSupportedForSparseNDArray(MXNetError):
     """Error: Not supported for SparseNDArray"""
     def __init__(self, function, alias, *args):
@@ -81,6 +98,7 @@ class NotSupportedForSparseNDArray(MXNetError):
         self.function = function.__name__
         self.alias = alias
         self.args = [str(type(a)) for a in args]
+
     def __str__(self):
         msg = 'Function {}'.format(self.function)
         if self.alias:
@@ -90,6 +108,7 @@ class NotSupportedForSparseNDArray(MXNetError):
         msg += ' is not supported for SparseNDArray and only available in NDArray.'
         return msg
 
+
 class MXCallbackList(ctypes.Structure):
     """Structure that holds Callback information. Passed to CustomOpProp."""
     _fields_ = [
@@ -98,6 +117,7 @@ class MXCallbackList(ctypes.Structure):
         ('contexts', ctypes.POINTER(ctypes.c_void_p))
         ]
 
+
 # Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property
 class _MXClassPropertyDescriptor(object):
     def __init__(self, fget, fset=None):
@@ -125,6 +145,7 @@ class _MXClassPropertyDescriptor(object):
         self.fset = func
         return self
 
+
 class _MXClassPropertyMetaClass(type):
     def __setattr__(cls, key, value):
         if key in cls.__dict__:
@@ -134,8 +155,9 @@ class _MXClassPropertyMetaClass(type):
 
         return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value)
 
+
 # with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py
-#pylint: disable=unused-argument
+# pylint: disable=unused-argument
 def with_metaclass(meta, *bases):
     """Create a base class with a metaclass."""
     # This requires a bit of explanation: the basic idea is to make a dummy
@@ -150,7 +172,8 @@ def with_metaclass(meta, *bases):
         def __prepare__(cls, name, this_bases):
             return meta.__prepare__(name, bases)
     return type.__new__(metaclass, 'temporary_class', (), {})
-#pylint: enable=unused-argument
+# pylint: enable=unused-argument
+
 
 def classproperty(func):
     if not isinstance(func, (classmethod, staticmethod)):
@@ -159,7 +182,6 @@ def classproperty(func):
     return _MXClassPropertyDescriptor(func)
 
 
-
 def _load_lib():
     """Load library by searching possible path."""
     lib_path = libinfo.find_lib_path()
@@ -168,6 +190,7 @@ def _load_lib():
     lib.MXGetLastError.restype = ctypes.c_char_p
     return lib
 
+
 # version number
 __version__ = libinfo.__version__
 # library instance of mxnet
@@ -192,6 +215,8 @@ RtcHandle = ctypes.c_void_p
 CudaModuleHandle = ctypes.c_void_p
 CudaKernelHandle = ctypes.c_void_p
 ProfileHandle = ctypes.c_void_p
+
+
 #----------------------------
 # helper function definition
 #----------------------------
@@ -346,6 +371,7 @@ def c_array_buf(ctype, buf):
     """
     return (ctype * len(buf)).from_buffer(buf)
 
+
 def c_handle_array(objs):
     """Create ctypes const void ** from a list of MXNet objects with handles.
 
@@ -363,6 +389,7 @@ def c_handle_array(objs):
     arr[:] = [o.handle for o in objs]
     return arr
 
+
 def ctypes2buffer(cptr, length):
     """Convert ctypes pointer to buffer type.
 
@@ -386,6 +413,7 @@ def ctypes2buffer(cptr, length):
         raise RuntimeError('memmove failed')
     return res
 
+
 def ctypes2numpy_shared(cptr, shape):
     """Convert a ctypes pointer to a numpy array.
 
@@ -456,6 +484,7 @@ def _notify_shutdown():
     """Notify MXNet about a shutdown."""
     check_call(_LIB.MXNotifyShutdown())
 
+
 atexit.register(_notify_shutdown)
 
 
@@ -585,7 +614,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
         setattr(cur_module, function.__name__, function)
         cur_module.__all__.append(function.__name__)
 
-
         if op_name_prefix == '_contrib_':
             hdl = OpHandle()
             check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
@@ -616,17 +644,18 @@ def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func)
         """Return the generated module file based on module name."""
         path = os.path.dirname(__file__)
         module_path = module_name.split('.')
-        module_path[-1] = 'gen_'+module_path[-1]
+        module_path[-1] = 'gen_' + module_path[-1]
         file_name = os.path.join(path, '..', *module_path) + '.py'
         module_file = open(file_name, 'w')
         dependencies = {'symbol': ['from ._internal import SymbolBase',
                                    'from ..base import _Null'],
                         'ndarray': ['from ._internal import NDArrayBase',
                                     'from ..base import _Null']}
-        module_file.write('# File content is auto-generated. Do not modify.'+os.linesep)
-        module_file.write('# pylint: skip-file'+os.linesep)
+        module_file.write('# File content is auto-generated. Do not modify.' + os.linesep)
+        module_file.write('# pylint: skip-file' + os.linesep)
         module_file.write(os.linesep.join(dependencies[module_name.split('.')[1]]))
         return module_file
+
     def write_all_str(module_file, module_all_list):
         """Write the proper __all__ based on available operators."""
         module_file.write(os.linesep)

Reply via email to