Posted to commits@labs.apache.org by to...@apache.org on 2017/03/10 09:01:11 UTC
svn commit: r1786303 - in /labs/yay/trunk/core/src:
main/java/org/apache/yay/StackedRNN.java
test/java/org/apache/yay/RNNCrossValidationTest.java
Author: tommaso
Date: Fri Mar 10 09:01:11 2017
New Revision: 1786303
URL: http://svn.apache.org/viewvc?rev=1786303&view=rev
Log:
improved sRNN
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1786303&r1=1786302&r2=1786303&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Fri Mar 10 09:01:11 2017
@@ -50,6 +50,7 @@ public class StackedRNN extends RNN {
private final INDArray whh; // hidden to hidden
private final INDArray whh2; // hidden to hidden2
private final INDArray wh2y; // hidden2 to output
+ private final INDArray wxh2;
private final INDArray bh; // hidden bias
private final INDArray bh2; // hidden2 bias
private final INDArray by; // output bias
@@ -69,6 +70,7 @@ public class StackedRNN extends RNN {
wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize));
whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
+ wxh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).div(Math.sqrt(vocabSize));
bh = Nd4j.zeros(hiddenLayerSize, 1);
bh2 = Nd4j.zeros(hiddenLayerSize, 1);
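For context, every weight matrix here is drawn from a standard normal and scaled down, and the new wxh2 follows the same convention as whh and whh2:

    wxh2 ~ N(0, 1) / sqrt(hiddenLayerSize)   =>   std(wxh2) ~= 1 / sqrt(hiddenLayerSize)

The small initial scale keeps the tanh pre-activations away from saturation early in training; the bias vectors start at zero.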
@@ -84,6 +86,7 @@ public class StackedRNN extends RNN {
// memory variables for Adagrad
INDArray mWxh = Nd4j.zerosLike(wxh);
+ INDArray mWxh2 = Nd4j.zerosLike(wxh2);
INDArray mWhh = Nd4j.zerosLike(whh);
INDArray mWhh2 = Nd4j.zerosLike(whh2);
INDArray mWh2y = Nd4j.zerosLike(wh2y);
@@ -118,6 +121,7 @@ public class StackedRNN extends RNN {
}
INDArray dWxh = Nd4j.zerosLike(wxh);
+ INDArray dWxh2 = Nd4j.zerosLike(wxh2);
INDArray dWhh = Nd4j.zerosLike(whh);
INDArray dWhh2 = Nd4j.zerosLike(whh2);
INDArray dWh2y = Nd4j.zerosLike(wh2y);
@@ -127,7 +131,7 @@ public class StackedRNN extends RNN {
INDArray dby = Nd4j.zerosLike(by);
// forward seqLength characters through the net and fetch gradient
- double loss = lossFun(inputs, targets, dWxh, dWhh, dWhh2, dWh2y, dbh, dbh2, dby);
+ double loss = lossFun(inputs, targets, dWxh, dWhh, dWxh2, dWhh2, dWh2y, dbh, dbh2, dby);
smoothLoss = smoothLoss * 0.999 + loss * 0.001;
if (Double.isNaN(smoothLoss) || Double.isInfinite(smoothLoss)) {
System.out.println("loss is " + smoothLoss + " (over/underflow occured, try adjusting hyperparameters)");
@@ -141,6 +145,9 @@ public class StackedRNN extends RNN {
mWxh.addi(dWxh.mul(dWxh));
wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
+ mWxh2.addi(dWxh2.mul(dWxh2));
+ wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg))));
+
mWhh.addi(dWhh.mul(dWhh));
whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
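The update applied to wxh2 here is the same per-parameter Adagrad step already used for the other matrices: accumulate the squared gradient, then scale the learning rate by the inverse square root of that accumulator. A minimal self-contained sketch of the rule, assuming ND4J on the classpath (class name, shapes and the eps value are illustrative, not taken from the commit; in the code above the reg term plays the role of eps):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class AdagradSketch {
        public static void main(String[] args) {
            INDArray w  = Nd4j.randn(3, 3);     // a parameter matrix
            INDArray dW = Nd4j.randn(3, 3);     // its gradient for the current step
            INDArray mW = Nd4j.zerosLike(w);    // running sum of squared gradients
            double lr = 1e-1;                   // learning rate (illustrative value)
            double eps = 1e-8;                  // smoothing term to avoid division by zero

            mW.addi(dW.mul(dW));                                    // mW += dW * dW (elementwise)
            w.subi(dW.mul(lr).div(Transforms.sqrt(mW.add(eps))));   // w -= lr * dW / sqrt(mW + eps)
        }
    }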
@@ -169,7 +176,7 @@ public class StackedRNN extends RNN {
* hprev is Hx1 array of initial hidden state
* returns the loss, gradients on model parameters and last hidden state
*/
- private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhh2, INDArray dWh2y,
+ private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWxh2, INDArray dWhh2, INDArray dWh2y,
INDArray dbh, INDArray dbh2, INDArray dby) {
INDArray xs = Nd4j.zeros(seqLength, vocabSize);
@@ -193,13 +200,13 @@ public class StackedRNN extends RNN {
}
hs.putRow(t, hPrev.dup());
- hPrev2 = Transforms.tanh((whh.mmul(hs.getRow(t)).add(whh2.mmul(hPrev2)).add(bh2))); // hidden state 2
+ hPrev2 = Transforms.tanh((wxh2.mmul(hPrev).add(whh2.mmul(hPrev2)).add(bh2))); // hidden state 2
if (hs2 == null) {
hs2 = init(seqLength, hPrev2.shape());
}
hs2.putRow(t, hPrev2.dup());
- INDArray yst = wh2y.mmul(hs2.getRow(t)).add(by); // unnormalized log probabilities for next chars
+ INDArray yst = wh2y.mmul(hPrev2).add(by); // unnormalized log probabilities for next chars
if (ys == null) {
ys = init(seqLength, yst.shape());
}
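With this change the forward pass of the stacked network reads, per time step t (a summary in the class's field names, not code from the commit):

    h(t)  = tanh( wxh  * x(t)  + whh  * h(t-1)  + bh  )
    h2(t) = tanh( wxh2 * h(t)  + whh2 * h2(t-1) + bh2 )
    ys(t) = wh2y * h2(t) + by        (unnormalized log probabilities)
    ps(t) = softmax(ys(t))

Previously the second layer projected h(t) through whh, so the same matrix served both as layer 1's recurrent map and as layer 2's input map; the dedicated wxh2 separates those roles, and ys(t) is now computed directly from hPrev2 instead of re-reading the row just stored in hs2.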
@@ -231,11 +238,11 @@ public class StackedRNN extends RNN {
INDArray dhraw2 = (Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t))).mul(dh2); // backprop through tanh nonlinearity
dbh2.addi(dhraw2);
INDArray hst = hs.getRow(t);
- dWhh.addi(dhraw2.mmul(hst.transpose()));
+ dWxh2.addi(dhraw2.mmul(hst.transpose()));
dWhh2.addi(dhraw2.mmul(hs2tm1.transpose()));
dh2Next = whh2.transpose().mmul(dhraw2);
- INDArray dh = dh2Next.add(dhNext); // backprop into h
+ INDArray dh = wxh2.transpose().mmul(dhraw2).add(dhNext); // backprop into h
INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst.mul(hst))).mul(dh); // backprop through tanh nonlinearity
dbh.addi(dhraw);
dWxh.addi(dhraw.mmul(xs.getRow(t)));
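The backward pass mirrors the new wiring: dhraw2, the gradient at layer 2's tanh pre-activation, is accumulated into wxh2 and flows back into the first hidden state through wxh2's transpose:

    dWxh2 += dhraw2 * h(t)^T
    dWhh2 += dhraw2 * h2(t-1)^T
    dh     = wxh2^T * dhraw2 + dhNext

Before the change that first gradient was added to the shared whh, and the gradient reaching h(t) was taken from whh2^T * dhraw2, a term that belongs to h2(t-1); with wxh2 the backpropagation matches the forward computation.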
@@ -250,13 +257,14 @@ public class StackedRNN extends RNN {
// clip exploding gradients
int clip = 5;
- Nd4j.getExecutioner().exec(new SetRange(dWxh, -clip, clip));
- Nd4j.getExecutioner().exec(new SetRange(dWhh, -clip, clip));
- Nd4j.getExecutioner().exec(new SetRange(dWhh2, -clip, clip));
- Nd4j.getExecutioner().exec(new SetRange(dWh2y, -clip, clip));
- Nd4j.getExecutioner().exec(new SetRange(dbh, -clip, clip));
- Nd4j.getExecutioner().exec(new SetRange(dbh2, -clip, clip));
- Nd4j.getExecutioner().exec(new SetRange(dby, -clip, clip));
+ dWxh = Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -clip, clip));
+ dWxh2 = Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh2, -clip, clip));
+ dWhh = Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -clip, clip));
+ dWhh2 = Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh2, -clip, clip));
+ dWh2y = Nd4j.getExecutioner().execAndReturn(new SetRange(dWh2y, -clip, clip));
+ dbh = Nd4j.getExecutioner().execAndReturn(new SetRange(dbh, -clip, clip));
+ dbh2 = Nd4j.getExecutioner().execAndReturn(new SetRange(dbh2, -clip, clip));
+ dby = Nd4j.getExecutioner().execAndReturn(new SetRange(dby, -clip, clip));
return loss;
}
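The clipping step bounds every gradient element to [-clip, clip] before the Adagrad update, i.e. for each entry g of a gradient array:

    g := min(max(g, -5), 5)

execAndReturn(...) hands back the clipped array so it can be reassigned to the local gradient variables, and the newly added dWxh2 is clipped along with the others.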
@@ -275,8 +283,8 @@ public class StackedRNN extends RNN {
INDArray h2 = hPrev2.dup();
for (int t = 0; t < sampleSize; t++) {
- h = Transforms.tanh(((wxh.mmul(x)).add((whh.mmul(h)).add(bh))));
- h2 = Transforms.tanh(((whh.mmul(h)).add((whh2.mmul(h2)).add(bh2))));
+ h = Transforms.tanh((wxh.mmul(x)).add(whh.mmul(h)).add(bh));
+ h2 = Transforms.tanh((wxh2.mmul(h)).add(whh2.mmul(h2)).add(bh2));
INDArray y = wh2y.mmul(h2).add(by);
INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
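In the sampling loop the two hidden states are advanced with the same equations as in lossFun, and pm holds the softmax distribution over the vocabulary for the next character. A hypothetical sketch of drawing the next index from that distribution by inverting its cumulative sum (the commit's own sampling code is not part of this hunk; rng, vocabSize and x are assumed to be in scope):

    double r = rng.nextDouble();        // rng: a java.util.Random instance (assumed)
    double cumulative = 0.0;
    int next = 0;
    for (int i = 0; i < vocabSize; i++) {
        cumulative += pm.getDouble(i);  // pm is already normalized by SoftMax
        if (r <= cumulative) {
            next = i;
            break;
        }
    }
    x = Nd4j.zeros(vocabSize, 1);       // feed the sampled character back in as a one-hot column
    x.putScalar(next, 1.0);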
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java?rev=1786303&r1=1786302&r2=1786303&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java Fri Mar 10 09:01:11 2017
@@ -61,11 +61,7 @@ public class RNNCrossValidationTest {
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][]{
- {1e-1f, 50, 15},
- {1e-1f, 50, 25},
- {1e-1f, 50, 50},
- {1e-1f, 50, 100},
- {1e-1f, 50, 150},
+ {1e-1f, 25, 100},
});
}
@@ -83,20 +79,6 @@ public class RNNCrossValidationTest {
rnn.serialize("target/wrnn-weights-");
}
- @Test
- public void testStackedCharRNNLearn() throws Exception {
- RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
- evaluate(rnn, true);
- rnn.serialize("target/scrnn-weights-");
- }
-
- @Test
- public void testStackedWordRNNLearn() throws Exception {
- RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false);
- evaluate(rnn, false);
- rnn.serialize("target/swrnn-weights-");
- }
-
private void evaluate(RNN rnn, boolean checkRatio) {
System.out.println(rnn);
rnn.learn();