Author: tommaso
Date: Wed Oct 12 16:39:03 2016
New Revision: 1764488

URL: http://svn.apache.org/viewvc?rev=1764488&view=rev
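Log:
Fixed sample generation: extracted a protected RNN#getSampleString(INDArray) helper that converts sampled indexes to text, reused it in StackedRNN in place of its duplicated conversion loop, and widened RNN.useChars from private to protected.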

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1764488&r1=1764487&r2=1764488&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Wed Oct 12 16:39:03 2016
@@ -47,7 +47,7 @@ public class RNN {
   protected final int seqLength; // no. of steps to unroll the RNN for
   protected final int hiddenLayerSize;
   protected final int epochs;
-  private final boolean useChars;
+  protected final boolean useChars;
   protected final int vocabSize;
   protected final Map<String, Integer> charToIx;
   protected final Map<Integer, String> ixToChar;
@@ -307,6 +307,10 @@ public class RNN {
       ixes.putScalar(t, ix);
     }
 
+    return getSampleString(ixes);
+  }
+
+  protected String getSampleString(INDArray ixes) {
     StringBuilder txt = new StringBuilder();
 
     NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape());

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1764488&r1=1764487&r2=1764488&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Wed Oct 12 16:39:03 2016
@@ -20,7 +20,6 @@ package org.apache.yay;
 
 import org.apache.commons.math3.distribution.EnumeratedDistribution;
 import org.apache.commons.math3.util.Pair;
-import org.nd4j.linalg.api.iter.NdIndexIterator;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
 import org.nd4j.linalg.factory.Nd4j;
@@ -283,14 +282,7 @@ public class StackedRNN extends RNN {
       ixes.putScalar(t, ix);
     }
 
-    StringBuilder txt = new StringBuilder();
-
-    NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape());
-    while (ndIndexIterator.hasNext()) {
-      int[] next = ndIndexIterator.next();
-      txt.append(ixToChar.get(ixes.getInt(next)));
-    }
-    return txt.toString();
+    return getSampleString(ixes);
   }
 
 }



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org
For additional commands, e-mail: commits-h...@labs.apache.org

Reply via email to