You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/19 19:30:15 UTC

[systemml] branch master updated: [SYSTEMML-540] Improve the performance of GPU lstm backward operator by passing the state

This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 91467c1  [SYSTEMML-540] Improve the performance of GPU lstm backward operator by passing the state
91467c1 is described below

commit 91467c164202f70c5a85ba7e0f7f9fcd16ddca1b
Author: Niketan Pansare <np...@us.ibm.com>
AuthorDate: Tue Mar 19 12:30:01 2019 -0700

    [SYSTEMML-540] Improve the performance of GPU lstm backward operator by passing the state
    
    - The lstm builtin function extended to return state: [out, c, state] = lstm(X, W, b, out0, c0, return_sequences)
    - The lstm_backward builtin function extended to accept state: [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0, given_sequences, dout, dc, state)
    - Updated the DML documentation to reflect this change.
    - Updated the release documentation.
    
    Closes #856.
---
 conf/SystemML-config.xml.template                  |   3 +
 docs/dml-language-reference.md                     |  21 +-
 docs/release-process.md                            |  25 +-
 scripts/nn/layers/lstm_staging.dml                 |  12 +-
 src/main/java/org/apache/sysml/conf/DMLConfig.java |   4 +-
 .../sysml/parser/BuiltinFunctionExpression.java    |  14 +-
 .../org/apache/sysml/parser/StatementBlock.java    |  13 +-
 .../controlprogram/caching/CacheableData.java      |  10 +
 .../runtime/instructions/cp/DnnCPInstruction.java  |  72 +++++-
 .../instructions/gpu/DnnGPUInstruction.java        | 278 ++++++++++++++++-----
 .../instructions/gpu/context/GPUObject.java        |   2 +-
 .../sysml/runtime/matrix/data/LibMatrixCuDNN.java  | 163 ++++++------
 .../matrix/data/LibMatrixCuDNNRnnAlgorithm.java    |  19 +-
 .../runtime/matrix/data/LibMatrixCuMatMult.java    |   3 +
 .../org/apache/sysml/test/gpu/LstmCPUTest.java     |   5 +-
 .../java/org/apache/sysml/test/gpu/LstmTest.java   |  10 +-
 16 files changed, 443 insertions(+), 211 deletions(-)

diff --git a/conf/SystemML-config.xml.template b/conf/SystemML-config.xml.template
index b9189b1..17cc2cc 100644
--- a/conf/SystemML-config.xml.template
+++ b/conf/SystemML-config.xml.template
@@ -118,4 +118,7 @@
    <!-- Should perform recomputation of activations such as ReLU to reduce memory consumption. Set this to true
    when performing inference or for training very large networks (default: false) -->
    <sysml.gpu.recompute.activations>false</sysml.gpu.recompute.activations>
+   
+   <!-- Should SystemML runtime force the lstm builtin functions to use the CuDNN kernels (default: true) -->
+   <sysml.gpu.lstm.force.cudnn>true</sysml.gpu.lstm.force.cudnn>
 </root>
\ No newline at end of file
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index 6f1c854..f64b6ea 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1521,16 +1521,17 @@ The images are assumed to be stored NCHW format, where N = batch size, C = #chan
 Hence, the images are internally represented as a matrix with dimension (N, C * H * W).
 
 
-| Function name                               | Input matrices           | Dimension of first input matrix                           | Dimension of second input matrix (if applicable)          | Dimension of (first) output matrix                                                          | Input Parameters                                                                                                                                                                              | Notes       [...]
-|---------------------------------------------|--------------------------|-----------------------------------------------------------|-----------------------------------------------------------|---------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------ [...]
-| conv2d                                      | input, filter            | [batch_size X num_channels* height_image* width_image]    | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out]                                      | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Performs 2D [...]
-| conv2d_backward_filter                      | input, dout              | [batch_size X num_channels* height_image* width_image]    | [batch_size X num_channels_out* height_out* width_out]    | [num_filters X num_channels* height_filter* width_filter]                                   | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes th [...]
-| conv2d_backward_data                        | filter, dout             | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out]    | [batch_size X num_channels* height_image* width_image]                                      | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter, width_filter] | Computes th [...]
-| max_pool, avg_pool                          | input                    | [batch_size X num_channels* height_image* width_image]    |                                                           | [batch_size X num_channels* height_out* width_out]                                          | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool]                                   | Performs ma [...]
-| max_pool_backward, avg_pool_backward        | input, dout              | [batch_size X num_channels* height_image* width_image]    | [batch_size X num_channels* height_out* width_out]        | [batch_size X num_channels* height_image* width_image]                                      | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool]                                   | Computes th [...]
-| bias_add                                    | input, bias              | [batch_size X num_channels* height_image* width_image]    | [num_channels X 1]                                        | [batch_size X num_channels* height_image* width_image]                                      |                                                                                                                                                                                               | Adds the bi [...]
-| bias_multiply                               | input, bias              | [batch_size X num_channels* height_image* width_image]    | [num_channels X 1]                                        | [batch_size X num_channels* height_image* width_image]                                      |                                                                                                                                                                                               | Multiplies  [...]
-| lstm                                        | X,  W, bias, out0, c0    | [batch_size X seq_length*num_features]                    | [num_features+hidden_size X 4*hidden_size]                | [batch_size X seq_length*hidden_size] if return_sequences else  [batch_size X hidden_size]  | return_sequences                                                                                                                                                                              | Perform com [...]
+| Function name                               | Input matrices                                      | Dimension of first input matrix                           | Dimension of second input matrix (if applicable)          | Dimension of (first) output matrix                                                          | Input Parameters                                                                                                                                                                 [...]
+|---------------------------------------------|-----------------------------------------------------|-----------------------------------------------------------|-----------------------------------------------------------|---------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- [...]
+| conv2d                                      | input, filter                                       | [batch_size X num_channels* height_image* width_image]    | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out]                                      | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter,  [...]
+| conv2d_backward_filter                      | input, dout                                         | [batch_size X num_channels* height_image* width_image]    | [batch_size X num_channels_out* height_out* width_out]    | [num_filters X num_channels* height_filter* width_filter]                                   | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter,  [...]
+| conv2d_backward_data                        | filter, dout                                        | [num_filters X num_channels* height_filter* width_filter] | [batch_size X num_channels_out* height_out* width_out]    | [batch_size X num_channels* height_image* width_image]                                      | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], filter_shape=[num_filters, num_channels, height_filter,  [...]
+| max_pool, avg_pool                          | input                                               | [batch_size X num_channels* height_image* width_image]    |                                                           | [batch_size X num_channels* height_out* width_out]                                          | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool]                      [...]
+| max_pool_backward, avg_pool_backward        | input, dout                                         | [batch_size X num_channels* height_image* width_image]    | [batch_size X num_channels* height_out* width_out]        | [batch_size X num_channels* height_image* width_image]                                      | stride=[stride_h, stride_w], padding=[pad_h, pad_w], input_shape=[batch_size, num_channels, height_image, width_image], pool_size=[height_pool, width_pool]                      [...]
+| bias_add                                    | input, bias                                         | [batch_size X num_channels* height_image* width_image]    | [num_channels X 1]                                        | [batch_size X num_channels* height_image* width_image]                                      |                                                                                                                                                                                  [...]
+| bias_multiply                               | input, bias                                         | [batch_size X num_channels* height_image* width_image]    | [num_channels X 1]                                        | [batch_size X num_channels* height_image* width_image]                                      |                                                                                                                                                                                  [...]
+| lstm                                        | X,  W, bias, out0, c0                               | [N X T*D]                                                 | [D+M X 4M]                                                | [N X T*M] if return_sequences is true else [ N X M ]                                        | return_sequences                                                                                                                                                                 [...]
+| lstm_backward                               | X, W, b, out0, c0, given_sequences, dout, dc, state | [N X T*M] if given_sequences is true else [ N X M]        | [N X M]                                                   | [N X T*D]                                                                                   | return_sequences                                                                                                                                                                 [...]
 
 Note: the builtin functions `batch_norm2d` and `batch_norm2d_backward` are deprecated and will be removed in the next release. The `lstm` builtin function is in an experimental phase and is only supported for the GPU backend. 
 
