You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2018/08/10 04:03:47 UTC

systemml git commit: [SYSTEMML-445] Added SGD Nesterov update operator via rewrite for the GPU backend

Repository: systemml
Updated Branches:
  refs/heads/master 5ca8706e9 -> 04bc667f3


[SYSTEMML-445] Added SGD Nesterov update operator via rewrite for the GPU backend

- This leads to 10-15% speedup for ResNet200 with batch size of 32.
- Also, added GPU tests for this operator.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/04bc667f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/04bc667f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/04bc667f

Branch: refs/heads/master
Commit: 04bc667f3650d57c0bc9de20e46e7624205cc1e6
Parents: 5ca8706
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Aug 9 21:00:21 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Aug 9 21:00:21 2018 -0700

----------------------------------------------------------------------
 src/main/cpp/kernels/SystemML.cu                |  23 ++-
 src/main/cpp/kernels/SystemML.ptx               | 161 +++++++++++++++----
 src/main/java/org/apache/sysml/hops/DnnOp.java  |  15 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |   4 +-
 .../hops/rewrite/RewriteGPUSpecificOps.java     |  61 +++++++
 .../org/apache/sysml/lops/DnnTransform.java     |  33 +++-
 .../instructions/GPUInstructionParser.java      |   1 +
 .../instructions/gpu/DnnGPUInstruction.java     |  56 +++++++
 .../gpu/context/GPUMemoryManager.java           |   5 +-
 .../org/apache/sysml/test/gpu/GPUTests.java     |  10 ++
 .../org/apache/sysml/test/gpu/SGDUpdate.java    |  91 +++++++++++
 11 files changed, 419 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 485b7e2..9ddaaff 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2248,12 +2248,6 @@ extern "C" __global__ void prepare_lstm_dinput_f(float* smlInput, float* cudnnIn
 }
 
 
-/**
- * Do an log over all the elements of a matrix
- * @param A the input matrix (of length = size)
- * @param C the pre-allocated output matrix (of length = size)
- * @param size the length of the input and output matrices
- */
 template <typename T>
 __device__ void colwise_reshape(T *A, T *C, unsigned int size, 
 	unsigned int inRows, unsigned int inCols,
@@ -2278,4 +2272,21 @@ extern "C" __global__ void colwise_reshape_f(float *A, float *C, unsigned int si
 	unsigned int inRows, unsigned int inCols,
 	unsigned int outRows, unsigned int outCols) {
   colwise_reshape(A, C, size, inRows, inCols, outRows, outCols);
+}
+
+// Performs the operation: out = X - mu*v_prev + (1+mu)*v
+template <typename T>
+__device__ void update_nesterov_x(T *X, T *v, T *v_prev, double mu, T *out, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+	out[index] = X[index] - mu*v_prev[index] + (1+mu)*v[index];
+  }
+}
+
+extern "C" __global__ void update_nesterov_x_d(double *X, double *v, double *v_prev, double mu, double *out, unsigned int size) {
+  update_nesterov_x(X, v, v_prev, mu, out, size);
+}
+
+extern "C" __global__ void update_nesterov_x_f(float *X, float *v, float *v_prev, double mu, float *out, unsigned int size) {
+  update_nesterov_x(X, v, v_prev, mu, out, size);
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index 93e5e35..8a14876 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -12977,12 +12977,119 @@ BB113_2:
 	ret;
 }
 
