This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 23f973b  amalgamation library fixes for python 2.7 and 3.6 compatibility (#9346)
23f973b is described below

commit 23f973b7523c033dd8cd18eb7a9ae2bda32d43c0
Author: Sina Afrooze <[email protected]>
AuthorDate: Mon Jan 15 12:45:11 2018 -0800

    amalgamation library fixes for python 2.7 and 3.6 compatibility (#9346)
---
 amalgamation/amalgamation.py         | 122 +++++++++++++++++++----------------
 amalgamation/python/mxnet_predict.py |   5 ++
 2 files changed, 70 insertions(+), 57 deletions(-)

diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py
index f1e1e02..1a82413 100644
--- a/amalgamation/amalgamation.py
+++ b/amalgamation/amalgamation.py
@@ -15,8 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from __future__ import print_function
+
 import sys
-import os.path, re, StringIO
+import os.path, re
+from io import BytesIO, StringIO
 import platform
 
 blacklist = [
@@ -39,15 +42,12 @@ if minimum != 0:
     blacklist.append('linalg.h')
 
 if platform.system() != 'Darwin':
-  blacklist.append('TargetConditionals.h')
+    blacklist.append('TargetConditionals.h')
 
 if platform.system() != 'Windows':
-  blacklist.append('windows.h')
-  blacklist.append('process.h')
+    blacklist.append('windows.h')
+    blacklist.append('process.h')
 
-def pprint(lst):
-    for item in lst:
-        print item
 
 def get_sources(def_file):
     sources = []
@@ -67,6 +67,7 @@ def get_sources(def_file):
             visited.add(fn)
     return sources
 
+
 sources = get_sources(sys.argv[1])
 
 
@@ -93,9 +94,7 @@ re2 = re.compile('"([./a-zA-Z0-9_-]*)"')
 
 sysheaders = []
 history = set([])
-out = StringIO.StringIO()
-
-
+out = BytesIO()
 
 
 def expand(x, pending, stage):
@@ -103,38 +102,44 @@ def expand(x, pending, stage):
         return
 
     if x in pending:
-        #print 'loop found: %s in ' % x, pending
+        #print('loop found: {} in {}'.format(x, pending))
         return
 
-    whtspace = '  '*expand.treeDepth
-    expand.fileCount+=1
-    print >>out, "//=====[%3d] STAGE:%4s %sEXPANDING: %s =====\n" % 
(expand.fileCount, stage, whtspace, x)
-    print        "//=====[%3d] STAGE:%4s %sEXPANDING: %s        " % 
(expand.fileCount, stage, whtspace, x)
-    for line in open(x):
-        if line.find('#include') < 0:
-            out.write(line)
-            continue
-        if line.strip().find('#include') > 0:
-            print line
-            continue
-        m = re1.search(line)
-        if not m: m = re2.search(line)
-        if not m:
-            print line + ' not found'
-            continue
-        h = m.groups()[0].strip('./')
-        source = find_source(h, x, stage)
-        if not source:
-            if (h not in blacklist and
-                h not in sysheaders and
-                'mkl' not in h and
-                'nnpack' not in h and
-                not h.endswith('.cuh')): sysheaders.append(h)
-        else:
-            expand.treeDepth+=1
-            expand(source, pending + [x], stage)
-            expand.treeDepth-=1
-    print >>out, "//===== EXPANDED  : %s =====\n" %x
+    whtspace = '  ' * expand.treeDepth
+    expand.fileCount += 1
+    comment = u"//=====[{:3d}] STAGE:{:>4} {}EXPANDING: {} 
=====\n\n".format(expand.fileCount, stage, whtspace, x)
+    out.write(comment.encode('ascii'))
+    print(comment)
+
+    with open(x, 'rb') as x_h:
+        for line in x_h.readlines():
+            uline = line.decode('utf-8')
+            if uline.find('#include') < 0:
+                out.write(line)
+                continue
+            if uline.strip().find('#include') > 0:
+                print(uline)
+                continue
+            m = re1.search(uline)
+            if not m:
+                m = re2.search(uline)
+            if not m:
+                print(uline + ' not found')
+                continue
+            h = m.groups()[0].strip('./')
+            source = find_source(h, x, stage)
+            if not source:
+                if (h not in blacklist and
+                    h not in sysheaders and
+                    'mkl' not in h and
+                    'nnpack' not in h and
+                    not h.endswith('.cuh')): sysheaders.append(h)
+            else:
+                expand.treeDepth += 1
+                expand(source, pending + [x], stage)
+                expand.treeDepth -= 1
+
+    out.write(u"//===== EXPANDED  : {} =====\n\n".format(x).encode('ascii'))
     history.add(x)
 
 
@@ -149,15 +154,16 @@ expand(sys.argv[3], [], "nnvm")
 expand(sys.argv[4], [], "src")
 
 # Write to amalgamation file
-f = open(sys.argv[5], 'wb')
+with open(sys.argv[5], 'wb') as f:
 
-if minimum != 0:
-    sysheaders.remove('cblas.h')
-    print >>f, "#define MSHADOW_STAND_ALONE 1"
-    print >>f, "#define MSHADOW_USE_SSE 0"
-    print >>f, "#define MSHADOW_USE_CBLAS 0"
+    if minimum != 0:
+        sysheaders.remove('cblas.h')
+        f.write(b"#define MSHADOW_STAND_ALONE 1\n")
+        f.write(b"#define MSHADOW_USE_SSE 0\n")
+        f.write(b"#define MSHADOW_USE_CBLAS 0\n")
 
-print >>f, '''
+    f.write(
+        b"""
 #if defined(__MACH__)
 #include <mach/clock.h>
 #include <mach/mach.h>
@@ -172,19 +178,21 @@ print >>f, '''
 #endif
 
 #endif
-'''
+\n"""
+    )
 
-if minimum != 0 and android != 0 and 'complex.h' not in sysheaders:
-    sysheaders.append('complex.h')
+    if minimum != 0 and android != 0 and 'complex.h' not in sysheaders:
+        sysheaders.append('complex.h')
 
-for k in sorted(sysheaders):
-    print >>f, "#include <%s>" % k
+    for k in sorted(sysheaders):
+        f.write("#include <{}>\n".format(k).encode('ascii'))
 
-print >>f, ''
-print >>f, out.getvalue()
+    f.write(b'\n')
+    f.write(out.getvalue())
+    f.write(b'\n')
 
-for x in sources:
-    if x not in history and not x.endswith('.o'):
-        print 'Not processed:', x
+for src in sources:
+    if src not in history and not src.endswith('.o'):
+        print('Not processed:', src)
 
 
diff --git a/amalgamation/python/mxnet_predict.py b/amalgamation/python/mxnet_predict.py
index 3dd6b38..627f375 100644
--- a/amalgamation/python/mxnet_predict.py
+++ b/amalgamation/python/mxnet_predict.py
@@ -35,14 +35,19 @@ if sys.version_info[0] == 3:
 else:
     py_str = lambda x: x
 
+
 def c_str(string):
     """"Convert a python string to C string."""
+    if not isinstance(string, str):
+        string = string.decode('ascii')
     return ctypes.c_char_p(string.encode('utf-8'))
 
+
 def c_array(ctype, values):
     """Create ctypes array from a python array."""
     return (ctype * len(values))(*values)
 
+
 def _find_lib_path():
     """Find mxnet library."""
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))

-- 
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].

Reply via email to