SINGA-126 Python Binding for Interactive Training
1. Replace 'x != None' with 'x is not None'
2. Fix a type-mismatch bug: `debug` must be the bool False (not the int 0)
when it is passed to the SINGA layer's ToString() (e.g., for the loss layer).
3. Set a default value for SingaProto's zookeeper endpoint, so the
`-singa_conf xxx` argument can be omitted. Without it, the driver no longer
falls back to reading conf/singa.conf; SINGA assumes the default zookeeper
endpoint ('localhost:2181'), and glog uses its default logging dir.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/1c8e0dc0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/1c8e0dc0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/1c8e0dc0
Branch: refs/heads/master
Commit: 1c8e0dc03e1fc06c2e6460892fc9a91e482e5434
Parents: 5a8fc37
Author: Wei Wang <[email protected]>
Authored: Wed Apr 6 12:26:47 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Wed Apr 6 12:26:47 2016 +0800
----------------------------------------------------------------------
src/driver.cc | 2 --
src/proto/singa.proto | 6 +++---
tool/python/examples/train_cifar10.py | 8 ++++----
tool/python/singa/layer.py | 8 ++++----
4 files changed, 11 insertions(+), 13 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/1c8e0dc0/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 702df5e..56680d2 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -54,8 +54,6 @@ void Driver::Init(int argc, char **argv) {
arg_pos = ArgPos(argc, argv, "-singa_conf");
if (arg_pos != -1)
ReadProtoFromTextFile(argv[arg_pos + 1], &singa_conf_);
- else
- ReadProtoFromTextFile("conf/singa.conf", &singa_conf_);
// set log path
if (singa_conf_.has_log_dir())
SetupLog(singa_conf_.log_dir(), "driver");
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/1c8e0dc0/src/proto/singa.proto
----------------------------------------------------------------------
diff --git a/src/proto/singa.proto b/src/proto/singa.proto
index 6e12d25..2fbf2db 100644
--- a/src/proto/singa.proto
+++ b/src/proto/singa.proto
@@ -7,9 +7,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
-*
+*
* http://www.apache.org/licenses/LICENSE-2.0
-*
+*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,7 +23,7 @@ package singa;
message SingaProto {
// ip/hostname:port[,ip/hostname:port]
- required string zookeeper_host = 1;
+ optional string zookeeper_host = 1 [default = "localhost:2181"];
// log dir for singa binary and job information(job id, host list, pid list)
optional string log_dir = 2 [default = "/tmp/singa-log/"];
}
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/1c8e0dc0/tool/python/examples/train_cifar10.py
----------------------------------------------------------------------
diff --git a/tool/python/examples/train_cifar10.py b/tool/python/examples/train_cifar10.py
index 6c621a3..e8ac973 100755
--- a/tool/python/examples/train_cifar10.py
+++ b/tool/python/examples/train_cifar10.py
@@ -38,7 +38,7 @@ from singa.model import *
'''
CIFAR10 dataset can be downloaded at
[https://www.cs.toronto.edu/~kriz/cifar.html]
-- please specify dataset_dir
+- please specify dataset_dir
'''
dataset_dir_ = singa_root_ +
"/tool/python/examples/datasets/cifar-10-batches-py"
mean_image = None
@@ -73,7 +73,7 @@ def load_dataset(did=1):
1 row (1 image) includes 1 label & 3072 pixels
3072 pixels are 3 channels of a 32x32 image
'''
- assert mean_image != None, 'mean_image is required'
+ assert mean_image is not None, 'mean_image is required'
print '[Load CIFAR10 dataset {}]'.format(did)
fname_train_data = dataset_dir_ + "/data_batch_{}".format(did)
cifar10 = unpickle(fname_train_data)
@@ -81,7 +81,7 @@ def load_dataset(did=1):
image = image - mean_image
print ' image x:', image.shape
label = np.asarray(cifar10['labels'], dtype=np.uint8)
- label = label.reshape(label.size, 1)
+ label = label.reshape(label.size, 1)
print ' label y:', label.shape
return image, label
@@ -116,7 +116,7 @@ loss = Loss('softmaxloss')
sgd = SGD(decay=0.004, momentum=0.9, lr_type='manual', step=(0,60000,65000),
step_lr=(0.001,0.0001,0.00001))
#-------------------------------------------------------------------
-batchsize = 100
+batchsize = 100
disp_freq = 50
train_step = 1000
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/1c8e0dc0/tool/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/tool/python/singa/layer.py b/tool/python/singa/layer.py
index f786245..c9a992d 100644
--- a/tool/python/singa/layer.py
+++ b/tool/python/singa/layer.py
@@ -160,7 +160,7 @@ class Layer(object):
return data
def display(self):
- debug, flag = 0, 0
+ debug, flag = False, 0
print self.singalayer.ToString(debug, flag)
def get_singalayer(self):
@@ -204,7 +204,7 @@ class Dummy(object):
if self.singalayer == None:
self.setup(shape)
- if data != None:
+ if data is not None:
data = data.astype(np.float)
dataVector = floatVector(datasize)
for i in range(batchsize):
@@ -212,7 +212,7 @@ class Dummy(object):
dataVector[i*hdim+j] = data[i, j]
labelVector = intVector(0)
- if aux_data != None:
+ if aux_data is not None:
aux_data = aux_data.astype(np.int)
labelVector = intVector(datasize)
for i in range(batchsize):
@@ -249,7 +249,7 @@ class LabelInput(Dummy):
def Feed(self, label_data):
Dummy.Feed(self, label_data.shape, None, label_data)
-
+
class Data(Layer):