Posted to commits@opennlp.apache.org by to...@apache.org on 2017/07/01 12:13:15 UTC

opennlp-sandbox git commit: removed useless state update, minor fixes

Repository: opennlp-sandbox
Updated Branches:
  refs/heads/master fe2b1d920 -> 6f0659f2a


removed useless state update, minor fixes


Project: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/repo
Commit: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/commit/6f0659f2
Tree: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/tree/6f0659f2
Diff: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/diff/6f0659f2

Branch: refs/heads/master
Commit: 6f0659f2ad2f3186ff0b266203ae960659ad1d98
Parents: fe2b1d9
Author: Tommaso Teofili <te...@adobe.com>
Authored: Sat Jul 1 14:12:48 2017 +0200
Committer: Tommaso Teofili <te...@adobe.com>
Committed: Sat Jul 1 14:12:48 2017 +0200

----------------------------------------------------------------------
 .../src/main/java/opennlp/tools/dl/StackedRNN.java | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/6f0659f2/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
----------------------------------------------------------------------
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
index 889fac1..e9a5f7e 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
@@ -29,6 +29,7 @@ import java.util.List;
 import org.apache.commons.math3.distribution.EnumeratedDistribution;
 import org.apache.commons.math3.util.Pair;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.ReplaceNans;
 import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.ops.transforms.Transforms;
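The newly imported ReplaceNans is an element-wise ND4J op that substitutes a fixed value for any NaN entries in an array; it is put to work in the softmax hunk further down. A minimal sketch of the op in isolation (the array contents are illustrative, not from the commit):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.ReplaceNans;
    import org.nd4j.linalg.factory.Nd4j;

    public class ReplaceNansSketch {
      public static void main(String[] args) {
        INDArray a = Nd4j.create(new double[] {1.0, Double.NaN, 3.0});
        // map every NaN entry to 0, leaving the other values untouched
        INDArray cleaned = Nd4j.getExecutioner().execAndReturn(new ReplaceNans(a, 0d));
        System.out.println(cleaned); // [1.00, 0.00, 3.00]
      }
    }
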
@@ -118,8 +119,10 @@ public class StackedRNN extends RNN {
 
       // sample from the model every now and then
       if (n % 1000 == 0 && n > 0) {
-        String txt = sample(inputs.getInt(0));
-        System.out.printf("\n---\n %s \n----\n", txt);
+        for (int i = 0; i < 3; i++) {
+          String txt = sample(inputs.getInt(0));
+          System.out.printf("\n---\n %s \n----\n", txt);
+        }
       }
 
       INDArray dWxh = Nd4j.zerosLike(wxh);
@@ -171,7 +174,6 @@ public class StackedRNN extends RNN {
           by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps)));
         } else {
           // perform parameter update with Adagrad
-
           mWxh.addi(dWxh.mul(dWxh));
           wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps)));
 
@@ -244,13 +246,13 @@ public class StackedRNN extends RNN {
       }
       ys.putRow(t, yst);
 
-      INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars
+      INDArray pst = Nd4j.getExecutioner().execAndReturn(new ReplaceNans(Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)), 0d)); // probabilities for next chars
       if (ps == null) {
         ps = init(seqLength, pst.shape());
       }
       ps.putRow(t, pst);
 
-      loss += -Math.log(pst.getDouble(targets.getInt(t))); // softmax (cross-entropy loss)
+      loss += -Math.log(pst.getDouble(targets.getInt(t),0)); // softmax (cross-entropy loss)
     }
 
     // backward pass: compute gradients going backwards
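Wrapping the SoftMax result in ReplaceNans guards against NaN probabilities (for example from overflow in the exponentials), zeroing them so the subsequent log and the sampling step stay well defined; -log(p[target]) is then the per-step cross-entropy. A minimal sketch of that pipeline using the same op classes as above (the logits and the target index are illustrative):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.ReplaceNans;
    import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
    import org.nd4j.linalg.factory.Nd4j;

    public class SoftmaxLossSketch {
      public static void main(String[] args) {
        INDArray logits = Nd4j.create(new double[] {2.0, 1.0, 0.1}); // unnormalized scores

        // softmax over the scores, then map any NaN entries to 0
        INDArray probs = Nd4j.getExecutioner().execAndReturn(
            new ReplaceNans(Nd4j.getExecutioner().execAndReturn(new SoftMax(logits)), 0d));

        int target = 0; // index of the correct next character
        double loss = -Math.log(probs.getDouble(target)); // cross-entropy for this step
        System.out.println("loss = " + loss);
      }
    }
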
@@ -284,9 +286,6 @@ public class StackedRNN extends RNN {
       dhNext = whh.transpose().mmul(dhraw);
     }
 
-    this.hPrev = hs.getRow(seqLength - 1);
-    this.hPrev2 = hs2.getRow(seqLength - 1);
-
     return loss;
   }
 
@@ -298,7 +297,7 @@ public class StackedRNN extends RNN {
 
     INDArray x = Nd4j.zeros(vocabSize, 1);
     x.putScalar(seedIx, 1);
-    int sampleSize = seqLength * 2;
+    int sampleSize = 100;
     INDArray ixes = Nd4j.create(sampleSize);
 
     INDArray h = hPrev.dup();
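
From here the method presumably rolls the network forward one character at a time, drawing each next index from the predicted distribution and feeding it back as the next one-hot input; the commons-math3 EnumeratedDistribution imported at the top of the file is the natural tool for that draw. A hypothetical sketch of sampling a single index from a probability column (the helper name pickIndex is illustrative, not part of the commit):

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.commons.math3.distribution.EnumeratedDistribution;
    import org.apache.commons.math3.util.Pair;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class SampleSketch {

      // draw one character index from a (vocabSize x 1) probability column
      static int pickIndex(INDArray p) {
        List<Pair<Integer, Double>> pmf = new ArrayList<>();
        for (int i = 0; i < (int) p.length(); i++) {
          pmf.add(new Pair<>(i, p.getDouble(i)));
        }
        // EnumeratedDistribution normalizes the weights internally
        return new EnumeratedDistribution<>(pmf).sample();
      }
    }

Repeating such a draw sampleSize times would produce the printed text, so the fixed sampleSize = 100 above yields 100-character samples regardless of seqLength.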