Author: tommaso Date: Wed Oct 12 16:39:03 2016 New Revision: 1764488 URL: http://svn.apache.org/viewvc?rev=1764488&view=rev Log: fixed sample generation
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1764488&r1=1764487&r2=1764488&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Wed Oct 12 16:39:03 2016 @@ -47,7 +47,7 @@ public class RNN { protected final int seqLength; // no. of steps to unroll the RNN for protected final int hiddenLayerSize; protected final int epochs; - private final boolean useChars; + protected final boolean useChars; protected final int vocabSize; protected final Map<String, Integer> charToIx; protected final Map<Integer, String> ixToChar; @@ -307,6 +307,10 @@ public class RNN { ixes.putScalar(t, ix); } + return getSampleString(ixes); + } + + protected String getSampleString(INDArray ixes) { StringBuilder txt = new StringBuilder(); NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape()); Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1764488&r1=1764487&r2=1764488&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Wed Oct 12 16:39:03 2016 @@ -20,7 +20,6 @@ package org.apache.yay; import org.apache.commons.math3.distribution.EnumeratedDistribution; import org.apache.commons.math3.util.Pair; -import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.SetRange; import org.nd4j.linalg.factory.Nd4j; @@ -283,14 +282,7 @@ public class StackedRNN extends RNN { ixes.putScalar(t, ix); } - StringBuilder txt = new StringBuilder(); - - NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape()); - while (ndIndexIterator.hasNext()) { - int[] next = ndIndexIterator.next(); - txt.append(ixToChar.get(ixes.getInt(next))); - } - return txt.toString(); + return getSampleString(ixes); } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org