+	// .globl	update_nesterov_x_d
+.visible .entry update_nesterov_x_d(
+	.param .u64 update_nesterov_x_d_param_0,
+	.param .u64 update_nesterov_x_d_param_1,
+	.param .u64 update_nesterov_x_d_param_2,
+	.param .f64 update_nesterov_x_d_param_3,
+	.param .u64 update_nesterov_x_d_param_4,
+	.param .u32 update_nesterov_x_d_param_5
+)
+{
+	.reg .pred 	%p<2>;
+	.reg .b32 	%r<6>;
+	.reg .f64 	%fd<9>;
+	.reg .b64 	%rd<14>;
+
+
+	ld.param.u64 	%rd1, [update_nesterov_x_d_param_0];
+	ld.param.u64 	%rd2, [update_nesterov_x_d_param_1];
+	ld.param.u64 	%rd3, [update_nesterov_x_d_param_2];
+	ld.param.f64 	%fd1, [update_nesterov_x_d_param_3];
+	ld.param.u64 	%rd4, [update_nesterov_x_d_param_4];
+	ld.param.u32 	%r2, [update_nesterov_x_d_param_5];
+	mov.u32 	%r3, %ctaid.x;
+	mov.u32 	%r4, %ntid.x;
+	mov.u32 	%r5, %tid.x;
+	mad.lo.s32 	%r1, %r4, %r3, %r5;
+	setp.ge.u32	%p1, %r1, %r2;
+	@%p1 bra 	BB114_2;
+
+	cvta.to.global.u64 	%rd5, %rd1;
+	mul.wide.s32 	%rd6, %r1, 8;
+	add.s64 	%rd7, %rd5, %rd6;
+	cvta.to.global.u64 	%rd8, %rd3;
+	add.s64 	%rd9, %rd8, %rd6;
+	ld.global.f64 	%fd2, [%rd9];
+	mul.f64 	%fd3, %fd2, %fd1;
+	ld.global.f64 	%fd4, [%rd7];
+	sub.f64 	%fd5, %fd4, %fd3;
+	cvta.to.global.u64 	%rd10, %rd2;
+	add.s64 	%rd11, %rd10, %rd6;
+	ld.global.f64 	%fd6, [%rd11];
+	add.f64 	%fd7, %fd1, 0d3FF0000000000000;
+	fma.rn.f64 	%fd8, %fd7, %fd6, %fd5;
+	cvta.to.global.u64 	%rd12, %rd4;
+	add.s64 	%rd13, %rd12, %rd6;
+	st.global.f64 	[%rd13], %fd8;
+
+BB114_2:
+	ret;
+}
+
+	// .globl	update_nesterov_x_f
+.visible .entry update_nesterov_x_f(
+	.param .u64 update_nesterov_x_f_param_0,
+	.param .u64 update_nesterov_x_f_param_1,
+	.param .u64 update_nesterov_x_f_param_2,
+	.param .f64 update_nesterov_x_f_param_3,
+	.param .u64 update_nesterov_x_f_param_4,
+	.param .u32 update_nesterov_x_f_param_5
+)
+{
+	.reg .pred 	%p<2>;
+	.reg .f32 	%f<5>;
+	.reg .b32 	%r<6>;
+	.reg .f64 	%fd<9>;
+	.reg .b64 	%rd<14>;
+
+
+	ld.param.u64 	%rd1, [update_nesterov_x_f_param_0];
+	ld.param.u64 	%rd2, [update_nesterov_x_f_param_1];
+	ld.param.u64 	%rd3, [update_nesterov_x_f_param_2];
+	ld.param.f64 	%fd1, [update_nesterov_x_f_param_3];
+	ld.param.u64 	%rd4, [update_nesterov_x_f_param_4];
+	ld.param.u32 	%r2, [update_nesterov_x_f_param_5];
+	mov.u32 	%r3, %ctaid.x;
+	mov.u32 	%r4, %ntid.x;
+	mov.u32 	%r5, %tid.x;
+	mad.lo.s32 	%r1, %r4, %r3, %r5;
+	setp.ge.u32	%p1, %r1, %r2;
+	@%p1 bra 	BB115_2;
+
+	cvta.to.global.u64 	%rd5, %rd1;
+	mul.wide.s32 	%rd6, %r1, 4;
+	add.s64 	%rd7, %rd5, %rd6;
+	ld.global.f32 	%f1, [%rd7];
+	cvt.f64.f32	%fd2, %f1;
+	cvta.to.global.u64 	%rd8, %rd3;
+	add.s64 	%rd9, %rd8, %rd6;
+	ld.global.f32 	%f2, [%rd9];
+	cvt.f64.f32	%fd3, %f2;
+	mul.f64 	%fd4, %fd3, %fd1;
+	sub.f64 	%fd5, %fd2, %fd4;
+	cvta.to.global.u64 	%rd10, %rd2;
+	add.s64 	%rd11, %rd10, %rd6;
+	ld.global.f32 	%f3, [%rd11];
+	cvt.f64.f32	%fd6, %f3;
+	add.f64 	%fd7, %fd1, 0d3FF0000000000000;
+	fma.rn.f64 	%fd8, %fd7, %fd6, %fd5;
+	cvt.rn.f32.f64	%f4, %fd8;
+	cvta.to.global.u64 	%rd12, %rd4;
+	add.s64 	%rd13, %rd12, %rd6;
+	st.global.f32 	[%rd13], %f4;
+
+BB115_2:
+	ret;
+}
+
 .func  (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
 	.param .b64 __internal_trig_reduction_slowpathd_param_0,
 	.param .b64 __internal_trig_reduction_slowpathd_param_1
 )
 {
-	.local .align 8 .b8 	__local_depot114[40];
+	.local .align 8 .b8 	__local_depot116[40];
 	.reg .b64 	%SP;
 	.reg .b64 	%SPL;
 	.reg .pred 	%p<9>;
@@ -12991,7 +13098,7 @@ BB113_2:
 	.reg .b64 	%rd<102>;
 
 
-	mov.u64 	%rd101, __local_depot114;
+	mov.u64 	%rd101, __local_depot116;
 	cvta.local.u64 	%SP, %rd101;
 	ld.param.f64 	%fd4, [__internal_trig_reduction_slowpathd_param_0];
 	ld.param.u64 	%rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -13005,7 +13112,7 @@ BB113_2:
 	shr.u32 	%r3, %r1, 20;
 	bfe.u32 	%r4, %r1, 20, 11;
 	setp.eq.s32	%p1, %r4, 2047;
-	@%p1 bra 	BB114_13;
+	@%p1 bra 	BB116_13;
 
 	add.s32 	%r15, %r4, -1024;
 	shr.u32 	%r16, %r15, 6;
@@ -13018,7 +13125,7 @@ BB113_2:
 	mov.u64 	%rd94, 0;
 	setp.ge.s32	%p2, %r5, %r6;
 	mov.u64 	%rd93, %rd1;
-	@%p2 bra 	BB114_4;
+	@%p2 bra 	BB116_4;
 
 	mov.b64 	 %rd41, %fd4;
 	shl.b64 	%rd42, %rd41, 11;
@@ -13035,7 +13142,7 @@ BB113_2:
 	mov.u64 	%rd91, %rd1;
 	mov.u32 	%r39, %r5;
 
-BB114_3:
+BB116_3:
 	.pragma "nounroll";
 	ld.const.u64 	%rd47, [%rd89];
 	// inline asm
@@ -13065,15 +13172,15 @@ BB114_3:
 	add.s64 	%rd93, %rd93, 8;
 	add.s64 	%rd89, %rd89, 8;
 	setp.lt.s32	%p3, %r39, %r6;
-	@%p3 bra 	BB114_3;
+	@%p3 bra 	BB116_3;
 
-BB114_4:
+BB116_4:
 	st.local.u64 	[%rd93], %rd94;
 	ld.local.u64 	%rd95, [%rd1+16];
 	ld.local.u64 	%rd96, [%rd1+24];
 	and.b32  	%r9, %r3, 63;
 	setp.eq.s32	%p4, %r9, 0;
-	@%p4 bra 	BB114_6;
+	@%p4 bra 	BB116_6;
 
 	mov.u32 	%r27, 64;
 	sub.s32 	%r28, %r27, %r9;
@@ -13085,7 +13192,7 @@ BB114_4:
 	shr.u64 	%rd55, %rd54, %r28;
 	or.b64  	%rd95, %rd55, %rd53;
 
-BB114_6:
+BB116_6:
 	cvta.to.local.u64 	%rd56, %rd37;
 	shr.u64 	%rd57, %rd96, 62;
 	cvt.u32.u64	%r29, %rd57;
@@ -13102,7 +13209,7 @@ BB114_6:
 	selp.b32	%r34, %r32, %r33, %p5;
 	st.local.u32 	[%rd56], %r34;
 	setp.eq.s32	%p6, %r31, 0;
-	@%p6 bra 	BB114_8;
+	@%p6 bra 	BB116_8;
 
 	mov.u64 	%rd64, 0;
 	// inline asm
@@ -13122,10 +13229,10 @@ BB114_6:
 	// inline asm
 	xor.b32  	%r40, %r40, -2147483648;
 
-BB114_8:
+BB116_8:
 	clz.b64 	%r41, %rd98;
 	setp.eq.s32	%p7, %r41, 0;
-	@%p7 bra 	BB114_10;
+	@%p7 bra 	BB116_10;
 
 	shl.b64 	%rd67, %rd98, %r41;
 	mov.u32 	%r35, 64;
@@ -13133,7 +13240,7 @@ BB114_8:
 	shr.u64 	%rd68, %rd97, %r36;
 	or.b64  	%rd98, %rd68, %rd67;
 
-BB114_10:
+BB116_10:
 	mov.u64 	%rd72, -3958705157555305931;
 	// inline asm
 	{
@@ -13154,7 +13261,7 @@ BB114_10:
 	}
 	// inline asm
 	setp.lt.s64	%p8, %rd100, 1;
-	@%p8 bra 	BB114_12;
+	@%p8 bra 	BB116_12;
 
 	// inline asm
 	{
@@ -13173,7 +13280,7 @@ BB114_10:
 	// inline asm
 	add.s32 	%r41, %r41, 1;
 
-BB114_12:
+BB116_12:
 	cvt.u64.u32	%rd79, %r40;
 	shl.b64 	%rd80, %rd79, 32;
 	mov.u32 	%r37, 1022;
@@ -13188,7 +13295,7 @@ BB114_12:
 	or.b64  	%rd88, %rd87, %rd80;
 	mov.b64 	 %fd4, %rd88;
 
-BB114_13:
+BB116_13:
 	st.param.f64	[func_retval0+0], %fd4;
 	ret;
 }
@@ -13216,7 +13323,7 @@ BB114_13:
 	}
 	shr.u32 	%r51, %r50, 20;
 	setp.ne.s32	%p1, %r51, 0;
