Author: tommaso Date: Thu Mar 3 11:00:32 2016 New Revision: 1733443 URL: http://svn.apache.org/viewvc?rev=1733443&view=rev Log: fixed momentum impl, more appropriate softmax usage
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1733443&r1=1733442&r2=1733443&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Thu Mar 3 11:00:32 2016 @@ -112,7 +112,9 @@ public class SkipGramNetwork { for (int d = 0; d < configuration.window - 1; d++) { int startColumn = d * len / (configuration.window - 1); RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + input.length); - probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn); + for (int sm = 0; sm < subMatrix.getRowDimension(); sm++) { + probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix.getRowMatrix(sm)).getData(), sm, startColumn); + } } RealVector d = probs.getRowVector(0); @@ -240,7 +242,7 @@ public class SkipGramNetwork { } } - configuration.alpha = configuration.alpha * 0.5; +// configuration.alpha = configuration.alpha * 0.5; RealMatrix w0t = weights[0].transpose(); final RealMatrix w1t = weights[1].transpose(); @@ -285,7 +287,9 @@ public class SkipGramNetwork { for (int d = 0; d < configuration.window - 1; d++) { int startColumn = d * len / (configuration.window - 1); RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + x.getColumnDimension()); - probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn); + for (int sm = 0; sm < subMatrix.getRowDimension(); sm++) { + probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix.getRowMatrix(sm)).getData(), sm, startColumn); + } } RealMatrix correctLogProbs = MatrixUtils.createRealMatrix(x.getRowDimension(), 1); @@ -510,7 +514,7 @@ public class SkipGramNetwork { @Override public double visit(int row, int column, double value) { - return configuration.mu * value - configuration.alpha + db.getEntry(row, column); + return configuration.mu * value - configuration.alpha * db.getEntry(row, column); } @Override @@ -527,7 +531,7 @@ public class SkipGramNetwork { @Override public double visit(int row, int column, double value) { - return configuration.mu * value - configuration.alpha + db2.getEntry(row, column); + return configuration.mu * value - configuration.alpha * db2.getEntry(row, column); } @Override @@ -545,7 +549,7 @@ public class SkipGramNetwork { @Override public double visit(int row, int column, double value) { - return configuration.mu * value - configuration.alpha + dWt.getEntry(row, column); + return configuration.mu * value - configuration.alpha * dWt.getEntry(row, column); } @Override @@ -563,7 +567,7 @@ public class SkipGramNetwork { @Override public double visit(int row, int column, double value) { - return configuration.mu * value + configuration.alpha - dWt2.getEntry(row, column); + return configuration.mu * value - configuration.alpha * dWt2.getEntry(row, column); } @Override Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1733443&r1=1733442&r2=1733443&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Thu Mar 3 11:00:32 2016 @@ -44,11 +44,13 @@ public class SkipGramNetworkTest { public void testWordVectorsLearningOnAbstracts() throws Exception { Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile()); SkipGramNetwork network = SkipGramNetwork.newModel(). - withWindow(3). + withWindow(4). fromTextAt(path). withDimension(10). - withAlpha(1). - withLambda(0.003). + withAlpha(0.09). + withLambda(0.03). + useMomentum(true). + withMu(0.9). withMaxIterations(500). build(); RealMatrix wv = network.getWeights()[0]; @@ -62,11 +64,13 @@ public class SkipGramNetworkTest { public void testWordVectorsLearningOnSentences() throws Exception { Path path = Paths.get(getClass().getResource("/word2vec/sentences.txt").getFile()); SkipGramNetwork network = SkipGramNetwork.newModel(). - withWindow(3). + withWindow(4). fromTextAt(path). withDimension(10). - withAlpha(1). - withLambda(0.03). + withAlpha(0.001). + withLambda(0.003). + useMomentum(true). + withMu(0.9). withMaxIterations(500). build(); RealMatrix wv = network.getWeights()[0]; @@ -83,10 +87,12 @@ public class SkipGramNetworkTest { withWindow(3). fromTextAt(path). withDimension(2). - withAlpha(1). + withAlpha(0.0008). withLambda(0.03). - withThreshold(0.000003). - withMaxIterations(1000). + useMomentum(true). + withMu(0.9). + withThreshold(0.00000000003). + withMaxIterations(10000). build(); System.err.println("accuracy: " + SkipGramNetwork.evaluate(network)); RealMatrix wv = network.getWeights()[0]; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org