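Author: tommaso Date: Sun Oct 18 13:59:39 2015 New Revision: 1709279 URL: http://svn.apache.org/viewvc?rev=1709279&view=rev Log: Update the neural network (NN) API: move learn(TrainingSet) from Hypothesis to NeuralNetwork and make it return the learned weights as a RealMatrix[].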
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java?rev=1709279&r1=1709278&r2=1709279&view=diff ============================================================================== --- labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java (original) +++ labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java Sun Oct 18 13:59:39 2015 @@ -47,12 +47,4 @@ public interface Hypothesis<T, I, O> { */ O[] predict(Input<I> input) throws PredictionException; - /** - * Let this <code>Hypothesis</code> learn by experience, in the form of a {@link org.apache.yay.TrainingSet} - * - * @param trainingExamples the learning {@link org.apache.yay.TrainingSet} - * @throws LearningException if any error occurs during the learning phase - */ - void learn(TrainingSet<I, O> trainingExamples) throws LearningException; - } Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1709279&r1=1709278&r2=1709279&view=diff ============================================================================== --- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original) +++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Sun Oct 18 13:59:39 2015 @@ -27,4 +27,13 @@ import org.apache.commons.math3.linear.R */ public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> { + /** + * Let this <code>Hypothesis</code> learn by experience, in 
the form of a {@link org.apache.yay.TrainingSet} + * + * @param trainingExamples the learning {@link org.apache.yay.TrainingSet} + * @return the learned weights + * @throws LearningException if any error occurs during the learning phase + */ + RealMatrix[] learn(TrainingSet<Double, Double> trainingExamples) throws LearningException; + } Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1709279&r1=1709278&r2=1709279&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Sun Oct 18 13:59:39 2015 @@ -52,10 +52,11 @@ public class BasicPerceptron implements } @Override - public void learn(TrainingSet<Double, Double> trainingExamples) throws LearningException { + public RealMatrix[] learn(TrainingSet<Double, Double> trainingExamples) throws LearningException { for (TrainingExample<Double, Double> example : trainingExamples) { learn(example); } + return new RealMatrix[]{new Array2DRowRealMatrix(currentWeights)}; } public void learn(TrainingExample<Double, Double> example) { Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1709279&r1=1709278&r2=1709279&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java Sun Oct 18 13:59:39 2015 @@ -56,9 +56,10 @@ class NeuralNetworkFactory { private RealMatrix[] updatedRealMatrixSet = realMatrixSet; @Override - public void 
learn(TrainingSet<Double, Double> samples) throws LearningException { + public RealMatrix[] learn(TrainingSet<Double, Double> samples) throws LearningException { try { updatedRealMatrixSet = learningStrategy.learnWeights(realMatrixSet, samples); + return updatedRealMatrixSet; } catch (WeightLearningException e) { throw new LearningException(e); } Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1709279&r1=1709278&r2=1709279&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Sun Oct 18 13:59:39 2015 @@ -66,24 +66,29 @@ public class WordVectorsTest { assertFalse(fragments.isEmpty()); TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, fragments); - TrainingExample<Double, Double> next = trainingSet.iterator().next(); + int inputSize = next.getFeatures().size() ; int outputSize = next.getOutput().length; int hiddenSize = new Random().nextInt(50) + 15; + System.err.println("i:"+inputSize+",h:"+hiddenSize+",o:"+outputSize); RealMatrix[] randomWeights = createRandomWeights(inputSize, hiddenSize, outputSize); Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer, ActivationFunction<Double>>(); activationFunctions.put(0, new IdentityActivationFunction<Double>()); activationFunctions.put(1, new SoftmaxActivationFunction()); FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions); - BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.03d, 10, - BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LMSCostFunction(), 10); + BackPropagationLearningStrategy learningStrategy = new 
BackPropagationLearningStrategy(0.03d, 1, + BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(), 10); NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy); - neuralNetwork.learn(trainingSet); + RealMatrix[] learnedWeights = neuralNetwork.learn(trainingSet); + + RealMatrix wordVectors = learnedWeights[learnedWeights.length - 1]; + + assertNotNull(wordVectors); - RealMatrix vectorsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length); + RealMatrix mappingsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length); BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.txt"))); int m = 0; @@ -115,7 +120,7 @@ public class WordVectorsTest { for (int x = 0; x < row.length; x++) { row[x] = predict[x]; } - vectorsMatrix.setRow(m, row); + mappingsMatrix.setRow(m, row); m++; String vectorString = Arrays.toString(predict); @@ -145,7 +150,7 @@ public class WordVectorsTest { bufferedWriter.close(); ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin"))); - MatrixUtils.serializeRealMatrix(vectorsMatrix, os); + MatrixUtils.serializeRealMatrix(mappingsMatrix, os); } private String hotDecode(Double[] doubles, List<String> vocabulary) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org