Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/05 20:53:28 UTC
[systemml] branch master updated: [SYSTEMML-540] Added an initial CP operator for lstm_backward builtin function
This is an automated email from the ASF dual-hosted git repository.
niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new b4ef84b [SYSTEMML-540] Added an initial CP operator for lstm_backward builtin function
b4ef84b is described below
commit b4ef84ba2568dc96fe20d1a45faeb4c1bd2d47b4
Author: Niketan Pansare <np...@us.ibm.com>
AuthorDate: Tue Mar 5 12:52:11 2019 -0800
[SYSTEMML-540] Added an initial CP operator for lstm_backward builtin function
---
scripts/nn/layers/lstm_staging.dml | 6 -
.../java/org/apache/sysml/hops/FunctionOp.java | 9 +-
.../runtime/instructions/cp/DnnCPInstruction.java | 104 +++++++
.../instructions/gpu/DnnGPUInstruction.java | 42 +--
.../sysml/runtime/matrix/data/LibMatrixDNN.java | 341 ++++++++++++++++++++-
.../org/apache/sysml/test/gpu/LstmCPUTest.java | 189 ++++++++----
6 files changed, 575 insertions(+), 116 deletions(-)
diff --git a/scripts/nn/layers/lstm_staging.dml b/scripts/nn/layers/lstm_staging.dml
index 886b88c..2f71f22 100644
--- a/scripts/nn/layers/lstm_staging.dml
+++ b/scripts/nn/layers/lstm_staging.dml
@@ -92,12 +92,6 @@ backward = function(matrix[double] dout, matrix[double] dc,
* Note: This is *optional* and could just be an empty matrix.
* - c0: Initial cell state, of shape (N, M).
* Note: This is *optional* and could just be an empty matrix.
- * - cache_out: Cache of outputs, of shape (T, N*M).
- * Note: This is used for performance during training.
- * - cache_c: Cache of cell state, of shape (T, N*M).
- * Note: This is used for performance during training.
- * - cache_ifog: Cache of intermediate values, of shape (T, N*4*M).
- * Note: This is used for performance during training.
*
* Outputs:
* - dX: Gradient wrt `X`, of shape (N, T*D).
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 5fdc8e7..dedbad6 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -282,13 +282,6 @@ public class FunctionOp extends MultiThreadedHop
if(getFunctionType() == FunctionType.MULTIRETURN_BUILTIN && isBuiltinFunction() &&
(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) {
-
- if(getFunctionName().equalsIgnoreCase("lstm_backward")) {
- if(!ConfigurationManager.isGPU())
- throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
- _etype = ExecType.GPU;
- }
-
ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
if( _etypeForced != null ) {
@@ -306,7 +299,7 @@ public class FunctionOp extends MultiThreadedHop
checkAndSetInvalidCPDimsAndSize();
}
- // Since lstm builtin functions are not supported on Spark
+ // lstm builtin functions are not supported on Spark or MR, so fall back to CP.
_etype = _etype == REMOTE ? ExecType.CP : _etype;
//mark for recompile (forever)
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
index 93ffd4f..50a11de 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -274,6 +274,24 @@ public class DnnCPInstruction extends UnaryCPInstruction {
int numThreads = Integer.parseInt(parts[9]);
return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, null, null, null, opcode, str, 0, numThreads);
}
+ else if (opcode.equalsIgnoreCase("lstm_backward")) {
+ InstructionUtils.checkNumFields(parts, 14);
+ CPOperand in1 = new CPOperand(parts[1]); // X
+ CPOperand in2 = new CPOperand(parts[2]); // W
+ CPOperand in3 = new CPOperand(parts[3]); // b
+ CPOperand in4 = new CPOperand(parts[4]); // out0
+ CPOperand in5 = new CPOperand(parts[5]); // c0
+ CPOperand in6 = new CPOperand(parts[6]); // return_seq
+ CPOperand in7 = new CPOperand(parts[7]); // dout
+ CPOperand in8 = new CPOperand(parts[8]); // dc
+ CPOperand out = new CPOperand(parts[9]); // dX
+ CPOperand out2 = new CPOperand(parts[10]); // dW
+ CPOperand out3 = new CPOperand(parts[11]); // db
+ CPOperand out4 = new CPOperand(parts[12]); // dout0
+ CPOperand out5 = new CPOperand(parts[13]); // dc0
+ int numThreads = Integer.parseInt(parts[14]);
+ return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0, numThreads);
+ }
else {
throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str);
}
@@ -329,6 +347,88 @@ public class DnnCPInstruction extends UnaryCPInstruction {
ec.setMatrixOutput(_out2.getName(), c, getExtendedOpcode());
}
+ public void processLstmBackwardInstruction(ExecutionContext ec) {
+ MatrixBlock X = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+ MatrixBlock W = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
+ MatrixBlock b = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
+ MatrixBlock out0 = ec.getMatrixInput(_in4.getName(), getExtendedOpcode());
+ MatrixBlock c0 = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
+ boolean return_seq = ec.getScalarInput(_in6.getName(), _in6.getValueType(), _in6.isLiteral()).getBooleanValue();
+ MatrixBlock dout = ec.getMatrixInput(_in7.getName(), getExtendedOpcode());
+ MatrixBlock dc = ec.getMatrixInput(_in8.getName(), getExtendedOpcode());
+
+ int N = X.getNumRows();
+ int TD = X.getNumColumns();
+ int DPlusM = W.getNumRows();
+ int M = W.getNumColumns() / 4;
+ int D = DPlusM - M;
+ int T = TD / D;
+ if(b.getNumRows() != 1 || b.getNumColumns() != M*4) {
+ throw new DMLRuntimeException("Incorrect dimensions of bias in lstm_backward instruction. Expected [1, " + (M*4) + "], "
+ + "but found [" + b.getNumRows() + "," + b.getNumColumns() + "]");
+ }
+ if(out0.getNumRows() != N) {
+ throw new DMLRuntimeException("Unsupported operation: The batch size of previous iteration " + out0.getNumRows() +
+ " is different than the batch size of current iteration " + N);
+ }
+ if(out0.getNumColumns() != M) {
+ throw new DMLRuntimeException("Incorrect dimensions of out0 in lstm_backward instruction. Expected [" + N + ", " + M + "], "
+ + "but found [" + out0.getNumRows() + "," + out0.getNumColumns() + "]");
+ }
+ if(c0.getNumRows() != N || c0.getNumColumns() != M) {
+ throw new DMLRuntimeException("Incorrect dimensions of c0 in lstm_backward instruction. Expected [" + N + ", " + M + "], "
+ + "but found [" + out0.getNumRows() + "," + out0.getNumColumns() + "]");
+ }
+ if(dout.getNumRows() != N || dout.getNumColumns() != (return_seq ? (T*M) : M)) {
+ throw new DMLRuntimeException("Incorrect dimensions of dout in lstm_backward instruction. Expected [" + N + ", " + (return_seq ? (T*M) : M) + "], "
+ + "but found [" + dout.getNumRows() + "," + dout.getNumColumns() + "]");
+ }
+ if(dc.getNumRows() != N || dc.getNumColumns() != M) {
+ throw new DMLRuntimeException("Incorrect dimensions of dc in lstm_backward instruction. Expected [" + N + ", " + M + "], "
+ + "but found [" + dc.getNumRows() + "," + dc.getNumColumns() + "]");
+ }
+
+ MatrixBlock out = new MatrixBlock(N, return_seq ? (T*M) : M, false);
+ MatrixBlock c = new MatrixBlock(N, M, false);
+ MatrixBlock cache_out = new MatrixBlock(T, N*M, false);
+ MatrixBlock cache_c = new MatrixBlock(T, N*M, false);
+ MatrixBlock cache_ifog = new MatrixBlock(T, N*4*M, false);
+
+ // In this initial implementation, redundantly re-run the lstm forward pass to
+ // materialize the caches (cache_out, cache_c, cache_ifog) needed by the backward pass.
+ // TODO: Optimize this later by reusing forward results where available.
+ cache_out.allocateDenseBlock();
+ cache_c.allocateDenseBlock();
+ cache_ifog.allocateDenseBlock();
+ LibMatrixDNN.lstm(X, W, b, out0, c0,
+ return_seq, N, T, D, M,
+ out, c, cache_out, cache_c, cache_ifog,
+ _numThreads);
+
+ MatrixBlock dX = new MatrixBlock(N, T*D, false);
+ MatrixBlock dW = new MatrixBlock(D+M, 4*M, false);
+ MatrixBlock db = new MatrixBlock(1, 4*M, false);
+ MatrixBlock dout0 = new MatrixBlock(N, M, false);
+ MatrixBlock dc0 = new MatrixBlock(N, M, false);
+ LibMatrixDNN.lstm_backward(dout, dc, X, W, b, out0, c0, return_seq, N, T, D, M,
+ cache_out, cache_c, cache_ifog, // from forward invocation
+ dX, dW, db, dout0, dc0, // output
+ _numThreads);
+
+ // release inputs/outputs
+ ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+ ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+ ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
+ ec.releaseMatrixInput(_in4.getName(), getExtendedOpcode());
+ ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
+ ec.releaseMatrixInput(_in7.getName(), getExtendedOpcode());
+ ec.releaseMatrixInput(_in8.getName(), getExtendedOpcode());
+ ec.setMatrixOutput(output.getName(), dX, getExtendedOpcode());
+ ec.setMatrixOutput(_out2.getName(), dW, getExtendedOpcode());
+ ec.setMatrixOutput(_out3.getName(), db, getExtendedOpcode());
+ ec.setMatrixOutput(_out4.getName(), dout0, getExtendedOpcode());
+ ec.setMatrixOutput(_out5.getName(), dc0, getExtendedOpcode());
+ }
+
public void processReluBackwardInstruction(ExecutionContext ec) {
// (X > 0) * dout
MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
@@ -498,6 +598,10 @@ public class DnnCPInstruction extends UnaryCPInstruction {
processLstmInstruction(ec);
return;
}
+ else if (instOpcode.equalsIgnoreCase("lstm_backward")) {
+ processLstmBackwardInstruction(ec);
+ return;
+ }
// acquire inputs
MatrixBlock outputBlock = null;
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index 35c9591..fbe7c9d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -361,31 +361,31 @@ public class DnnGPUInstruction extends GPUInstruction {
}
else if (opcode.equalsIgnoreCase("lstm")) {
InstructionUtils.checkNumFields(parts, 8);
- CPOperand in1 = new CPOperand(parts[1]);
- CPOperand in2 = new CPOperand(parts[2]);
- CPOperand in3 = new CPOperand(parts[3]);
- CPOperand in4 = new CPOperand(parts[4]);
- CPOperand in5 = new CPOperand(parts[5]);
- CPOperand in6 = new CPOperand(parts[6]);
- CPOperand out = new CPOperand(parts[7]);
- CPOperand out2 = new CPOperand(parts[8]);
+ CPOperand in1 = new CPOperand(parts[1]); // X
+ CPOperand in2 = new CPOperand(parts[2]); // W
+ CPOperand in3 = new CPOperand(parts[3]); // b
+ CPOperand in4 = new CPOperand(parts[4]); // out0
+ CPOperand in5 = new CPOperand(parts[5]); // c0
+ CPOperand in6 = new CPOperand(parts[6]); // return_seq
+ CPOperand out = new CPOperand(parts[7]); // out
+ CPOperand out2 = new CPOperand(parts[8]); // c
return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0);
}
else if (opcode.equalsIgnoreCase("lstm_backward")) {
InstructionUtils.checkNumFields(parts, 13);
- CPOperand in1 = new CPOperand(parts[1]); // image
- CPOperand in2 = new CPOperand(parts[2]); // scale
- CPOperand in3 = new CPOperand(parts[3]); // bias
- CPOperand in4 = new CPOperand(parts[4]); // runningMean
- CPOperand in5 = new CPOperand(parts[5]); // runningVar
- CPOperand in6 = new CPOperand(parts[6]); // mode
- CPOperand in7 = new CPOperand(parts[7]); // epsilon
- CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor
- CPOperand out = new CPOperand(parts[9]); // ret
- CPOperand out2 = new CPOperand(parts[10]); // retRunningMean
- CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
- CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
- CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
+ CPOperand in1 = new CPOperand(parts[1]); // X
+ CPOperand in2 = new CPOperand(parts[2]); // W
+ CPOperand in3 = new CPOperand(parts[3]); // b
+ CPOperand in4 = new CPOperand(parts[4]); // out0
+ CPOperand in5 = new CPOperand(parts[5]); // c0
+ CPOperand in6 = new CPOperand(parts[6]); // return_seq
+ CPOperand in7 = new CPOperand(parts[7]); // dout
+ CPOperand in8 = new CPOperand(parts[8]); // dc
+ CPOperand out = new CPOperand(parts[9]); // dX
+ CPOperand out2 = new CPOperand(parts[10]); // dW
+ CPOperand out3 = new CPOperand(parts[11]); // db
+ CPOperand out4 = new CPOperand(parts[12]); // dout0
+ CPOperand out5 = new CPOperand(parts[13]); // dc0
return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
}
else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 365d7a2..0005932 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -20,6 +20,7 @@ package org.apache.sysml.runtime.matrix.data;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
@@ -35,19 +36,29 @@ import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.functionobjects.PlusMultiply;
+import org.apache.sysml.runtime.functionobjects.Power;
+import org.apache.sysml.runtime.functionobjects.Power2;
+import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.LeftScalarOperator;
+import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysml.runtime.matrix.operators.RightScalarOperator;
+import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
import org.apache.sysml.runtime.matrix.operators.TernaryOperator;
import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
import org.apache.sysml.runtime.util.CommonThreadPool;
import org.apache.sysml.runtime.util.DnnUtils;
+
/*
* This class allows users to invoke deep learning related operations
* (such as conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling, maxpooling_backward, bias_add)
@@ -297,6 +308,10 @@ public class LibMatrixDNN {
return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()),
matBlock2, matBlock3, new MatrixBlock());
}
+ private static MatrixBlock minusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) {
+ return matBlock1.ternaryOperations(new TernaryOperator(MinusMultiply.getFnObject()),
+ matBlock2, matBlock3, new MatrixBlock());
+ }
private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) {
@@ -311,6 +326,11 @@ public class LibMatrixDNN {
}
}
+ private static MatrixBlock multiply(MatrixBlock matBlock1, double scalar, boolean inplace) {
+ // Note: the result is always materialized in a new block; the inplace flag is currently unused.
+ ScalarOperator sc_op = new LeftScalarOperator(Multiply.getMultiplyFnObject(), scalar);
+ return (MatrixBlock) matBlock1.scalarOperations(sc_op, new MatrixBlock());
+ }
+
// sigmoid(0)*c_prev + sigmoid(0)*tanh(0);
@@ -322,6 +342,175 @@ public class LibMatrixDNN {
private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean inPlace) {
return (MatrixBlock) in.unaryOperations(new UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock());
}
+ private static MatrixBlock power(MatrixBlock in, double exponent) {
+ return (MatrixBlock) in.scalarOperations(new RightScalarOperator(Power.getPowerFnObject(), exponent), new MatrixBlock());
+ }
+ private static MatrixBlock minus(double scalar, MatrixBlock in) {
+ return (MatrixBlock) in.scalarOperations(new LeftScalarOperator(org.apache.sysml.runtime.functionobjects.Minus.getMinusFnObject(), scalar), new MatrixBlock());
+ }
+ private static MatrixBlock tanh_backward(MatrixBlock dout, MatrixBlock X, int numThreads) {
+ MatrixBlock out = tanh(X, numThreads, false);
+ return minusMultiply(dout, power(out, 2), dout);
+ }
+
+ public static void lstm_backward(MatrixBlock dout, MatrixBlock dc,
+ MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0,
+ boolean given_sequences, int N, int T, int D, int M,
+ MatrixBlock cache_out, MatrixBlock cache_c, MatrixBlock cache_ifog, // from forward invocation
+ MatrixBlock dX, MatrixBlock dW, MatrixBlock db, MatrixBlock dout0, MatrixBlock dc0,
+ int numThreads) {
+ MatrixBlock dct = dc;
+ if (!given_sequences) {
+ // only given dout for output at final timestep, so prepend empty douts for all other timesteps
+ dout = new MatrixBlock(N, (T-1)*M, true).append(dout, new MatrixBlock());
+ }
+ MatrixBlock dW_ret = dW;
+ MatrixBlock db_ret = db;
+ MatrixBlock dout_t = dout.slice(0, N-1, (T-1)*M, T*M-1, new MatrixBlock());
+ for(int t = T; t > 0; t--) {
+ MatrixBlock X_t = (T == 1) ? X : X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock());
+ MatrixBlock ct = sliceAndReshape(cache_c, new MatrixBlock(), t-1, N, M);
+ MatrixBlock out_prev = (t == 1) ? out0 : sliceAndReshape(cache_out, new MatrixBlock(), t-2, N, M);
+ MatrixBlock c_prev = (t == 1) ? c0 : sliceAndReshape(cache_c, new MatrixBlock(), t-2, N, M);
+ MatrixBlock input = X_t.append(out_prev, new MatrixBlock());
+ MatrixBlock ifog = sliceAndReshape(cache_ifog, new MatrixBlock(), t-1, N, 4*M);
+ MatrixBlock i = ifog.slice(0, N-1, 0, M-1, new MatrixBlock());
+ MatrixBlock f = ifog.slice(0, N-1, M, 2*M-1, new MatrixBlock());
+ MatrixBlock o = ifog.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
+ MatrixBlock g = ifog.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock());
+ dct = plusMultiply(dct, o, tanh_backward(dout_t, ct, numThreads));
+ MatrixBlock dc_prev = multiply(f, dct, false);
+
+ MatrixBlock di_raw = multiply(new MatrixBlock[] {i, minus(1, i), g, dct});
+ MatrixBlock df_raw = multiply(new MatrixBlock[] {f, minus(1, f), c_prev, dct});
+ MatrixBlock do_raw = multiply(new MatrixBlock[] {o, minus(1, o), tanh(ct, numThreads, false), dout_t});
+ MatrixBlock dg_raw = multiply(new MatrixBlock[] {minus(1, power(g, 2)), i, dct});
+ MatrixBlock difog_raw = di_raw.append(new MatrixBlock[] { df_raw, do_raw, dg_raw}, new MatrixBlock(), true);
+
+ // dW = dW + t(input) %*% difog_raw
+ dW = add(matmult(transpose(input, numThreads), difog_raw, numThreads), dW, true);
+ // db = db + colSums(difog_raw)
+ db = add(colSums(difog_raw), db, true);
+ // dinput = difog_raw %*% t(W)
+ MatrixBlock dinput = matmult(difog_raw, transpose(W, numThreads), numThreads);
+ // dX[,(t-1)*D+1:t*D] = dinput[,1:D]
+ dX.leftIndexingOperations(dinput.slice(0, N-1, 0, D-1, new MatrixBlock()), 0, N-1, (t-1)*D, t*D-1, dX, UpdateType.INPLACE);
+ // dout_prev = dinput[,D+1:D+M]
+ MatrixBlock dout_prev = dinput.slice(0, N-1, D, D+M-1, new MatrixBlock());
+
+ if(t == 1) {
+ dout0.copy(dout_prev);
+ dc0.copy(dc_prev);
+ }
+ else {
+ dout_t = add(dout.slice(0, N-1, (t-2)*M, (t-1)*M-1, new MatrixBlock()), dout_prev, true);
+ dct = dc_prev;
+ }
+ }
+ dW_ret.copy(dW);
+ db_ret.copy(db);
+ }
+
+
+ private static MatrixBlock colSums(MatrixBlock in) {
+ MatrixBlock ret = new MatrixBlock(1, in.getNumColumns(), false);
+ if(in.isEmpty()) {
+ // Do nothing
+ ret.setNonZeros(0);
+ }
+ else if(in.isInSparseFormat()) {
+ ret.allocateDenseBlock();
+ double [] retArr = ret.getDenseBlockValues();
+ SparseBlock sblock = in.getSparseBlock();
+ for(int n = 0; n < in.getNumRows(); n++) {
+ if( sblock.isEmpty(n) )
+ continue;
+ int apos = sblock.pos(n);
+ int alen = sblock.size(n);
+ int[] aix = sblock.indexes(n);
+ double[] avals = sblock.values(n);
+
+ // Iterate over the sparse block
+ for(int j=apos; j<apos+alen; j++) {
+ retArr[aix[j]] += avals[j];
+ }
+ }
+ ret.recomputeNonZeros();
+ }
+ else {
+ double [] inArr = in.getDenseBlockValues();
+ if(inArr != null) {
+ int index = 0;
+ ret.allocateDenseBlock();
+ double [] retArr = ret.getDenseBlockValues();
+ for(int r = 0; r < in.getNumRows(); r++) {
+ for(int c = 0; c < in.getNumColumns(); c++, index++) {
+ retArr[c] += inArr[index];
+ }
+ }
+ ret.recomputeNonZeros();
+ }
+ else {
+ ret.setNonZeros(0);
+ }
+ }
+ return ret;
+ }
+
+ private static MatrixBlock transpose(MatrixBlock in, int numThreads) {
+ ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+ return (MatrixBlock) (in.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0));
+ }
+
+ private static MatrixBlock multiply(MatrixBlock [] in) {
+ boolean allDense = true;
+ int rows = 0; int cols = 0;
+ for(MatrixBlock mb : in) {
+ rows = Math.max(rows, mb.getNumRows());
+ cols = Math.max(cols, mb.getNumColumns());
+ }
+ for(MatrixBlock mb : in) {
+ if(mb.isEmpty() || (!mb.isInSparseFormat() && mb.getDenseBlockValues() == null)) {
+ MatrixBlock ret = new MatrixBlock(rows, cols, true);
+ ret.setNonZeros(0);
+ return ret;
+ }
+ allDense = allDense && !mb.isInSparseFormat();
+ }
+ if(allDense) {
+ MatrixBlock ret = new MatrixBlock(rows, cols, false);
+ ret.allocateDenseBlock();
+ double [] retArr = null;
+ // Avoids (in.length-1) recomputeNonZeros calls
+ for(MatrixBlock mb : in) {
+ if(retArr == null) {
+ retArr = ret.getDenseBlockValues();
+ System.arraycopy(mb.getDenseBlockValues(), 0, retArr, 0, retArr.length);
+ }
+ else {
+ double [] inArr = mb.getDenseBlockValues();
+ for(int index = 0; index < retArr.length; index++) {
+ retArr[index] *= inArr[index];
+ }
+ }
+ }
+ ret.recomputeNonZeros();
+ return ret;
+ }
+ else {
+ // Multiply the sparsest blocks first to keep intermediates sparse.
+ Arrays.sort(in, (mb1, mb2) -> Long.compare(mb1.getNonZeros(), mb2.getNonZeros()));
+ // Seed the running product with the first block; seeding with an empty
+ // (all-zero) block would incorrectly zero out the entire result.
+ MatrixBlock ret = in[0];
+ for(int k = 1; k < in.length; k++) {
+ ret = multiply(ret, in[k], k > 1); // copy on the first multiply to avoid mutating in[0]
+ }
+ return ret;
+ }
+ }
+
+ // Performs the following operation: ret = matrix(in[rowIndex+1,], rows=numRows, cols=numCols)
+ public static MatrixBlock sliceAndReshape(MatrixBlock in, MatrixBlock ret, int rowIndex, int numRows, int numCols) {
+ return LibMatrixReorg.reshape(in.slice(rowIndex, rowIndex), ret, numRows, numCols, true);
+ }
public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0,
boolean return_seq, int N, int T, int D, int M,
@@ -367,27 +556,84 @@ public class LibMatrixDNN {
ifog_raw = add(matmult(input, W, numThreads), b, true);
}
- MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock());
- ifo = sigmoid(ifo, numThreads, true);
- MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock());
- MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock());
- MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
- MatrixBlock g = tanh(ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()), numThreads, true);
-
- // c_t = f*c_prev + i*g
- c_t = plusMultiply(multiply(f, c_prev, true), i, g);
- // out_t = o*tanh(c)
- out_t = multiply(o, tanh(c_t, numThreads, false), true);
+ if(!ifog_raw.isInSparseFormat() && !c_prev.isInSparseFormat()) {
+ double [] ifog_rawArr = ifog_raw.getDenseBlockValues();
+ double [] c_prevArr = c_prev.getDenseBlockValues();
+ double [] cache_ifogArr = null;
+ if(cache_ifog != null) {
+ cache_ifogArr = cache_ifog.getDenseBlockValues();
+ if(cache_ifogArr == null)
+ throw new DMLRuntimeException("Expected cache_ifog to be allocated in the dense format");
+ }
+ if(ifog_rawArr == null && c_prevArr == null) {
+ // Both ifog_raw and c_prev are empty matrix
+ c_t = new MatrixBlock(N, M, 0);
+ out_t = new MatrixBlock(N, M, 0);
+ c_t.setNonZeros(0);
+ out_t.setNonZeros(0);
+ updateIfogCache(cache_ifogArr, t, N, M);
+ }
+ else if(ifog_rawArr == null) {
+ // ifog_raw is an empty matrix
+ // c_t = f*c_prev + i*g
+ // = 0.5*c_prev
+ c_t = multiply(c_prev, 0.5, false);
+ // out_t = o*tanh(c)
+ // = 0.5*tanh(c)
+ out_t = multiply(tanh(c_t, numThreads, false), 0.5, false);
+ updateIfogCache(cache_ifogArr, t, N, M);
+ }
+ else {
+ // ifog_raw is not an empty matrix
+ c_t = new MatrixBlock(N, M, false); c_t.allocateDenseBlock();
+ double [] c_tArr = c_t.getDenseBlockValues();
+ out_t = new MatrixBlock(N, M, false); out_t.allocateDenseBlock();
+ double [] out_tArr = out_t.getDenseBlockValues();
+ int index = 0;
+ int offset = (t-1)*N*4*M;
+ for(int n = 0; n < N; n++) {
+ for(int m = 0; m < M; m++, index++) {
+ double c_prevVal = (c_prevArr == null) ? 0 : c_prevArr[index];
+ // c_t = f*c_prev + i*g
+ double i = sigmoidOp.execute(ifog_rawArr[n*4*M + m]);
+ double f = sigmoidOp.execute(ifog_rawArr[n*4*M + M + m]);
+ double o = sigmoidOp.execute(ifog_rawArr[n*4*M + 2*M + m]);
+ double g = tanhOp.execute(ifog_rawArr[n*4*M + 3*M + m]);
+ c_tArr[index] = f*c_prevVal + i*g;
+ // out_t = o*tanh(c)
+ out_tArr[index] = o*tanhOp.execute(c_tArr[index]);
+ updateIfogCache(cache_ifogArr, i, f, o, g, offset, n, m, N, M);
+ }
+ }
+ c_t.recomputeNonZeros();
+ out_t.recomputeNonZeros();
+ }
+ }
+ else {
+ MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock());
+ ifo = sigmoid(ifo, numThreads, true);
+ MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock());
+ MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock());
+ MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
+ MatrixBlock g = tanh(ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()), numThreads, true);
+
+ // c_t = f*c_prev + i*g
+ c_t = plusMultiply(multiply(f, c_prev, true), i, g);
+ // out_t = o*tanh(c)
+ out_t = multiply(o, tanh(c_t, numThreads, false), true);
+ updateIfogCache(cache_ifog, ifo, g, t, N, M);
+ }
+
if(return_seq) {
- out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE);
+ out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, out, UpdateType.INPLACE);
}
out_prev = out_t;
c_prev = c_t;
- // TODO: Add this when implementing lstm_backward
-// cache_out[t,] = matrix(out_t, rows=1, cols=N*M) # reshape
-// cache_c[t,] = matrix(c, rows=1, cols=N*M) # reshape
-// cache_ifog[t,] = matrix(cbind(ifo, g), rows=1, cols=N*4*M) # reshape
+ if(cache_out != null) {
+ reshapeAsRowMatrixAndLeftIndex(cache_out, out_t, t-1, N*M);
+ reshapeAsRowMatrixAndLeftIndex(cache_c, c_t, t-1, N*M);
+ }
}
if(out_t != null && !return_seq)
out.copy(out_t);
@@ -395,6 +641,69 @@ public class LibMatrixDNN {
c.copy(c_t);
else
c.copy(c0);
+ if(cache_out != null) {
+ cache_out.recomputeNonZeros();
+ cache_c.recomputeNonZeros();
+ cache_ifog.recomputeNonZeros();
+ }
+ }
+
+ private static void updateIfogCache(MatrixBlock cache_ifog, MatrixBlock ifo, MatrixBlock g, int t, int N, int M) {
+ if(cache_ifog != null) {
+ // cache_ifog has N*4*M columns per timestep row (ifo contributes 3*M columns, g contributes M).
+ reshapeAsRowMatrixAndLeftIndex(cache_ifog, ifo.append(g, new MatrixBlock()), t-1, N*4*M);
+ }
+ }
+
+ // Case: ifog_raw is an empty matrix, i.e. i = f = o = sigmoid(0) = 0.5 and g = tanh(0) = 0
+ private static void updateIfogCache(double[] cache_ifogArr, int t, int N, int M) {
+ if(cache_ifogArr != null) {
+ int offset = (t-1)*N*4*M;
+ for(int n = 0 ; n < N; n++) {
+ int srcIndex = offset + n*4*M;
+ Arrays.fill(cache_ifogArr, srcIndex, srcIndex + 3*M, 0.5);
+ }
+ }
+ }
+
+ private static void updateIfogCache(double[] cache_ifogArr, double i, double f, double o, double g, int offset, int n, int m, int N, int M) {
+ if(cache_ifogArr != null) {
+ cache_ifogArr[offset + n*4*M + m] = i;
+ cache_ifogArr[offset + n*4*M + M + m] = f;
+ cache_ifogArr[offset + n*4*M + 2*M + m] = o;
+ cache_ifogArr[offset + n*4*M + 3*M + m] = g;
+ }
+ }
+
+ // Performs operation: lhsMatrix[rowIndex+1, ] = matrix(rhsMatrix, rows=1, cols=numCols)
+ private static void reshapeAsRowMatrixAndLeftIndex(MatrixBlock lhsMatrix, MatrixBlock rhsMatrix, int rowIndex, int numCols) {
+ double [] lhsArr = lhsMatrix.getDenseBlockValues();
+ if(lhsArr == null)
+ throw new DMLRuntimeException("Incorrect usage: lhsMatrix needs to be allocated in dense format before invocation of this method.");
+ if(rhsMatrix.isInSparseFormat()) {
+ int rhsCols = rhsMatrix.getNumColumns();
+ SparseBlock sblock = rhsMatrix.getSparseBlock();
+ for(int n = 0; n < rhsMatrix.getNumRows(); n++) {
+ if( sblock.isEmpty(n) )
+ continue;
+ int apos = sblock.pos(n);
+ int alen = sblock.size(n);
+ int[] aix = sblock.indexes(n);
+ double[] avals = sblock.values(n);
+
+ // Copy cell (n, aix[j]) of the rhs into the reshaped target row at offset rowIndex*numCols
+ for(int j=apos; j<apos+alen; j++) {
+ lhsArr[rowIndex*numCols + n*rhsCols + aix[j]] = avals[j];
+ }
+ }
+ }
+ }
+ else if(!rhsMatrix.isInSparseFormat()) {
+ double [] rhsArr = rhsMatrix.getDenseBlockValues();
+ if(rhsArr != null) {
+ System.arraycopy(rhsArr, 0, lhsArr, rowIndex*numCols, numCols);
+ }
+ else {
+ // Do nothing: assumption => lhsMatrix is initialized to 0 before invocation.
+ }
+ }
}
/**
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
index 785c890..5c93bca 100644
--- a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
@@ -131,82 +131,141 @@ public class LstmCPUTest extends GPUTests {
inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
List<String> outputs = Arrays.asList("output", "c");
- List<Object> outGPUWithCuDNN = null;
- List<Object> outCPUWithNN = null;
- synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
- try {
- DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.CUDNN;
- outGPUWithCuDNN = runOnCPU(spark, scriptStr1, inputs, outputs);
- outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, outputs);
- }
- finally {
- DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.NONE;
- }
- }
- assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
- assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+ List<Object> outBuiltin = runOnCPU(spark, scriptStr1, inputs, outputs);
+ List<Object> outNNLayer = runOnCPU(spark, scriptStr2, inputs, outputs);
+ assertEqualObjects(outBuiltin.get(0), outNNLayer.get(0));
+ assertEqualObjects(outBuiltin.get(1), outNNLayer.get(1));
}
+ @Test
+ public void testLstmBackward1() {
+ testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "TRUE", 0.9, 0.9);
+ }
+
+ @Test
+ public void testLstmBackward2() {
+ testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "FALSE", 0.9, 0.9);
+ }
+
+ @Test
+ public void testLstmBackward3() {
+ testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "TRUE", 0.9, 0.9);
+ }
// @Test
-// public void testLstmBackward7() {
-// testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9, 0.9);
+// public void testLstmBackward4() {
+// testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "FALSE", 0.9, 0.9);
// }
-//
+
+ @Test
+ public void testLstmBackward5() {
+ testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "TRUE", 0.9, 0.9);
+ }
+
+ @Test
+ public void testLstmBackward6() {
+ testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "FALSE", 0.9, 0.9);
+ }
+
+ @Test
+ public void testLstmBackward7() {
+ testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9, 0.9);
+ }
+
+ @Test
+ public void testLstmBackward8() {
+ testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9, 0.9);
+ }
+
+ @Test
+ public void testLstmBackward9() {
+ testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9, 0.9);
+ }
+
+ @Test
+ public void testLstmBackward10() {
+ testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9, 0.9);
+ }
+
// @Test
-// public void testLstmBackward8() {
-// testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9, 0.9);
+// public void testLstmBackward11() {
+// testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "TRUE", 0.2, 0.3);
// }
-//
+
+ @Test
+ public void testLstmBackward12() {
+ testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "FALSE", 0.2, 0.9);
+ }
+
// @Test
-// public void testLstmBackward9() {
-// testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9, 0.9);
+// public void testLstmBackward13() {
+// testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "TRUE", 0.9, 0.1);
// }
-//
+
// @Test
-// public void testLstmBackward10() {
-// testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9, 0.9);
+// public void testLstmBackward14() {
+// testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "FALSE", 0.3, 0.6);
// }
-//
-//
-// public void testLstmBackwardCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity,
-// double weightSparsity) {
-// boolean returnSequences1 = returnSequences.equals("TRUE");
-//
-// String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-// + "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
-// String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
-// + "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, "
-// + T + ", " + D + ", " + returnSequences + ", out0, c0); \n"
-// + "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, "
-// + T + ", " + D + ", " + returnSequences + ", out0, c0, cache_out, cache_c, cache_ifog);";
-//
-// HashMap<String, Object> inputs = new HashMap<>();
-// inputs.put("dout", generateInputMatrix(spark, N, returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
-// inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
-// inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, sparsity, seed));
-// inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, weightSparsity, seed));
-// inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, sparsity, seed));
-// inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
-// inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
-// List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0", "dc0");
-// List<Object> outGPUWithCuDNN = null;
-// List<Object> outCPUWithNN = null;
-// synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
-// try {
-// DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.CUDNN;
-// outGPUWithCuDNN = runOnCPU(spark, scriptStr1, inputs, outputs);
-// }
-// finally {
-// DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.NONE;
-// }
-// outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, outputs);
-// }
-// assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
-// assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
-// assertEqualObjects(outGPUWithCuDNN.get(2), outCPUWithNN.get(2));
-// assertEqualObjects(outGPUWithCuDNN.get(3), outCPUWithNN.get(3));
-// assertEqualObjects(outGPUWithCuDNN.get(4), outCPUWithNN.get(4));
+
+ @Test
+ public void testLstmBackward15() {
+ testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "TRUE", 0.2, 0.9);
+ }
+
+// @Test
+// public void testLstmBackward16() {
+// testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "FALSE", 0.3, 0.1);
// }
+
+ @Test
+ public void testLstmBackward17() {
+ testLstmBackwardCuDNNWithNNLayer(20, 13, 15, 25, "TRUE", 0.9, 0.9);
+ }
+
+ @Test
+ public void testLstmBackward18() {
+ testLstmBackwardCuDNNWithNNLayer(20, 13, 15, 25, "FALSE", 0.9, 0.9);
+ }
+
+ @Test
+ public void testLstmBackward19() {
+ testLstmBackwardCuDNNWithNNLayer(12, 17, 15, 26, "TRUE", 0.9, 0.9);
+ }
+
+ @Test
+ public void testLstmBackward20() {
+ testLstmBackwardCuDNNWithNNLayer(12, 17, 15, 26, "FALSE", 0.9, 0.9);
+ }
+
+
+ public void testLstmBackwardCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity,
+ double weightSparsity) {
+ boolean returnSequences1 = returnSequences.equals("TRUE");
+
+ String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+ + "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+ String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+ + "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, "
+ + T + ", " + D + ", " + returnSequences + ", out0, c0); \n"
+ + "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, "
+ + T + ", " + D + ", " + returnSequences + ", out0, c0, cache_out, cache_c, cache_ifog);";
+
+ HashMap<String, Object> inputs = new HashMap<>();
+ inputs.put("dout", generateInputMatrix(spark, N, returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
+ inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+ inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, sparsity, seed));
+ inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, weightSparsity, seed));
+ inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, sparsity, seed));
+ inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+ inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+ List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0", "dc0");
+ List<Object> outBuiltin = runOnCPU(spark, scriptStr1, inputs, outputs);
+ List<Object> outNN = runOnCPU(spark, scriptStr2, inputs, outputs);
+ assertEqualObjects(outBuiltin.get(0), outNN.get(0));
+ assertEqualObjects(outBuiltin.get(1), outNN.get(1));
+ assertEqualObjects(outBuiltin.get(2), outNN.get(2));
+ assertEqualObjects(outBuiltin.get(3), outNN.get(3));
+ assertEqualObjects(outBuiltin.get(4), outNN.get(4));
+ }
}