Posted to commits@opennlp.apache.org by to...@apache.org on 2017/07/01 12:13:15 UTC
opennlp-sandbox git commit: removed useless state update, minor fixes
Repository: opennlp-sandbox
Updated Branches:
refs/heads/master fe2b1d920 -> 6f0659f2a
removed useless state update, minor fixes
Project: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/repo
Commit: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/commit/6f0659f2
Tree: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/tree/6f0659f2
Diff: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/diff/6f0659f2
Branch: refs/heads/master
Commit: 6f0659f2ad2f3186ff0b266203ae960659ad1d98
Parents: fe2b1d9
Author: Tommaso Teofili <te...@adobe.com>
Authored: Sat Jul 1 14:12:48 2017 +0200
Committer: Tommaso Teofili <te...@adobe.com>
Committed: Sat Jul 1 14:12:48 2017 +0200
----------------------------------------------------------------------
.../src/main/java/opennlp/tools/dl/StackedRNN.java | 17 ++++++++---------
1 file changed, 8 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/6f0659f2/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
----------------------------------------------------------------------
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
index 889fac1..e9a5f7e 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
@@ -29,6 +29,7 @@ import java.util.List;
import org.apache.commons.math3.distribution.EnumeratedDistribution;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
@@ -118,8 +119,10 @@ public class StackedRNN extends RNN {
// sample from the model every now and then
if (n % 1000 == 0 && n > 0) {
- String txt = sample(inputs.getInt(0));
- System.out.printf("\n---\n %s \n----\n", txt);
+ for (int i = 0; i < 3; i++) {
+ String txt = sample(inputs.getInt(0));
+ System.out.printf("\n---\n %s \n----\n", txt);
+ }
}
INDArray dWxh = Nd4j.zerosLike(wxh);
@@ -171,7 +174,6 @@ public class StackedRNN extends RNN {
by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps)));
} else {
// perform parameter update with Adagrad
-
mWxh.addi(dWxh.mul(dWxh));
wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps)));
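
The else branch above is a plain Adagrad update: each squared gradient is accumulated into a per-weight memory, and the step size becomes learningRate / (sqrt(memory) + eps) for every weight. A minimal, self-contained sketch of the same update with the ND4J calls used in the diff; the class and variable names here are illustrative stand-ins for wxh, dWxh and mWxh, not part of the commit:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class AdagradStepDemo {                                // hypothetical demo class
  // one Adagrad update for a single parameter matrix
  static void update(INDArray param, INDArray grad, INDArray memory,
                     double learningRate, double eps) {
    memory.addi(grad.mul(grad));                              // accumulate squared gradients
    param.subi(grad.mul(learningRate)
                   .div(Transforms.sqrt(memory).add(eps)));   // per-weight scaled step
  }

  public static void main(String[] args) {
    INDArray w = Nd4j.rand(3, 3), dW = Nd4j.rand(3, 3), mW = Nd4j.zerosLike(w);
    update(w, dW, mW, 1e-1, 1e-8);
    System.out.println(w);
  }
}
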
@@ -244,13 +246,13 @@ public class StackedRNN extends RNN {
}
ys.putRow(t, yst);
- INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars
+ INDArray pst = Nd4j.getExecutioner().execAndReturn(new ReplaceNans(Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)), 0d)); // probabilities for next chars
if (ps == null) {
ps = init(seqLength, pst.shape());
}
ps.putRow(t, pst);
- loss += -Math.log(pst.getDouble(targets.getInt(t))); // softmax (cross-entropy loss)
+ loss += -Math.log(pst.getDouble(targets.getInt(t),0)); // softmax (cross-entropy loss)
}
// backward pass: compute gradients going backwards
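
The wrapped SoftMax call above sanitizes the probability row before it is stored and used for the loss: if the logits blow up, SoftMax can yield NaN entries, and ReplaceNans maps them back to 0 so that neither the cross-entropy term nor the later sampling step propagates NaN. A minimal sketch of the same op chain in isolation; the demo class and the example logits are illustrative, only the op chain itself comes from the commit:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
import org.nd4j.linalg.factory.Nd4j;

public class SafeSoftmaxDemo {                                // hypothetical demo class
  public static void main(String[] args) {
    // stand-in for yst: unnormalised scores over the vocabulary
    INDArray logits = Nd4j.create(new double[] {2.0, 0.5, -1.0});
    // probabilities for the next char; any NaNs are replaced with 0
    INDArray probs = Nd4j.getExecutioner().execAndReturn(
        new ReplaceNans(Nd4j.getExecutioner().execAndReturn(new SoftMax(logits)), 0d));
    System.out.println(-Math.log(probs.getDouble(2)));        // cross-entropy term for target index 2
  }
}
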
@@ -284,9 +286,6 @@ public class StackedRNN extends RNN {
dhNext = whh.transpose().mmul(dhraw);
}
- this.hPrev = hs.getRow(seqLength - 1);
- this.hPrev2 = hs2.getRow(seqLength - 1);
-
return loss;
}
@@ -298,7 +297,7 @@ public class StackedRNN extends RNN {
INDArray x = Nd4j.zeros(vocabSize, 1);
x.putScalar(seedIx, 1);
- int sampleSize = seqLength * 2;
+ int sampleSize = 100;
INDArray ixes = Nd4j.create(sampleSize);
INDArray h = hPrev.dup();
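
The sample size is now a fixed 100 characters instead of twice the sequence length. Since the class already imports EnumeratedDistribution and Pair from commons-math3, a sketch of how one next-character index could be drawn from a probability row like the one produced above; the helper class and method name are illustrative, not part of this commit:

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.distribution.EnumeratedDistribution;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;

public class CharSamplerDemo {                                // hypothetical helper class
  // draw one character index from a probability row
  static int drawNext(INDArray probs) {
    List<Pair<Integer, Double>> pmf = new ArrayList<>();
    for (int i = 0; i < (int) probs.length(); i++) {
      pmf.add(new Pair<>(i, probs.getDouble(i)));             // index -> probability
    }
    return new EnumeratedDistribution<>(pmf).sample();
  }
}
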