http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b8104eff/tool/python/singa/generatepy.sh
----------------------------------------------------------------------
diff --git a/tool/python/singa/generatepy.sh b/tool/python/singa/generatepy.sh
new file mode 100755
index 0000000..bb4a3af
--- /dev/null
+++ b/tool/python/singa/generatepy.sh
@@ -0,0 +1,8 @@
+
+swig -c++ -python driver.i
+g++ -fPIC /home/zhongle/incubator-singa/src/driver.cc driver_wrap.cxx -shared 
-o _driver.so \
+    -L/home/zhongle/incubator-singa/.libs/ -lsinga -DMSHADOW_USE_CUDA=0 \
+    -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0 -DUSE_GPU -std=c++11 \
+    -I/usr/cuda-7.5/include -I/home/zhongle/local/include \
+    -I/home/zhongle/incubator-singa/include \
+    -I/usr/include/python2.7/

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b8104eff/tool/python/utils/__init__.py
----------------------------------------------------------------------
diff --git a/tool/python/utils/__init__.py b/tool/python/utils/__init__.py
new file mode 100644
index 0000000..e69de29

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b8104eff/tool/python/utils/message.py
----------------------------------------------------------------------
diff --git a/tool/python/utils/message.py b/tool/python/utils/message.py
new file mode 100644
index 0000000..10a1a18
--- /dev/null
+++ b/tool/python/utils/message.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python 
+import sys, os 
+from utility import * 
+sys.path.append(os.path.join(os.path.dirname(__file__),'../../pb2')) 
+
+module_list=[]
+
+# import all modules in dir singa_root/too/pb2, except common, singa and 
__init__
+for f in os.listdir(os.path.join(os.path.dirname(__file__),'../../pb2')):
+  if (f.endswith(".pyc")):
+    continue
+  if(f == "__init__.py" or f == "common_pb2.py" or f == "singa_pb2.py" ):
+    continue
+  module_name = f.split('.')[0]
+  module=__import__(module_name)  
+  module_list.append(module)
+  for func_name in dir(module):
+    if not func_name.startswith("__"):
+      globals()[func_name] = getattr(module,func_name)
+
+class Message(object):
+  def __init__(self,protoname,**kwargs):
+    for module in module_list:
+      if hasattr(module,protoname+"Proto"):
+        class_ = getattr(module,protoname+"Proto")
+        self.proto = class_()
+        return setval(self.proto,**kwargs)
+    raise Exception('invalid protoname')
+
+enumDict_=dict()
+
+#get all enum type list in modules
+for module in module_list:
+  for enumtype in module.DESCRIPTOR.enum_types_by_name:
+    tempDict=enumDict_[enumtype]=dict()
+    for name in getattr(module,enumtype).DESCRIPTOR.values_by_name: 
+      tempDict[name[1:].lower()]=getattr(module,name)
+
+def make_function(enumtype):
+  def _function(key):
+    return enumDict_[enumtype][key]
+  return _function
+
+current_module = sys.modules[__name__]
+
+#def all the enumtypes
+for module in module_list:
+  for enumtype in module.DESCRIPTOR.enum_types_by_name:
+    setattr(current_module,"enum"+enumtype,make_function(enumtype))
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b8104eff/tool/python/utils/utility.py
----------------------------------------------------------------------
diff --git a/tool/python/utils/utility.py b/tool/python/utils/utility.py
new file mode 100644
index 0000000..42840c9
--- /dev/null
+++ b/tool/python/utils/utility.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+
+layerid = 0 
+paramid = 0 
+
+def generateName(label, op=0):
+  global layerid, paramid
+  num = layerid
+  if label == 'layer':
+    if op ==1: layerid += 1
+    num = layerid
+  elif label == 'param':
+    if op ==1: paramid += 1
+    num = paramid
+  else:
+    if op ==1: layerid += 1
+    num = layerid
+
+  return '{0}{1}'.format(label, num)
+
+
+def setval(proto, **kwargs):
+  for k,v in kwargs.items():
+    #print 'kv: ', k, ', ', v
+    if hasattr(proto, k):
+      flabel = proto.DESCRIPTOR.fields_by_name[k].label
+      ftype  = proto.DESCRIPTOR.fields_by_name[k].type
+
+      fattr  = getattr(proto, k) 
+      if flabel == 3: # repeated field
+        if ftype == 11: # message type 
+          fattr = fattr.add()
+          fattr.MergeFrom(v)
+        else:
+          if type(v) == list or type(v) == tuple:
+            for i in range(len(v)):
+              fattr.append(v[i])
+          else:
+            fattr.append(v)
+      else:
+        if ftype == 11: # message type 
+          fattr = getattr(proto,k)
+          fattr.MergeFrom(v)
+        else:
+          setattr(proto, k, v)

Reply via email to