http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b8104eff/tool/python/singa/generatepy.sh
----------------------------------------------------------------------
diff --git a/tool/python/singa/generatepy.sh b/tool/python/singa/generatepy.sh
new file mode 100755
index 0000000..bb4a3af
--- /dev/null
+++ b/tool/python/singa/generatepy.sh
@@ -0,0 +1,26 @@
+#!/bin/sh
+# Generate the SWIG wrapper for driver.i and build the _driver.so Python
+# extension against libsinga.
+#
+# Each path defaults to the value originally hard-coded in this script;
+# override it through the environment, e.g.
+#   SINGA_ROOT=$HOME/incubator-singa LOCAL_PREFIX=$HOME/local ./generatepy.sh
+#
+# NOTE(review): -DUSE_GPU is passed together with -DMSHADOW_USE_CUDA=0;
+# confirm which of the two is intended.
+
+# Stop if swig fails instead of compiling a stale driver_wrap.cxx.
+set -e
+
+SINGA_ROOT=${SINGA_ROOT:-/home/zhongle/incubator-singa}
+LOCAL_PREFIX=${LOCAL_PREFIX:-/home/zhongle/local}
+CUDA_INCLUDE=${CUDA_INCLUDE:-/usr/cuda-7.5/include}
+PYTHON_INCLUDE=${PYTHON_INCLUDE:-/usr/include/python2.7/}
+
+swig -c++ -python driver.i
+g++ -fPIC "$SINGA_ROOT/src/driver.cc" driver_wrap.cxx -shared -o _driver.so \
+    -L"$SINGA_ROOT/.libs/" -lsinga -DMSHADOW_USE_CUDA=0 \
+    -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0 -DUSE_GPU -std=c++11 \
+    -I"$CUDA_INCLUDE" -I"$LOCAL_PREFIX/include" \
+    -I"$SINGA_ROOT/include" \
+    -I"$PYTHON_INCLUDE"
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b8104eff/tool/python/utils/__init__.py
----------------------------------------------------------------------
diff --git a/tool/python/utils/__init__.py b/tool/python/utils/__init__.py
new file mode 100644
index 0000000..e69de29

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b8104eff/tool/python/utils/message.py
----------------------------------------------------------------------
diff --git a/tool/python/utils/message.py b/tool/python/utils/message.py
new file mode 100644
index 0000000..10a1a18
--- /dev/null
+++ b/tool/python/utils/message.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+"""Build SINGA protobuf configuration messages from keyword arguments.
+
+At import time every module in singa_root/tool/pb2 (except common_pb2,
+singa_pb2 and __init__) is imported and its public names are re-exported
+from this module.  For each top-level enum type T of those modules a
+helper enumT(key) is added; it returns the value whose name, with the
+first character dropped and lower-cased, equals key.
+"""
+import sys, os
+from utility import *
+
+_PB2_DIR = os.path.join(os.path.dirname(__file__), '../../pb2')
+sys.path.append(_PB2_DIR)
+
+module_list = []
+
+# Import all modules in singa_root/tool/pb2, except common, singa and
+# __init__.  Only *.py files are considered, so directories (e.g.
+# __pycache__) and stray files are never handed to __import__; the listing
+# is sorted so the import order, and hence which module wins a name clash,
+# is reproducible.
+for f in sorted(os.listdir(_PB2_DIR)):
+    if not f.endswith('.py'):
+        continue
+    if f in ('__init__.py', 'common_pb2.py', 'singa_pb2.py'):
+        continue
+    module = __import__(os.path.splitext(f)[0])
+    module_list.append(module)
+    for func_name in dir(module):
+        if not func_name.startswith("__"):
+            globals()[func_name] = getattr(module, func_name)
+
+
+class Message(object):
+    """Holds a new <protoname>Proto message in self.proto.
+
+    The class comes from the first module in module_list that defines
+    protoname + "Proto"; kwargs are applied to the message via setval.
+    Raises Exception('invalid protoname') if no module defines it.
+    """
+
+    def __init__(self, protoname, **kwargs):
+        for module in module_list:
+            class_ = getattr(module, protoname + "Proto", None)
+            if class_ is not None:
+                self.proto = class_()
+                setval(self.proto, **kwargs)
+                return
+        raise Exception('invalid protoname')
+
+
+enumDict_ = dict()
+
+# Collect every top-level enum type of the imported modules.  A value's
+# lookup key is its name without the first character, lower-cased
+# (presumably the .proto enums use a one-letter prefix such as kFoo --
+# TODO confirm against the .proto sources).
+for module in module_list:
+    for enumtype in module.DESCRIPTOR.enum_types_by_name:
+        tempDict = enumDict_[enumtype] = dict()
+        for name in getattr(module, enumtype).DESCRIPTOR.values_by_name:
+            tempDict[name[1:].lower()] = getattr(module, name)
+
+
+def make_function(enumtype):
+    """Return a function mapping a key to a value of enum `enumtype`.
+
+    The returned function raises KeyError for an unknown key.
+    """
+    def _function(key):
+        return enumDict_[enumtype][key]
+    return _function
+
+
+current_module = sys.modules[__name__]
+
+# Define an enum<EnumType>(key) helper on this module for every enum type.
+for module in module_list:
+    for enumtype in module.DESCRIPTOR.enum_types_by_name:
+        setattr(current_module, "enum" + enumtype, make_function(enumtype))

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b8104eff/tool/python/utils/utility.py
----------------------------------------------------------------------
diff --git a/tool/python/utils/utility.py b/tool/python/utils/utility.py
new file mode 100644
index 0000000..42840c9
--- /dev/null
+++ b/tool/python/utils/utility.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+"""Name generation and protobuf field-assignment helpers."""
+
+# Running counters used by generateName().
+layerid = 0
+paramid = 0
+
+# protobuf FieldDescriptor.LABEL_REPEATED / FieldDescriptor.TYPE_MESSAGE,
+# the values setval() compares field descriptors against.
+_LABEL_REPEATED = 3
+_TYPE_MESSAGE = 11
+
+
+def generateName(label, op=0):
+    """Return `label` followed by the current value of its counter.
+
+    'param' uses the param counter; every other label (including 'layer')
+    uses the layer counter.  When op == 1 the counter is incremented before
+    it is read.
+    """
+    global layerid, paramid
+    if label == 'param':
+        if op == 1:
+            paramid += 1
+        num = paramid
+    else:
+        if op == 1:
+            layerid += 1
+        num = layerid
+    return '{0}{1}'.format(label, num)
+
+
+def setval(proto, **kwargs):
+    """Assign each keyword argument to the same-named field of `proto`.
+
+    * repeated message field: a new element is added and v merged into it;
+      a list/tuple v adds one element per item;
+    * repeated scalar field: v is appended, or each item if v is a
+      list/tuple;
+    * singular message field: v is merged into the existing sub-message;
+    * singular scalar field: v is assigned.
+
+    Keys that do not name a field of `proto` are ignored.
+    """
+    fields = proto.DESCRIPTOR.fields_by_name
+    for k, v in kwargs.items():
+        if k not in fields:
+            continue
+        field = fields[k]
+        fattr = getattr(proto, k)
+        if field.label == _LABEL_REPEATED:
+            values = v if isinstance(v, (list, tuple)) else [v]
+            if field.type == _TYPE_MESSAGE:
+                for item in values:
+                    fattr.add().MergeFrom(item)
+            else:
+                fattr.extend(values)
+        elif field.type == _TYPE_MESSAGE:
+            fattr.MergeFrom(v)
+        else:
+            setattr(proto, k, v)
