SINGA-126 Python Binding for Interactive Training

1. Replace 'x != None' with 'x is not None'
2. Fix a type-mismatch bug: `debug` must be set to False (a bool, not the
integer 0) before it is passed to the SINGA layer's ToString() in
Layer.display().
3. Set a default value for SingaProto's ZooKeeper endpoint, so that
`-singa_conf xxx` can be omitted. In that case SINGA assumes the default
ZooKeeper endpoint ('localhost:2181'), and glog uses its default logging
directory.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/1c8e0dc0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/1c8e0dc0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/1c8e0dc0

Branch: refs/heads/master
Commit: 1c8e0dc03e1fc06c2e6460892fc9a91e482e5434
Parents: 5a8fc37
Author: Wei Wang <[email protected]>
Authored: Wed Apr 6 12:26:47 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Wed Apr 6 12:26:47 2016 +0800

----------------------------------------------------------------------
 src/driver.cc                         | 2 --
 src/proto/singa.proto                 | 6 +++---
 tool/python/examples/train_cifar10.py | 8 ++++----
 tool/python/singa/layer.py            | 8 ++++----
 4 files changed, 11 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/1c8e0dc0/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 702df5e..56680d2 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -54,8 +54,6 @@ void Driver::Init(int argc, char **argv) {
   arg_pos = ArgPos(argc, argv, "-singa_conf");
   if (arg_pos != -1)
     ReadProtoFromTextFile(argv[arg_pos + 1], &singa_conf_);
-  else
-    ReadProtoFromTextFile("conf/singa.conf", &singa_conf_);
   // set log path
   if (singa_conf_.has_log_dir())
     SetupLog(singa_conf_.log_dir(), "driver");

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/1c8e0dc0/src/proto/singa.proto
----------------------------------------------------------------------
diff --git a/src/proto/singa.proto b/src/proto/singa.proto
index 6e12d25..2fbf2db 100644
--- a/src/proto/singa.proto
+++ b/src/proto/singa.proto
@@ -7,9 +7,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
-* 
+*
 *   http://www.apache.org/licenses/LICENSE-2.0
-* 
+*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,7 +23,7 @@ package singa;
 
 message SingaProto {
   // ip/hostname:port[,ip/hostname:port]
-  required string zookeeper_host = 1;
+  optional string zookeeper_host = 1 [default = "localhost:2181"];
   // log dir for singa binary and job information(job id, host list, pid list)
   optional string log_dir = 2 [default = "/tmp/singa-log/"];
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/1c8e0dc0/tool/python/examples/train_cifar10.py
----------------------------------------------------------------------
diff --git a/tool/python/examples/train_cifar10.py b/tool/python/examples/train_cifar10.py
index 6c621a3..e8ac973 100755
--- a/tool/python/examples/train_cifar10.py
+++ b/tool/python/examples/train_cifar10.py
@@ -38,7 +38,7 @@ from singa.model import *
 
 '''
 CIFAR10 dataset can be downloaded at [https://www.cs.toronto.edu/~kriz/cifar.html]
-- please specify dataset_dir 
+- please specify dataset_dir
 '''
 dataset_dir_ = singa_root_ + 
"/tool/python/examples/datasets/cifar-10-batches-py"
 mean_image = None
@@ -73,7 +73,7 @@ def load_dataset(did=1):
         1 row (1 image) includes 1 label & 3072 pixels
         3072 pixels are  3 channels of a 32x32 image
     '''
-    assert mean_image != None, 'mean_image is required'
+    assert mean_image is not None, 'mean_image is required'
     print '[Load CIFAR10 dataset {}]'.format(did)
     fname_train_data = dataset_dir_ + "/data_batch_{}".format(did)
     cifar10 = unpickle(fname_train_data)
@@ -81,7 +81,7 @@ def load_dataset(did=1):
     image = image - mean_image
     print '  image x:', image.shape
     label = np.asarray(cifar10['labels'], dtype=np.uint8)
-    label = label.reshape(label.size, 1) 
+    label = label.reshape(label.size, 1)
     print '  label y:', label.shape
     return image, label
 
@@ -116,7 +116,7 @@ loss = Loss('softmaxloss')
 sgd = SGD(decay=0.004, momentum=0.9, lr_type='manual', step=(0,60000,65000), 
step_lr=(0.001,0.0001,0.00001))
 
 #-------------------------------------------------------------------
-batchsize = 100 
+batchsize = 100
 disp_freq = 50
 train_step = 1000
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/1c8e0dc0/tool/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/tool/python/singa/layer.py b/tool/python/singa/layer.py
index f786245..c9a992d 100644
--- a/tool/python/singa/layer.py
+++ b/tool/python/singa/layer.py
@@ -160,7 +160,7 @@ class Layer(object):
         return data
 
     def display(self):
-        debug, flag = 0, 0
+        debug, flag = False, 0
         print self.singalayer.ToString(debug, flag)
 
     def get_singalayer(self):
@@ -204,7 +204,7 @@ class Dummy(object):
         if self.singalayer == None:
             self.setup(shape)
 
-        if data != None:
+        if data is not None:
             data = data.astype(np.float)
             dataVector = floatVector(datasize)
             for i in range(batchsize):
@@ -212,7 +212,7 @@ class Dummy(object):
                     dataVector[i*hdim+j] = data[i, j]
             labelVector = intVector(0)
 
-        if aux_data != None:
+        if aux_data is not None:
             aux_data = aux_data.astype(np.int)
             labelVector = intVector(datasize)
             for i in range(batchsize):
@@ -249,7 +249,7 @@ class LabelInput(Dummy):
 
     def Feed(self, label_data):
         Dummy.Feed(self, label_data.shape, None, label_data)
-    
+
 
 class Data(Layer):
 

Reply via email to