diff --git a/docs/release-process.md b/docs/release-process.md
index 3798ec7..dec6b15 100644
--- a/docs/release-process.md
+++ b/docs/release-process.md
@@ -255,22 +255,19 @@ this OS X example.
 
 ## Python Tests
 
-For Spark 1.*, the Python tests at (`src/main/python/tests`) can be executed in the following manner:
+Compile SystemML distribution:
 
-	PYSPARK_PYTHON=python3 pyspark --driver-class-path SystemML.jar test_matrix_agg_fn.py
-	PYSPARK_PYTHON=python3 pyspark --driver-class-path SystemML.jar test_matrix_binary_op.py
-	PYSPARK_PYTHON=python3 pyspark --driver-class-path SystemML.jar test_mlcontext.py
-	PYSPARK_PYTHON=python3 pyspark --driver-class-path SystemML.jar test_mllearn_df.py
-	PYSPARK_PYTHON=python3 pyspark --driver-class-path SystemML.jar test_mllearn_numpy.py
+	mvn package -P distribution
+	cd src/main/python/tests/
 
-For Spark 2.*, pyspark can't be used to run the Python tests, so they can be executed using
-spark-submit:
+For Spark 2.*, the Python tests at (`src/main/python/tests`) can be executed in the following manner:
 
-	spark-submit --driver-class-path SystemML.jar test_matrix_agg_fn.py
-	spark-submit --driver-class-path SystemML.jar test_matrix_binary_op.py
-	spark-submit --driver-class-path SystemML.jar test_mlcontext.py
-	spark-submit --driver-class-path SystemML.jar test_mllearn_df.py
-	spark-submit --driver-class-path SystemML.jar test_mllearn_numpy.py
+	PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_matrix_agg_fn.py
+	PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_matrix_binary_op.py
+	PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_mlcontext.py
+	PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_mllearn_df.py
+	PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_mllearn_numpy.py
+	PYSPARK_PYTHON=python3 spark-submit --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-SNAPSHOT-extra.jar test_nn_numpy.py
 
 
 ## Check LICENSE and NOTICE Files
@@ -385,7 +382,7 @@ file and remove all the `@Ignore` annotations from all the tests. Then run the N
 # Run other GPU Unit Tests 
 
 	rm result.txt
-	for t in AggregateUnaryOpTests  BinaryOpTests  MatrixMatrixElementWiseOpTests  RightIndexingTests AppendTest  MatrixMultiplicationOpTest ReorgOpTests ScalarMatrixElementwiseOpTests UnaryOpTests
+	for t in AggregateUnaryOpTests  BinaryOpTests  MatrixMatrixElementWiseOpTests  RightIndexingTests AppendTest  MatrixMultiplicationOpTest ReorgOpTests ScalarMatrixElementwiseOpTests UnaryOpTests LstmTest LstmCPUTest
 	do
 		mvn -Dit.test="org.apache.sysml.test.gpu."$t verify -PgpuTests &> tmp.txt
 		SUCCESS=`grep "BUILD SUCCESS" tmp.txt`
diff --git a/scripts/nn/layers/lstm_staging.dml b/scripts/nn/layers/lstm_staging.dml
index 2f71f22..f1934da 100644
--- a/scripts/nn/layers/lstm_staging.dml
+++ b/scripts/nn/layers/lstm_staging.dml
@@ -27,7 +27,7 @@ source("nn/layers/tanh.dml") as tanh
 
 forward = function(matrix[double] X, matrix[double] W, matrix[double] b, 
                    boolean return_sequences, matrix[double] out0, matrix[double] c0)
-    return (matrix[double] out, matrix[double] c) {
+    return (matrix[double] out, matrix[double] c, matrix[double] state) {
   /*
    * Computes the forward pass for an LSTM layer with M neurons.
    * The input data has N sequences of T examples, each with D features.
@@ -58,14 +58,15 @@ forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
    *      of shape (N, T*M).  Else, outputs for the final timestep, of
    *      shape (N, M).
    *  - c: Cell state for final timestep, of shape (N, M). 
+   *  - state: Intermediate state of unknown dimensions used for performance.
    */
-  out = 0; c = c0;
-  [out, c] = lstm(X, W, b, out0, c0, return_sequences)
+  out = 0; c = c0; state = c0;
+  [out, c, state] = lstm(X, W, b, out0, c0, return_sequences)
 }
 
 backward = function(matrix[double] dout, matrix[double] dc,
                     matrix[double] X, matrix[double] W, matrix[double] b,
-                    boolean given_sequences, matrix[double] out0, matrix[double] c0)
+                    boolean given_sequences, matrix[double] out0, matrix[double] c0, matrix[double] state)
     return (matrix[double] dX, matrix[double] dW, matrix[double] db,
             matrix[double] dout0, matrix[double] dc0) {
   /*
@@ -92,6 +93,7 @@ backward = function(matrix[double] dout, matrix[double] dc,
    *      Note: This is *optional* and could just be an empty matrix.
    *  - c0: Initial cell state, of shape (N, M).
    *      Note: This is *optional* and could just be an empty matrix.
+   *  - state: state generated by the forward call.
    *
    * Outputs:
    *  - dX: Gradient wrt `X`, of shape (N, T*D).
@@ -101,7 +103,7 @@ backward = function(matrix[double] dout, matrix[double] dc,
    *  - dc0: Gradient wrt `c0`, of shape (N, M).
    */
   dX = X; dW = W; db = b; dout0 = out0; dc0 = c0
-  [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0, given_sequences, dout, dc)
+  [dX, dW, db, dout0, dc0] = lstm_backward(X, W, b, out0, c0, given_sequences, dout, dc, state)
 }
 
 init = function(int N, int D, int M)
diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java b/src/main/java/org/apache/sysml/conf/DMLConfig.java
index 8459fd4..0b5ed78 100644
--- a/src/main/java/org/apache/sysml/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java
@@ -88,6 +88,7 @@ public class DMLConfig
 	public static final String SYNCHRONIZE_GPU      = "sysml.gpu.sync.postProcess"; // boolean: whether to synchronize GPUs after every instruction 
 	public static final String EAGER_CUDA_FREE		= "sysml.gpu.eager.cudaFree"; // boolean: whether to perform eager CUDA free on rmvar
 	public static final String GPU_EVICTION_POLICY	= "sysml.gpu.eviction.policy"; // string: can be lru, lfu, min_evict
+	public static final String FORCE_LSTM_CUDNN		= "sysml.gpu.lstm.force.cudnn"; // boolean: should we force a cudnn operator for LSTM
 	
 	// Fraction of available memory to use. The available memory is computed when the GPUContext is created
 	// to handle the tradeoff on calling cudaMemGetInfo too often.
@@ -148,6 +149,7 @@ public class DMLConfig
 		_defaultVals.put(SYNCHRONIZE_GPU,        "false" );
 		_defaultVals.put(CACHING_BUFFER_SIZE,    "0.15" );
 		_defaultVals.put(EAGER_CUDA_FREE,        "false" );
+		_defaultVals.put(FORCE_LSTM_CUDNN,		 "true" );
 		_defaultVals.put(GPU_RECOMPUTE_ACTIVATIONS, "false" );
 		_defaultVals.put(FLOATING_POINT_PRECISION,        	 "double" );
 	}
@@ -432,7 +434,7 @@ public class DMLConfig
 				CODEGEN, CODEGEN_COMPILER, CODEGEN_OPTIMIZER, CODEGEN_PLANCACHE, CODEGEN_LITERALS,
 				EXTRA_FINEGRAINED_STATS, STATS_MAX_WRAP_LEN, PRINT_GPU_MEMORY_INFO, CACHING_BUFFER_SIZE,
 				AVAILABLE_GPUS, SYNCHRONIZE_GPU, EAGER_CUDA_FREE, FLOATING_POINT_PRECISION, GPU_EVICTION_POLICY, EVICTION_SHADOW_BUFFERSIZE,
-				GPU_MEMORY_ALLOCATOR, GPU_MEMORY_UTILIZATION_FACTOR, GPU_RECOMPUTE_ACTIVATIONS
+				GPU_MEMORY_ALLOCATOR, GPU_MEMORY_UTILIZATION_FACTOR, GPU_RECOMPUTE_ACTIVATIONS, FORCE_LSTM_CUDNN
 		}; 
 		
 		StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index f27958f..325107c 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -219,12 +219,13 @@ public class BuiltinFunctionExpression extends DataIdentifier
 			checkMatrixParam(getFifthExpr());
 			
 			// setup output properties