-	@%p1 bra 	BB115_2;
+	@%p1 bra 	BB117_2;
 
 	mul.f64 	%fd14, %fd12, 0d4350000000000000;
 	{
@@ -13230,13 +13337,13 @@ BB114_13:
 	shr.u32 	%r16, %r50, 20;
 	add.s32 	%r51, %r16, -54;
 
-BB115_2:
+BB117_2:
 	add.s32 	%r52, %r51, -1023;
 	and.b32  	%r17, %r50, -2146435073;
 	or.b32  	%r18, %r17, 1072693248;
 	mov.b64 	%fd135, {%r49, %r18};
 	setp.lt.u32	%p2, %r18, 1073127583;
-	@%p2 bra 	BB115_4;
+	@%p2 bra 	BB117_4;
 
 	{
 	.reg .b32 %temp; 
@@ -13250,7 +13357,7 @@ BB115_2:
 	mov.b64 	%fd135, {%r19, %r21};
 	add.s32 	%r52, %r51, -1022;
 
-BB115_4:
+BB117_4:
 	add.f64 	%fd15, %fd135, 0d3FF0000000000000;
 	rcp.approx.ftz.f64 	%fd16, %fd15;
 	neg.f64 	%fd17, %fd15;
@@ -13413,13 +13520,13 @@ BB115_4:
 	mov.b32 	 %f2, %r35;
 	abs.f32 	%f1, %f2;
 	setp.lt.f32	%p4, %f1, 0f4086232B;
-	@%p4 bra 	BB115_7;
+	@%p4 bra 	BB117_7;
 
 	setp.lt.f64	%p5, %fd4, 0d0000000000000000;
 	add.f64 	%fd129, %fd4, 0d7FF0000000000000;
 	selp.f64	%fd136, 0d0000000000000000, %fd129, %p5;
 	setp.geu.f32	%p6, %f1, 0f40874800;
-	@%p6 bra 	BB115_7;
+	@%p6 bra 	BB117_7;
 
 	mov.f64 	%fd134, 0d4338000000000000;
 	mov.f64 	%fd133, 0d3FF71547652B82FE;
@@ -13441,26 +13548,26 @@ BB115_4:
 	mov.b64 	%fd131, {%r44, %r43};
 	mul.f64 	%fd136, %fd130, %fd131;
 
-BB115_7:
+BB117_7:
 	{
 	.reg .b32 %temp; 
 	mov.b64 	{%temp, %r45}, %fd136;
 	}
 	and.b32  	%r46, %r45, 2147483647;
 	setp.ne.s32	%p7, %r46, 2146435072;
-	@%p7 bra 	BB115_9;
+	@%p7 bra 	BB117_9;
 
 	{
 	.reg .b32 %temp; 
 	mov.b64 	{%r47, %temp}, %fd136;
 	}
 	setp.eq.s32	%p8, %r47, 0;
-	@%p8 bra 	BB115_10;
+	@%p8 bra 	BB117_10;
 
-BB115_9:
+BB117_9:
 	fma.rn.f64 	%fd136, %fd136, %fd5, %fd136;
 
-BB115_10:
+BB117_10:
 	st.param.f64	[func_retval0+0], %fd136;
 	ret;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java
index 4e22f59..4ca90f8 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -136,6 +136,7 @@ public class DnnOp extends MultiThreadedHop
 			}
 			case BATCH_NORM2D_TEST:
 			case CHANNEL_SUMS:
+			case UPDATE_NESTEROV_X:
 			{	
 				if(et == ExecType.GPU) {
 					setLops(constructDnnLops(et, inputs));
@@ -175,6 +176,8 @@ public class DnnOp extends MultiThreadedHop
 				return 6;
 			case CHANNEL_SUMS:
 				return 3;
+			case UPDATE_NESTEROV_X:
+				return 4;
 			default:
 				return 13;
 		}
@@ -528,7 +531,9 @@ public class DnnOp extends MultiThreadedHop
 		// [numRows, numCols, NNZ] 
 		long[] ret = new long[3];
 		
-		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST) {
+		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
+			op == OpOpDnn.UPDATE_NESTEROV_X) {
+			// Same dimension as the first input
 			MatrixCharacteristics[] mc = memo.getAllInputStats(getInput());
 			ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
 			ret[1] = mc[0].colsKnown() ? mc[0].getCols() : -1;
@@ -734,7 +739,8 @@ public class DnnOp extends MultiThreadedHop
 	@Override
 	public void refreshSizeInformation()
 	{
-		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST) {
+		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) {
+			// Same dimension as the first input
 			Hop input1 = getInput().get(0);
 			setDim1(input1.getDim1());
 			setDim2(input1.getDim2());
@@ -840,8 +846,9 @@ public class DnnOp extends MultiThreadedHop
 	 * @return either -1 or value associated with the dimString
 	 */
 	private long getDim(String dimString) {
-		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS) {
-			throw new RuntimeException("getDim method should not be invoked for batch_norm_test, channel_sums, bias_add and bias_multiply");
+		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
+			op == OpOpDnn.UPDATE_NESTEROV_X) {
+			throw new RuntimeException("getDim method should not be invoked for " + op.name());
 		}
 		try {
 			parseInput();

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 6466575..73a58e3 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1099,7 +1099,8 @@ public abstract class Hop implements ParseInfo
 	public enum OpOpDnn {
 		MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
 		CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
-		BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS
+		BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS,
+		UPDATE_NESTEROV_X
 	}
 	
 	public enum DataGenMethod {
@@ -1174,6 +1175,7 @@ public abstract class Hop implements ParseInfo
 		HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA);
 		HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST);
 		HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS);
+		HopsConv2Lops.put(OpOpDnn.UPDATE_NESTEROV_X, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_NESTEROV_X);
 	}
 
 	protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 2a1699d..b603aa7 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -124,6 +124,7 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 			}
 			hi = batchNormTest(hop, hi, i); 
 			hi = channelSums(hop, hi, i); 
