Posted to commits@labs.apache.org by to...@apache.org on 2017/03/10 09:01:11 UTC

svn commit: r1786303 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/StackedRNN.java test/java/org/apache/yay/RNNCrossValidationTest.java

Author: tommaso
Date: Fri Mar 10 09:01:11 2017
New Revision: 1786303

URL: http://svn.apache.org/viewvc?rev=1786303&view=rev
Log:
improved sRNN

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1786303&r1=1786302&r2=1786303&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Fri Mar 10 09:01:11 2017
@@ -50,6 +50,7 @@ public class StackedRNN extends RNN {
   private final INDArray whh; // hidden to hidden
   private final INDArray whh2; // hidden to hidden2
   private final INDArray wh2y; // hidden2 to output
+  private final INDArray wxh2; // first hidden state to hidden2
   private final INDArray bh; // hidden bias
   private final INDArray bh2; // hidden2 bias
   private final INDArray by; // output bias
@@ -69,6 +70,7 @@ public class StackedRNN extends RNN {
     wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize));
     whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
     whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
+    wxh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
     wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).div(Math.sqrt(vocabSize));
     bh = Nd4j.zeros(hiddenLayerSize, 1);
     bh2 = Nd4j.zeros(hiddenLayerSize, 1);
@@ -84,6 +86,7 @@ public class StackedRNN extends RNN {
 
     // memory variables for Adagrad
     INDArray mWxh = Nd4j.zerosLike(wxh);
+    INDArray mWxh2 = Nd4j.zerosLike(wxh2);
     INDArray mWhh = Nd4j.zerosLike(whh);
     INDArray mWhh2 = Nd4j.zerosLike(whh2);
     INDArray mWh2y = Nd4j.zerosLike(wh2y);
@@ -118,6 +121,7 @@ public class StackedRNN extends RNN {
       }
 
       INDArray dWxh = Nd4j.zerosLike(wxh);
+      INDArray dWxh2 = Nd4j.zerosLike(wxh2);
       INDArray dWhh = Nd4j.zerosLike(whh);
       INDArray dWhh2 = Nd4j.zerosLike(whh2);
       INDArray dWh2y = Nd4j.zerosLike(wh2y);
@@ -127,7 +131,7 @@ public class StackedRNN extends RNN {
       INDArray dby = Nd4j.zerosLike(by);
 
       // forward seqLength characters through the net and fetch gradient
-      double loss = lossFun(inputs, targets, dWxh, dWhh, dWhh2, dWh2y, dbh, dbh2, dby);
+      double loss = lossFun(inputs, targets, dWxh, dWhh, dWxh2, dWhh2, dWh2y, dbh, dbh2, dby);
       smoothLoss = smoothLoss * 0.999 + loss * 0.001;
       if (Double.isNaN(smoothLoss) || Double.isInfinite(smoothLoss)) {
         System.out.println("loss is " + smoothLoss + " (over/underflow occured, try adjusting hyperparameters)");
@@ -141,6 +145,9 @@ public class StackedRNN extends RNN {
       mWxh.addi(dWxh.mul(dWxh));
       wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
 
+      mWxh2.addi(dWxh2.mul(dWxh2));
+      wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg))));
+
       mWhh.addi(dWhh.mul(dWhh));
       whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
 
@@ -169,7 +176,7 @@ public class StackedRNN extends RNN {
    * hprev is Hx1 array of initial hidden state
    * returns the loss, gradients on model parameters and last hidden state
    */
-  private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhh2, INDArray dWh2y,
+  private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWxh2, INDArray dWhh2, INDArray dWh2y,
                          INDArray dbh, INDArray dbh2, INDArray dby) {
 
     INDArray xs = Nd4j.zeros(seqLength, vocabSize);
@@ -193,13 +200,13 @@ public class StackedRNN extends RNN {
       }
       hs.putRow(t, hPrev.dup());
 
-      hPrev2 = Transforms.tanh((whh.mmul(hs.getRow(t)).add(whh2.mmul(hPrev2)).add(bh2))); // hidden state 2
+      hPrev2 = Transforms.tanh((wxh2.mmul(hPrev).add(whh2.mmul(hPrev2)).add(bh2))); // hidden state 2
       if (hs2 == null) {
         hs2 = init(seqLength, hPrev2.shape());
       }
       hs2.putRow(t, hPrev2.dup());
 
-      INDArray yst = wh2y.mmul(hs2.getRow(t)).add(by); // unnormalized log probabilities for next chars
+      INDArray yst = wh2y.mmul(hPrev2).add(by); // unnormalized log probabilities for next chars
       if (ys == null) {
         ys = init(seqLength, yst.shape());
       }
@@ -231,11 +238,11 @@ public class StackedRNN extends RNN {
       INDArray dhraw2 = (Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t))).mul(dh2); //  backprop through tanh nonlinearity
       dbh2.addi(dhraw2);
       INDArray hst = hs.getRow(t);
-      dWhh.addi(dhraw2.mmul(hst.transpose()));
+      dWxh2.addi(dhraw2.mmul(hst.transpose()));
       dWhh2.addi(dhraw2.mmul(hs2tm1.transpose()));
       dh2Next = whh2.transpose().mmul(dhraw2);
 
-      INDArray dh = dh2Next.add(dhNext); // backprop into h
+      INDArray dh = wxh2.transpose().mmul(dhraw2).add(dhNext); // backprop into h
       INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst.mul(hst))).mul(dh); // backprop through tanh nonlinearity
       dbh.addi(dhraw);
       dWxh.addi(dhraw.mmul(xs.getRow(t)));
@@ -250,13 +257,14 @@ public class StackedRNN extends RNN {
 
     // clip exploding gradients
     int clip = 5;
-    Nd4j.getExecutioner().exec(new SetRange(dWxh, -clip, clip));
-    Nd4j.getExecutioner().exec(new SetRange(dWhh, -clip, clip));
-    Nd4j.getExecutioner().exec(new SetRange(dWhh2, -clip, clip));
-    Nd4j.getExecutioner().exec(new SetRange(dWh2y, -clip, clip));
-    Nd4j.getExecutioner().exec(new SetRange(dbh, -clip, clip));
-    Nd4j.getExecutioner().exec(new SetRange(dbh2, -clip, clip));
-    Nd4j.getExecutioner().exec(new SetRange(dby, -clip, clip));
+    dWxh = Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -clip, clip));
+    dWxh2 = Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh2, -clip, clip));
+    dWhh = Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -clip, clip));
+    dWhh2 = Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh2, -clip, clip));
+    dWh2y = Nd4j.getExecutioner().execAndReturn(new SetRange(dWh2y, -clip, clip));
+    dbh = Nd4j.getExecutioner().execAndReturn(new SetRange(dbh, -clip, clip));
+    dbh2 = Nd4j.getExecutioner().execAndReturn(new SetRange(dbh2, -clip, clip));
+    dby = Nd4j.getExecutioner().execAndReturn(new SetRange(dby, -clip, clip));
     return loss;
   }
 
@@ -275,8 +283,8 @@ public class StackedRNN extends RNN {
     INDArray h2 = hPrev2.dup();
 
     for (int t = 0; t < sampleSize; t++) {
-      h = Transforms.tanh(((wxh.mmul(x)).add((whh.mmul(h)).add(bh))));
-      h2 = Transforms.tanh(((whh.mmul(h)).add((whh2.mmul(h2)).add(bh2))));
+      h = Transforms.tanh((wxh.mmul(x)).add(whh.mmul(h)).add(bh));
+      h2 = Transforms.tanh((wxh2.mmul(h)).add(whh2.mmul(h2)).add(bh2));
       INDArray y = wh2y.mmul(h2).add(by);
       INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
 

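Taken together, the hunks above give the second hidden layer its own input weight matrix wxh2 (previously whh was reused to feed the first hidden state into layer 2) and route the forward pass, the backprop-into-h term, and the sampling loop through it. Below is a minimal, self-contained sketch of one forward step with the revised wiring, using the same ND4J calls the class already relies on; the sketch's class name and toy sizes are illustrative and not part of the commit.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/** Illustrative sketch of one forward step of the revised stacked RNN (variable names match the diff). */
public class StackedStepSketch {
  public static void main(String[] args) {
    int vocabSize = 5;
    int hiddenLayerSize = 4;
    // parameters, initialized as in the constructor hunk above
    INDArray wxh  = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize));
    INDArray whh  = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
    INDArray wxh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
    INDArray whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
    INDArray wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).div(Math.sqrt(vocabSize));
    INDArray bh  = Nd4j.zeros(hiddenLayerSize, 1);
    INDArray bh2 = Nd4j.zeros(hiddenLayerSize, 1);
    INDArray by  = Nd4j.zeros(vocabSize, 1);

    INDArray x = Nd4j.zeros(vocabSize, 1); // one-hot input column
    x.putScalar(0, 1.0);
    INDArray hPrev  = Nd4j.zeros(hiddenLayerSize, 1);
    INDArray hPrev2 = Nd4j.zeros(hiddenLayerSize, 1);

    // layer 1: input -> hidden
    INDArray h = Transforms.tanh(wxh.mmul(x).add(whh.mmul(hPrev)).add(bh));
    // layer 2: hidden -> hidden2, now driven by its own matrix wxh2 instead of reusing whh
    INDArray h2 = Transforms.tanh(wxh2.mmul(h).add(whh2.mmul(hPrev2)).add(bh2));
    // unnormalized log probabilities for the next character
    INDArray y = wh2y.mmul(h2).add(by);
    System.out.println(y);
  }
}

The backward-pass hunk (dh = wxh2.transpose().mmul(dhraw2).add(dhNext)) keeps the gradient path consistent with this wiring, and the clipping block now reassigns the result of execAndReturn, presumably so that the clipped arrays are the ones fed to the Adagrad update.
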
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java?rev=1786303&r1=1786302&r2=1786303&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java Fri Mar 10 09:01:11 2017
@@ -61,11 +61,7 @@ public class RNNCrossValidationTest {
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][]{
-            {1e-1f, 50, 15},
-            {1e-1f, 50, 25},
-            {1e-1f, 50, 50},
-            {1e-1f, 50, 100},
-            {1e-1f, 50, 150},
+            {1e-1f, 25, 100},
     });
   }
 
@@ -83,20 +79,6 @@ public class RNNCrossValidationTest {
     rnn.serialize("target/wrnn-weights-");
   }
 
-  @Test
-  public void testStackedCharRNNLearn() throws Exception {
-    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
-    evaluate(rnn, true);
-    rnn.serialize("target/scrnn-weights-");
-  }
-
-  @Test
-  public void testStackedWordRNNLearn() throws Exception {
-    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false);
-    evaluate(rnn, false);
-    rnn.serialize("target/swrnn-weights-");
-  }
-
   private void evaluate(RNN rnn, boolean checkRatio) {
     System.out.println(rnn);
     rnn.learn();



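The cross-validation test now runs a single hyperparameter row and no longer exercises the stacked variants directly. For readers unfamiliar with the JUnit 4 Parameterized runner used in this test, here is a stripped-down, illustrative sketch of how such a row reaches a test constructor; the class, field, and method names below are invented for the example, and only the runner mechanics mirror the real test.

import java.util.Arrays;
import java.util.Collection;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(Parameterized.class)
public class ParameterizedRowSketch {
  private final float learningRate;
  private final int paramA; // stands in for one of the test's int hyperparameters
  private final int paramB; // stands in for the other

  public ParameterizedRowSketch(float learningRate, int paramA, int paramB) {
    this.learningRate = learningRate;
    this.paramA = paramA;
    this.paramB = paramB;
  }

  @Parameterized.Parameters
  public static Collection<Object[]> data() {
    // one row -> one constructor invocation -> one run of each @Test method
    return Arrays.asList(new Object[][]{
            {1e-1f, 25, 100},
    });
  }

  @Test
  public void runsOncePerRow() {
    System.out.println(learningRate + " " + paramA + " " + paramB);
  }
}
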
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org