-			if(getOutputs() == null || getOutputs().length != 2) {
+			if(getOutputs() == null || getOutputs().length != 3) {
 				int numOutputs = getOutputs() == null ? 0 : getOutputs().length;
-				raiseValidateError("The builtin function lstm has two outputs, but instead found: " + numOutputs, conditional);
+				raiseValidateError("The builtin function lstm has three outputs, but instead found: " + numOutputs, conditional);
 			}
 			DataIdentifier out = (DataIdentifier) getOutputs()[0];
 			DataIdentifier cy = (DataIdentifier) getOutputs()[1];
+			DataIdentifier cache = (DataIdentifier) getOutputs()[2];
 			
 			// Output1 - out: If `return_sequences` is True, outputs for all timesteps, else outputs for the final timestep.
 			out.setDataType(DataType.MATRIX);
@@ -238,12 +239,17 @@ public class BuiltinFunctionExpression extends DataIdentifier
 			cy.setDimensions(getExpr(4).getOutput().getDim1(), getExpr(4).getOutput().getDim2());
 			cy.setBlockDimensions(getExpr(4).getOutput().getRowsInBlock(), getExpr(4).getOutput().getColumnsInBlock());
 			
+			cache.setDataType(DataType.MATRIX);
+			cache.setValueType(ValueType.DOUBLE);
+			cache.setDimensions(1, 1); // Use dummy dimension for now. 
+			cache.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock());
+			
 			break;
 		}
 		case LSTM_BACKWARD:
 		{
-			// Input: X, W, b, out0, c0, return_sequences, dout, cy
-			checkNumParameters(8);
+			// Input: X, W, b, out0, c0, return_sequences, dout, cy, cache
+			checkNumParameters(9);
 			checkMatrixParam(getFirstExpr());
 			checkMatrixParam(getSecondExpr());
 			checkMatrixParam(getThirdExpr());
diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java
index 3988a7f..fdd2025 100644
--- a/src/main/java/org/apache/sysml/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java
@@ -976,12 +976,17 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
 			throw new LanguageException("Unexpected error.");
 		
 		if ( source instanceof FunctionCallIdentifier ) {
+			// set target properties (based on type info in function call statement return params)
+			FunctionCallIdentifier fci = (FunctionCallIdentifier)source;
+			FunctionStatement fstmt = (FunctionStatement)_dmlProg
+				.getFunctionStatementBlock(fci.getNamespace(), fci.getName()).getStatement(0);
+			if(targetList.size() != fstmt.getOutputParams().size()) {
+				// throws a controlled error if the builtin functions are used incorrectly
+				fci.raiseValidateError("Incorrect number of outputs for the function " + fci.getNamespace() + "::" +  fci.getName() 
+					+ ":" + targetList.size() + " != " + fstmt.getOutputParams().size(), conditional);
+			}
 			for (int j =0; j< targetList.size(); j++) {
 				DataIdentifier target = targetList.get(j);
-				// set target properties (based on type info in function call statement return params)
-				FunctionCallIdentifier fci = (FunctionCallIdentifier)source;
-				FunctionStatement fstmt = (FunctionStatement)_dmlProg
-					.getFunctionStatementBlock(fci.getNamespace(), fci.getName()).getStatement(0);
 				if (fstmt == null){
 					fci.raiseValidateError(" function " + fci.getName() 
 						+ " is undefined in namespace " + fci.getNamespace(), conditional);
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index f0a44f7..2fa1274 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -190,6 +190,16 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 	private boolean _requiresLocalWrite = false; //flag if local write for read obj
 	private boolean _isAcquireFromEmpty = false; //flag if read from status empty 
 	
+	// If the cacheable data is an intermediate cache, then this value is set to identify the type of operator that created this cache.
+	// This avoids unnecessary GPU stalling as well as supports hybrid forward/backward calls. 
+	private int     _intermediateCacheType = -1;
+	public void setIntermediateCacheType(int newValue) {
+		_intermediateCacheType = newValue;
+	}
+	public int getIntermediateCacheType() {
+		return _intermediateCacheType;
+	}
+	
 	//spark-specific handles
 	//note: we use the abstraction of LineageObjects for two reasons: (1) to keep track of cleanup
 	//for lazily evaluated RDDs, and (2) as abstraction for environments that do not necessarily have spark libraries available
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
index 35ac5b6..9167e8c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -39,6 +39,25 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 	private static final Log LOG = LogFactory.getLog(DnnCPInstruction.class.getName());
 	private static boolean warnedUnderUtilitization = false;
 	
+	public static enum LSTM_CACHE_TYPE {
+		CP_NN,
+		GPU_CUDNN,
+		GPU_NN;
+		
+		public static LSTM_CACHE_TYPE fromInteger(int x) {
+			switch(x) {
+				case 0:
+					return CP_NN;
+				case 1:
+					return GPU_CUDNN;
+				case 2:
+					return GPU_NN;
+				default:
+					throw new DMLRuntimeException("Unsupported value:" + x);
+			}
+		}
+	}
+	
 	private final CPOperand _in2;
 	private final CPOperand _in3;
 	private final CPOperand _in4;
@@ -46,6 +65,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 	private final CPOperand _in6;
 	private final CPOperand _in7;
 	private final CPOperand _in8;
+	private final CPOperand _in9;
 	private final CPOperand _out2;
 	private final CPOperand _out3;
 	private final CPOperand _out4;
@@ -63,7 +83,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 		super(CPType.Dnn, null, in, out, opcode, istr);
 		_in2 = in2;
 		_in3 = in3;
-		_in4 = null; _in5 = null; _in6 = null; _in7 = null; _in8 = null;
+		_in4 = null; _in5 = null; _in6 = null; _in7 = null; _in8 = null; _in9 = null;
 		_out2 = null; _out3 = null; _out4 = null; _out5 = null;
 		_stride = stride;
 		_padding = padding;
@@ -112,6 +132,32 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 		_in6 = in6;
 		_in7 = in7;
 		_in8 = in8;
+		_in9 = null;
+		_out2 = out2;
+		_out3 = out3;
+		_out4 = out4;
+		_out5 = out5;
+		_stride = null;
+		_padding = null;
+		_input_shape = null;
+		_filter_shape = null;
+		_numThreads = numThreads;
+		_intermediateMemoryBudget = intermediateMemoryBudget;
+	}
+	
+	public DnnCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5,
+			CPOperand in6, CPOperand in7, CPOperand in8, CPOperand in9,
+			CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, 
+			double intermediateMemoryBudget, int numThreads) throws DMLRuntimeException {
+		super(CPType.Dnn, null, in1, out, opcode, istr);
+		_in2 = in2;
+		_in3 = in3;
+		_in4 = in4;
+		_in5 = in5;
+		_in6 = in6;
+		_in7 = in7;
+		_in8 = in8;
+		_in9 = in9;
 		_out2 = out2;
 		_out3 = out3;
 		_out4 = out4;
@@ -262,7 +308,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 			return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0, 0);
 		}
 		else if (opcode.equalsIgnoreCase("lstm")) {
-			InstructionUtils.checkNumFields(parts, 9);
+			InstructionUtils.checkNumFields(parts, 10);
 			CPOperand in1 = new CPOperand(parts[1]); // X
 			CPOperand in2 = new CPOperand(parts[2]); // W
 			CPOperand in3 = new CPOperand(parts[3]); // b
@@ -271,11 +317,12 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 			CPOperand in6 = new CPOperand(parts[6]); // return_seq
 			CPOperand out = new CPOperand(parts[7]);  // out
 			CPOperand out2 = new CPOperand(parts[8]); // c
-			int numThreads = Integer.parseInt(parts[9]);
-			return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, null, null, null, opcode, str, 0, numThreads);
+			CPOperand out3 = new CPOperand(parts[9]); // cache
+			int numThreads = Integer.parseInt(parts[10]);
+			return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0, numThreads);
 		}
 		else if (opcode.equalsIgnoreCase("lstm_backward")) {
-			InstructionUtils.checkNumFields(parts, 14);
+			InstructionUtils.checkNumFields(parts, 15);
 			CPOperand in1 = new CPOperand(parts[1]); // X
 			CPOperand in2 = new CPOperand(parts[2]); // W
 			CPOperand in3 = new CPOperand(parts[3]); // b
@@ -284,13 +331,14 @@ public class DnnCPInstruction extends UnaryCPInstruction {
 			CPOperand in6 = new CPOperand(parts[6]); // return_seq
 			CPOperand in7 = new CPOperand(parts[7]); // dout
 			CPOperand in8 = new CPOperand(parts[8]); // dc
-			CPOperand out = new CPOperand(parts[9]);  // dX
-			CPOperand out2 = new CPOperand(parts[10]); // dW
-			CPOperand out3 = new CPOperand(parts[11]); // db
-			CPOperand out4 = new CPOperand(parts[12]); // dout0
-			CPOperand out5 = new CPOperand(parts[13]); // dc0
-			int numThreads = Integer.parseInt(parts[14]);
-			return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0, numThreads);
+			CPOperand in9 = new CPOperand(parts[9]); // cache
+			CPOperand out = new CPOperand(parts[10]);  // dX
+			CPOperand out2 = new CPOperand(parts[11]); // dW
+			CPOperand out3 = new CPOperand(parts[12]); // db
+			CPOperand out4 = new CPOperand(parts[13]); // dout0
+			CPOperand out5 = new CPOperand(parts[14]); // dc0
+			int numThreads = Integer.parseInt(parts[15]);
+			return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, in9, out, out2, out3, out4, out5, opcode, str, 0, numThreads);
 		}
 		else {
 			throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str);
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index fbe7c9d..b243b64 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -22,16 +22,22 @@ import java.util.ArrayList;
 import jcuda.Pointer;
 import jcuda.jcudnn.JCudnn;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.conf.DMLConfig;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.functionobjects.SwapIndex;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.DnnCPInstruction;
 import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNNRnnAlgorithm;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.DnnUtils;
