Hey David,

I attached a couple of classes that I am using in my program to manage Networks and serialize/deserialize Connections. The code is a bit rough, but I went through it and added some comments to explain things. These classes are used by a GUI program that lets a user interact with the Network, but I didn't include any of the UI-related code. Also, I've used a few different kinds of data as input to the network (fed in through the feed() method in my HTMNetwork class), but I've primarily used sine-wave values for simplicity while working on the serialization and deserialization aspects of the program.
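Sine-wave input of the kind mentioned above takes only a few lines of Java to generate. This is a hypothetical sketch: the class name `SineInput`, the method `sineSamples`, and the amplitude and period values are illustrative assumptions, not taken from the attached program. Each generated value would then be passed, one at a time, to `feed()`.

```java
import java.util.Arrays;

// Hypothetical helper: generates evenly spaced samples of a sine wave.
class SineInput {

    // Returns `count` samples of amplitude * sin(2*pi*i / period), for i = 0..count-1.
    static double[] sineSamples(int count, int period, double amplitude) {
        if (period <= 0) {
            throw new IllegalArgumentException("period must be positive");
        }
        double[] samples = new double[count];
        for (int i = 0; i < count; i++) {
            samples[i] = amplitude * Math.sin(2 * Math.PI * i / period);
        }
        return samples;
    }

    public static void main(String[] args) {
        // Illustrative values only: one full cycle every 100 samples, peak of 50.
        double[] values = sineSamples(200, 100, 50.0);
        System.out.println(Arrays.toString(Arrays.copyOf(values, 5)));
    }
}
```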

Andrew

On 12/24/2015 2:09 PM, cogmission (David Ray) wrote:
No problem Andrew, I knew what you meant... ;-) Thanks!

On Thu, Dec 24, 2015 at 2:07 PM, Andrew Dillon <[email protected] <mailto:[email protected]>> wrote:

    I just realized I said I tagged them with the Serialized interface
    - I don't know if that interface exists - I meant to say the
    Serializable interface. It's a minor point, but I don't want to
    cause any confusion.

    Thanks again, David, and Happy Holidays to you as well!

    On Dec 24, 2015 1:41 PM, "cogmission (David Ray)"
    <[email protected] <mailto:[email protected]>>
    wrote:

        Hi Andrew,

        Thank you for sending over the files. I will take a look and
        get back to you in the next couple of days.

        Happy Holidays!

        David

        On Thu, Dec 24, 2015 at 12:46 PM, Andrew Dillon
        <[email protected] <mailto:[email protected]>>
        wrote:

            I attached the classes I tagged with the Serialized
            interface, but I'll list them below as well.

              * *Connections*
              * *Cell*
              * *Column*
              * *DistantDendrite*
              * *Pool*
              * *ProximalDendrite*
              * *Segment*
              * *Synapse*
              * *FlatMatrixSupport*
              * *SparseBinaryMatrixSupport*
              * *SparseMatrixSupport*
              * *SparseObjectMatrix*

            The reason I found it necessary to tag the other eleven
            classes besides *Connections* was to include all of
            *Connections*'s members in the serialization process.
            Some of the classes I tagged are not directly used as
            members of *Connections*, but I needed to serialize them
            because they were superclasses of certain *Connections*
            members, and the only way to deserialize the subclass
            without serializing the superclass would have been to add
            a no-args constructor to the superclass. This would result
            in missing data upon deserialization, though, due to
            certain fields in the superclass not being initialized.
            This quick jguru article
            <http://www.jguru.com/faq/view.jsp?EID=34802> might
            explain what I'm saying better.
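The superclass behavior described above can be reproduced in isolation. In this minimal sketch, `Base` and `Derived` are hypothetical stand-ins, not htm.java classes. `Base` does not implement `Serializable`, so when a `Derived` object is deserialized, Java runs `Base`'s no-arg constructor and the value set through the other constructor is lost, while `Derived`'s own field is restored from the stream.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

class SuperclassSerializationDemo {

    // Hypothetical base class that does NOT implement Serializable.
    static class Base {
        int baseValue;

        // Needed: the first non-serializable superclass must have an accessible
        // no-arg constructor, which runs during deserialization.
        Base() {
            this.baseValue = 0;
        }

        Base(int baseValue) {
            this.baseValue = baseValue;
        }
    }

    // Hypothetical serializable subclass.
    static class Derived extends Base implements Serializable {
        private static final long serialVersionUID = 1L;
        int derivedValue;

        Derived(int baseValue, int derivedValue) {
            super(baseValue);
            this.derivedValue = derivedValue;
        }
    }

    // Serializes obj to an in-memory byte array and reads it back.
    @SuppressWarnings("unchecked")
    static <T extends Serializable> T roundTrip(T obj) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(obj);
            }
            try (ObjectInputStream in = new ObjectInputStream(
                    new ByteArrayInputStream(bytes.toByteArray()))) {
                return (T) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("serialization round trip failed", e);
        }
    }

    public static void main(String[] args) {
        Derived copy = roundTrip(new Derived(7, 42));
        System.out.println("derivedValue = " + copy.derivedValue); // prints derivedValue = 42
        System.out.println("baseValue = " + copy.baseValue);       // prints baseValue = 0
    }
}
```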

            Thanks for the assistance and no worries about a late
            reply. I wasn't even expecting a response today, since it
            is Christmas Eve.


            On 12/24/2015 11:19 AM, cogmission (David Ray) wrote:
            Hi Andrew,

            Welcome aboard! :-)

            First let me say thanks for using HTM.java, it's very
            nice to hear about user experiences indeed. Perhaps it
            would be easier if you attached the classes that you
            altered so that we can mock up our own example to see
            what's going on - such as the Connections.java file and
            whatever else was necessary to alter? (I would be
            surprised to find that you had to alter anything else,
            actually?).

            As this is the next thing on HTM.java's agenda, this is
            very interesting indeed... Also, as it is x-mas eve, I
            may not be able to get back to you as promptly as I
            otherwise would - but please send over the files as soon
            as you are able because I am anxious to play with them! ;-)

            Cheers,
            David

            On Thu, Dec 24, 2015 at 10:51 AM, Andrew Dillon
            <[email protected]
            <mailto:[email protected]>> wrote:

                Hello all. I've been learning about HTM for a couple
                of months now, reading On Intelligence and Numenta's
                papers, as well as watching their videos. I just
                began actually working with htm.java (I'm not very
                familiar with Python, so I am unable to use that
                version) a week or two ago, so I'm no expert on it,
                but I have been able to create some working
                demonstrations of it.

                I have, however, run into a bit of a problem with
                saving networks. I am aware that htm.java does not
                currently support this type of operation
                (https://github.com/numenta/htm.java/wiki/Call-To-Arms),
                so I am attempting to develop a basic method of
                saving and recreating my networks myself. What I have
                done thus far is modify a couple of classes to
                implement Java's Serializable interface, in order to
                save my Layer's Connections object (I am just working
                with one Layer right now). I have succeeded in
                serializing and deserializing the Connections object,
                and putting it back into a new Layer with the
                Layer.using() method.

                The problem is that when I feed the network (that is
                using the deserialized Connections) the same data it
                had learned to recognize before I serialized it, it
                no longer predicts the proper values. Its output
                looks exactly like that of a new network, as though my
                saved Connections is being overwritten or ignored somehow.
                I've spent the past few days trying to figure out
                what is happening, digging around the source code and
                trying a few different things, but have been unable
                to produce any results. Do any of you folks have any
                idea how I might go about resolving this issue?
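One way to test the "overwritten or ignored" hypothesis is to fingerprint the Connections object's serialized state at different points and compare the results, for example right after loading it and again from `layer.getConnections()` after the network has run. The helper below is a hypothetical sketch, not part of htm.java. It hashes the Java-serialized bytes of any `Serializable` object with SHA-256. It assumes the object's serialized form is the same whenever its state is the same. That holds for plain collections such as `ArrayList`, but may not hold for every class.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;

// Hypothetical debugging helper: fingerprints an object's serialized state.
final class StateFingerprint {

    private StateFingerprint() {}

    // Serializes the object and returns a Base64-encoded SHA-256 digest of the bytes.
    static String of(Serializable obj) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        try {
            byte[] digest = MessageDigest.getInstance("SHA-256").digest(bytes.toByteArray());
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to support SHA-256.
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        ArrayList<Integer> state = new ArrayList<>(Arrays.asList(1, 2, 3));
        String before = of(state);
        state.set(0, 99);
        String after = of(state);
        System.out.println(before.equals(after) ? "unchanged" : "changed"); // prints changed
    }
}
```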

                I am sure code samples would be of interest here, but
                I'm not sure what, specifically, I should include as
                my program is of a decent size. If anybody would like
                some samples, please mention what general
                functions/areas of my program you would like to see,
                and I'll be happy to oblige.

                Thanks in advance for any help. I am very fascinated
                by this project and HTM theory in general. I really
                appreciate what you all are doing and that this
                project was made open source!




-- /With kind regards,/
            David Ray
            Java Solutions Architect
            *Cortical.io <http://cortical.io/>*
            Sponsor of: HTM.java <https://github.com/numenta/htm.java>
            [email protected] <mailto:[email protected]>
            http://cortical.io


package com.ajdillon.numyo.htm;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.concurrent.ConcurrentLinkedQueue;

import org.numenta.nupic.Connections;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.Anomaly;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.network.Inference;
import org.numenta.nupic.network.Layer;
import org.numenta.nupic.network.Network;
import org.numenta.nupic.network.Region;
import org.numenta.nupic.network.sensor.ObservableSensor;
import org.numenta.nupic.network.sensor.Publisher;
import org.numenta.nupic.network.sensor.Sensor;
import org.numenta.nupic.network.sensor.SensorParams;
import org.numenta.nupic.network.sensor.SensorParams.Keys;

import rx.Subscriber;

// HTMNetwork is a sort of 'wrapper' class that simplifies interactions with an
// htm.java Network that are common in my program.
public class HTMNetwork implements Serializable {
        private boolean networkIsRunning = false;
        private double predictedValue = 0.0;
        private Publisher manualPublisher;
        private Network network;
        private Region region;
        private Layer layer;
        private Connections savedConnections;
        // volatile: written by stop() and read by the input thread and by feed()
        private volatile boolean stopNetworkThread = false;
        private ConcurrentLinkedQueue<Double> inputBuffer = new ConcurrentLinkedQueue<Double>();
        private ConcurrentLinkedQueue<Double> outputBuffer = new ConcurrentLinkedQueue<Double>();

        // Constructor that takes in a File object. If null is passed in for the file, a new
        // Connections object will be created for the Network; otherwise the file is assumed to
        // contain the serialized Connections from a previous run of the program, and it is
        // deserialized and put into the Layer.
        public HTMNetwork(File readFile) throws FileNotFoundException,
                        ClassNotFoundException, IOException {

                // Load the Connections from the passed-in File using the PersistenceManager class.
                if (readFile != null) {
                        System.out.println("LOADING FROM FILE");
                        PersistenceManager persMan = new PersistenceManager(readFile, null);
                        savedConnections = persMan.loadConnections();
                        persMan.close();
                        System.out.println("DONE LOADING FROM FILE");
                }

                manualPublisher = Publisher.builder()
                                .addHeader("Pod1EMG")
                                .addHeader("float")
                                .addHeader("B")
                                .build();

                Parameters params = HTMNetworkHarness.getParameters();
                params.union(HTMNetworkHarness.getEncoderParams());

                final Sensor<ObservableSensor<String>> obsSensor = Sensor.create(
                                ObservableSensor::create,
                                SensorParams.create(Keys::obs, "", manualPublisher));

                layer = Network.createLayer("Layer 2/3", params)
                                .alterParameter(KEY.AUTO_CLASSIFY, Boolean.TRUE)
                                .add(Anomaly.create())
                                .add(new TemporalMemory())
                                .add(new SpatialPooler())
                                .add(obsSensor);

                // If a Connections object was obtained from the passed-in file, put it into the Layer.
                if (savedConnections != null) {
                        System.out.println("LOADING CONNECTIONS INTO LAYER...");
                        layer.using(savedConnections);
                        System.out.println("DONE LOADING CONNS INTO LAYER.");
                }

                region = Network.createRegion("Region1")
                                .add(layer);

                network = Network.create("NuMyoNetwork", params)
                                .add(region);
        }

        private Subscriber<Inference> getSubscriber() throws FileNotFoundException {
                return (new Subscriber<Inference>() {
                        private PrintWriter printWriter = new PrintWriter(
                                        new File(System.getProperty("user.home") +
                                        File.separator + "Desktop" +
                                        File.separator + "HTMNetwork.txt"));

                        @Override
                        public void onCompleted() {
                                printWriter.close();
                                System.out.println("-----NETWORK COMPLETED-----");
                        }

                        @Override
                        public void onError(Throwable e) {
                                e.printStackTrace();
                        }

                        @Override
                        public void onNext(Inference infer) {
                                String classifierField = "EMGPod1";
                                double newPrediction;
                                if (infer.getClassification(classifierField).getMostProbableValue(1) != null) {
                                        newPrediction = (Double) infer.getClassification(classifierField)
                                                        .getMostProbableValue(1);
                                } else {
                                        newPrediction = predictedValue;
                                }
                                // Add the new prediction to outputBuffer (a Queue object) so that the
                                // feed() method will obtain the value, stop looping, and return.
                                outputBuffer.add(newPrediction);

                                double actual = (Double) infer.getClassifierInput()
                                                .get(classifierField).get("inputValue");
                                System.out.printf("%d - PRED: %3.2f ACTL: %3.2f%n",
                                                infer.getRecordNum(), predictedValue, actual);
                                printWriter.printf("%d - PRED: %3.2f ACTL: %3.2f%n",
                                                infer.getRecordNum(), predictedValue, actual);
                                predictedValue = newPrediction;
                        }
                });
        }

        // Starts the network and begins running a thread that is responsible for polling the input
        // Queue and giving values it finds to the manualPublisher's onNext(). Values are placed into
        // the input buffer by the feed() method.
        public void start() throws FileNotFoundException {

                networkIsRunning = true;
                stopNetworkThread = false;
                network.observe().subscribe(getSubscriber());
                network.start();
                new Thread(new Runnable() {

                        @Override
                        public void run() {
                                while (!stopNetworkThread) {
                                        System.out.println("Running Network Thread");
                                        Double nextInputVal;
                                        if ((nextInputVal = inputBuffer.poll()) != null) {
                                                System.out.println("Network found this input: " + nextInputVal);
                                                manualPublisher.onNext(String.valueOf(nextInputVal));
                                        }
                                        try {
                                                Thread.sleep(5);
                                        } catch (InterruptedException e) {
                                                e.printStackTrace();
                                        }
                                }
                        }
                }, "HTMNetwork-Input-Update-Thread").start();
        }

        //Self explanatory :)
        public void stop() {
                stopNetworkThread = true;
                manualPublisher.onComplete();
                networkIsRunning = false;
        }

        // Saves the Network's Connections to the specified file using a PersistenceManager.
        public void save(File file) throws FileNotFoundException, IOException {
                if (networkIsRunning) {
                        stop();
                }
                PersistenceManager persMan = new PersistenceManager(null, file);
                persMan.saveConnections(layer.getConnections());
                persMan.close();
        }

        // Returns the prediction of the next value based on the value fed in.
        // Returns Double.NaN if the network is stopped before a prediction arrives.
        public double feed(double value) {
                System.out.println("Feeding in value to network");
                inputBuffer.add(value);
                Double networkPrediction;
                // Loop until the network generates its prediction (or is stopped)
                while ((networkPrediction = outputBuffer.poll()) == null && !stopNetworkThread) {
                        System.out.println("Waiting for network's prediction");
                        try {
                                Thread.sleep(5);
                        } catch (InterruptedException e) {
                                e.printStackTrace();
                        }
                }
                // Unboxing a null Double here would throw a NullPointerException
                return networkPrediction != null ? networkPrediction : Double.NaN;

        }
}
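The `feed()` method and the input thread above hand values back and forth by polling `ConcurrentLinkedQueue`s with `Thread.sleep(5)`. A possible alternative uses `java.util.concurrent.LinkedBlockingQueue`, whose timed `poll` waits for an element instead of spinning. It is sketched here under stated assumptions, not as a drop-in replacement. In this hypothetical `BlockingFeedSketch` class, a `DoubleUnaryOperator` stands in for the network and its subscriber. In `HTMNetwork`, the subscriber's `onNext()` would put the prediction into the output queue instead.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.DoubleUnaryOperator;

// Hypothetical sketch of the feed()/input-thread handoff using blocking queues.
class BlockingFeedSketch {
    private final BlockingQueue<Double> inputBuffer = new LinkedBlockingQueue<>();
    private final BlockingQueue<Double> outputBuffer = new LinkedBlockingQueue<>();
    private final DoubleUnaryOperator model; // stand-in for the network + subscriber
    private volatile boolean stopped = false;
    private Thread worker;

    BlockingFeedSketch(DoubleUnaryOperator model) {
        this.model = model;
    }

    void start() {
        worker = new Thread(() -> {
            try {
                while (!stopped) {
                    // Waits up to 100 ms for input instead of sleeping in a loop.
                    Double next = inputBuffer.poll(100, TimeUnit.MILLISECONDS);
                    if (next != null) {
                        outputBuffer.put(model.applyAsDouble(next));
                    }
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // exit quietly when stop() interrupts us
            }
        }, "Input-Update-Thread");
        worker.setDaemon(true); // never keeps the JVM alive on its own
        worker.start();
    }

    // Returns the prediction for `value`, or NaN if none arrives within one second.
    double feed(double value) {
        try {
            inputBuffer.put(value);
            Double prediction = outputBuffer.poll(1, TimeUnit.SECONDS);
            return prediction != null ? prediction : Double.NaN;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return Double.NaN;
        }
    }

    void stop() {
        stopped = true;
        if (worker != null) {
            worker.interrupt();
        }
    }

    public static void main(String[] args) {
        BlockingFeedSketch sketch = new BlockingFeedSketch(x -> 2 * x);
        sketch.start();
        System.out.println(sketch.feed(1.5)); // prints 3.0
        sketch.stop();
    }
}
```

The timeout in `feed()` replaces the `stopNetworkThread` check as the way out of the wait, so a stalled network cannot block the caller forever.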
package com.ajdillon.numyo.htm;

import java.util.HashMap;
import java.util.Map;

import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;

class HTMNetworkHarness {
        static Map<String, Map<String, Object>> setupMap(
                        Map<String, Map<String, Object>> map,
                        int n,
                        int w,
                        double min,
                        double max,
                        double radius,
                        double resolution,
                        Boolean periodic,
                        Boolean clip,
                        Boolean forced,
                        String fieldName,
                        String fieldType,
                        String encoderType) {
                        
                if (map == null) {
                        map = new HashMap<String, Map<String, Object>>();
                }
                Map<String, Object> inner = null;
                if ((inner = map.get(fieldName)) == null) {
                        map.put(fieldName, inner = new HashMap<String, Object>());
                }
                
                inner.put("n", n);
                inner.put("w", w);
                inner.put("minVal", min);
                inner.put("maxVal", max);
                inner.put("radius", radius);
                inner.put("resolution", resolution);
                
                if (periodic != null)
                        inner.put("periodic", periodic);
                if (clip != null)
                        inner.put("clipInput", clip);
                if (forced != null)
                        inner.put("forced", forced);
                if (fieldName != null)
                        inner.put("fieldName", fieldName);
                if (fieldType != null)
                        inner.put("fieldType", fieldType);
                if (encoderType != null)
                        inner.put("encoderType", encoderType);
                        
                return map;
        }
        
        static Map<String, Map<String, Object>> getFieldEncodingMap() {
                Map<String, Map<String, Object>> fieldEncodings = setupMap(
                                null,          // Map to work with. None exists yet, so a new one will be created and returned
                                50,            // n: total number of bits in the encoder's output vector
                                21,            // w: number of active bits used to represent a value
                                -100,          // Lowest value the encoder will see (lower inputs are treated as this value)
                                100,           // Highest value the encoder will see (higher inputs are treated as this value)
                                0,             // radius: inputs separated by more than this get non-overlapping encodings
                                0.1,           // resolution: inputs separated by at least this much get different encodings
                                null,          // Is the input periodic, i.e. does it cycle or loop (e.g. days of the week)?
                                Boolean.TRUE,  // clipInput: values lower than minVal are clipped to (seen as) minVal,
                                               // and values greater than maxVal are clipped to (seen as) maxVal
                                null,          // forced: if true, skip some of the encoder's safety checks on its settings
                                "EMGPod1",     // Name of the field
                                "float",       // Type of data in this field
                                "ScalarEncoder"); // Type of encoder to use for this data

                return fieldEncodings;
        }
        
        static Parameters getEncoderParams() {
                Map<String, Map<String, Object>> fieldEncodings = getFieldEncodingMap();
                
                Parameters p = Parameters.getEncoderDefaultParameters();
                p.setParameterByKey(KEY.GLOBAL_INHIBITIONS, true);
                p.setParameterByKey(KEY.COLUMN_DIMENSIONS, new int[] { 2048 });
                p.setParameterByKey(KEY.CELLS_PER_COLUMN, 32);
                p.setParameterByKey(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 40.0);
                p.setParameterByKey(KEY.POTENTIAL_PCT, 0.8);
                p.setParameterByKey(KEY.SYN_PERM_CONNECTED, 0.1);
                p.setParameterByKey(KEY.SYN_PERM_ACTIVE_INC, 0.0001);
                p.setParameterByKey(KEY.SYN_PERM_INACTIVE_DEC, 0.0005);
                p.setParameterByKey(KEY.MAX_BOOST, 1.0);
                
                p.setParameterByKey(KEY.MAX_NEW_SYNAPSE_COUNT, 20);
                p.setParameterByKey(KEY.INITIAL_PERMANENCE, 0.21);
                p.setParameterByKey(KEY.PERMANENCE_INCREMENT, 0.1);
                p.setParameterByKey(KEY.PERMANENCE_DECREMENT, 0.1);
                p.setParameterByKey(KEY.MIN_THRESHOLD, 9);
                p.setParameterByKey(KEY.ACTIVATION_THRESHOLD, 12);
                
                p.setParameterByKey(KEY.CLIP_INPUT, true);
                p.setParameterByKey(KEY.FIELD_ENCODING_MAP, fieldEncodings);
                
                return p;
        }
        
        static Parameters getParameters() {
                Parameters parameters = Parameters.getAllDefaultParameters();
                parameters.setParameterByKey(KEY.INPUT_DIMENSIONS, new int[] { 2048 });
                parameters.setParameterByKey(KEY.COLUMN_DIMENSIONS, new int[] { 20 });
                parameters.setParameterByKey(KEY.CELLS_PER_COLUMN, 6);
                
                // SpatialPooler specific
                parameters.setParameterByKey(KEY.POTENTIAL_RADIUS, 12);// 3
                parameters.setParameterByKey(KEY.POTENTIAL_PCT, 0.5);// 0.5
                parameters.setParameterByKey(KEY.GLOBAL_INHIBITIONS, false);
                parameters.setParameterByKey(KEY.LOCAL_AREA_DENSITY, -1.0);
                parameters.setParameterByKey(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 5.0);
                parameters.setParameterByKey(KEY.STIMULUS_THRESHOLD, 1.0);
                parameters.setParameterByKey(KEY.SYN_PERM_INACTIVE_DEC, 0.01);
                parameters.setParameterByKey(KEY.SYN_PERM_ACTIVE_INC, 0.1);
                parameters.setParameterByKey(KEY.SYN_PERM_TRIM_THRESHOLD, 0.05);
                parameters.setParameterByKey(KEY.SYN_PERM_CONNECTED, 0.1);
                parameters.setParameterByKey(KEY.MIN_PCT_OVERLAP_DUTY_CYCLE, 0.1);
                parameters.setParameterByKey(KEY.MIN_PCT_ACTIVE_DUTY_CYCLE, 0.1);
                parameters.setParameterByKey(KEY.DUTY_CYCLE_PERIOD, 10);
                parameters.setParameterByKey(KEY.MAX_BOOST, 10.0);
                parameters.setParameterByKey(KEY.SEED, 42);
                parameters.setParameterByKey(KEY.SP_VERBOSITY, 0);
                
                // Temporal Memory specific
                parameters.setParameterByKey(KEY.INITIAL_PERMANENCE, 0.2);
                parameters.setParameterByKey(KEY.CONNECTED_PERMANENCE, 0.8);
                parameters.setParameterByKey(KEY.MIN_THRESHOLD, 5);
                parameters.setParameterByKey(KEY.MAX_NEW_SYNAPSE_COUNT, 6);
                parameters.setParameterByKey(KEY.PERMANENCE_INCREMENT, 0.05);
                parameters.setParameterByKey(KEY.PERMANENCE_DECREMENT, 0.05);
                parameters.setParameterByKey(KEY.ACTIVATION_THRESHOLD, 4);
                
                return parameters;
        }
}
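The parameter comments in `getFieldEncodingMap()` describe how a scalar encoder turns a number into a run of `w` active bits out of `n`. The following is a simplified, hypothetical model of that idea, not htm.java's `ScalarEncoder`. It assumes the relationship NuPIC documents for a non-periodic encoder whose `n` is given: resolution = (maxVal - minVal) / (n - w). Under that assumption, n = 50 and w = 21 over [-100, 100] give a resolution of roughly 6.9 rather than 0.1. It may therefore be worth checking how htm.java resolves `n`, `radius`, and `resolution` when all three are set.

```java
import java.util.Arrays;
import java.util.Locale;

// Simplified, hypothetical model of a non-periodic scalar encoder (not htm.java's ScalarEncoder).
class ScalarEncoderSketch {
    final int n;            // total number of bits in the output
    final int w;            // number of contiguous active bits per value
    final double minVal;
    final double maxVal;
    final double resolution;

    ScalarEncoderSketch(int n, int w, double minVal, double maxVal) {
        if (w <= 0 || w >= n || maxVal <= minVal) {
            throw new IllegalArgumentException("need 0 < w < n and minVal < maxVal");
        }
        this.n = n;
        this.w = w;
        this.minVal = minVal;
        this.maxVal = maxVal;
        // Assumed formula: n - w + 1 start positions cover the range [minVal, maxVal].
        this.resolution = (maxVal - minVal) / (n - w);
    }

    // Position of the first active bit; inputs outside [minVal, maxVal] are clipped.
    int bucketIndex(double value) {
        double clipped = Math.max(minVal, Math.min(maxVal, value));
        return (int) Math.round((clipped - minVal) / resolution);
    }

    // Returns the n-bit encoding with w contiguous 1s starting at bucketIndex(value).
    int[] encode(double value) {
        int[] bits = new int[n];
        int start = bucketIndex(value);
        Arrays.fill(bits, start, start + w, 1);
        return bits;
    }

    public static void main(String[] args) {
        ScalarEncoderSketch enc = new ScalarEncoderSketch(50, 21, -100, 100);
        System.out.printf(Locale.ROOT, "resolution = %.3f%n", enc.resolution); // prints resolution = 6.897
        // Values 0.05 apart fall into the same bucket at this resolution.
        System.out.println(enc.bucketIndex(1.0) == enc.bucketIndex(1.05)); // prints true
    }
}
```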
package com.ajdillon.numyo.htm;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

import org.numenta.nupic.Connections;

// PersistenceManager is responsible for the serialization and deserialization of an
// HTMNetwork Network's Connections.
public class PersistenceManager {

        static long                     startTime;
        static long                     endTime;
        static long                     durationSecs;

        ObjectInputStream       in;
        ObjectOutputStream      out;

        // Constructor that takes a file from which to deserialize a Connections and/or a
        // file to serialize a Connections to.
        public PersistenceManager(File readFile, File saveFile)
                        throws FileNotFoundException, IOException {
                if (readFile != null)
                        in = new ObjectInputStream(new FileInputStream(readFile.getAbsolutePath()));
                if (saveFile != null)
                        out = new ObjectOutputStream(new FileOutputStream(saveFile.getAbsolutePath()));
        }

        // Returns a Connections object after deserializing it from the file specified
        // in the constructor.
        public Connections loadConnections()
                        throws FileNotFoundException, IOException, ClassNotFoundException {

                System.out.println("LOADING CONNECTIONS FROM FILE...");
                startTime = System.currentTimeMillis();
                Connections connections = (Connections) in.readObject();
                endTime = System.currentTimeMillis();
                durationSecs = (endTime - startTime) / 1000;
                System.out.println("DONE LOADING CONNECTIONS. (" + durationSecs + " secs)");
                return connections;
        }

        // Takes a Connections object and serializes it to the file specified in the
        // constructor.
        public void saveConnections(Connections connections) throws IOException {
                System.out.println("SAVING CONNECTIONS...");
                out.writeObject(connections);
                System.out.println("DONE SAVING CONNECTIONS.");
        }

        //Self explanatory :)
        public void close() throws IOException {
                if (in != null)
                        in.close();
                if (out != null)
                        out.close();
        }

}
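PersistenceManager keeps its streams open between the constructor and `close()`, so an exception thrown by `readObject()` or `writeObject()` can leave a file handle open. A possible alternative opens and closes the stream inside each call with try-with-resources, so the stream is closed even when an exception occurs. It is sketched below as a hypothetical utility, not a fix verified against HTMNetwork. The class and method names are illustrative. The same calls should work with a Connections object, provided it and its members implement `Serializable` as described in the thread.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;

// Hypothetical utility: each call opens and closes its own stream.
final class SerializationFiles {

    private SerializationFiles() {}

    // Serializes obj to file, replacing any existing content.
    static void save(Serializable obj, Path file) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(file))) {
            out.writeObject(obj);
        }
    }

    // Deserializes an object of the expected type from file.
    static <T> T load(Path file, Class<T> type) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(file))) {
            return type.cast(in.readObject());
        }
    }

    public static void main(String[] args) throws Exception {
        Path file = Files.createTempFile("connections", ".ser");
        ArrayList<String> original = new ArrayList<>(Arrays.asList("a", "b"));
        save(original, file);
        ArrayList<?> restored = load(file, ArrayList.class);
        System.out.println(restored.equals(original)); // prints true
        Files.delete(file);
    }
}
```

With this shape, loading would look like `SerializationFiles.load(path, Connections.class)`, and the `type.cast` call fails fast with a `ClassCastException` if the file holds something else.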
