Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/05 20:53:28 UTC

[systemml] branch master updated: [SYSTEMML-540] Added an initial CP operator for lstm_backward builtin function

This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new b4ef84b  [SYSTEMML-540] Added an initial CP operator for lstm_backward builtin function
b4ef84b is described below

commit b4ef84ba2568dc96fe20d1a45faeb4c1bd2d47b4
Author: Niketan Pansare <np...@us.ibm.com>
AuthorDate: Tue Mar 5 12:52:11 2019 -0800

    [SYSTEMML-540] Added an initial CP operator for lstm_backward builtin function
---
 scripts/nn/layers/lstm_staging.dml                 |   6 -
 .../java/org/apache/sysml/hops/FunctionOp.java     |   9 +-
 .../runtime/instructions/cp/DnnCPInstruction.java  | 104 +++++++
 .../instructions/gpu/DnnGPUInstruction.java        |  42 +--
 .../sysml/runtime/matrix/data/LibMatrixDNN.java    | 341 ++++++++++++++++++++-
 .../org/apache/sysml/test/gpu/LstmCPUTest.java     | 189 ++++++++----
 6 files changed, 575 insertions(+), 116 deletions(-)
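
For readers skimming the diff: the commit removes the GPU-only restriction on lstm_backward in FunctionOp, teaches DnnCPInstruction to parse and execute an lstm_backward instruction, extends LibMatrixDNN.lstm to populate the cache_out/cache_c/cache_ifog blocks, and adds a single-node LibMatrixDNN.lstm_backward that consumes them. The sketch below (a hypothetical helper, not part of the commit) strings the two LibMatrixDNN entry points together the same way the new CP instruction does; the signatures are the ones added in this diff, while the wrapper class and method name are illustrative only.

import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;

public class LstmBackwardCPSketch {
	// Illustrative only: mirrors DnnCPInstruction.processLstmBackwardInstruction.
	public static MatrixBlock[] backward(MatrixBlock X, MatrixBlock W, MatrixBlock b,
			MatrixBlock out0, MatrixBlock c0, MatrixBlock dout, MatrixBlock dc,
			boolean return_seq, int numThreads) {
		int N = X.getNumRows();               // batch size
		int M = W.getNumColumns() / 4;        // hidden size (W is (D+M) x 4M)
		int D = W.getNumRows() - M;           // input feature size
		int T = X.getNumColumns() / D;        // sequence length (X is N x T*D)

		// The initial CP operator re-runs the forward pass to materialize the
		// caches needed by the backward pass (see the TODO in DnnCPInstruction).
		MatrixBlock out = new MatrixBlock(N, return_seq ? T*M : M, false);
		MatrixBlock c = new MatrixBlock(N, M, false);
		MatrixBlock cache_out = new MatrixBlock(T, N*M, false);
		MatrixBlock cache_c = new MatrixBlock(T, N*M, false);
		MatrixBlock cache_ifog = new MatrixBlock(T, N*4*M, false);
		cache_out.allocateDenseBlock();
		cache_c.allocateDenseBlock();
		cache_ifog.allocateDenseBlock();
		LibMatrixDNN.lstm(X, W, b, out0, c0, return_seq, N, T, D, M,
				out, c, cache_out, cache_c, cache_ifog, numThreads);

		// Gradients: dX is N x T*D, dW is (D+M) x 4M, db is 1 x 4M, dout0/dc0 are N x M.
		MatrixBlock dX = new MatrixBlock(N, T*D, false);
		MatrixBlock dW = new MatrixBlock(D+M, 4*M, false);
		MatrixBlock db = new MatrixBlock(1, 4*M, false);
		MatrixBlock dout0 = new MatrixBlock(N, M, false);
		MatrixBlock dc0 = new MatrixBlock(N, M, false);
		LibMatrixDNN.lstm_backward(dout, dc, X, W, b, out0, c0, return_seq, N, T, D, M,
				cache_out, cache_c, cache_ifog, dX, dW, db, dout0, dc0, numThreads);
		return new MatrixBlock[] { dX, dW, db, dout0, dc0 };
	}
}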

diff --git a/scripts/nn/layers/lstm_staging.dml b/scripts/nn/layers/lstm_staging.dml
index 886b88c..2f71f22 100644
--- a/scripts/nn/layers/lstm_staging.dml
+++ b/scripts/nn/layers/lstm_staging.dml
@@ -92,12 +92,6 @@ backward = function(matrix[double] dout, matrix[double] dc,
    *      Note: This is *optional* and could just be an empty matrix.
    *  - c0: Initial cell state, of shape (N, M).
    *      Note: This is *optional* and could just be an empty matrix.
-   *  - cache_out: Cache of outputs, of shape (T, N*M).
-   *      Note: This is used for performance during training.
-   *  - cache_c: Cache of cell state, of shape (T, N*M).
-   *      Note: This is used for performance during training.
-   *  - cache_ifog: Cache of intermediate values, of shape (T, N*4*M).
-   *      Note: This is used for performance during training.
    *
    * Outputs:
    *  - dX: Gradient wrt `X`, of shape (N, T*D).
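
Note that the cache_out/cache_c/cache_ifog inputs disappear from the staging layer's backward documentation because the builtin lstm_backward now manages them internally: the CP operator re-runs the forward pass and stores, for every timestep t, the reshaped row matrix(out_t, rows=1, cols=N*M) in row t of cache_out (likewise cache_c, and cache_ifog with N*4*M columns). A minimal sketch of reading one row back, using the sliceAndReshape helper this commit adds to LibMatrixDNN (t, N, M and the cache blocks are assumed to be in scope):

	// Row t-1 of the T x (N*M) cache holds out_t as a single row; sliceAndReshape
	// turns it back into the N x M hidden-state block of timestep t.
	MatrixBlock out_t = LibMatrixDNN.sliceAndReshape(cache_out, new MatrixBlock(), t - 1, N, M);
	MatrixBlock c_t   = LibMatrixDNN.sliceAndReshape(cache_c, new MatrixBlock(), t - 1, N, M);
	MatrixBlock ifog  = LibMatrixDNN.sliceAndReshape(cache_ifog, new MatrixBlock(), t - 1, N, 4 * M);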
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 5fdc8e7..dedbad6 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -282,13 +282,6 @@ public class FunctionOp extends MultiThreadedHop
 		
 		if(getFunctionType() == FunctionType.MULTIRETURN_BUILTIN && isBuiltinFunction() &&
 			(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) {
-			
-			if(getFunctionName().equalsIgnoreCase("lstm_backward")) {
-				if(!ConfigurationManager.isGPU())
-					throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
-				_etype = ExecType.GPU;
-			}
-			
 			ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
 			
 			if( _etypeForced != null ) {
@@ -306,7 +299,7 @@ public class FunctionOp extends MultiThreadedHop
 				checkAndSetInvalidCPDimsAndSize();
 			}
 			
-			// Since lstm builtin functions are not supported on Spark
+			// The lstm builtin functions are not supported on Spark or MR, so fall back to CP.
 			_etype = _etype == REMOTE ?  ExecType.CP : _etype;
 			
 			//mark for recompile (forever)
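
With the GPU-only check removed, lstm and lstm_backward share the same execution-type fallback: neither builtin has a Spark or MR operator, so a remote choice degrades to CP. Paraphrasing the hunk above (the effective rule, not new code):

	// REMOTE is SPARK or MR depending on the execution mode; the lstm builtins
	// have no distributed implementation, so a remote selection falls back to CP.
	_etype = (_etype == REMOTE) ? ExecType.CP : _etype;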
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
index 93ffd4f..50a11de 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -274,6 +274,24 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 			int numThreads = Integer.parseInt(parts[9]);
 			return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, null, null, null, opcode, str, 0, numThreads);
 		}
+		else if (opcode.equalsIgnoreCase("lstm_backward")) {
+			InstructionUtils.checkNumFields(parts, 14);
+			CPOperand in1 = new CPOperand(parts[1]); // X
+			CPOperand in2 = new CPOperand(parts[2]); // W
+			CPOperand in3 = new CPOperand(parts[3]); // b
+			CPOperand in4 = new CPOperand(parts[4]); // out0
+			CPOperand in5 = new CPOperand(parts[5]); // c0
+			CPOperand in6 = new CPOperand(parts[6]); // return_seq
+			CPOperand in7 = new CPOperand(parts[7]); // dout
+			CPOperand in8 = new CPOperand(parts[8]); // dc
+			CPOperand out = new CPOperand(parts[9]);  // dX
+			CPOperand out2 = new CPOperand(parts[10]); // dW
+			CPOperand out3 = new CPOperand(parts[11]); // db
+			CPOperand out4 = new CPOperand(parts[12]); // dout0
+			CPOperand out5 = new CPOperand(parts[13]); // dc0
+			int numThreads = Integer.parseInt(parts[14]);
+			return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0, numThreads);
+		}
 		else {
 			throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str);
 		}
@@ -329,6 +347,88 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 		ec.setMatrixOutput(_out2.getName(), c, getExtendedOpcode());
 	}
 	
+	public void processLstmBackwardInstruction(ExecutionContext ec) {
+		MatrixBlock X = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+		MatrixBlock W = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
+		MatrixBlock b = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
+		MatrixBlock out0 = ec.getMatrixInput(_in4.getName(), getExtendedOpcode());
+		MatrixBlock c0 = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
+		boolean return_seq = ec.getScalarInput(_in6.getName(), _in6.getValueType(), _in6.isLiteral()).getBooleanValue();
+		MatrixBlock dout = ec.getMatrixInput(_in7.getName(), getExtendedOpcode());
+		MatrixBlock dc = ec.getMatrixInput(_in8.getName(), getExtendedOpcode());
+		
+		int N = X.getNumRows();
+		int TD = X.getNumColumns();
+		int DPlusM = W.getNumRows();
+		int M = W.getNumColumns() / 4;
+		int D = DPlusM - M;
+		int T = TD / D;
+		if(b.getNumRows() != 1 || b.getNumColumns() != M*4) {
+			throw new DMLRuntimeException("Incorrect dimensions of bias in lstm_backward instruction. Expected [1, " + (M*4) + "], "
+					+ "but found [" + b.getNumRows() + "," + b.getNumColumns() + "]");
+		}
+		if(out0.getNumRows() != N) {
+			throw new DMLRuntimeException("Unsupported operation: The batch size of previous iteration " + out0.getNumRows() + 
+					" is different than the batch size of current iteration " + N);
+		}
+		if(out0.getNumColumns() != M) {
+			throw new DMLRuntimeException("Incorrect dimensions of out0 in lstm_backward instruction. Expected [" + N + ", " + M + "], "
+					+ "but found [" + out0.getNumRows() + "," + out0.getNumColumns() + "]");
+		}
+		if(c0.getNumRows() != N || c0.getNumColumns() != M) {
+			throw new DMLRuntimeException("Incorrect dimensions of c0 in lstm_backward instruction. Expected [" + N + ", " + M + "], "
+					+ "but found [" + out0.getNumRows() + "," + out0.getNumColumns() + "]");
+		}
+		if(dout.getNumRows() != N || dout.getNumColumns() != (return_seq ? (T*M) : M)) {
+			throw new DMLRuntimeException("Incorrect dimensions of dout in lstm_backward instruction. Expected [" + N + ", " + (return_seq ? (T*M) : M) + "], "
+					+ "but found [" + dout.getNumRows() + "," + dout.getNumColumns() + "]");
+		}
+		if(dc.getNumRows() != N || dc.getNumColumns() != M) {
+			throw new DMLRuntimeException("Incorrect dimensions of dc in lstm_backward instruction. Expected [" + N + ", " + M + "], "
+					+ "but found [" + dc.getNumRows() + "," + dc.getNumColumns() + "]");
+		}
+		
+		MatrixBlock out = new MatrixBlock(N, return_seq ? (T*M) : M, false);
+		MatrixBlock c = new MatrixBlock(N, M, false);
+		MatrixBlock cache_out = new MatrixBlock(T, N*M, false);
+		MatrixBlock cache_c = new MatrixBlock(T, N*M, false);
+		MatrixBlock cache_ifog = new MatrixBlock(T, N*4*M, false);
+		
+		// In the initial implementation, invoke lstm redundantly.
+		// TODO: Optimize this later.
+		cache_out.allocateDenseBlock();
+		cache_c.allocateDenseBlock();
+		cache_ifog.allocateDenseBlock();
+		LibMatrixDNN.lstm(X, W, b, out0, c0, 
+				return_seq, N, T, D, M,
+				out,  c, cache_out, cache_c, cache_ifog,
+				_numThreads);
+		
+		MatrixBlock dX = new MatrixBlock(N, T*D, false);
+		MatrixBlock dW = new MatrixBlock(D+M, 4*M, false);
+		MatrixBlock db = new MatrixBlock(1, 4*M, false);
+		MatrixBlock dout0 = new MatrixBlock(N, M, false);
+		MatrixBlock dc0 = new MatrixBlock(N, M, false);
+		LibMatrixDNN.lstm_backward(dout, dc, X, W, b, out0, c0, return_seq, N, T, D, M,
+				cache_out, cache_c, cache_ifog, // from forward invocation
+				dX, dW, db, dout0, dc0, // output
+				_numThreads);
+		
+		// release inputs/outputs
+		ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in4.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in7.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in8.getName(), getExtendedOpcode());
+		ec.setMatrixOutput(output.getName(), dX, getExtendedOpcode());
+		ec.setMatrixOutput(_out2.getName(), dW, getExtendedOpcode());
+		ec.setMatrixOutput(_out3.getName(), db, getExtendedOpcode());
+		ec.setMatrixOutput(_out4.getName(), dout0, getExtendedOpcode());
+		ec.setMatrixOutput(_out5.getName(), dc0, getExtendedOpcode());
+	}
+	
 	public void processReluBackwardInstruction(ExecutionContext ec) {
 		// (X > 0) * dout
 		MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
@@ -498,6 +598,10 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 			processLstmInstruction(ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("lstm_backward")) {
+			processLstmBackwardInstruction(ec);
+			return;
+		}
 		
 		// acquire inputs
 		MatrixBlock outputBlock = null;
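
As a concrete check of the shape bookkeeping in processLstmBackwardInstruction, take the dimensions used by most of the new tests (N=20, T=13, D=50, M=10): X is 20 x 650 and W is 60 x 40, so the operator recovers M = 40/4 = 10, D = 60 - 10 = 50 and T = 650/50 = 13, expects b as 1 x 40, out0/c0/dc as 20 x 10 and dout as 20 x 130 (return_seq=TRUE) or 20 x 10 (FALSE), and produces dX: 20 x 650, dW: 60 x 40, db: 1 x 40, and dout0/dc0: 20 x 10.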
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index 35c9591..fbe7c9d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -361,31 +361,31 @@ public class DnnGPUInstruction extends GPUInstruction {
 		}
 		else if (opcode.equalsIgnoreCase("lstm")) {
 			InstructionUtils.checkNumFields(parts, 8);
-			CPOperand in1 = new CPOperand(parts[1]);
-			CPOperand in2 = new CPOperand(parts[2]);
-			CPOperand in3 = new CPOperand(parts[3]);
-			CPOperand in4 = new CPOperand(parts[4]);
-			CPOperand in5 = new CPOperand(parts[5]);
-			CPOperand in6 = new CPOperand(parts[6]);
-			CPOperand out = new CPOperand(parts[7]);
-			CPOperand out2 = new CPOperand(parts[8]);
+			CPOperand in1 = new CPOperand(parts[1]); // X
+			CPOperand in2 = new CPOperand(parts[2]); // W
+			CPOperand in3 = new CPOperand(parts[3]); // b
+			CPOperand in4 = new CPOperand(parts[4]); // out0
+			CPOperand in5 = new CPOperand(parts[5]); // c0
+			CPOperand in6 = new CPOperand(parts[6]); // return_seq
+			CPOperand out = new CPOperand(parts[7]); // out
+			CPOperand out2 = new CPOperand(parts[8]); // c
 			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0);
 		}
 		else if (opcode.equalsIgnoreCase("lstm_backward")) {
 			InstructionUtils.checkNumFields(parts, 13);
-			CPOperand in1 = new CPOperand(parts[1]); // image
-			CPOperand in2 = new CPOperand(parts[2]); // scale
-			CPOperand in3 = new CPOperand(parts[3]); // bias
-			CPOperand in4 = new CPOperand(parts[4]); // runningMean
-			CPOperand in5 = new CPOperand(parts[5]); // runningVar
-			CPOperand in6 = new CPOperand(parts[6]); // mode
-			CPOperand in7 = new CPOperand(parts[7]); // epsilon
-			CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor
-			CPOperand out = new CPOperand(parts[9]);  // ret
-			CPOperand out2 = new CPOperand(parts[10]); // retRunningMean
-			CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
-			CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
-			CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
+			CPOperand in1 = new CPOperand(parts[1]); // X
+			CPOperand in2 = new CPOperand(parts[2]); // W
+			CPOperand in3 = new CPOperand(parts[3]); // b
+			CPOperand in4 = new CPOperand(parts[4]); // out0
+			CPOperand in5 = new CPOperand(parts[5]); // c0
+			CPOperand in6 = new CPOperand(parts[6]); // return_seq
+			CPOperand in7 = new CPOperand(parts[7]); // dout
+			CPOperand in8 = new CPOperand(parts[8]); // dc
+			CPOperand out = new CPOperand(parts[9]);  // dX
+			CPOperand out2 = new CPOperand(parts[10]); // dW
+			CPOperand out3 = new CPOperand(parts[11]); // db
+			CPOperand out4 = new CPOperand(parts[12]); // dout0
+			CPOperand out5 = new CPOperand(parts[13]); // dc0
 			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
 		}
 		else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 365d7a2..0005932 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -20,6 +20,7 @@ package org.apache.sysml.runtime.matrix.data;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
@@ -35,19 +36,29 @@ import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysml.runtime.functionobjects.Builtin;
 import org.apache.sysml.runtime.functionobjects.KahanPlus;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
 import org.apache.sysml.runtime.functionobjects.Multiply;
 import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.functionobjects.PlusMultiply;
+import org.apache.sysml.runtime.functionobjects.Power;
+import org.apache.sysml.runtime.functionobjects.Power2;
+import org.apache.sysml.runtime.functionobjects.SwapIndex;
 import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysml.runtime.instructions.cp.KahanObject;
 import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.LeftScalarOperator;
+import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysml.runtime.matrix.operators.RightScalarOperator;
+import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
 import org.apache.sysml.runtime.matrix.operators.TernaryOperator;
 import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
 import org.apache.sysml.runtime.util.CommonThreadPool;
 import org.apache.sysml.runtime.util.DnnUtils;
 
 /*
  * This class allows users to invoke deep learning related operations 
  * (such as conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling, maxpooling_backward, bias_add)
@@ -297,6 +308,10 @@ public class LibMatrixDNN {
 		return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()), 
 				matBlock2, matBlock3, new MatrixBlock());
 	}
+	private static MatrixBlock minusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) {
+		return matBlock1.ternaryOperations(new TernaryOperator(MinusMultiply.getFnObject()), 
+				matBlock2, matBlock3, new MatrixBlock());
+	}
 	
 		
 	private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) {
@@ -311,6 +326,11 @@ public class LibMatrixDNN {
 		}
 	}
 	
+	private static MatrixBlock multiply(MatrixBlock matBlock1, double scalar, boolean inplace) {
+		ScalarOperator sc_op = new LeftScalarOperator(Multiply.getMultiplyFnObject(), scalar);
+		return (MatrixBlock) matBlock1.scalarOperations(sc_op, new MatrixBlock());
+	}
+	
 	
 	// sigmoid(0)*c_prev + sigmoid(0)*tanh(0);
 	
@@ -322,6 +342,175 @@ public class LibMatrixDNN {
 	private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean inPlace) {
 		return (MatrixBlock) in.unaryOperations(new UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock());
 	}
+	private static MatrixBlock power(MatrixBlock in, double exponent) {
+		return (MatrixBlock) in.scalarOperations(new RightScalarOperator(Power.getPowerFnObject(), exponent), new MatrixBlock());
+	}
+	private static MatrixBlock minus(double scalar, MatrixBlock in) {
+		return (MatrixBlock) in.scalarOperations(new LeftScalarOperator(org.apache.sysml.runtime.functionobjects.Minus.getMinusFnObject(), scalar), new MatrixBlock());
+	}
+	private static MatrixBlock tanh_backward(MatrixBlock dout, MatrixBlock X, int numThreads) {
+		MatrixBlock out = tanh(X, numThreads, false);
+		return minusMultiply(dout, power(out, 2), dout);
+	}
+	
+	public static void lstm_backward(MatrixBlock dout, MatrixBlock dc,
+			MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0, 
+			boolean given_sequences, int N, int T, int D, int M,
+			MatrixBlock cache_out, MatrixBlock cache_c, MatrixBlock cache_ifog, // from forward invocation
+			MatrixBlock dX, MatrixBlock dW, MatrixBlock db, MatrixBlock dout0, MatrixBlock dc0,
+			int numThreads) {
+		MatrixBlock dct = dc;
+		if (!given_sequences) {
+			// only given dout for output at final timestep, so prepend empty douts for all other timesteps
+			dout = new MatrixBlock(N, (T-1)*M, true).append(dout, new MatrixBlock());
+		}
+		MatrixBlock dW_ret = dW;
+		MatrixBlock db_ret = db;
+		MatrixBlock dout_t = dout.slice(0, N-1, (T-1)*M, T*M-1, new MatrixBlock());
+		for(int t = T; t > 0; t--) {
+			MatrixBlock X_t = (T == 1) ? X : X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock());
+			MatrixBlock ct = sliceAndReshape(cache_c, new MatrixBlock(), t-1, N, M);
+			MatrixBlock out_prev = (t == 1) ? out0 : sliceAndReshape(cache_out, new MatrixBlock(), t-2, N, M);
+			MatrixBlock c_prev = (t == 1) ? c0 : sliceAndReshape(cache_c, new MatrixBlock(), t-2, N, M);
+			MatrixBlock input = X_t.append(out_prev, new MatrixBlock());
+			MatrixBlock ifog = sliceAndReshape(cache_ifog, new MatrixBlock(), t-1, N, 4*M);
+			MatrixBlock i = ifog.slice(0, N-1, 0, M-1, new MatrixBlock());
+			MatrixBlock f = ifog.slice(0, N-1, M, 2*M-1, new MatrixBlock());
+			MatrixBlock o = ifog.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
+			MatrixBlock g = ifog.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock());
+			dct = plusMultiply(dct, o, tanh_backward(dout_t, ct, numThreads));
+			MatrixBlock dc_prev = multiply(f, dct, false);
+			
+			MatrixBlock di_raw = multiply(new MatrixBlock[] {i, minus(1, i), g, dct}); 
+			MatrixBlock df_raw = multiply(new MatrixBlock[] {f, minus(1, f), c_prev, dct});
+			MatrixBlock do_raw = multiply(new MatrixBlock[] {o, minus(1, o), tanh(ct, numThreads, false), dout_t});
+			MatrixBlock dg_raw = multiply(new MatrixBlock[] {minus(1, power(g, 2)), i, dct});
+			MatrixBlock difog_raw = di_raw.append(new MatrixBlock[] { df_raw, do_raw, dg_raw}, new MatrixBlock(), true);
+			
+			// dW = dW + t(input) %*% difog_raw
+			dW = add(matmult(transpose(input, numThreads), difog_raw, numThreads), dW, true);
+			// db = db + colSums(difog_raw)
+			db = add(colSums(difog_raw), db, true);
+			// dinput = difog_raw %*% t(W)
+			MatrixBlock dinput = matmult(difog_raw, transpose(W, numThreads), numThreads);
+			// dX[,(t-1)*D+1:t*D] = dinput[,1:D]
+			dX.leftIndexingOperations(dinput.slice(0, N-1, 0, D-1, new MatrixBlock()), 0, N-1, (t-1)*D, t*D-1, dX, UpdateType.INPLACE);
+			// dout_prev = dinput[,D+1:D+M]
+			MatrixBlock dout_prev = dinput.slice(0, N-1, D, D+M-1, new MatrixBlock());
+			
+			if(t == 1) {
+				dout0.copy(dout_prev);
+				dc0.copy(dc_prev);
+			}
+			else {
+				dout_t = add(dout.slice(0, N-1, (t-2)*M, (t-1)*M-1, new MatrixBlock()), dout_prev, true);
+				dct = dc_prev;
+			}
+		}
+		dW_ret.copy(dW);
+		db_ret.copy(db);
+	}
+	
+	
+	private static MatrixBlock colSums(MatrixBlock in) {
+		MatrixBlock ret = new MatrixBlock(1, in.getNumColumns(), false);
+		if(in.isEmpty()) {
+			// Do nothing
+			ret.setNonZeros(0);
+		}
+		else if(in.isInSparseFormat()) {
+			ret.allocateDenseBlock();
+			double [] retArr = ret.getDenseBlockValues();
+			SparseBlock sblock = in.getSparseBlock();
+			for(int n = 0; n < in.getNumRows(); n++) {
+				if( sblock.isEmpty(n) )
+					continue;
+				int apos = sblock.pos(n);
+				int alen = sblock.size(n);
+				int[] aix = sblock.indexes(n);
+				double[] avals = sblock.values(n);
+				
+				// Iterate over the sparse block
+				for(int j=apos; j<apos+alen; j++) {
+					retArr[aix[j]] += avals[j];
+				}
+			}
+			ret.recomputeNonZeros();
+		}
+		else {
+			double [] inArr = in.getDenseBlockValues();
+			if(inArr != null) {
+				int index = 0;
+				ret.allocateDenseBlock();
+				double [] retArr = ret.getDenseBlockValues();
+				for(int r = 0; r < in.getNumRows(); r++) {
+					for(int c = 0; c < in.getNumColumns(); c++, index++) {
+						retArr[c] += inArr[index];
+					}
+				}
+				ret.recomputeNonZeros();
+			}
+			else {
+				ret.setNonZeros(0);
+			}
+		}
+		return ret;
+	}
+	
+	private static MatrixBlock transpose(MatrixBlock in, int numThreads) {
+		ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), numThreads);
+		return (MatrixBlock) (in.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0));
+	}
+	
+	private static MatrixBlock multiply(MatrixBlock [] in) {
+		boolean allDense = true;
+		int rows = 0; int cols = 0;
+		for(MatrixBlock mb : in) {
+			rows = Math.max(rows, mb.getNumRows());
+			cols = Math.max(cols, mb.getNumColumns());
+		}
+		for(MatrixBlock mb : in) {
+			if(mb.isEmpty() || (!mb.isInSparseFormat() && mb.getDenseBlockValues() == null)) {
+				MatrixBlock ret = new MatrixBlock(rows, cols, true);
+				ret.setNonZeros(0);
+				return ret;
+			}
+			allDense = allDense && !mb.isInSparseFormat();
+		}
+		if(allDense) {
+			MatrixBlock ret = new MatrixBlock(rows, cols, false);
+			ret.allocateDenseBlock();
+			double [] retArr = null;
+			// Avoids (in.length-1) recomputeNonZeros calls
+			for(MatrixBlock mb : in) {
+				if(retArr == null) {
+					retArr = ret.getDenseBlockValues();
+					System.arraycopy(mb.getDenseBlockValues(), 0, retArr, 0, retArr.length);
+				}
+				else {
+					double [] inArr = mb.getDenseBlockValues();
+					for(int index = 0; index < retArr.length; index++) {
+						retArr[index] *= inArr[index];
+					}
+				}
+			}
+			ret.recomputeNonZeros();
+			return ret;
+		}
+		else {
+			Arrays.sort(in, (mb1, mb2) -> Long.compare(mb1.getNonZeros(), mb2.getNonZeros()));
+			MatrixBlock ret = new MatrixBlock(rows, cols, in[0].isInSparseFormat());
+			for(MatrixBlock mb : in) {
+				ret = multiply(ret, mb, true);
+			}
+			return ret;
+		}
+	}
+	
+	// Performs the following operation: ret = matrix(in[rowIndex+1,], rows=numRows, cols=numCols)
+	public static MatrixBlock sliceAndReshape(MatrixBlock in, MatrixBlock ret, int rowIndex, int numRows, int numCols) {
+		return LibMatrixReorg.reshape(in.slice(rowIndex, rowIndex), ret, numRows, numCols, true);
+	}
 	
 	public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0, 
 			boolean return_seq, int N, int T, int D, int M,
@@ -367,27 +556,84 @@ public class LibMatrixDNN {
 				ifog_raw = add(matmult(input, W, numThreads), b, true);
 			}
 			
-			MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock());
-			ifo = sigmoid(ifo, numThreads, true);
-			MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock());
-			MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock());
-			MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
-			MatrixBlock g = tanh(ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()), numThreads, true);
-					
-			// c_t = f*c_prev + i*g
-			c_t = plusMultiply(multiply(f, c_prev, true), i, g);
-			// out_t = o*tanh(c)
-			out_t = multiply(o, tanh(c_t, numThreads, false), true);
+			if(!ifog_raw.isInSparseFormat() && !c_prev.isInSparseFormat()) {
+				double [] ifog_rawArr = ifog_raw.getDenseBlockValues();
+				double [] c_prevArr = c_prev.getDenseBlockValues();
+				double [] cache_ifogArr = null;
+				if(cache_ifog != null) {
+					cache_ifogArr = cache_ifog.getDenseBlockValues();
+					if(cache_ifogArr == null)
+						throw new DMLRuntimeException("Expected cache_ifog to be allocated in the dense format");
+				}
+				if(ifog_rawArr == null && c_prevArr == null) {
+					// Both ifog_raw and c_prev are empty matrix
+					c_t = new MatrixBlock(N, M, 0);
+					out_t = new MatrixBlock(N, M, 0);
+					c_t.setNonZeros(0);
+					out_t.setNonZeros(0);
+					updateIfogCache(cache_ifogArr, t, N, M);
+				}
+				else if(ifog_rawArr == null) {
+					// ifog_raw is an empty matrix
+					// c_t = f*c_prev + i*g 
+					//     = 0.5*c_prev
+					c_t = multiply(c_prev, 0.5, false);
+					// out_t = o*tanh(c)
+					//       = 0.5*tanh(c)
+					out_t = multiply(tanh(c_t, numThreads, false), 0.5, false);
+					updateIfogCache(cache_ifogArr, t, N, M);
+				}
+				else {
+					// ifog_raw is not an empty matrix
+					c_t = new MatrixBlock(N, M, false); c_t.allocateDenseBlock();
+					double [] c_tArr = c_t.getDenseBlockValues();
+					out_t = new MatrixBlock(N, M, false); out_t.allocateDenseBlock();
+					double [] out_tArr = out_t.getDenseBlockValues();
+					int index = 0;
+					int offset = (t-1)*N*4*M;
+					for(int n = 0; n < N; n++) {
+						for(int m = 0; m < M; m++, index++) {
+							double c_prevVal = (c_prevArr == null) ? 0 : c_prevArr[index];
+							// c_t = f*c_prev + i*g
+							double i = sigmoidOp.execute(ifog_rawArr[n*4*M + m]);
+							double f = sigmoidOp.execute(ifog_rawArr[n*4*M + M + m]);
+							double o = sigmoidOp.execute(ifog_rawArr[n*4*M + 2*M + m]);
+							double g = tanhOp.execute(ifog_rawArr[n*4*M + 3*M + m]);
+							c_tArr[index] = f*c_prevVal + i*g;
+							// out_t = o*tanh(c)
+							out_tArr[index] = o*tanhOp.execute(c_tArr[index]);
+							updateIfogCache(cache_ifogArr, i, f, o, g, offset, n, m, N, M);
+						}
+					}
+					c_t.recomputeNonZeros();
+					out_t.recomputeNonZeros();
+				}
+			}
+			else {
+				MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock());
+				ifo = sigmoid(ifo, numThreads, true);
+				MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock());
+				MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock());
+				MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
+				MatrixBlock g = tanh(ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()), numThreads, true);
+						
+				// c_t = f*c_prev + i*g
+				c_t = plusMultiply(multiply(f, c_prev, true), i, g);
+				// out_t = o*tanh(c)
+				out_t = multiply(o, tanh(c_t, numThreads, false), true);
+				updateIfogCache(cache_ifog, ifo, g, t, N, M);
+			}
+			
 			if(return_seq) {
-				out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE);
+				out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, out, UpdateType.INPLACE);
 			}
 			out_prev = out_t;
 			c_prev = c_t;
 			
-			// TODO: Add this when implementing lstm_backward
-//			cache_out[t,] = matrix(out_t, rows=1, cols=N*M)  # reshape
-//		    cache_c[t,] = matrix(c, rows=1, cols=N*M)  # reshape
-//		    cache_ifog[t,] = matrix(cbind(ifo, g), rows=1, cols=N*4*M)  # reshape
+			if(cache_out != null) {
+				reshapeAsRowMatrixAndLeftIndex(cache_out, out_t, t-1, N*M);
+				reshapeAsRowMatrixAndLeftIndex(cache_c, c_t, t-1, N*M);
+			}
 		}
 		if(out_t != null && !return_seq)
 			out.copy(out_t);
@@ -395,6 +641,69 @@ public class LibMatrixDNN {
 			c.copy(c_t);
 		else
 			c.copy(c0);
+		if(cache_out != null) {
+			cache_out.recomputeNonZeros();
+			cache_c.recomputeNonZeros();
+			cache_ifog.recomputeNonZeros();
+		}
+	}
+	
+	private static void updateIfogCache(MatrixBlock cache_ifog, MatrixBlock ifo, MatrixBlock g, int t, int N, int M) {
+		if(cache_ifog != null) {
+			reshapeAsRowMatrixAndLeftIndex(cache_ifog, ifo.append(g, new MatrixBlock()), t-1, N*M);
+		}
+	}
+	
+	// ifog_raw is an empty matrix
+	private static void updateIfogCache(double[] cache_ifogArr, int t, int N, int M) {
+		if(cache_ifogArr != null) {
+			int offset = (t-1)*N*4*M;
+			for(int n = 0 ; n < N; n++) {
+				int srcIndex = offset + n*4*M;
+				Arrays.fill(cache_ifogArr, srcIndex, srcIndex + 3*M, 0.5);
+			}
+		}
+	}
+	
+	private static void updateIfogCache(double[] cache_ifogArr, double i, double f, double o, double g, int offset, int n, int m, int N, int M) {
+		if(cache_ifogArr != null) {
+			cache_ifogArr[offset + n*4*M + m] = i;
+			cache_ifogArr[offset + n*4*M + M + m] = f;
+			cache_ifogArr[offset + n*4*M + 2*M + m] = o;
+			cache_ifogArr[offset + n*4*M + 3*M + m] = g;
+		}
+	}
+	
+	// Performs operation: lhsMatrix[rowIndex+1, ] =  matrix(rhsMatrix, rows=1, cols=numCols)
+	private static void reshapeAsRowMatrixAndLeftIndex(MatrixBlock lhsMatrix, MatrixBlock rhsMatrix, int rowIndex, int numCols) {
+		double [] lhsArr = lhsMatrix.getDenseBlockValues();
+		if(lhsArr == null)
+			throw new DMLRuntimeException("Incorrect usage: lhsMatrix needs to be allocated in dense format before invocation of this method.");
+		if(rhsMatrix.isInSparseFormat()) {
+			SparseBlock sblock = rhsMatrix.getSparseBlock();
+			for(int n = 0; n < rhsMatrix.getNumRows(); n++) {
+				if( sblock.isEmpty(n) )
+					continue;
+				int apos = sblock.pos(n);
+				int alen = sblock.size(n);
+				int[] aix = sblock.indexes(n);
+				double[] avals = sblock.values(n);
+				
+				// Iterate over the sparse block
+				for(int j=apos; j<apos+alen; j++) {
+					lhsArr[n*numCols + aix[j]] = avals[j];
+				}
+			}
+		}
+		else if(!rhsMatrix.isInSparseFormat()) {
+			double [] rhsArr = rhsMatrix.getDenseBlockValues();
+			if(rhsArr != null) {
+				System.arraycopy(rhsArr, 0, lhsArr, rowIndex*numCols, numCols);
+			}
+			else {
+				// Do nothing: assumption => lhsMatrix is initialized to 0 before invocation.
+			}
+		}
 	}
 	
 	/**
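
For reference, the per-timestep recurrence that LibMatrixDNN.lstm_backward implements above is, transcribing the code into equations (with $\odot$ the elementwise product, $i_t, f_t, o_t, g_t$ read from row $t$ of cache_ifog, $c_t$ from cache_c, and $out_{t-1}$, $c_{t-1}$ taken from the previous row or from out0/c0 at $t=1$):

\begin{aligned}
dc_t &\mathrel{+}= o_t \odot \big(1 - \tanh^2(c_t)\big) \odot dout_t, \qquad dc_{t-1} = f_t \odot dc_t \\
di^{raw}_t &= i_t \odot (1 - i_t) \odot g_t \odot dc_t, \qquad
df^{raw}_t = f_t \odot (1 - f_t) \odot c_{t-1} \odot dc_t \\
do^{raw}_t &= o_t \odot (1 - o_t) \odot \tanh(c_t) \odot dout_t, \qquad
dg^{raw}_t = \big(1 - g_t^2\big) \odot i_t \odot dc_t \\
dW &\mathrel{+}= [X_t,\, out_{t-1}]^{\top}\, difog^{raw}_t, \qquad
db \mathrel{+}= \mathrm{colSums}\big(difog^{raw}_t\big) \\
[dX_t,\; dout_{t-1}] &= difog^{raw}_t\, W^{\top}
\end{aligned}

For t > 1, dout_{t-1} is added to the corresponding column slice of dout before the loop moves to the earlier timestep; at t = 1 the accumulated dout_prev and dc_prev are copied into dout0 and dc0.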
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
index 785c890..5c93bca 100644
--- a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
@@ -131,82 +131,141 @@ public class LstmCPUTest extends GPUTests {
 		inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
 		inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
 		List<String> outputs = Arrays.asList("output", "c");
-		List<Object> outGPUWithCuDNN = null;
-		List<Object> outCPUWithNN = null;
-		synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
-			try {
-				DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.CUDNN;
-				outGPUWithCuDNN = runOnCPU(spark, scriptStr1, inputs, outputs);
-				outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, outputs);
-			}
-			finally {
-				DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.NONE;
-			}
-		}
-		assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
-		assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
+		List<Object> outBuiltin = runOnCPU(spark, scriptStr1, inputs, outputs);
+		List<Object> outNNLayer = runOnCPU(spark, scriptStr2, inputs, outputs);
+		assertEqualObjects(outBuiltin.get(0), outNNLayer.get(0));
+		assertEqualObjects(outBuiltin.get(1), outNNLayer.get(1));
 	}
 	
 	
+	@Test
+	public void testLstmBackward1() {
+		testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "TRUE", 0.9, 0.9);
+	}
+	
+	@Test
+	public void testLstmBackward2() {
+		testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "FALSE", 0.9, 0.9);
+	}
+	
+	@Test
+	public void testLstmBackward3() {
+		testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "TRUE", 0.9, 0.9);
+	}
 	
 //	@Test
-//	public void testLstmBackward7() {
-//		testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9, 0.9);
+//	public void testLstmBackward4() {
+//		testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "FALSE", 0.9, 0.9);
 //	}
-//	
+	
+	@Test
+	public void testLstmBackward5() {
+		testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "TRUE", 0.9, 0.9);
+	}
+	
+	@Test
+	public void testLstmBackward6() {
+		testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "FALSE", 0.9, 0.9);
+	}
+	
+	@Test
+	public void testLstmBackward7() {
+		testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9, 0.9);
+	}
+	
+	@Test
+	public void testLstmBackward8() {
+		testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9, 0.9);
+	}
+	
+	@Test
+	public void testLstmBackward9() {
+		testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9, 0.9);
+	}
+	
+	@Test
+	public void testLstmBackward10() {
+		testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9, 0.9);
+	}
+	
 //	@Test
-//	public void testLstmBackward8() {
-//		testLstmBackwardCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.9, 0.9);
+//	public void testLstmBackward11() {
+//		testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "TRUE", 0.2, 0.3);
 //	}
-//	
+	
+	@Test
+	public void testLstmBackward12() {
+		testLstmBackwardCuDNNWithNNLayer(20, 1, 50, 10, "FALSE", 0.2, 0.9);
+	}
+	
 //	@Test
-//	public void testLstmBackward9() {
-//		testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.9, 0.9);
+//	public void testLstmBackward13() {
+//		testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "TRUE", 0.9, 0.1);
 //	}
-//	
+	
 //	@Test
-//	public void testLstmBackward10() {
-//		testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9, 0.9);
+//	public void testLstmBackward14() {
+//		testLstmBackwardCuDNNWithNNLayer(20, 13, 1, 10, "FALSE", 0.3, 0.6);
 //	}
-//	
-//	
-//	public void testLstmBackwardCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity,
-//			double weightSparsity) {
-//		boolean returnSequences1 = returnSequences.equals("TRUE");
-//		
-//		String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-//				+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
-//		String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
-//				+ "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, " 
-//				+ T + ", " + D + ", " + returnSequences + ", out0, c0); \n"
-//				+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " 
-//				+ T + ", " + D + ", " + returnSequences + ", out0, c0, cache_out, cache_c, cache_ifog);";
-//		
-//		HashMap<String, Object> inputs = new HashMap<>();
-//		inputs.put("dout", generateInputMatrix(spark, N, returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
-//		inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
-//		inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, sparsity, seed));
-//		inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, weightSparsity, seed));
-//		inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, sparsity, seed));
-//		inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
-//		inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
-//		List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0", "dc0");
-//		List<Object> outGPUWithCuDNN = null;
-//		List<Object> outCPUWithNN = null;
-//		synchronized (DnnGPUInstruction.FORCED_LSTM_OP) {
-//			try {
-//				DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.CUDNN;
-//				outGPUWithCuDNN = runOnCPU(spark, scriptStr1, inputs, outputs);
-//			}
-//			finally {
-//				DnnGPUInstruction.FORCED_LSTM_OP = LstmOperator.NONE;
-//			}
-//			outCPUWithNN = runOnCPU(spark, scriptStr2, inputs, outputs);
-//		}
-//		assertEqualObjects(outGPUWithCuDNN.get(0), outCPUWithNN.get(0));
-//		assertEqualObjects(outGPUWithCuDNN.get(1), outCPUWithNN.get(1));
-//		assertEqualObjects(outGPUWithCuDNN.get(2), outCPUWithNN.get(2));
-//		assertEqualObjects(outGPUWithCuDNN.get(3), outCPUWithNN.get(3));
-//		assertEqualObjects(outGPUWithCuDNN.get(4), outCPUWithNN.get(4));
+	
+	@Test
+	public void testLstmBackward15() {
+		testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "TRUE", 0.2, 0.9);
+	}
+	
+//	@Test
+//	public void testLstmBackward16() {
+//		testLstmBackwardCuDNNWithNNLayer(20, 13, 50, 1, "FALSE", 0.3, 0.1);
 //	}
+	
+	@Test
+	public void testLstmBackward17() {
+		testLstmBackwardCuDNNWithNNLayer(20, 13, 15, 25, "TRUE", 0.9, 0.9);
+	}
+	
+	@Test
+	public void testLstmBackward18() {
+		testLstmBackwardCuDNNWithNNLayer(20, 13, 15, 25, "FALSE", 0.9, 0.9);
+	}
+	
+	@Test
+	public void testLstmBackward19() {
+		testLstmBackwardCuDNNWithNNLayer(12, 17, 15, 26, "TRUE", 0.9, 0.9);
+	}
+	
+	@Test
+	public void testLstmBackward20() {
+		testLstmBackwardCuDNNWithNNLayer(12, 17, 15, 26, "FALSE", 0.9, 0.9);
+	}
+	
+	
+	public void testLstmBackwardCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity,
+			double weightSparsity) {
+		boolean returnSequences1 = returnSequences.equals("TRUE");
+		
+		String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
+				+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+		String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
+				+ "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, " 
+				+ T + ", " + D + ", " + returnSequences + ", out0, c0); \n"
+				+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " 
+				+ T + ", " + D + ", " + returnSequences + ", out0, c0, cache_out, cache_c, cache_ifog);";
+		
+		HashMap<String, Object> inputs = new HashMap<>();
+		inputs.put("dout", generateInputMatrix(spark, N, returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
+		inputs.put("dc", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+		inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, sparsity, seed));
+		inputs.put("w", generateInputMatrix(spark, D+M, 4*M, 0, 10, weightSparsity, seed));
+		inputs.put("b", generateInputMatrix(spark, 1, 4*M, 0, 10, sparsity, seed));
+		inputs.put("out0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+		inputs.put("c0", generateInputMatrix(spark, N, M, 0, 10, sparsity, seed));
+		List<String> outputs = Arrays.asList("dX", "dW", "db", "dout0", "dc0");
+		List<Object> outBuiltin = runOnCPU(spark, scriptStr1, inputs, outputs);
+		List<Object> outNN = runOnCPU(spark, scriptStr2, inputs, outputs);
+		assertEqualObjects(outBuiltin.get(0), outNN.get(0));
+		assertEqualObjects(outBuiltin.get(1), outNN.get(1));
+		assertEqualObjects(outBuiltin.get(2), outNN.get(2));
+		assertEqualObjects(outBuiltin.get(3), outNN.get(3));
+		assertEqualObjects(outBuiltin.get(4), outNN.get(4));
+	}
 }