@@ -45,7 +51,16 @@ public class DnnGPUInstruction extends GPUInstruction {
 		NONE
 	}
 	
-	public static LstmOperator FORCED_LSTM_OP = LstmOperator.NONE;
+	// Operator forced for lstm/lstm_backward; assigned once by the static initializer below.
+	public static LstmOperator FORCED_LSTM_OP; 
+	static {
+		// CUDNN is forced only when the FORCE_LSTM_CUDNN flag of the DML configuration is set.
+		final boolean forceCuDNN = ConfigurationManager.getDMLConfig()
+				.getBooleanValue(DMLConfig.FORCE_LSTM_CUDNN);
+		// With NONE, the lstm instructions pick the operator themselves.
+		FORCED_LSTM_OP = forceCuDNN ? LstmOperator.CUDNN : LstmOperator.NONE;
+	}
+	private static final Log LOG = LogFactory.getLog(DnnGPUInstruction.class.getName());
 	
 	private CPOperand _input1;
 	private CPOperand _input2;
@@ -55,6 +70,7 @@ public class DnnGPUInstruction extends GPUInstruction {
 	private CPOperand _input6;
 	private CPOperand _input7;
 	private CPOperand _input8;
+	private CPOperand _input9;
 	private CPOperand _output;
 	private CPOperand _output2;
 	private CPOperand _output3;
@@ -97,6 +113,23 @@ public class DnnGPUInstruction extends GPUInstruction {
 		_intermediateMemoryBudget = intermediateMemoryBudget;
 	}
 	
+	// Six inputs and three outputs, e.g. for the "lstm" opcode whose third output is the cache.
+	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, 
+			CPOperand out, CPOperand out2, CPOperand out3, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
+		_gputype = GPUINSTRUCTION_TYPE.Dnn;
+		_intermediateMemoryBudget = intermediateMemoryBudget;
+		_input1 = in1;
+		_input2 = in2;
+		_input3 = in3;
+		_input4 = in4;
+		_input5 = in5;
+		_input6 = in6;
+		_output = out;
+		_output2 = out2;
+		_output3 = out3;
+	}
+	
 	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5,
 			CPOperand in6, CPOperand in7, CPOperand in8,
 			CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, 
@@ -119,6 +152,29 @@ public class DnnGPUInstruction extends GPUInstruction {
 		_intermediateMemoryBudget = intermediateMemoryBudget;
 	}
 	
+	// Nine inputs and five outputs, e.g. for the "lstm_backward" opcode whose ninth input is the cache.
+	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5,
+			CPOperand in6, CPOperand in7, CPOperand in8, CPOperand in9, CPOperand out, CPOperand out2, CPOperand out3,
+			CPOperand out4, CPOperand out5, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
+		_gputype = GPUINSTRUCTION_TYPE.Dnn;
+		_intermediateMemoryBudget = intermediateMemoryBudget;
+		_input1 = in1;
+		_input2 = in2;
+		_input3 = in3;
+		_input4 = in4;
+		_input5 = in5;
+		_input6 = in6;
+		_input7 = in7;
+		_input8 = in8;
+		_input9 = in9;
+		_output = out;
+		_output2 = out2;
+		_output3 = out3;
+		_output4 = out4;
+		_output5 = out5;
+	}
+	
 	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, 
 			double intermediateMemoryBudget) throws DMLRuntimeException {
 		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
@@ -360,7 +416,7 @@ public class DnnGPUInstruction extends GPUInstruction {
 			return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
 		}
 		else if (opcode.equalsIgnoreCase("lstm")) {
-			InstructionUtils.checkNumFields(parts, 8);
+			InstructionUtils.checkNumFields(parts, 9);
 			CPOperand in1 = new CPOperand(parts[1]); // X
 			CPOperand in2 = new CPOperand(parts[2]); // W
 			CPOperand in3 = new CPOperand(parts[3]); // b
@@ -369,10 +425,11 @@ public class DnnGPUInstruction extends GPUInstruction {
 			CPOperand in6 = new CPOperand(parts[6]); // return_seq
 			CPOperand out = new CPOperand(parts[7]); // out
 			CPOperand out2 = new CPOperand(parts[8]); // c
-			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0);
+			CPOperand out3 = new CPOperand(parts[9]); // cache
+			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, out3, opcode, str, 0);
 		}
 		else if (opcode.equalsIgnoreCase("lstm_backward")) {
-			InstructionUtils.checkNumFields(parts, 13);
+			InstructionUtils.checkNumFields(parts, 14);
 			CPOperand in1 = new CPOperand(parts[1]); // X
 			CPOperand in2 = new CPOperand(parts[2]); // W
 			CPOperand in3 = new CPOperand(parts[3]); // b
@@ -381,12 +438,13 @@ public class DnnGPUInstruction extends GPUInstruction {
 			CPOperand in6 = new CPOperand(parts[6]); // return_seq
 			CPOperand in7 = new CPOperand(parts[7]); // dout
 			CPOperand in8 = new CPOperand(parts[8]); // dc
-			CPOperand out = new CPOperand(parts[9]);  // dX
-			CPOperand out2 = new CPOperand(parts[10]); // dW
-			CPOperand out3 = new CPOperand(parts[11]); // db
-			CPOperand out4 = new CPOperand(parts[12]); // dout0
-			CPOperand out5 = new CPOperand(parts[13]); // dc0
-			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
+			CPOperand cache = new CPOperand(parts[9]); // cache
+			CPOperand out = new CPOperand(parts[10]);  // dX
+			CPOperand out2 = new CPOperand(parts[11]); // dW
+			CPOperand out3 = new CPOperand(parts[12]); // db
+			CPOperand out4 = new CPOperand(parts[13]); // dout0
+			CPOperand out5 = new CPOperand(parts[14]); // dc0
+			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, cache, out, out2, out3, out4, out5, opcode, str, 0);
 		}
 		else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
 			InstructionUtils.checkNumFields(parts, 7);
@@ -661,6 +719,25 @@ public class DnnGPUInstruction extends GPUInstruction {
 		return (long)memRequired;
 	}
 	
