Author: tommaso
Date: Thu Nov 12 16:11:25 2015
New Revision: 1714080

URL: http://svn.apache.org/viewvc?rev=1714080&view=rev
Log:
added cross-entropy cost function

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java (with props) Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java?rev=1714080&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java Thu Nov 12 16:11:25 2015 @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.yay.core; + +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.yay.Hypothesis; +import org.apache.yay.NeuralNetworkCostFunction; +import org.apache.yay.PredictionException; +import org.apache.yay.TrainingExample; +import org.apache.yay.TrainingSet; + +/** + * This calculates the cross entropy cost function for neural networks + */ +public class CrossEntropyCostFunction implements NeuralNetworkCostFunction { + + @Override + public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingSet, + Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception { + TrainingExample<Double, Double>[] samples = new TrainingExample[trainingSet.size()]; + int i = 0; + for (TrainingExample<Double, Double> sample : trainingSet) { + samples[i] = sample; + i++; + } + return calculateCost(hypothesis, samples); + } + + private Double calculateErrorTerm(Hypothesis<RealMatrix, Double, Double> hypothesis, + TrainingExample<Double, Double>... trainingExamples) throws PredictionException { + Double res = 0d; + + for (TrainingExample<Double, Double> input : trainingExamples) { + Double[] predictedOutput = hypothesis.predict(input); + Double[] sampleOutput = input.getOutput(); + for (int i = 0; i < predictedOutput.length; i++) { + Double so = sampleOutput[i]; + Double po = predictedOutput[i]; + res -= so * Math.log(po); + } + } + return res; + } + + @Override + public Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis, TrainingExample<Double, Double>... 
trainingExamples) throws Exception { + return calculateErrorTerm(hypothesis, trainingExamples); + } +} Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java ------------------------------------------------------------------------------ svn:eol-style = native Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1714080&r1=1714079&r2=1714080&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Thu Nov 12 16:11:25 2015 @@ -105,7 +105,7 @@ public class WordVectorsTest { activationFunctions.put(1, new SoftmaxActivationFunction()); FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions); BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.01d, 1, - BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(), + BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new CrossEntropyCostFunction(), trainingSet.size()); NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy); @@ -238,9 +238,8 @@ public class WordVectorsTest { } } - private TrainingSet<Double, Double> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) { + private TrainingSet<Double, Double> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException { long start = System.currentTimeMillis(); - Path file = Paths.get("/Users/teofili/Desktop/ts.txt"); Collection<TrainingExample<Double, Double>> samples = new LinkedList<>(); List<byte[]> fragment; while 
((fragment = fragments.poll()) != null) {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org