Hi,
This is a continuation of http://emergent.brynmawr.edu/pipermail/pyro-
users/2007-April/thread.html#584, recording weight updates in conx.
After looking at conx's code, I noticed that there were some hooks,
notably postStep, which appear to be intended to let one add bits of
code to various parts of train, presumably with the intention of
making it more general purpose.
It took me quite some time to figure out how to actually use this
feature, but I eventually did, so I'll include the code I developed
below.
I'd also like to know whether I'm using the hook feature the way the
authors of conx's train intended.
Finally, I wrote some fairly extensive comments in the part of the
code I had the most difficulty with (how to override the postStep
method). If anyone has insight into the issues I describe there, I'd
love to hear from you.
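For what it's worth, here is a minimal, conx-free sketch of my current
understanding of the override puzzle (the names Base/Child/my_hook are
just for illustration, not anything from conx): an instance attribute
shadows a method inherited from the class, and a plain function stored
on the instance does not get self passed in automatically.

```python
class Base(object):
    def hook(self, **args):
        return "base hook"

class Child(Base):
    def __init__(self):
        # a bare "def hook(**args): ..." here would only bind a name
        # local to __init__; it disappears when __init__ returns and
        # never shadows the inherited method
        def my_hook(**args):
            return "overridden hook"
        # assigning to self is the crucial step: instance attributes
        # are looked up before class attributes, so c.hook now finds
        # this plain function -- and since it is a plain function
        # (not a method stored on the class), Python does not insert
        # the instance as a first argument, hence no explicit self
        self.hook = my_hook

c = Child()
# c.hook() -> "overridden hook", while Base().hook() -> "base hook"
```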
The solution I propose handles the particular (simple) case of a
single-output, two-input AND function, but since I did it by
subclassing the Network class, it should be easily applicable to
other problems, and should thus let one generate figures like
Figure 4.8 in Tom Mitchell's Machine Learning text. (Aside: I've used
matplotlib to generate the figures, but you can use whatever plotting
facility you like, as I provide you with the lists of weights over
time.) I've found these types of figures very useful for student
learning, so if there's a desire to include my work in the Pyro
collection somehow, feel free to use what's appropriate.
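The core trick in the code below -- appending one pickle per training
step to a single file, then reading snapshots back until EOFError --
can be sketched standalone (this toy uses a BytesIO stream and a
made-up {'w0': ...} payload in place of the real (layers, connections)
tuples):

```python
import pickle
from io import BytesIO

# record phase: dump one snapshot per training step into one stream
stream = BytesIO()
for step in range(3):
    # stand-in for the real (self.layers, self.connections) payload
    pickle.dump({'w0': step}, stream)

# replay phase: repeatedly unpickle until the stream is exhausted;
# pickle signals end-of-stream by raising EOFError
stream.seek(0)
snapshots = []
try:
    while True:
        snapshots.append(pickle.load(stream))
except EOFError:
    pass
# snapshots == [{'w0': 0}, {'w0': 1}, {'w0': 2}]
```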
Hope this is helpful to someone else,
--b
========
#from IPython.Debugger import Tracer; debug_here = Tracer()
"""
demonstrates simple perceptron-like network trained to perform boolean
and function.
"""
import pyrobot.brain.conx as C
import pickle
import pylab
fname = 'network.pickle'
class myNetwork(C.Network):
    def __init__(self, name='Backprop Network', verbosity=0, fname=fname):
        """
        Create a conx neural net with the added ability to write its
        layers/connections data to a pickle file, allowing, for
        example, weight change over time to be explored.
        """
        C.Network.__init__(self, name=name, verbosity=verbosity)
        self.fname = fname
        if self.fname != None:
            # pickling data as the net's training progresses is made
            # possible courtesy of conx.py's postStep
            self.fh = file(fname, 'w')
            def myPostStep(**args): self.dumpLayer()
            self.postStep = myPostStep
            # aside: what I found bizarre is that I couldn't merely
            # overwrite postStep via, e.g.
            #
            #     def postStep(**args): self.dumpLayer()
            #
            # because that was treated as a local variable?  the
            # "self.postStep = myPostStep" piece seems crucial.
            #
            # the other tricky thing about this solution was that it
            # required myPostStep, although effectively a method, to
            # not explicitly include the self argument.
    def reportParams(n):
        """
        Some fluff I added to make viewing the network's meta
        parameters easier.
        Note: this is where I discovered that things like
        n.getEpsilon(), although documented in pyro's neural network
        curriculum modules, don't work.
        """
        s = ""
        s += " eta=%g " % n.epsilon
        s += " tol=%g " % n.tolerance
        if n.batch:
            s += "batch "
        else:
            s += "stochastic "
        if n.learning:
            s += "learn:on "
        else:
            s += "learn:off "
        s += "mom=%g " % n.momentum
        if n.orderedInputs:
            s += "train:deterministic "
        else:
            s += "train:randomized "
        s += "allowedErrors=%g " % n.stopPercent
        s += " resetEpoch=%d " % n.resetEpoch
        s += " resetLimit=%d " % n.resetLimit
        print s
    def setLearning(self, val):
        """
        Assumption: when you turn off learning, you wish to close
        the pickle file.
        """
        if self.fname != None:
            self.fh.close()
        C.Network.setLearning(self, val)

    def dumpLayer(self):
        """
        The actual hook into train.
        """
        pickle.dump((self.layers, self.connections), self.fh)
    def wtsOverTime(self, display=True, closeAll=True):
        """
        Decode the pickle file and return the weights as they change
        over time.  The weights are returned in a dictionary, where
        each entry is the list of values a particular weight has
        taken on over time.
        The display arg controls whether you also get a view of
        these weights over time via matplotlib.
        """
        assert self.fname != None, "bogus usage"
        f = file(self.fname, 'r')
        data = []
        try:
            while True:
                data.append(pickle.load(f))
        except EOFError:
            f.close()
        n = len(data)
        assert n > 0, "bogus file read"
        # the following code was basically ripped out of conx's
        # write-weights-to-file functionality
        wts = {}
        for i in range(n):
            (layers, connections) = data[i]
            d = wtsFromLayersAndConnections(layers, connections)
            for k in d.keys():
                wts.setdefault(k, []).append(d[k])
        # displaying
        nWts = len(wts)
        if display:
            if closeAll:
                pylab.close('all')
            ct = 1
            for k in wts.keys():
                pylab.subplot(nWts, 1, ct)
                pylab.plot(wts[k])
                pylab.ylabel(k)
                ct += 1
        return wts
def wtsFromLayersAndConnections(layers, connections):
    """
    Convert layers/connections weight data into a dictionary so it's
    easier to understand.
    """
    d = {}
    for layer in layers:
        if layer.type != 'Input':
            for i in range(layer.size):
                d['%s_%d_bias' % (layer.name[0:3], i)] = layer.weight[i]
    for connection in connections:
        for i in range(connection.fromLayer.size):
            for j in range(connection.toLayer.size):
                d['%s:%d->%s:%d' % (connection.fromLayer.name[0:3], i,
                                    connection.toLayer.name[0:3], j)] = \
                    connection.weight[i][j]
    return d
def makeNet(fname=fname):
    """
    Create, train, and return a simple perceptron-like network
    trained to perform the boolean AND function.
    """
    n = myNetwork(fname=fname)
    n.add(C.Layer('input', 2))    # the input layer has two nodes
    n.add(C.Layer('output', 1))   # the output layer has one node
    n.connect('input', 'output')  # input layer connected to output layer
    # provide training patterns
    n.setInputs([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    n.setOutputs([[0.0], [0.0], [0.0], [1.0]])
    # specify learning parameters
    n.setEpsilon(0.5)    # learning rate
    n.setTolerance(0.2)  # when |target - activation| < .2, considered correct
    n.setReportRate(1)   # report every update
    n.reportParams()
    n.train()
    n.setLearning(0)
    if fname != None:
        wts = n.wtsOverTime()
        return (n, wts)
    return n
_______________________________________________
Pyro-users mailing list
[email protected]
http://emergent.brynmawr.edu/mailman/listinfo/pyro-users