+	// The lstm cache is a dense (numRows x 1) matrix laid out as [reserve space | cudnnW | cudnnInput].
+	// The reserve space is rounded up to whole elements so that the offsets of the regions behind it are multiples of the element size.
+	private static long getReserveSpaceNumElems(LibMatrixCuDNNRnnAlgorithm algo) {
+		return (algo.reserveSpaceSizeInBytes + LibMatrixCUDA.sizeOfDataType - 1) / LibMatrixCUDA.sizeOfDataType;
+	}
+	
+	private int getNumRowsLSTMTempCache(LibMatrixCuDNNRnnAlgorithm algo, long N, long T, long D, long M) {
+		// reserve space + cudnnW + cudnnInput (all in number of elements)
+		return toInt(getReserveSpaceNumElems(algo) + (D+M+2)*(4*M) + (N*T*D));
+	}
+	
+	private Pointer getCudnnWPointer(Pointer cachePointer, LibMatrixCuDNNRnnAlgorithm algo) {
+		return cachePointer.withByteOffset(getReserveSpaceNumElems(algo)*LibMatrixCUDA.sizeOfDataType);
+	}
+	
+	private Pointer getCudnnInputPointer(Pointer cachePointer, LibMatrixCuDNNRnnAlgorithm algo, long N, long T, long D, long M) {
+		return cachePointer.withByteOffset((getReserveSpaceNumElems(algo) + (D+M+2)*(4*M))*LibMatrixCUDA.sizeOfDataType);
+	}
+	
 	private void processLstmBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
 		MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName());
 		long M = out0.getNumColumns(); // hiddenSize .. since out0: (N, M)
@@ -676,8 +753,6 @@ public class DnnGPUInstruction extends GPUInstruction {
 		long numColsX = X.getNumColumns();
 		int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
 		boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue();
-		
-		// long memRequired = getMemRequiredForCuDNNLSTMBackward(N, T, M, D, return_sequences);
 		 
 		String dxName = _output.getName();
 		String dwName = _output2.getName();
@@ -689,37 +764,86 @@ public class DnnGPUInstruction extends GPUInstruction {
 		
 		long memRequired = getMemRequiredForCuDNNLSTMBackward(N, T, M, D, return_sequences);
 		
-		boolean isWSparse = LibMatrixCUDA.isInSparseFormat(gCtx, W);
-		
-		
 		
 		if(FORCED_LSTM_OP == LstmOperator.CUDNN || 
 			N != N1 || // Use CuDNN operator when batch size of previous iteration is different that current iteration
-			(!isWSparse && // Don't use CuDNN kernel when w is sparse.
+			(
+			// ----------------------------------------------------------------------------------
+			// Skip sparse check
+			// !LibMatrixCUDA.isInSparseFormat(gCtx, W) && // Don't use CuDNN kernel when w is sparse.
+			// ----------------------------------------------------------------------------------
 			// When an operator is not forced, then prefer CuDNN kernel if it can fit in the GPU memory
-			FORCED_LSTM_OP == LstmOperator.NONE && gCtx.getMemoryManager().canAllocate(instName, memRequired))) {
+			FORCED_LSTM_OP == LstmOperator.NONE && gCtx.getMemoryManager().canAllocate(instName, 
+					memRequired - getSizeOnDevice(new MatrixObject[] {out0, W, bias, X})))) {
 			// Use CuDNN LSTM kernel
-			Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
-			Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
-			Pointer cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
-			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
-					ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
-					sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M);
-			ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-			ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-			Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
-			Pointer cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
-			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
-					ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
-					xPointer, cudnnInput, N, D, T*D, N*T*D);
-			ec.releaseMatrixInputForGPUInstruction(_input1.getName());
 			Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
-			LibMatrixCuDNN.cuDNNLstmBackward(ec, gCtx, instName, 
-					cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName,  // input
-					dxName, dwName, dbName, dhxName, dcxName, // output 
-					return_sequences, N, M, D, T);
-			gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
-			gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+			try(LibMatrixCuDNNRnnAlgorithm algo = 
+					new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, "lstm", toInt(N), toInt(T), toInt(M), toInt(D), true)) {
+				Pointer cachePtr = null;
+				try {
+					switch(DnnCPInstruction.LSTM_CACHE_TYPE.fromInteger(ec.getMatrixObject(_input9).getIntermediateCacheType())) {
+						case GPU_CUDNN:
+							cachePtr = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, 
+									getMatrixInputForGPUInstruction(ec, _input9.getName()), instName, 
+									getNumRowsLSTMTempCache(algo, N, T, D, M), 1);
+							break;
+						case CP_NN:
+							LOG.warn("Invoking CuDNN lstm backward operator, but the intermediate state was generated by CP lstm nn operator");
+							break;
+						case GPU_NN:
+							LOG.warn("Invoking CuDNN lstm backward operator, but the intermediate state was generated by GPU lstm nn operator");
+							break;
+						default:
+							LOG.warn("Invoking CuDNN lstm forward redundantly in the backward operator. Found unknown cache type.");
+							break;
+					}
+				}
+				catch(DMLRuntimeException e) {
+					LOG.warn("Invoking CuDNN lstm forward redundantly in the backward operator");
+				}
+				if (algo.reserveSpaceSizeInBytes != 0) {
+					algo.reserveSpace = cachePtr;
+				}
+				else {
+					algo.reserveSpace = new Pointer();
+				}
+				
+				Pointer cudnnWPointer = null;
+				Pointer cudnnInput = null;
+				if(cachePtr != null) {
+					cudnnWPointer = getCudnnWPointer(cachePtr, algo);
+					cudnnInput = getCudnnInputPointer(cachePtr, algo, N, T, D, M);
+				}
+				else {
+					Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+					Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+					cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+					LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
+							ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
+							sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M);
+					ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+					ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+					Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
+					cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
+					LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
+							ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
+							xPointer, cudnnInput, N, D, T*D, N*T*D);
+					ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+				}
+				
+				LibMatrixCuDNN.cuDNNLstmBackward(ec, gCtx, instName, 
+						cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName,  // input
+						dxName, dwName, dbName, dhxName, dcxName, // output 
+						return_sequences, N, M, D, T, algo);
+				if(cachePtr != null) {
+					ec.releaseMatrixInputForGPUInstruction(_input9.getName());
+				}
+				else {
+					gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
+					gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+				}
+			}
+			
 		}
 		else {
 			if(N != N1) {
@@ -727,6 +851,8 @@ public class DnnGPUInstruction extends GPUInstruction {
 						" is different than the batch size of current iteration " + N);
 			}
 			
+			LOG.info("Switching to gpu lstm nn backward operator. (CuDNN memory requirement=" + String.format("%.3f", memRequired*1e-6) + " MB.");
+			
 			Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
 			Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
 			Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
@@ -781,6 +907,14 @@ public class DnnGPUInstruction extends GPUInstruction {
 		ec.releaseMatrixInputForGPUInstruction(_input5.getName());
 	}
 	
+	private long getSizeOnDevice(MatrixObject[] mObjects) {
+		long totalSize = 0; // running sum over the GPU objects of the given matrices
+		for(int idx = 0; idx < mObjects.length; idx++) {
+			totalSize += mObjects[idx].getGPUObject(gCtx).getSizeOnDevice();
+		}
+		return totalSize;
+	}
+	
 	private void processLstmInstruction(ExecutionContext ec) throws DMLRuntimeException {
 		// batchSize=N, seqLength=T, numFeatures=D and hiddenSize=M
 		// input  X:(N, T*D), 	==> (T, D, N)
@@ -801,42 +935,62 @@ public class DnnGPUInstruction extends GPUInstruction {
 		long numColsX = X.getNumColumns();
 		long T = numColsX/D; // since X:(N, T*D) ... seqLength
 		boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue();
-		
+				
 		long memRequired = getMemRequiredForCuDNNLSTMBackward(N, T, M, D, return_sequences);
 		
-		boolean isWSparse = LibMatrixCUDA.isInSparseFormat(gCtx, W);
-		
 		if(FORCED_LSTM_OP == LstmOperator.CUDNN || 
 			N != N1 || // Use CuDNN operator when batch size of previous iteration is different that current iteration
-			(!isWSparse && // Don't use CuDNN kernel when w is sparse.
+			(
+			// ----------------------------------------------------------------------------------
+			// Skip sparse check
+			// !LibMatrixCUDA.isInSparseFormat(gCtx, W) && // Don't use CuDNN kernel when w is sparse.
+			// ----------------------------------------------------------------------------------
 			// When an operator is not forced, then prefer CuDNN kernel if it can fit in the GPU memory
-			FORCED_LSTM_OP == LstmOperator.NONE && gCtx.getMemoryManager().canAllocate(instName, memRequired))) {
-			Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
-			Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
-			Pointer cudnnWPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
-			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
-					ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
-					sysmlWPointer, sysmlBiasPointer, cudnnWPointer, toInt(D), toInt(M));
-			ec.releaseMatrixInputForGPUInstruction(_input2.getName()); // W
-			ec.releaseMatrixInputForGPUInstruction(_input3.getName()); // bias
-			// Beause the matrices are released immediately, the output for transpose need not be taken into account
-			Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
-			Pointer cudnnInput = gCtx.allocate(instName, (N*T*D)*LibMatrixCUDA.sizeOfDataType);
-			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
-					ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
-					xPointer, cudnnInput, toInt(N), toInt(D), toInt(T*D), toInt(N*T*D));
-			ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-			Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName); 
-			LibMatrixCuDNN.cuDNNLstm(ec, gCtx, instName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), 
-					toInt(N), toInt(M), toInt(D), toInt(T));
-			gCtx.cudaFreeHelper(instName, cudnnWPointer, gCtx.EAGER_CUDA_FREE);
-			gCtx.cudaFreeHelper(instName, cudnnInput, gCtx.EAGER_CUDA_FREE);
+			FORCED_LSTM_OP == LstmOperator.NONE && gCtx.getMemoryManager().canAllocate(instName, 
+					memRequired - getSizeOnDevice(new MatrixObject[] {out0, W, bias, X})))) {
+			Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instName);
+			try(LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, "lstm", toInt(N), toInt(T), toInt(M), toInt(D), true)) {
+				int numRows = getNumRowsLSTMTempCache(algo, N, T, D, M);
+				ec.setMetaData(_output3.getName(), numRows, 1);
+				Pointer cachePtr = LibMatrixCuDNN.getDenseOutputPointer(ec, gCtx, instName,  _output3.getName(), numRows, 1);
+				if(algo.reserveSpaceSizeInBytes != 0) {
+					algo.reserveSpace = cachePtr;
+				}
+				else {
+					algo.reserveSpace = new Pointer();
+				}
+				
+				// Compute cudnnWPointer
+				Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instName, D+M, 4*M);
+				Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
+				Pointer cudnnWPointer = getCudnnWPointer(cachePtr, algo); 
+				LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
+						ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
+						sysmlWPointer, sysmlBiasPointer, cudnnWPointer, toInt(D), toInt(M));
+				ec.releaseMatrixInputForGPUInstruction(_input2.getName()); // W
+				ec.releaseMatrixInputForGPUInstruction(_input3.getName()); // bias
+				
+				// Compute cudnnInput
+				// Because the matrices are released immediately, the output for transpose need not be taken into account
+				Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
+				Pointer cudnnInput = getCudnnInputPointer(cachePtr, algo, N, T, D, M);
+				LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
+						ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
+						xPointer, cudnnInput, toInt(N), toInt(D), toInt(T*D), toInt(N*T*D));
+				ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+				LibMatrixCuDNN.cuDNNLstm(ec, gCtx, instName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), 
+						N, M, D, T, algo);
+				ec.getMatrixObject(_output3.getName()).setIntermediateCacheType(DnnCPInstruction.LSTM_CACHE_TYPE.GPU_CUDNN.ordinal());
+				ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
+			}
+			
 		}
 		else {
 			if(N != N1) {
 				throw new DMLRuntimeException("Unsupported operation: The batch size of previous iteration " + N1 + 
 						" is different than the batch size of current iteration " + N);
 			}
+			LOG.info("Switching to gpu lstm nn operator. (CuDNN memory requirement=" + String.format("%.3f", memRequired*1e-6) + " MB.");
 			
 			Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instName, 1, 4*M);
 			Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instName); 
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index 04af229..9d263aa 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -797,7 +797,7 @@ public class GPUObject {
 		setSparseMatrixCudaPointer(tmp);
 	}
 