+			hi = updateNesterovX(hop, hi, i);
 	
 			if( !descendFirst )
 				rule_GPUKernels(roots, hi, descendFirst);
@@ -281,6 +282,11 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 				&& getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
 	}
 	
+	private static boolean isBinaryMMMinus(Hop h) {
+		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS 
+				&& getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
+	}
+	
 	private static boolean isBinaryMSMult(Hop h, double expectedValue) {
 		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT 
 				&& getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
@@ -323,6 +329,16 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 				&& getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR;
 	}
 	
+	private static boolean isBinarySMMult(Hop h, double expectedVal) {
+		return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT 
+				&& getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR
+				&& getValue(getFirstInput(h)) == expectedVal;
+	}
+	
+	private static double getValue(Hop h) {
+		return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>());
+	}
+	
 	/**
 	 * Checks if the "mean" hop is a moving average of mean in batch normalization layer.
 	 *  
@@ -704,6 +720,51 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
 	// ------------------------------------------------------------
 	
 	/**
+	 * Checks for the update_nesterov_x pattern (X = X - mu*v_prev + (1+mu)*v)
+	 * and returns a new DnnOp if matched
+	 * 
+	 * @param parent parent of the input
+	 * @param hi input to be matched
+	 * @param pos position
+	 * @return a new DnnOp or hi
+	 */
+	private static Hop updateNesterovX(Hop parent, Hop hi, int pos) {
+		if(fitsOnGPU(hi, 4) && isBinaryMMAdd(hi) && isBinaryMMMinus(getFirstInput(hi))
+			&& isBinarySMMult(getSecondInput(getFirstInput(hi))) 
+			&& isBinarySMMult(getSecondInput(hi))) {
+			Hop onePlusMu = getFirstInput(getSecondInput(hi));
+			Hop tmp = getSecondInput(getFirstInput(hi));
+			Hop mu = getFirstInput(tmp);
+			if(isOnePlusMu(onePlusMu, mu)) {
+				Hop v_prev = getSecondInput(tmp);
+				Hop v = getSecondInput(getSecondInput(hi));
+				Hop X = getFirstInput(getFirstInput(hi));
+				if(hasSameDimensions(X, v) && hasSameDimensions(X, v_prev)) {
+					ArrayList<Hop> inHops = new ArrayList<Hop>();
+					inHops.add(X);
+					inHops.add(v);
+					inHops.add(v_prev);
+					inHops.add(mu);
+					LOG.debug("Applied updateNesterovX rewrite.");
+					Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
+							OpOpDnn.UPDATE_NESTEROV_X, inHops);
+					return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+				}
+			}
+		}
+		return hi;
+	}
+	
+	private static boolean hasSameDimensions(Hop x, Hop y) {
+		return x.dimsKnown() && y.dimsKnown() && (x.getDim1() == y.getDim1()) && (x.getDim2() == y.getDim2());
+	}
+	
+	private static boolean isOnePlusMu(Hop onePlusMu, Hop mu) {
+		return (isBinarySMMult(onePlusMu, 1.0) && getSecondInput(onePlusMu) == mu) ||
+				getValue(onePlusMu) == getValue(mu) + 1;
+	}
+	
+	/**
 	 * Checks for the batch norm (mode="test") pattern using the helper isBatchNormTrainMean and isBatchNormTrainVar
 	 * and returns a new DnnOp if matched
 	 * 

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/lops/DnnTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java
index 6c61d4a..3183b5f 100644
--- a/src/main/java/org/apache/sysml/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java
@@ -31,7 +31,8 @@ public class DnnTransform extends Lop
 		MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
 		RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD,
 		CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
-		BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST
+		BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST, 
+		UPDATE_NESTEROV_X
 	}
 	
 	private OperationTypes operation;
@@ -165,6 +166,9 @@ public class DnnTransform extends Lop
 			
 		case CHANNEL_SUMS:
 			return "channel_sums";
+		
+		case UPDATE_NESTEROV_X:
+			return "update_nesterov_x";
 			
 		case BATCH_NORM2D_TEST:
 			return "batch_norm2d_test";
@@ -232,6 +236,33 @@ public class DnnTransform extends Lop
 	}
 	
 	@Override
+	public String getInstructions(String input1, String input2, String input3, String input4, String output) {
+		if(operation == OperationTypes.UPDATE_NESTEROV_X) {
+			StringBuilder sb = new StringBuilder();
+			sb.append( getExecType() );
+			
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getOpcode() );
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(0).prepInputOperand(input1));
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(1).prepInputOperand(input2));
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(2).prepInputOperand(input3));
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( getInputs().get(3).prepInputOperand(input4));
+			//output
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( this.prepOutputOperand(output));
+			
+			return sb.toString();
+		}
+		else {
+			throw new LopsException("The operation is not supported with three operands:" + operation.name());
+		}
+	}
+	
+	@Override
 	public String getInstructions(String[] inputs, String output) {
 		StringBuilder sb = new StringBuilder();
 		appendOpcode(sb);

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index c90f9f9..f4122d9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -63,6 +63,7 @@ public class GPUInstructionParser  extends InstructionParser
 		String2GPUInstructionType.put( "batch_norm2d_backward",  GPUINSTRUCTION_TYPE.Dnn);
 		String2GPUInstructionType.put( "batch_norm2d_test",      GPUINSTRUCTION_TYPE.Dnn);
 		String2GPUInstructionType.put( "batch_norm2d_train",      GPUINSTRUCTION_TYPE.Dnn);
+		String2GPUInstructionType.put( "update_nesterov_x",      GPUINSTRUCTION_TYPE.Dnn);
 		
 		// Matrix Multiply Operators
 		String2GPUInstructionType.put( "ba+*",  GPUINSTRUCTION_TYPE.AggregateBinary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index a36d0fc..8d89032 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -124,6 +124,21 @@ public class DnnGPUInstruction extends GPUInstruction {
 		_intermediateMemoryBudget = intermediateMemoryBudget;
 	}
 	
+	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr, 
+			double intermediateMemoryBudget) throws DMLRuntimeException {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
+		if( !opcode.equals("update_nesterov_x") ) {
+			throw new DMLRuntimeException("Incorrect opcode: " + opcode);
+		}
+		_input1 = in1;
+		_input2 = in2;
+		_input3 = in3;
+		_input4 = in4;
+		_gputype = GPUINSTRUCTION_TYPE.Dnn;
+		_output = out;
+		_intermediateMemoryBudget = intermediateMemoryBudget;
+	}
+	
 	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode,
 			String istr, ArrayList<CPOperand> stride,
 			ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
@@ -298,6 +313,15 @@ public class DnnGPUInstruction extends GPUInstruction {
 			CPOperand out = new CPOperand(parts[4]);
 			return new DnnGPUInstruction(in, in2, in3, out, opcode, str, 0);
 		}
+		else if (opcode.equalsIgnoreCase("update_nesterov_x")) {
+			InstructionUtils.checkNumFields(parts, 5);
+			CPOperand in = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand in3 = new CPOperand(parts[3]);
+			CPOperand in4 = new CPOperand(parts[4]);
+			CPOperand out = new CPOperand(parts[5]);
+			return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
+		}
 		else if (opcode.equalsIgnoreCase("lstm")) {
 			InstructionUtils.checkNumFields(parts, 8);
 			CPOperand in1 = new CPOperand(parts[1]);
@@ -552,6 +576,34 @@ public class DnnGPUInstruction extends GPUInstruction {
 		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
 	}
 	
+	private void processNesterovUpdateInstruction(ExecutionContext ec) {
+		GPUStatistics.incrementNoOfExecutedGPUInst();;
+		MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
+		MatrixObject v = getMatrixInputForGPUInstruction(ec, _input2.getName());
+		MatrixObject v_prev = getMatrixInputForGPUInstruction(ec, _input3.getName());
+		double mu = (int) ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue();
+		int rows = LibMatrixCUDA.toInt(input.getNumRows());
+		int cols = LibMatrixCUDA.toInt(input.getNumColumns());
+		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), rows, cols);
+		
+		GPUContext gCtx = ec.getGPUContext(0);
+		String instName = getExtendedOpcode();
+		LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x", 
+				ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)),
+				LibMatrixCUDA.getDensePointer(gCtx, input, instName), 
+				LibMatrixCUDA.getDensePointer(gCtx, v, instName),
+				LibMatrixCUDA.getDensePointer(gCtx, v_prev, instName),
+				mu, 
+				LibMatrixCUDA.getDensePointer(gCtx, out, instName),
+				rows*cols);
+		
+		// release inputs/outputs
+		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+	}
+	
 	private static int toInt(long num) throws DMLRuntimeException {
 		if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) {
 			throw new DMLRuntimeException("GPU : Exceeded supported size " + num);
@@ -697,6 +749,10 @@ public class DnnGPUInstruction extends GPUInstruction {
 			processChannelSumsInstruction(ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("update_nesterov_x")) {
+			processNesterovUpdateInstruction(ec);
+			return;
+		}
 		else if (instOpcode.equalsIgnoreCase("lstm")) {
 			processLstmInstruction(ec);
 			return;

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index 3d7ab2c..5d0e4bc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -405,7 +405,8 @@ public class GPUMemoryManager {
 			allPointers.remove(toFree);
 			lazyCudaFreeMemoryManager.removeIfPresent(size, toFree);
 			allocator.free(toFree);
-			// JCuda.cudaDeviceSynchronize(); // Force a device synchronize after free-ing the pointer for debugging
+			if(DMLScript.SYNCHRONIZE_GPU)
+				jcuda.runtime.JCuda.cudaDeviceSynchronize(); // Force a device synchronize after free-ing the pointer for debugging
 		}
 		else {
 			throw new RuntimeException("Attempting to free an unaccounted pointer:" + toFree);
@@ -447,7 +448,7 @@ public class GPUMemoryManager {
 	public void removeGPUObject(GPUObject gpuObj) {
 		if(LOG.isDebugEnabled())
 			LOG.debug("Removing the GPU object: " + gpuObj);
-		matrixMemoryManager.gpuObjects.removeIf(a -> a.equals(gpuObj));
+		matrixMemoryManager.gpuObjects.remove(gpuObj);
 	}
 
 	

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
index cae2e33..e1ae1ae 100644
--- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
@@ -311,6 +311,16 @@ public abstract class GPUTests extends AutomatedTestBase {
 		Set<String> heavyHitterOpCodes = Statistics.getCPHeavyHitterOpCodes();
 		Assert.assertTrue(heavyHitterOpCodes.contains(heavyHitterOpCode));
 	}
+	
+	/**
+	 * asserts that the given op was NOT executed (i.e. its opcode is absent from the heavy hitters)
+	 *
+	 * @param heavyHitterOpCode opcode of the heavy hitter for the unary op
+	 */
+	protected void assertHeavyHitterNotPresent(String heavyHitterOpCode) {
+		Set<String> heavyHitterOpCodes = Statistics.getCPHeavyHitterOpCodes();
+		Assert.assertTrue(!heavyHitterOpCodes.contains(heavyHitterOpCode));
+	}
 
 	/**
 	 * Runs a program on the CPU

http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java b/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java
new file mode 100644
index 0000000..c98a74d
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests update rewrites for SGD
+ */
+public class SGDUpdate extends GPUTests {
+
+	private final static String TEST_NAME = "SGDUpdateTests";
+	private final int seed = 42;
+
+	@Override
+	public void setUp() {
+		super.setUp();
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_DIR, TEST_NAME);
+		getAndLoadTestConfiguration(TEST_NAME);
+	}
+
+	@Test
+	public void testNesterovRewrite() {
+		String scriptStr = "mu=0.99; output = x - mu*v_prev + (1+mu)*v;" ;
+		int inRows = 10;
+		int inCols = 30;
+		HashMap<String, Object> inputs = new HashMap<>();
+		inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+		inputs.put("v_prev", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+		inputs.put("v", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+		List<String> outputs = Arrays.asList("output");
+		List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+		List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+		assertHeavyHitterPresent("gpu_update_nesterov_x");
+		assertEqualObjects(outCPU.get(0), outGPU.get(0));
+	}
+	
+	@Test
+	public void testNoNesterovRewrite1() {
+		String scriptStr = "mu=0.99; output = x - mu*v_prev + (1+mu)*v;" ;
+		int inRows = 10;
+		int inCols = 30;
+		HashMap<String, Object> inputs = new HashMap<>();
+		inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+		inputs.put("v_prev", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+		inputs.put("v", generateInputMatrix(spark, inRows, 1, 0, 10, 0.9, seed));
+		List<String> outputs = Arrays.asList("output");
+		List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+		List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+		assertHeavyHitterNotPresent("gpu_update_nesterov_x");
+		assertEqualObjects(outCPU.get(0), outGPU.get(0));
+	}
+	
+	@Test
+	public void testNoNesterovRewrite2() {
+		String scriptStr = "mu=0.99; output = x - mu*v_prev + mu*v;" ;
+		int inRows = 10;
+		int inCols = 30;
+		HashMap<String, Object> inputs = new HashMap<>();
+		inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+		inputs.put("v_prev", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+		inputs.put("v", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+		List<String> outputs = Arrays.asList("output");
+		List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+		List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+		assertHeavyHitterNotPresent("gpu_update_nesterov_x");
+		assertEqualObjects(outCPU.get(0), outGPU.get(0));
+	}
+}