Hey David,
I attached a couple of classes that I am using in my program to manage
Networks and serialize/deserialize Connections. The code is a bit rough,
but I ran through and put in some comments to explain some things. These
classes are used by a GUI program to allow a user to interact with the
Network, but I didn't include any of the UI related stuff. Also, I've
used a few different kinds of data as input to the network (feeding it
in using the feed() method in my HTMNetwork class), but I've primarily
used sine wave values for simplicity while working on the serialization
and deserialization aspects of the program.
Andrew
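Sine input like the values Andrew describes is easy to generate. The sketch below shows one way to do it. The class name, period, and amplitude are illustrative assumptions, not values from his program, although an amplitude of 100 keeps the samples inside the -100 to 100 range configured for the ScalarEncoder in HTMNetworkHarness.

```java
import java.util.Arrays;

public class SineInput {

    // Returns `count` samples of amplitude * sin(2 * pi * i / period) for i = 0 .. count - 1.
    static double[] sineSamples(int count, int period, double amplitude) {
        double[] samples = new double[count];
        for (int i = 0; i < count; i++) {
            samples[i] = amplitude * Math.sin(2 * Math.PI * i / period);
        }
        return samples;
    }

    public static void main(String[] args) {
        // A period of 20 repeats the same sequence every 20 samples,
        // giving the temporal memory a simple pattern to learn.
        double[] samples = sineSamples(40, 20, 100.0);
        System.out.println(Arrays.toString(Arrays.copyOf(samples, 6)));
        // In the attached program, each sample would be passed to HTMNetwork.feed(double) in turn.
    }
}
```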
On 12/24/2015 2:09 PM, cogmission (David Ray) wrote:
No problem Andrew, I knew what you meant... ;-) Thanks!
On Thu, Dec 24, 2015 at 2:07 PM, Andrew Dillon
<[email protected] <mailto:[email protected]>> wrote:
I just realized I said I tagged them with the Serialized interface
- I don't know if that interface exists - I meant to say the
Serializable interface. It's a minor point, but I don't want to
cause any confusion.
Thanks again, David, and Happy Holidays to you as well!
On Dec 24, 2015 1:41 PM, "cogmission (David Ray)"
<[email protected] <mailto:[email protected]>>
wrote:
Hi Andrew,
Thank you for sending over the files. I will take a look and
get back to you in the next couple of days.
Happy Holidays!
David
On Thu, Dec 24, 2015 at 12:46 PM, Andrew Dillon
<[email protected] <mailto:[email protected]>>
wrote:
I attached the classes I tagged with the Serialized
interface, but I'll list them below as well.
* *Connections*
* *Cell*
* *Column*
* *DistalDendrite*
* *Pool*
* *ProximalDendrite*
* *Segment*
* *Synapse*
* *FlatMatrixSupport*
* *SparseBinaryMatrixSupport*
* *SparseMatrixSupport*
* *SparseObjectMatrix*
The reason I found it necessary to tag the other eleven
classes besides *Connections* was to include all of
*Connections*'s members in the serialization process.
Some of the classes I tagged are not directly used as
members of *Connections*, but I needed to serialize them
because they were superclasses of certain *Connections*
members, and the only way to deserialize the subclass
without serializing the superclass would have been to add
a no-args constructor to the superclass. This would result
in missing data upon deserialization, though, due to
certain fields in the superclass not being initialized.
This quick jguru article
<http://www.jguru.com/faq/view.jsp?EID=34802> might
explain what I'm saying better.
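The Java rule being worked around here (also covered in the jguru FAQ) can be shown with a short, self-contained sketch: when a Serializable class extends a class that is not Serializable, only the subclass's fields are written to the stream, and on deserialization the superclass's no-arg constructor runs and supplies the superclass's fields. `Base` and `Derived` are hypothetical stand-ins, not htm.java classes.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

public class SuperclassSerializationDemo {

    // Hypothetical superclass that does NOT implement Serializable. Java requires it to have
    // an accessible no-arg constructor, which runs during deserialization.
    static class Base {
        int learnedValue;

        Base() {
            this.learnedValue = 0;
        }

        Base(int learnedValue) {
            this.learnedValue = learnedValue;
        }
    }

    // Hypothetical Serializable subclass: only fields declared here are written to the stream.
    static class Derived extends Base implements Serializable {
        private static final long serialVersionUID = 1L;
        int ownValue;

        Derived(int learnedValue, int ownValue) {
            super(learnedValue);
            this.ownValue = ownValue;
        }
    }

    // Serializes and deserializes an object entirely in memory.
    static Object roundTrip(Object obj) throws IOException, ClassNotFoundException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        }
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            return in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        Derived copy = (Derived) roundTrip(new Derived(42, 7));
        System.out.println("ownValue = " + copy.ownValue);         // 7: restored from the stream
        System.out.println("learnedValue = " + copy.learnedValue); // 0: reset by Base()
    }
}
```

The learned value of 42 is silently replaced by whatever the no-arg constructor assigns, which is the data loss that marking the superclasses Serializable avoids.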
Thanks for the assistance and no worries about a late
reply. I wasn't even expecting a response today, since it
is Christmas Eve.
On 12/24/2015 11:19 AM, cogmission (David Ray) wrote:
Hi Andrew,
Welcome aboard! :-)
First let me say thanks for using HTM.java, it's very
nice to hear about user experiences indeed. Perhaps it
would be easier if you attached the classes that you
altered so that we can mock up our own example to see
what's going on - such as the Connections.java file and
whatever else was necessary to alter? (I would be
surprised to find that you had to alter anything else,
actually?).
As this is the next thing on HTM.java's agenda, this is
very interesting indeed... Also, as it is x-mas eve, I
may not be able to get back to you as promptly as I
otherwise would - but please send over the files as soon
as you are able because I am anxious to play with them! ;-)
Cheers,
David
On Thu, Dec 24, 2015 at 10:51 AM, Andrew Dillon
<[email protected]
<mailto:[email protected]>> wrote:
Hello all. I've been learning about HTM for a couple
of months now, reading On Intelligence and Numenta's
papers, as well as watching their videos. I just
began actually working with htm.java (I'm not very
familiar with Python, so I am unable to use that
version) a week or two ago, so I'm no expert on it,
but I have been able to create some working
demonstrations of it.
I have, however, run into a bit of a problem with
saving networks. I am aware that htm.java does not
currently support this type of operation
(https://github.com/numenta/htm.java/wiki/Call-To-Arms),
so I am attempting to develop a basic method of
saving and recreating my networks myself. What I have
done thus far is modify a couple of classes to
implement Java's Serializable interface, in order to
save my Layer's (I am just working with one right
now) Connections object. I have succeeded in
serializing and deserializing the Connections object,
and putting it back into a new Layer with the
Layer.using() method.
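One way to check the serialization step in isolation, before any Layer is involved, is an in-memory round trip: write the object to a byte array, read it back, and compare. The sketch is generic. `deepCopy` is a hypothetical helper, and a `HashMap` of made-up permanence values stands in for a Connections object; comparing two Connections objects would need a state comparison of your own, since this sketch does not assume Connections overrides `equals()`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.HashMap;

public class RoundTripCheck {

    // Serializes obj to an in-memory buffer and reads it back as a new, independent object.
    @SuppressWarnings("unchecked")
    static <T extends Serializable> T deepCopy(T obj) throws IOException, ClassNotFoundException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(buffer)) {
            out.writeObject(obj);
        }
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
            // Unchecked cast: readObject() returns an instance of the class that was written.
            return (T) in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        // Hypothetical learned state: synapse index -> permanence.
        HashMap<Integer, Double> permanences = new HashMap<>();
        permanences.put(3, 0.21);
        permanences.put(17, 0.85);

        HashMap<Integer, Double> copy = deepCopy(permanences);
        System.out.println(copy.equals(permanences)); // true: same contents
        System.out.println(copy == permanences);      // false: a separate object
    }
}
```

If the copied state matches but the network still behaves like a fresh one, the problem lies in how the restored object is used after loading rather than in serialization itself.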
The problem is that when I feed the network (that is
using the deserialized Connections) the same data it
had learned to recognize before I serialized it, it
no longer predicts the proper values. Its output
looks exactly like a new network; as though my saved
Connections is being overwritten or ignored somehow.
I've spent the past few days trying to figure out
what is happening, digging around the source code and
trying a few different things, but have been unable
to produce any results. Do any of you folks have any
idea how I might go about resolving this issue?
I am sure code samples would be of interest here, but
I'm not sure what, specifically, I should include as
my program is of a decent size. If anybody would like
some samples, please mention what general
functions/areas of my program you would like to see,
and I'll be happy to oblige.
Thanks in advance for any help. I am very fascinated
by this project and HTM theory in general. I really
appreciate what you all are doing and that this
project was made open source!
--
/With kind regards,/
David Ray
Java Solutions Architect
*Cortical.io <http://cortical.io/>*
Sponsor of: HTM.java <https://github.com/numenta/htm.java>
[email protected] <mailto:[email protected]>
http://cortical.io
package com.ajdillon.numyo.htm;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.numenta.nupic.Connections;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.Anomaly;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.network.Inference;
import org.numenta.nupic.network.Layer;
import org.numenta.nupic.network.Network;
import org.numenta.nupic.network.Region;
import org.numenta.nupic.network.sensor.ObservableSensor;
import org.numenta.nupic.network.sensor.Publisher;
import org.numenta.nupic.network.sensor.Sensor;
import org.numenta.nupic.network.sensor.SensorParams;
import org.numenta.nupic.network.sensor.SensorParams.Keys;
import rx.Subscriber;
// HTMNetwork is a sort of 'wrapper' class that simplifies interactions with an
// htm.java Network that are common in my program.
public class HTMNetwork implements Serializable {
// volatile: these fields are read and written from more than one thread
// (the caller, the input thread, and the network's subscriber).
private volatile boolean networkIsRunning = false;
private volatile double predictedValue = 0.0;
private Publisher manualPublisher;
private Network network;
private Region region;
private Layer layer;
private Connections savedConnections;
private volatile boolean stopNetworkThread = false;
private ConcurrentLinkedQueue<Double> inputBuffer = new ConcurrentLinkedQueue<Double>();
private ConcurrentLinkedQueue<Double> outputBuffer = new ConcurrentLinkedQueue<Double>();
// Constructor that takes in a File object. If null is passed in for the file, a new
// Connections object will be created for the Network; otherwise the file is assumed to
// contain the serialized Connections from a previous run of the program, and it is
// deserialized and put into the Layer.
public HTMNetwork(File readFile) throws FileNotFoundException,
ClassNotFoundException, IOException {
// Load Connections from the passed-in File using the PersistenceManager class.
if (readFile != null) {
System.out.println("LOADING FROM FILE");
PersistenceManager persMan = new
PersistenceManager(readFile, null);
savedConnections = persMan.loadConnections();
persMan.close();
System.out.println("DONE LOADING FROM FILE");
}
manualPublisher =
Publisher.builder().addHeader("Pod1EMG").addHeader("float").addHeader("B").build();
Parameters params = HTMNetworkHarness.getParameters();
params.union(HTMNetworkHarness.getEncoderParams());
final Sensor<ObservableSensor<String>> obsSensor =
Sensor.create(ObservableSensor::create,
SensorParams.create(Keys::obs, "",
manualPublisher));
layer = Network.createLayer("Layer 2/3", params)
.alterParameter(KEY.AUTO_CLASSIFY, Boolean.TRUE)
.add(Anomaly.create()).add(new
TemporalMemory()).add(new SpatialPooler())
.add(obsSensor);
// If a Connections object was obtained from the passed-in file, put it into the Layer.
if (savedConnections != null) {
System.out.println("LOADING CONNECTIONS INTO LAYER...");
layer.using(savedConnections);
System.out.println("DONE LOADING CONNS INTO LAYER.");
}
region = Network.createRegion("Region1")
.add(layer);
network = Network.create("NuMyoNetwork", params)
.add(region);
}
private Subscriber<Inference> getSubscriber() throws
FileNotFoundException {
return (new Subscriber<Inference>() {
private PrintWriter printWriter = new PrintWriter(
new File(System.getProperty("user.home") + File.separator + "Desktop"
+ File.separator + "HTMNetwork.txt"));
@Override
public void onCompleted() {
printWriter.close();
System.out.println("-----NETWORK COMPLETED-----");
}
@Override
public void onError(Throwable e) {
e.printStackTrace();
}
@Override
public void onNext(Inference infer) {
String classifierField = "EMGPod1";
double newPrediction;
if
(infer.getClassification(classifierField).getMostProbableValue(1) != null) {
newPrediction = (Double)
infer.getClassification(classifierField).getMostProbableValue(1);
} else {
newPrediction = predictedValue;
}
// Add the new prediction to outputBuffer (a Queue object) so that the feed()
// method will obtain the value, stop looping, and return.
outputBuffer.add(newPrediction);
double actual = (Double)
infer.getClassifierInput().get(classifierField).get("inputValue");
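// predictedValue still holds the prediction made on the previous step,
// so it is compared here against the current actual value.
System.out.printf("%d - PRED: %3.2f ACTL: %3.2f%n", infer.getRecordNum(), predictedValue, actual);
printWriter.printf("%d - PRED: %3.2f ACTL: %3.2f%n", infer.getRecordNum(), predictedValue, actual);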
predictedValue = newPrediction;
}
});
}
// Starts the network and begins running a thread that is responsible for polling the
// input Queue and giving the values it finds to the manualPublisher's onNext(). Values
// are placed into the input buffer by the feed() method.
public void start() throws FileNotFoundException {
networkIsRunning = true;
stopNetworkThread = false;
network.observe().subscribe(getSubscriber());
network.start();
new Thread(new Runnable() {
@Override
public void run() {
while (!stopNetworkThread) {
System.out.println("Running Network Thread");
Double nextInputVal;
if ((nextInputVal = inputBuffer.poll()) != null) {
System.out.println("Network found this input: " + nextInputVal);
manualPublisher.onNext(String.valueOf(nextInputVal));
}
try {
Thread.sleep(5);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}, "HTMNetwork-Input-Update-Thread").start();
}
// Self-explanatory :)
public void stop() {
stopNetworkThread = true;
manualPublisher.onComplete();
networkIsRunning = false;
}
// Saves the Network's Connections to the specified file using a PersistenceManager.
public void save(File file) throws FileNotFoundException, IOException {
if (networkIsRunning) {
stop();
}
PersistenceManager persMan = new PersistenceManager(null, file);
persMan.saveConnections(layer.getConnections());
persMan.close();
}
//Returns prediction of next value based on value fed in.
public double feed(double value) {
System.out.println("Feeding in value to network");
inputBuffer.add(value);
Double networkPrediction;
while ((networkPrediction = outputBuffer.poll()) == null &&
!stopNetworkThread) { //Loop until network generates its prediction
System.out.println("Waiting for network's prediction");
try {
Thread.sleep(5);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
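// If the network was stopped before producing a prediction, networkPrediction is
// null; return NaN instead of letting auto-unboxing throw a NullPointerException.
return networkPrediction != null ? networkPrediction : Double.NaN;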
}
}
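The input thread in start() and the wait loop in feed() both poll a ConcurrentLinkedQueue and sleep 5 ms between checks. A possible alternative, swapping in `java.util.concurrent.LinkedBlockingQueue`, lets each side block until a value arrives or a timeout expires. This is a standalone sketch, not a drop-in replacement: the `DoubleUnaryOperator` "model" is a hypothetical stand-in for the Network and its Subscriber, and the 100 ms poll interval and timeouts are arbitrary.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.DoubleUnaryOperator;

public class BlockingHandoff {

    private final BlockingQueue<Double> inputBuffer = new LinkedBlockingQueue<>();
    private final BlockingQueue<Double> outputBuffer = new LinkedBlockingQueue<>();
    private volatile boolean running = true;

    // The worker applies `model` to each input and publishes the result,
    // playing the role of the network plus the Subscriber's onNext().
    void startWorker(DoubleUnaryOperator model) {
        Thread worker = new Thread(() -> {
            try {
                while (running) {
                    // Blocks for up to 100 ms instead of sleeping in a fixed 5 ms loop.
                    Double next = inputBuffer.poll(100, TimeUnit.MILLISECONDS);
                    if (next != null) {
                        outputBuffer.put(model.applyAsDouble(next));
                    }
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // preserve the interrupt status
            }
        }, "Input-Handoff-Thread");
        worker.setDaemon(true);
        worker.start();
    }

    // Feeds one value and waits up to timeoutMs for an output; returns NaN on timeout.
    double feed(double value, long timeoutMs) throws InterruptedException {
        inputBuffer.put(value);
        Double result = outputBuffer.poll(timeoutMs, TimeUnit.MILLISECONDS);
        return result != null ? result : Double.NaN;
    }

    void stop() {
        running = false;
    }

    public static void main(String[] args) throws InterruptedException {
        BlockingHandoff handoff = new BlockingHandoff();
        handoff.startWorker(x -> x * 2); // toy model for demonstration only
        System.out.println(handoff.feed(1.5, 1000));
        handoff.stop();
    }
}
```

Blocking with a timeout also gives feed() a natural way to give up if the network stops, instead of spinning until a flag is set.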
package com.ajdillon.numyo.htm;
import java.util.HashMap;
import java.util.Map;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
class HTMNetworkHarness {
static Map<String, Map<String, Object>> setupMap(
Map<String, Map<String, Object>> map,
int n,
int w,
double min,
double max,
double radius,
double resolution,
Boolean periodic,
Boolean clip,
Boolean forced,
String fieldName,
String fieldType,
String encoderType) {
if (map == null) {
map = new HashMap<String, Map<String, Object>>();
}
Map<String, Object> inner = null;
if ((inner = map.get(fieldName)) == null) {
map.put(fieldName, inner = new HashMap<String,
Object>());
}
inner.put("n", n);
inner.put("w", w);
inner.put("minVal", min);
inner.put("maxVal", max);
inner.put("radius", radius);
inner.put("resolution", resolution);
if (periodic != null)
inner.put("periodic", periodic);
if (clip != null)
inner.put("clipInput", clip);
if (forced != null)
inner.put("forced", forced);
if (fieldName != null)
inner.put("fieldName", fieldName);
if (fieldType != null)
inner.put("fieldType", fieldType);
if (encoderType != null)
inner.put("encoderType", encoderType);
return map;
}
static Map<String, Map<String, Object>> getFieldEncodingMap() {
Map<String, Map<String, Object>> fieldEncodings = setupMap(
null, // Map to work with. None exists yet, so a new one will be created and returned
50, // n: number of bits in the SDR vector
21, // w: number of active (on) bits used to represent a value
-100, // minVal: lowest value the encoder expects (with clipInput, lower values are treated as this value)
100, // maxVal: highest value the encoder expects (with clipInput, higher values are treated as this value)
0, // radius: inputs separated by more than this share no active bits (0 = not specified)
0.1, // resolution: two values are only encoded differently if their difference is >= this value
null, // periodic: does the input cycle or loop (e.g. days of the week)? null = not set
Boolean.TRUE, // clipInput: values below minVal are clipped to (seen as the same as) minVal,
// and values above maxVal are clipped to maxVal
null, // forced: if true, skip the encoder's safety checks on n and w; null = not set
"EMGPod1", // Name of field
"float", // Type of data this field is
"ScalarEncoder"); // Type of encoder to be used for this data
return fieldEncodings;
}
static Parameters getEncoderParams() {
Map<String, Map<String, Object>> fieldEncodings =
getFieldEncodingMap();
Parameters p = Parameters.getEncoderDefaultParameters();
p.setParameterByKey(KEY.GLOBAL_INHIBITIONS, true);
p.setParameterByKey(KEY.COLUMN_DIMENSIONS, new int[] { 2048 });
p.setParameterByKey(KEY.CELLS_PER_COLUMN, 32);
p.setParameterByKey(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 40.0);
p.setParameterByKey(KEY.POTENTIAL_PCT, 0.8);
p.setParameterByKey(KEY.SYN_PERM_CONNECTED, 0.1);
p.setParameterByKey(KEY.SYN_PERM_ACTIVE_INC, 0.0001);
p.setParameterByKey(KEY.SYN_PERM_INACTIVE_DEC, 0.0005);
p.setParameterByKey(KEY.MAX_BOOST, 1.0);
p.setParameterByKey(KEY.MAX_NEW_SYNAPSE_COUNT, 20);
p.setParameterByKey(KEY.INITIAL_PERMANENCE, 0.21);
p.setParameterByKey(KEY.PERMANENCE_INCREMENT, 0.1);
p.setParameterByKey(KEY.PERMANENCE_DECREMENT, 0.1);
p.setParameterByKey(KEY.MIN_THRESHOLD, 9);
p.setParameterByKey(KEY.ACTIVATION_THRESHOLD, 12);
p.setParameterByKey(KEY.CLIP_INPUT, true);
p.setParameterByKey(KEY.FIELD_ENCODING_MAP, fieldEncodings);
return p;
}
static Parameters getParameters() {
Parameters parameters = Parameters.getAllDefaultParameters();
parameters.setParameterByKey(KEY.INPUT_DIMENSIONS, new int[] {
2048 });
parameters.setParameterByKey(KEY.COLUMN_DIMENSIONS, new int[] {
20 });
parameters.setParameterByKey(KEY.CELLS_PER_COLUMN, 6);
// SpatialPooler specific
parameters.setParameterByKey(KEY.POTENTIAL_RADIUS, 12);// 3
parameters.setParameterByKey(KEY.POTENTIAL_PCT, 0.5);// 0.5
parameters.setParameterByKey(KEY.GLOBAL_INHIBITIONS, false);
parameters.setParameterByKey(KEY.LOCAL_AREA_DENSITY, -1.0);
parameters.setParameterByKey(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 5.0);
parameters.setParameterByKey(KEY.STIMULUS_THRESHOLD, 1.0);
parameters.setParameterByKey(KEY.SYN_PERM_INACTIVE_DEC, 0.01);
parameters.setParameterByKey(KEY.SYN_PERM_ACTIVE_INC, 0.1);
parameters.setParameterByKey(KEY.SYN_PERM_TRIM_THRESHOLD, 0.05);
parameters.setParameterByKey(KEY.SYN_PERM_CONNECTED, 0.1);
parameters.setParameterByKey(KEY.MIN_PCT_OVERLAP_DUTY_CYCLE,
0.1);
parameters.setParameterByKey(KEY.MIN_PCT_ACTIVE_DUTY_CYCLE,
0.1);
parameters.setParameterByKey(KEY.DUTY_CYCLE_PERIOD, 10);
parameters.setParameterByKey(KEY.MAX_BOOST, 10.0);
parameters.setParameterByKey(KEY.SEED, 42);
parameters.setParameterByKey(KEY.SP_VERBOSITY, 0);
// Temporal Memory specific
parameters.setParameterByKey(KEY.INITIAL_PERMANENCE, 0.2);
parameters.setParameterByKey(KEY.CONNECTED_PERMANENCE, 0.8);
parameters.setParameterByKey(KEY.MIN_THRESHOLD, 5);
parameters.setParameterByKey(KEY.MAX_NEW_SYNAPSE_COUNT, 6);
parameters.setParameterByKey(KEY.PERMANENCE_INCREMENT, 0.05);
parameters.setParameterByKey(KEY.PERMANENCE_DECREMENT, 0.05);
parameters.setParameterByKey(KEY.ACTIVATION_THRESHOLD, 4);
return parameters;
}
}
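The setupMap() call above passes n, radius, and resolution together. Assuming htm.java's ScalarEncoder follows NuPIC's rule for a non-periodic encoder configured by n (resolution = (maxVal - minVal) / (n - w), radius = w * resolution), the values actually used can be worked out as below. That formula is an assumption to check against the htm.java source, and the class and method names are hypothetical.

```java
import java.util.Locale;

public class ScalarEncoderMath {

    // Resolution implied by n for a non-periodic encoder (assumed NuPIC-style formula).
    static double impliedResolution(int n, int w, double minVal, double maxVal) {
        return (maxVal - minVal) / (n - w);
    }

    // Radius implied by that resolution: inputs further apart than this share no active bits.
    static double impliedRadius(int n, int w, double minVal, double maxVal) {
        return w * impliedResolution(n, w, minVal, maxVal);
    }

    public static void main(String[] args) {
        // Values passed to setupMap() in getFieldEncodingMap(): n=50, w=21, range [-100, 100].
        double resolution = impliedResolution(50, 21, -100, 100);
        double radius = impliedRadius(50, 21, -100, 100);
        System.out.printf(Locale.ROOT, "resolution=%.3f radius=%.3f%n", resolution, radius);
        // prints resolution=6.897 radius=144.828
    }
}
```

If the assumption holds, the 0.1 resolution passed in is not what the encoder ends up using, so it is worth confirming the encoder's effective resolution at runtime.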
package com.ajdillon.numyo.htm;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import org.numenta.nupic.Connections;
// PersistenceManager is responsible for the serialization and deserialization of an
// HTMNetwork Network's Connections.
public class PersistenceManager {
static long startTime;
static long endTime;
static long durationSecs;
ObjectInputStream in;
ObjectOutputStream out;
// Constructor that takes a file from which to deserialize a Connections object and/or
// a file to serialize a Connections object to.
public PersistenceManager(File readFile, File saveFile) throws
FileNotFoundException, IOException {
if (readFile != null)
in = new ObjectInputStream(new
FileInputStream(readFile.getAbsolutePath()));
if (saveFile != null)
out = new ObjectOutputStream(new
FileOutputStream(saveFile.getAbsolutePath()));
}
// Returns a Connections object after deserializing it from the file specified
// in the constructor.
public Connections loadConnections()
throws FileNotFoundException, IOException,
ClassNotFoundException {
System.out.println("LOADING CONNECTIONS FROM FILE...");
startTime = System.currentTimeMillis();
Connections connections = (Connections) in.readObject();
endTime = System.currentTimeMillis();
durationSecs = (endTime - startTime) / 1000;
System.out.println("DONE LOADING CONNECTIONS.(" + durationSecs
+ "secs)");
return connections;
}
// Takes a Connections object and serializes it to the file specified in the
// constructor.
public void saveConnections(Connections connections) throws IOException
{
System.out.println("SAVING CONNECTIONS...");
out.writeObject(connections);
System.out.println("DONE SAVING CONNECTIONS.");
}
// Self-explanatory :)
public void close() throws IOException {
if (in != null)
in.close();
if (out != null)
out.close();
}
}
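PersistenceManager keeps both streams as fields and relies on the caller to invoke close(); if writeObject() or readObject() throws, as can happen in HTMNetwork.save(), the stream is left open. A possible more compact variant, sketched here assuming one object is written per file, opens each stream in try-with-resources so it is closed on every path. `SafePersistence`, `save`, and `load` are hypothetical names, and a String stands in for Connections so the sketch runs without htm.java.

```java
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

public class SafePersistence {

    // Writes obj to file; try-with-resources flushes and closes the stream even if writeObject throws.
    static void save(Serializable obj, File file) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(
                new BufferedOutputStream(new FileOutputStream(file)))) {
            out.writeObject(obj);
        }
    }

    // Reads one object back from file and checks it against the expected type,
    // e.g. load(file, Connections.class).
    static <T> T load(File file, Class<T> type) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(
                new BufferedInputStream(new FileInputStream(file)))) {
            return type.cast(in.readObject());
        }
    }

    public static void main(String[] args) throws Exception {
        File tmp = File.createTempFile("connections", ".ser");
        tmp.deleteOnExit();
        save("placeholder state", tmp); // a String stands in for a Connections object
        String restored = load(tmp, String.class);
        System.out.println(restored); // prints: placeholder state
    }
}
```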