-	protected long getSizeOnDevice() {
+	public long getSizeOnDevice() {
 		long rlen = mat.getNumRows();
 		long clen = mat.getNumColumns();
 		long nnz = mat.getNnz();
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index e496ddb..4151234 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -1032,16 +1032,15 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 					LibMatrixCuMatMult.denseSparseMatMult(gCtx.getCusparseHandle(), instName, dinput, difog_raw, wSparsePointer, param2);
 				}
 			}
-			else
+			else {
 				LibMatrixCuMatMult.denseDenseMatMult(gCtx.getCublasHandle(), instName, dinput, difog_raw, wDensePointer, param2);
+			}
 			
 			// db = db + colSums(difog_raw)  # shape (1, 4M)
-			reduceCol(gCtx, instName, "reduce_col_sum", difog_raw, tmpDb, 1, toInt(4*M));
+			reduceCol(gCtx, instName, "reduce_col_sum", difog_raw, tmpDb, toInt(N), toInt(4*M));
 			matrixMatrixOp(gCtx, instName, tmpDb, db, 1, toInt(4*M), VectorShape.NONE.code(), VectorShape.NONE.code(), db, 
 					new BinaryOperator(Plus.getPlusFnObject()));
 			
-			// jcuda.runtime.JCuda.cudaDeviceSynchronize();
-			
 			int size = toInt(Math.max(N*D, N*M));
 			getCudaKernels(gCtx).launchKernel("postProcessNNLstmBackward",
 					ExecutionConfig.getConfigForSimpleVectorOperations(size),
@@ -1177,18 +1176,20 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	 * @param M hidden size
 	 * @param D number of features
 	 * @param T sequence length
+	 * @param algo rnn algorithm
 	 * @throws DMLRuntimeException if error
 	 */
 	public static void cuDNNLstm(ExecutionContext ec, GPUContext gCtx, String instName,
 			Pointer X,  Pointer wPointer, Pointer out0, Pointer c0, boolean return_sequences,
-			String outputName, String cyName, int N, int M, int D, int T) throws DMLRuntimeException {
-		cuDNNSingleLayerUnidirectionalRNNForward(ec, gCtx, instName, X, out0, c0, wPointer, outputName, cyName, "lstm", return_sequences, N, M, D, T);
+			String outputName, String cyName, long N, long M, long D, long T, LibMatrixCuDNNRnnAlgorithm algo) throws DMLRuntimeException {
+		cuDNNSingleLayerUnidirectionalRNNForward(ec, gCtx, instName, X, out0, c0, wPointer, outputName, cyName, "lstm", return_sequences, N, M, D, T, algo);
 	}
 	
 	private static void cuDNNSingleLayerUnidirectionalRNNForward(ExecutionContext ec, GPUContext gCtx, String instName,
 			Pointer x, Pointer hx, Pointer cx, Pointer wPointer,  // input
 			String outputName, String cyName,  					 // output
-			String rnnMode, boolean return_sequences, int N, int M, int D, int T) throws DMLRuntimeException {
+			String rnnMode, boolean return_sequences, long N, long M, long D, long T,
+			LibMatrixCuDNNRnnAlgorithm algo) throws DMLRuntimeException {
 		boolean hasCarry = rnnMode.equalsIgnoreCase("lstm");
 		if(LOG.isDebugEnabled()) {
 			long memRequired = (N*T*M + 2*N*M + N*T*M)*sizeOfDataType;
@@ -1201,25 +1202,23 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 		Pointer cyPointer = hasCarry ? getDenseOutputPointer(ec, gCtx, instName, cyName, N, M) : new Pointer();
 		// Pointer wPointer = getDensePointerForCuDNN(gCtx, w, instName, D+M+2, 4*M);
 		
-		try(LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, rnnMode, N, T, M, D, true, wPointer)) {
-			JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), algo.rnnDesc, T, 
-					algo.xDesc, x, 
-					algo.hxDesc, hx, 
-					algo.cxDesc, cx, 
-					algo.wDesc, wPointer, 
-					algo.yDesc, cudnnYPointer, 
-					algo.hyDesc, hyPointer, 
-					algo.cyDesc, cyPointer, 
-					algo.workSpace, algo.sizeInBytes, 
-					algo.reserveSpace, algo.reserveSpaceSizeInBytes);
-		}
+		JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), algo.rnnDesc, toInt(T), 
+				algo.xDesc, x, 
+				algo.hxDesc, hx, 
+				algo.cxDesc, cx, 
+				algo.wDesc, wPointer, 
+				algo.yDesc, cudnnYPointer, 
+				algo.hyDesc, hyPointer, 
+				algo.cyDesc, cyPointer, 
+				algo.workSpace, algo.sizeInBytes, 
+				algo.reserveSpace, algo.reserveSpaceSizeInBytes);
 		
 		if(return_sequences) {
 			gCtx.cudaFreeHelper(instName, hyPointer, gCtx.EAGER_CUDA_FREE);
 			Pointer sysmlYPointer = getDenseOutputPointer(ec, gCtx, instName, outputName, N, T*M);
 			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_output",
-					ExecutionConfig.getConfigForSimpleVectorOperations(N*T*M),
-					sysmlYPointer, cudnnYPointer, N, T, M, N*T*M);
+					ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*M)),
+					sysmlYPointer, cudnnYPointer, toInt(N), toInt(T), toInt(M), toInt(N*T*M));
 		}
 		gCtx.cudaFreeHelper(instName, cudnnYPointer, gCtx.EAGER_CUDA_FREE);
 	}
@@ -1227,7 +1226,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 	public static void cuDNNLstmBackward(ExecutionContext ec, GPUContext gCtx, String instName,
 			Pointer x, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName,  // input
 			String dxName, String dwName, String dbName, String dhxName, String dcxName,  	// output
-			boolean return_sequences, long N, long M, long D, long T) throws DMLRuntimeException {
+			boolean return_sequences, long N, long M, long D, long T, LibMatrixCuDNNRnnAlgorithm algo) throws DMLRuntimeException {
 		
 		if(LOG.isDebugEnabled()) {
 			long memRequired = (D+M)*4*M // sysmlWPointer
@@ -1252,8 +1251,11 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 				
 		// Allocate intermediate pointers computed by forward
 		Pointer yPointer = gCtx.allocate(instName, N*T*M*sizeOfDataType);
-		try(LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(ec, gCtx, instName, "lstm", toInt(N), toInt(T), 
-				toInt(M), toInt(D), true, wPointer)) {
+		
+		boolean freeReserveSpace = false;
+		if(algo.reserveSpace == null) {
+			freeReserveSpace = true;
+			algo.reserveSpace = gCtx.allocate(instName, algo.reserveSpaceSizeInBytes);
 			JCudnn.cudnnRNNForwardTraining(gCtx.getCudnnHandle(), algo.rnnDesc, toInt(T), 
 					algo.xDesc, x, 
 					algo.hxDesc, hx, 
@@ -1264,59 +1266,68 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
 					algo.cyDesc, new Pointer(), 
 					algo.workSpace, algo.sizeInBytes, 
 					algo.reserveSpace, algo.reserveSpaceSizeInBytes);
-			
-			Pointer cudnnDx = gCtx.allocate(instName, N*T*D*LibMatrixCUDA.sizeOfDataType);
-			JCudnn.cudnnRNNBackwardData(gCtx.getCudnnHandle(), algo.rnnDesc, toInt(T), 
-					algo.yDesc, yPointer,
-					// ----------------------
-					// Additional inputs:
-					algo.dyDesc, dy, 
-					algo.dhyDesc, new Pointer(), 
-					algo.dcyDesc, getDenseInputPointer(ec, gCtx, instName, dcyName, N, M),
-					// ----------------------
-					algo.wDesc, wPointer, 
-					algo.hxDesc, hx,
-					algo.cxDesc, cx,
-					// ----------------------
-					// Output:
-					algo.dxDesc, cudnnDx, 
-					algo.dhxDesc, getDenseOutputPointer(ec, gCtx, instName, dhxName, N, M), 
-					algo.dcxDesc, getDenseOutputPointer(ec, gCtx, instName, dcxName, N, M),
-					// ----------------------
-					algo.workSpace, algo.sizeInBytes, 
-					algo.reserveSpace, algo.reserveSpaceSizeInBytes);
-			gCtx.cudaFreeHelper(instName, dy, gCtx.EAGER_CUDA_FREE);
-			ec.releaseMatrixInputForGPUInstruction(dcyName);
-			ec.releaseMatrixOutputForGPUInstruction(dhxName);
-			ec.releaseMatrixOutputForGPUInstruction(dcxName);
-			
-			Pointer smlDx = getDenseOutputPointer(ec, gCtx, instName, dxName, N, T*D);
-			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dinput",
-					ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
-					smlDx, cudnnDx, N, D, T*D, N*T*D);
-			ec.releaseMatrixOutputForGPUInstruction(dxName);
-			gCtx.cudaFreeHelper(instName, cudnnDx, gCtx.EAGER_CUDA_FREE);
-			
-			// -------------------------------------------------------------------------------------------
-			Pointer cudnnDwPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
-			JCudnn.cudnnRNNBackwardWeights(gCtx.getCudnnHandle(), algo.rnnDesc, toInt(T), 
-					algo.xDesc, x, 
-					algo.hxDesc, hx, 
-					algo.yDesc, yPointer, 
-					algo.workSpace, algo.sizeInBytes, 
-					algo.dwDesc, cudnnDwPointer, 
-					algo.reserveSpace, algo.reserveSpaceSizeInBytes);
-			LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dweight",
-					ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
-					getDenseOutputPointer(ec, gCtx, instName, dwName, D+M, 4*M), 
-					getDenseOutputPointer(ec, gCtx, instName, dbName, 1, 4*M), cudnnDwPointer, D, M);
-			gCtx.cudaFreeHelper(instName, cudnnDwPointer, gCtx.EAGER_CUDA_FREE);
-			ec.releaseMatrixOutputForGPUInstruction(dwName);
-			ec.releaseMatrixOutputForGPUInstruction(dbName);
-			// -------------------------------------------------------------------------------------------
-			
-			gCtx.cudaFreeHelper(instName, yPointer, gCtx.EAGER_CUDA_FREE);
 		}
+		else {
+			if(LOG.isDebugEnabled())
+				LOG.debug("Skipping cudnnRNNForwardTraining call");
+		}
+		
+		Pointer cudnnDx = gCtx.allocate(instName, N*T*D*LibMatrixCUDA.sizeOfDataType);
+		JCudnn.cudnnRNNBackwardData(gCtx.getCudnnHandle(), algo.rnnDesc, toInt(T), 
+				algo.yDesc, yPointer,
+				// ----------------------
+				// Additional inputs:
+				algo.dyDesc, dy, 
+				algo.dhyDesc, new Pointer(), 
+				algo.dcyDesc, getDenseInputPointer(ec, gCtx, instName, dcyName, N, M),
+				// ----------------------
+				algo.wDesc, wPointer, 
+				algo.hxDesc, hx,
+				algo.cxDesc, cx,
+				// ----------------------
+				// Output:
+				algo.dxDesc, cudnnDx, 
+				algo.dhxDesc, getDenseOutputPointer(ec, gCtx, instName, dhxName, N, M), 
+				algo.dcxDesc, getDenseOutputPointer(ec, gCtx, instName, dcxName, N, M),
+				// ----------------------
+				algo.workSpace, algo.sizeInBytes, 
+				algo.reserveSpace, algo.reserveSpaceSizeInBytes);
+		gCtx.cudaFreeHelper(instName, dy, gCtx.EAGER_CUDA_FREE);
+		ec.releaseMatrixInputForGPUInstruction(dcyName);
+		ec.releaseMatrixOutputForGPUInstruction(dhxName);
+		ec.releaseMatrixOutputForGPUInstruction(dcxName);
+		
+		Pointer smlDx = getDenseOutputPointer(ec, gCtx, instName, dxName, N, T*D);
+		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dinput",
+				ExecutionConfig.getConfigForSimpleVectorOperations(toInt(N*T*D)),
+				smlDx, cudnnDx, N, D, T*D, N*T*D);
+		ec.releaseMatrixOutputForGPUInstruction(dxName);
+		gCtx.cudaFreeHelper(instName, cudnnDx, gCtx.EAGER_CUDA_FREE);
+		
+		// -------------------------------------------------------------------------------------------
+		Pointer cudnnDwPointer = gCtx.allocate(instName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
+		JCudnn.cudnnRNNBackwardWeights(gCtx.getCudnnHandle(), algo.rnnDesc, toInt(T), 
+				algo.xDesc, x, 
+				algo.hxDesc, hx, 
+				algo.yDesc, yPointer, 
+				algo.workSpace, algo.sizeInBytes, 
+				algo.dwDesc, cudnnDwPointer, 
+				algo.reserveSpace, algo.reserveSpaceSizeInBytes);
+		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_dweight",
+				ExecutionConfig.getConfigForSimpleVectorOperations(toInt((D+M+2)*(4*M))),
+				getDenseOutputPointer(ec, gCtx, instName, dwName, D+M, 4*M), 
+				getDenseOutputPointer(ec, gCtx, instName, dbName, 1, 4*M), cudnnDwPointer, D, M);
+		gCtx.cudaFreeHelper(instName, cudnnDwPointer, gCtx.EAGER_CUDA_FREE);
+		ec.releaseMatrixOutputForGPUInstruction(dwName);
+		ec.releaseMatrixOutputForGPUInstruction(dbName);
+		// -------------------------------------------------------------------------------------------
+		
+		gCtx.cudaFreeHelper(instName, yPointer, gCtx.EAGER_CUDA_FREE);
+		if(freeReserveSpace) {
+			gCtx.cudaFreeHelper(instName, algo.reserveSpace, gCtx.EAGER_CUDA_FREE);
+			algo.reserveSpace = null;
+		}
+		
 	}
 	
 	
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
index a1d799d..4c2a844 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.java
@@ -56,10 +56,10 @@ public class LibMatrixCuDNNRnnAlgorithm implements java.lang.AutoCloseable {
 	cudnnFilterDescriptor wDesc;
 	cudnnFilterDescriptor dwDesc;
 	long sizeInBytes; Pointer workSpace;
-	long reserveSpaceSizeInBytes; Pointer reserveSpace;
+	public long reserveSpaceSizeInBytes; public Pointer reserveSpace;
 	long dropOutSizeInBytes; Pointer dropOutStateSpace;
 	public LibMatrixCuDNNRnnAlgorithm(ExecutionContext ec, GPUContext gCtx, String instName, 
-			String rnnMode, int N, int T, int M, int D, boolean isTraining, Pointer w) throws DMLRuntimeException {
+			String rnnMode, int N, int T, int M, int D, boolean isTraining) throws DMLRuntimeException {
 		this.gCtx = gCtx;
 		this.instName = instName;
 		
@@ -113,7 +113,7 @@ public class LibMatrixCuDNNRnnAlgorithm implements java.lang.AutoCloseable {
 		dwDesc = allocateFilterDescriptor(expectedNumWeights);
 		
 		// Setup workspace
-		workSpace = new Pointer(); reserveSpace = new Pointer();
+		workSpace = new Pointer();
 		sizeInBytes = getWorkspaceSize(T);
 		if(sizeInBytes != 0) {
 			if(LOG.isDebugEnabled()) 
@@ -123,11 +123,6 @@ public class LibMatrixCuDNNRnnAlgorithm implements java.lang.AutoCloseable {
 		reserveSpaceSizeInBytes = 0;
 		if(isTraining) {
 			reserveSpaceSizeInBytes = getReservespaceSize(T);
-			if (reserveSpaceSizeInBytes != 0) {
-				if(LOG.isDebugEnabled()) 
-					LOG.debug("Allocating " +  reserveSpaceSizeInBytes + " bytes for lstm reserve space.");
-				reserveSpace = gCtx.allocate(instName, reserveSpaceSizeInBytes);
-			}
 		}
 	}
 	
@@ -277,14 +272,6 @@ public class LibMatrixCuDNNRnnAlgorithm implements java.lang.AutoCloseable {
 			}
 		}
 		workSpace = null;
-		if(reserveSpaceSizeInBytes != 0) {
-			try {
-				gCtx.cudaFreeHelper(instName, reserveSpace, gCtx.EAGER_CUDA_FREE);
-			} catch (DMLRuntimeException e) {
-				throw new RuntimeException(e);
-			}
-		}	
-		reserveSpace = null;
 		if(dropOutSizeInBytes != 0) {
 			try {
 				gCtx.cudaFreeHelper(instName, dropOutStateSpace, gCtx.EAGER_CUDA_FREE);
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java
index 6dacf28..5d3b527 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuMatMult.java
@@ -369,6 +369,9 @@ public class LibMatrixCuMatMult extends LibMatrixCUDA {
 	 */
 	static void denseDenseMatMult(cublasHandle handle, String instName, Pointer C, Pointer A, Pointer B,
 			CuMatMultParameters param) {
+		if(A == null || B == null || C == null) {
+			throw new DMLRuntimeException("The input and output pointers are not allocated.");
+		}
 		long t0 = ConfigurationManager.isFinegrainedStatistics() ? System.nanoTime() : 0;
 		String kernel = null;
 		param.rowToColumnMajor();
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
index 4c4ab74..828a809 100644
--- a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
@@ -117,7 +117,7 @@ public class LstmCPUTest extends GPUTests {
 	
 	public void testLstmCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity) {
 		String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-				+ "[output, c] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0)";
+				+ "[output, c, cache] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0)";
 		String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
 				+ "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, " 
 				+ T + ", " + D + ", " + returnSequences + ", out0, c0)";
@@ -242,7 +242,8 @@ public class LstmCPUTest extends GPUTests {
 		boolean returnSequences1 = returnSequences.equals("TRUE");
 		
 		String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-				+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+				+ "[output, c, cache] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0); \n"
+				+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0, cache);";
 		String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
 				+ "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, " 
 				+ T + ", " + D + ", " + returnSequences + ", out0, c0); \n"
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmTest.java
index 47afe3a..996b12a 100644
--- a/src/test/java/org/apache/sysml/test/gpu/LstmTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmTest.java
@@ -109,7 +109,7 @@ public class LstmTest extends GPUTests {
 	
 	public void testLstmCuDNNWithNNBuiltinOperator(int N, int T, int D, int M, String returnSequences, double sparsity) {
 		String scriptStr = "source(" + builtinDML + ") as lstm;\n "
-				+ "[output, c] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0)";
+				+ "[output, c, cache] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0)";
 		
 		HashMap<String, Object> inputs = new HashMap<>();
 		inputs.put("x", generateInputMatrix(spark, N, T*D, 0, 10, sparsity, seed));
@@ -143,7 +143,7 @@ public class LstmTest extends GPUTests {
 	
 	public void testLstmCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity) {
 		String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-				+ "[output, c] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0)";
+				+ "[output, c, cache] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0)";
 		String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
 				+ "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, " 
 				+ T + ", " + D + ", " + returnSequences + ", out0, c0)";
@@ -237,7 +237,8 @@ public class LstmTest extends GPUTests {
 		boolean returnSequences1 = returnSequences.equals("TRUE");
 				
 		String scriptStr = "source(" + builtinDML + ") as lstm;\n "
-				+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+				+ "[output, c, cache] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0); \n"
+				+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0, cache);";
 		
 		HashMap<String, Object> inputs = new HashMap<>();
 		inputs.put("dout", generateInputMatrix(spark, N, returnSequences1 ? T*M : M, 0, 10, sparsity, seed));
@@ -281,7 +282,8 @@ public class LstmTest extends GPUTests {
 		boolean returnSequences1 = returnSequences.equals("TRUE");
 		
 		String scriptStr1 = "source(" + builtinDML + ") as lstm;\n "
-				+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0);";
+				+ "[output, c, cache] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0); \n"
+				+ "[dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, x, w, b, " + returnSequences + ", out0, c0, cache);";
 		String scriptStr2 = "source(" + nnDML + ") as lstm;\n "
 				+ "[output, c, cache_out, cache_c, cache_ifog] = lstm::forward(x, w, b, " 
 				+ T + ", " + D + ", " + returnSequences + ", out0, c0); \n"