Posted to commits@systemml.apache.org by ni...@apache.org on 2018/10/09 21:59:31 UTC

systemml git commit: [SYSTEMML-445] Improved the performance of batchnorm backward

Repository: systemml
Updated Branches:
  refs/heads/master 512fb9e11 -> 3702df7c1


[SYSTEMML-445] Improved the performance of batchnorm backward

- Added a custom kernel for computing dgamma in the batch normalization
layer (a reference sketch of the computation follows below).
- Also fixed a minor bug in the GPUDenseInputPointerFetcher class.
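
For reference, the fused computation is dgamma = channel_sums(dout * (X - ema_mean) * ema_var, C, HW),
where, per the pattern matched in RewriteGPUSpecificOps below, the ema_var input
already holds the precomputed inverse-variance term (it is multiplied, not divided).
A minimal CPU sketch in Java of the end-to-end result; the method name and the
row-major NCHW layout are illustrative assumptions, not part of this commit:

  // Illustrative CPU reference of what backward_dgamma_tmp + channel_sums compute.
  static double[] referenceDGamma(double[] dout, double[] X, double[] emaMean,
      double[] emaInvVar, int N, int C, int HW) {
    double[] dgamma = new double[C];
    for (int n = 0; n < N; n++)
      for (int c = 0; c < C; c++)
        for (int hw = 0; hw < HW; hw++) {
          int i = n * C * HW + c * HW + hw; // flat index into the N x CHW matrix
          // same elementwise term as the new kernel, then a per-channel sum
          dgamma[c] += dout[i] * (X[i] - emaMean[c]) * emaInvVar[c];
        }
    return dgamma;
  }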

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3702df7c
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3702df7c
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3702df7c

Branch: refs/heads/master
Commit: 3702df7c1890b8c87c42715260240c604a5c3c64
Parents: 512fb9e
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Tue Oct 9 14:58:09 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Tue Oct 9 14:58:09 2018 -0700

----------------------------------------------------------------------
 src/main/cpp/kernels/SystemML.cu                |  21 +++
 src/main/cpp/kernels/SystemML.ptx               | 188 ++++++++++++++++---
 src/main/java/org/apache/sysml/hops/DnnOp.java  |   8 +-
 src/main/java/org/apache/sysml/hops/Hop.java    |   3 +-
 .../hops/rewrite/RewriteGPUSpecificOps.java     |  22 ++-
 .../org/apache/sysml/lops/DnnTransform.java     |   7 +-
 .../instructions/GPUInstructionParser.java      |   1 +
 .../instructions/gpu/DnnGPUInstruction.java     |  51 ++++-
 .../gpu/GPUDenseInputPointerFetcher.java        |   4 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      |  19 +-
 10 files changed, 285 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index a53d07a..26d7f43 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2385,3 +2385,24 @@ extern "C" __global__ void invVar_f(float *X, float *C, double eps, unsigned int
   invVar(X, C, eps, size);
 }
 
+template <typename T>
+__device__ void backward_dgamma_tmp(T *ema_mean, T *dout, T *X, T *ema_var, T *ret, int N, int C,
+                         int HW, int CHW, unsigned int NCHW) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int ix = tid / CHW;
+  int iy = tid % CHW;
+  if (ix < N && iy < CHW) {
+    int c = iy / HW;
+    ret[tid] = dout[tid] * ((X[tid] - ema_mean[c]) * ema_var[c]);
+  }
+}
+
+extern "C" __global__ void backward_dgamma_tmp_d(double *ema_mean, double *dout, double *X, double *ema_var, double *ret, 
+	int N, int C, int HW, int CHW, unsigned int NCHW) {
+  backward_dgamma_tmp(ema_mean, dout, X, ema_var, ret, N, C, HW, CHW, NCHW);
+}
+
+extern "C" __global__ void backward_dgamma_tmp_f(float *ema_mean, float *dout, float *X, float *ema_var, float *ret, 
+	int N, int C, int HW, int CHW, unsigned int NCHW) {
+  backward_dgamma_tmp(ema_mean, dout, X, ema_var, ret, N, C, HW, CHW, NCHW);
+}
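
Each GPU thread writes one element of ret: ix = tid / CHW is the image index
(used only for the bounds check against N), and the channel is recovered from
the offset within the row to look up the per-channel statistics. A hedged Java
mirror of that index arithmetic (the helper name is hypothetical):

  // Hypothetical helper mirroring the kernel's tid -> channel mapping.
  static int channelOf(int tid, int HW, int CHW) {
    int iy = tid % CHW; // offset inside one image's flattened C*H*W block
    return iy / HW;     // channel index used for ema_mean[c] and ema_var[c]
  }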

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index ac04967..3043373 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -15084,12 +15084,146 @@ BB123_2:
 	ret;
 }
 
+	// .globl	backward_dgamma_tmp_d
+.visible .entry backward_dgamma_tmp_d(
+	.param .u64 backward_dgamma_tmp_d_param_0,
+	.param .u64 backward_dgamma_tmp_d_param_1,
+	.param .u64 backward_dgamma_tmp_d_param_2,
+	.param .u64 backward_dgamma_tmp_d_param_3,
+	.param .u64 backward_dgamma_tmp_d_param_4,
+	.param .u32 backward_dgamma_tmp_d_param_5,
+	.param .u32 backward_dgamma_tmp_d_param_6,
+	.param .u32 backward_dgamma_tmp_d_param_7,
+	.param .u32 backward_dgamma_tmp_d_param_8,
+	.param .u32 backward_dgamma_tmp_d_param_9
+)
+{
+	.reg .pred 	%p<4>;
+	.reg .b32 	%r<11>;
+	.reg .f64 	%fd<8>;
+	.reg .b64 	%rd<18>;
+
+
+	ld.param.u64 	%rd1, [backward_dgamma_tmp_d_param_0];
+	ld.param.u64 	%rd2, [backward_dgamma_tmp_d_param_1];
+	ld.param.u64 	%rd3, [backward_dgamma_tmp_d_param_2];
+	ld.param.u64 	%rd4, [backward_dgamma_tmp_d_param_3];
+	ld.param.u64 	%rd5, [backward_dgamma_tmp_d_param_4];
+	ld.param.u32 	%r4, [backward_dgamma_tmp_d_param_5];
+	ld.param.u32 	%r2, [backward_dgamma_tmp_d_param_7];
+	ld.param.u32 	%r3, [backward_dgamma_tmp_d_param_8];
+	mov.u32 	%r5, %ctaid.x;
+	mov.u32 	%r6, %ntid.x;
+	mov.u32 	%r7, %tid.x;
+	mad.lo.s32 	%r1, %r6, %r5, %r7;
+	div.s32 	%r8, %r1, %r3;
+	setp.lt.s32	%p1, %r8, %r4;
+	setp.gt.s32	%p2, %r3, -1;
+	and.pred  	%p3, %p1, %p2;
+	@!%p3 bra 	BB124_2;
+	bra.uni 	BB124_1;
+
+BB124_1:
+	rem.s32 	%r9, %r1, %r3;
+	cvta.to.global.u64 	%rd6, %rd2;
+	mul.wide.s32 	%rd7, %r1, 8;
+	add.s64 	%rd8, %rd6, %rd7;
+	cvta.to.global.u64 	%rd9, %rd3;
+	add.s64 	%rd10, %rd9, %rd7;
+	div.s32 	%r10, %r9, %r2;
+	cvta.to.global.u64 	%rd11, %rd1;
+	mul.wide.s32 	%rd12, %r10, 8;
+	add.s64 	%rd13, %rd11, %rd12;
+	ld.global.f64 	%fd1, [%rd13];
+	ld.global.f64 	%fd2, [%rd10];
+	sub.f64 	%fd3, %fd2, %fd1;
+	cvta.to.global.u64 	%rd14, %rd4;
+	add.s64 	%rd15, %rd14, %rd12;
+	ld.global.f64 	%fd4, [%rd15];
+	mul.f64 	%fd5, %fd3, %fd4;
+	ld.global.f64 	%fd6, [%rd8];
+	mul.f64 	%fd7, %fd6, %fd5;
+	cvta.to.global.u64 	%rd16, %rd5;
+	add.s64 	%rd17, %rd16, %rd7;
+	st.global.f64 	[%rd17], %fd7;
+
+BB124_2:
+	ret;
+}
+
+	// .globl	backward_dgamma_tmp_f
+.visible .entry backward_dgamma_tmp_f(
+	.param .u64 backward_dgamma_tmp_f_param_0,
+	.param .u64 backward_dgamma_tmp_f_param_1,
+	.param .u64 backward_dgamma_tmp_f_param_2,
+	.param .u64 backward_dgamma_tmp_f_param_3,
+	.param .u64 backward_dgamma_tmp_f_param_4,
+	.param .u32 backward_dgamma_tmp_f_param_5,
+	.param .u32 backward_dgamma_tmp_f_param_6,
+	.param .u32 backward_dgamma_tmp_f_param_7,
+	.param .u32 backward_dgamma_tmp_f_param_8,
+	.param .u32 backward_dgamma_tmp_f_param_9
+)
+{
+	.reg .pred 	%p<4>;
+	.reg .b32 	%r<11>;
+	.reg .f64 	%fd<8>;
+	.reg .b64 	%rd<18>;
+
+
+	ld.param.u64 	%rd1, [backward_dgamma_tmp_f_param_0];
+	ld.param.u64 	%rd2, [backward_dgamma_tmp_f_param_1];
+	ld.param.u64 	%rd3, [backward_dgamma_tmp_f_param_2];
+	ld.param.u64 	%rd4, [backward_dgamma_tmp_f_param_3];
+	ld.param.u64 	%rd5, [backward_dgamma_tmp_f_param_4];
+	ld.param.u32 	%r4, [backward_dgamma_tmp_f_param_5];
+	ld.param.u32 	%r2, [backward_dgamma_tmp_f_param_7];
+	ld.param.u32 	%r3, [backward_dgamma_tmp_f_param_8];
+	mov.u32 	%r5, %ctaid.x;
+	mov.u32 	%r6, %ntid.x;
+	mov.u32 	%r7, %tid.x;
+	mad.lo.s32 	%r1, %r6, %r5, %r7;
+	div.s32 	%r8, %r1, %r3;
+	setp.lt.s32	%p1, %r8, %r4;
+	setp.gt.s32	%p2, %r3, -1;
+	and.pred  	%p3, %p1, %p2;
+	@!%p3 bra 	BB125_2;
+	bra.uni 	BB125_1;
+
+BB125_1:
+	rem.s32 	%r9, %r1, %r3;
+	cvta.to.global.u64 	%rd6, %rd2;
+	mul.wide.s32 	%rd7, %r1, 8;
+	add.s64 	%rd8, %rd6, %rd7;
+	cvta.to.global.u64 	%rd9, %rd3;
+	add.s64 	%rd10, %rd9, %rd7;
+	div.s32 	%r10, %r9, %r2;
+	cvta.to.global.u64 	%rd11, %rd1;
+	mul.wide.s32 	%rd12, %r10, 8;
+	add.s64 	%rd13, %rd11, %rd12;
+	ld.global.f64 	%fd1, [%rd13];
+	ld.global.f64 	%fd2, [%rd10];
+	sub.f64 	%fd3, %fd2, %fd1;
+	cvta.to.global.u64 	%rd14, %rd4;
+	add.s64 	%rd15, %rd14, %rd12;
+	ld.global.f64 	%fd4, [%rd15];
+	mul.f64 	%fd5, %fd3, %fd4;
+	ld.global.f64 	%fd6, [%rd8];
+	mul.f64 	%fd7, %fd6, %fd5;
+	cvta.to.global.u64 	%rd16, %rd5;
+	add.s64 	%rd17, %rd16, %rd7;
+	st.global.f64 	[%rd17], %fd7;
+
+BB125_2:
+	ret;
+}
+
 .func  (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
 	.param .b64 __internal_trig_reduction_slowpathd_param_0,
 	.param .b64 __internal_trig_reduction_slowpathd_param_1
 )
 {
-	.local .align 8 .b8 	__local_depot124[40];
+	.local .align 8 .b8 	__local_depot126[40];
 	.reg .b64 	%SP;
 	.reg .b64 	%SPL;
 	.reg .pred 	%p<9>;
@@ -15098,7 +15232,7 @@ BB123_2:
 	.reg .b64 	%rd<102>;
 
 
-	mov.u64 	%rd101, __local_depot124;
+	mov.u64 	%rd101, __local_depot126;
 	cvta.local.u64 	%SP, %rd101;
 	ld.param.f64 	%fd4, [__internal_trig_reduction_slowpathd_param_0];
 	ld.param.u64 	%rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -15112,7 +15246,7 @@ BB123_2:
 	shr.u32 	%r3, %r1, 20;
 	bfe.u32 	%r4, %r1, 20, 11;
 	setp.eq.s32	%p1, %r4, 2047;
-	@%p1 bra 	BB124_13;
+	@%p1 bra 	BB126_13;
 
 	add.s32 	%r15, %r4, -1024;
 	shr.u32 	%r16, %r15, 6;
@@ -15125,7 +15259,7 @@ BB123_2:
 	mov.u64 	%rd94, 0;
 	setp.ge.s32	%p2, %r5, %r6;
 	mov.u64 	%rd93, %rd1;
-	@%p2 bra 	BB124_4;
+	@%p2 bra 	BB126_4;
 
 	mov.b64 	 %rd41, %fd4;
 	shl.b64 	%rd42, %rd41, 11;
@@ -15142,7 +15276,7 @@ BB123_2:
 	mov.u64 	%rd91, %rd1;
 	mov.u32 	%r39, %r5;
 
-BB124_3:
+BB126_3:
 	.pragma "nounroll";
 	ld.const.u64 	%rd47, [%rd89];
 	// inline asm
@@ -15172,15 +15306,15 @@ BB124_3:
 	add.s64 	%rd93, %rd93, 8;
 	add.s64 	%rd89, %rd89, 8;
 	setp.lt.s32	%p3, %r39, %r6;
-	@%p3 bra 	BB124_3;
+	@%p3 bra 	BB126_3;
 
-BB124_4:
+BB126_4:
 	st.local.u64 	[%rd93], %rd94;
 	ld.local.u64 	%rd95, [%rd1+16];
 	ld.local.u64 	%rd96, [%rd1+24];
 	and.b32  	%r9, %r3, 63;
 	setp.eq.s32	%p4, %r9, 0;
-	@%p4 bra 	BB124_6;
+	@%p4 bra 	BB126_6;
 
 	mov.u32 	%r27, 64;
 	sub.s32 	%r28, %r27, %r9;
@@ -15192,7 +15326,7 @@ BB124_4:
 	shr.u64 	%rd55, %rd54, %r28;
 	or.b64  	%rd95, %rd55, %rd53;
 
-BB124_6:
+BB126_6:
 	cvta.to.local.u64 	%rd56, %rd37;
 	shr.u64 	%rd57, %rd96, 62;
 	cvt.u32.u64	%r29, %rd57;
@@ -15209,7 +15343,7 @@ BB124_6:
 	selp.b32	%r34, %r32, %r33, %p5;
 	st.local.u32 	[%rd56], %r34;
 	setp.eq.s32	%p6, %r31, 0;
-	@%p6 bra 	BB124_8;
+	@%p6 bra 	BB126_8;
 
 	mov.u64 	%rd64, 0;
 	// inline asm
@@ -15229,10 +15363,10 @@ BB124_6:
 	// inline asm
 	xor.b32  	%r40, %r40, -2147483648;
 
-BB124_8:
+BB126_8:
 	clz.b64 	%r41, %rd98;
 	setp.eq.s32	%p7, %r41, 0;
-	@%p7 bra 	BB124_10;
+	@%p7 bra 	BB126_10;
 
 	shl.b64 	%rd67, %rd98, %r41;
 	mov.u32 	%r35, 64;
@@ -15240,7 +15374,7 @@ BB124_8:
 	shr.u64 	%rd68, %rd97, %r36;
 	or.b64  	%rd98, %rd68, %rd67;
 
-BB124_10:
+BB126_10:
 	mov.u64 	%rd72, -3958705157555305931;
 	// inline asm
 	{
@@ -15261,7 +15395,7 @@ BB124_10:
 	}
 	// inline asm
 	setp.lt.s64	%p8, %rd100, 1;
-	@%p8 bra 	BB124_12;
+	@%p8 bra 	BB126_12;
 
 	// inline asm
 	{
@@ -15280,7 +15414,7 @@ BB124_10:
 	// inline asm
 	add.s32 	%r41, %r41, 1;
 
-BB124_12:
+BB126_12:
 	cvt.u64.u32	%rd79, %r40;
 	shl.b64 	%rd80, %rd79, 32;
 	mov.u32 	%r37, 1022;
@@ -15295,7 +15429,7 @@ BB124_12:
 	or.b64  	%rd88, %rd87, %rd80;
 	mov.b64 	 %fd4, %rd88;
 
-BB124_13:
+BB126_13:
 	st.param.f64	[func_retval0+0], %fd4;
 	ret;
 }
@@ -15323,7 +15457,7 @@ BB124_13:
 	}
 	shr.u32 	%r51, %r50, 20;
 	setp.ne.s32	%p1, %r51, 0;
-	@%p1 bra 	BB125_2;
+	@%p1 bra 	BB127_2;
 
 	mul.f64 	%fd14, %fd12, 0d4350000000000000;
 	{
@@ -15337,13 +15471,13 @@ BB124_13:
 	shr.u32 	%r16, %r50, 20;
 	add.s32 	%r51, %r16, -54;
 
-BB125_2:
+BB127_2:
 	add.s32 	%r52, %r51, -1023;
 	and.b32  	%r17, %r50, -2146435073;
 	or.b32  	%r18, %r17, 1072693248;
 	mov.b64 	%fd135, {%r49, %r18};
 	setp.lt.u32	%p2, %r18, 1073127583;
-	@%p2 bra 	BB125_4;
+	@%p2 bra 	BB127_4;
 
 	{
 	.reg .b32 %temp; 
@@ -15357,7 +15491,7 @@ BB125_2:
 	mov.b64 	%fd135, {%r19, %r21};
 	add.s32 	%r52, %r51, -1022;
 
-BB125_4:
+BB127_4:
 	add.f64 	%fd15, %fd135, 0d3FF0000000000000;
 	rcp.approx.ftz.f64 	%fd16, %fd15;
 	neg.f64 	%fd17, %fd15;
@@ -15520,13 +15654,13 @@ BB125_4:
 	mov.b32 	 %f2, %r35;
 	abs.f32 	%f1, %f2;
 	setp.lt.f32	%p4, %f1, 0f4086232B;
-	@%p4 bra 	BB125_7;
+	@%p4 bra 	BB127_7;
 
 	setp.lt.f64	%p5, %fd4, 0d0000000000000000;
 	add.f64 	%fd129, %fd4, 0d7FF0000000000000;
 	selp.f64	%fd136, 0d0000000000000000, %fd129, %p5;
 	setp.geu.f32	%p6, %f1, 0f40874800;
-	@%p6 bra 	BB125_7;
+	@%p6 bra 	BB127_7;
 
 	mov.f64 	%fd134, 0d4338000000000000;
 	mov.f64 	%fd133, 0d3FF71547652B82FE;
@@ -15548,26 +15682,26 @@ BB125_4:
 	mov.b64 	%fd131, {%r44, %r43};
 	mul.f64 	%fd136, %fd130, %fd131;
 
-BB125_7:
+BB127_7:
 	{
 	.reg .b32 %temp; 
 	mov.b64 	{%temp, %r45}, %fd136;
 	}
 	and.b32  	%r46, %r45, 2147483647;
 	setp.ne.s32	%p7, %r46, 2146435072;
-	@%p7 bra 	BB125_9;
+	@%p7 bra 	BB127_9;
 
 	{
 	.reg .b32 %temp; 
 	mov.b64 	{%r47, %temp}, %fd136;
 	}
 	setp.eq.s32	%p8, %r47, 0;
-	@%p8 bra 	BB125_10;
+	@%p8 bra 	BB127_10;
 
-BB125_9:
+BB127_9:
 	fma.rn.f64 	%fd136, %fd136, %fd5, %fd136;
 
-BB125_10:
+BB127_10:
 	st.param.f64	[func_retval0+0], %fd136;
 	ret;
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java
index c4ce466..7cf5061 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -141,6 +141,7 @@ public class DnnOp extends MultiThreadedHop
 			case UPDATE_EMA:
 			case INV_VAR:
 			case BATCH_NORM2D_BACKWARD_DX:
+			case BATCH_NORM2D_BACKWARD_DGAMMA:
 			{	
 				// GPU-specific operators
 				setLops(constructDnnLops(ExecType.GPU, inputs));
@@ -181,6 +182,7 @@ public class DnnOp extends MultiThreadedHop
 			case CHANNEL_SUMS:
 			case UPDATE_EMA:
 				return 3;
+			case BATCH_NORM2D_BACKWARD_DGAMMA:
 			case UPDATE_NESTEROV_X:
 				return 4;
 			default:
@@ -538,7 +540,7 @@ public class DnnOp extends MultiThreadedHop
 		
 		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
 			op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
-			op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
+			op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) {
 			// Same dimension as the first input
 			MatrixCharacteristics[] mc = memo.getAllInputStats(getInput());
 			ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
@@ -755,7 +757,7 @@ public class DnnOp extends MultiThreadedHop
 	{
 		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || 
 			op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
-			op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
+			op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) {
 			// Same dimension as the first input
 			Hop input1 = getInput().get(0);
 			setDim1(input1.getDim1());
@@ -873,7 +875,7 @@ public class DnnOp extends MultiThreadedHop
 		if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
 			op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.RESHAPE_COLMEANS ||
 			op == OpOpDnn.UPDATE_EMA_VAR || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
-			op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
+			op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) {
 			throw new RuntimeException("getDim method should not be invoked for " + op.name());
 		}
 		try {

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index c8356e0..82a6669 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1101,7 +1101,7 @@ public abstract class Hop implements ParseInfo
 		CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
 		BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS,
 		UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
-		BATCH_NORM2D_BACKWARD_DX
+		BATCH_NORM2D_BACKWARD_DX, BATCH_NORM2D_BACKWARD_DGAMMA
 	}
 	
 	public enum DataGenMethod {
@@ -1182,6 +1182,7 @@ public abstract class Hop implements ParseInfo
 		HopsConv2Lops.put(OpOpDnn.UPDATE_EMA, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA);
 		HopsConv2Lops.put(OpOpDnn.INV_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.INV_VAR);
 		HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DX, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DX);
+		HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DGAMMA);
 	}
 
 	protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 577adc3..ab40d7b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -170,6 +170,25 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
 		return hi;
 	};
 	
+	// Avoids unnecessary intermediates:
+	// mean = cache_mean
+	// centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
+	// norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
+	// # Compute gradients during training
+	// dgamma = util::channel_sums(dout*norm, C, Hin, Win)
+	private static final HopDagPatternMatcher _batchNormDGamma;
+	static {
+		_batchNormDGamma = util_channel_sums(
+				mult(leaf("dout", MATRIX).fitsOnGPU(3),
+					bias_multiply(
+						bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", MATRIX))),
+						leaf("ema_var", MATRIX))),
+				leaf("C", SCALAR), leaf("HW", SCALAR));
+	}
+	private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi -> {
+		LOG.debug("Applied batchNormDGamma rewrite.");
+		Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, 
+				"ema_mean", "dout", "X", "ema_var");
+		return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+	};
 		
 	// Pattern 3:
 	private static final HopDagPatternMatcher _batchNormTest;
@@ -282,8 +301,9 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
 		if(_rewriters == null) {
 			ArrayList<HopPatternRewriter> rewriters = new ArrayList<>();
 			rewriters.add(new HopPatternRewriter("batchNormdX", _batchNormdX, _batchNormdXReplacer));
-			rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer));
 			rewriters.add(new HopPatternRewriter("batchNormTest", _batchNormTest, _batchNormTestReplacer));
+			rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer));
+			// rewriters.add(new HopPatternRewriter("batchNormDGamma", _batchNormDGamma, _batchNormDGammaReplacer));
 			rewriters.add(new HopPatternRewriter("channelSums", _channelSums, _channelSumsReplacer));
 			rewriters.add(new HopPatternRewriter("updateNesterovX", _updateNesterovX, _updateNesterovXReplacer));
 			rewriters.add(new HopPatternRewriter("reshapeColMeans", _reshapeColMeans, _reshapeColMeansReplacer));

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/lops/DnnTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java
index 2d2d5f1..3496b5b 100644
--- a/src/main/java/org/apache/sysml/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java
@@ -33,7 +33,7 @@ public class DnnTransform extends Lop
 		CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
 		BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST, 
 		UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
-		BATCH_NORM2D_BACKWARD_DX
+		BATCH_NORM2D_BACKWARD_DX, BATCH_NORM2D_BACKWARD_DGAMMA
 	}
 	
 	private OperationTypes operation;
@@ -174,6 +174,9 @@ public class DnnTransform extends Lop
 		case UPDATE_NESTEROV_X:
 			return "update_nesterov_x";
 			
+		case BATCH_NORM2D_BACKWARD_DGAMMA:
+			return "batch_norm2d_bwd_dgamma";
+			
 		case BATCH_NORM2D_TEST:
 			return "batch_norm2d_test";
 		
@@ -254,7 +257,7 @@ public class DnnTransform extends Lop
 	
 	@Override
 	public String getInstructions(String input1, String input2, String input3, String input4, String output) {
-		if(operation == OperationTypes.UPDATE_NESTEROV_X) {
+		if(operation == OperationTypes.UPDATE_NESTEROV_X || operation == OperationTypes.BATCH_NORM2D_BACKWARD_DGAMMA) {
 			StringBuilder sb = new StringBuilder();
 			sb.append( getExecType() );
 			

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 3480504..c8a0e8d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -66,6 +66,7 @@ public class GPUInstructionParser  extends InstructionParser
 		String2GPUInstructionType.put( "reshape_colmeans",      GPUINSTRUCTION_TYPE.Dnn);
 		String2GPUInstructionType.put( "inv_var",      			GPUINSTRUCTION_TYPE.Dnn);
 		String2GPUInstructionType.put( "batch_norm2d_bwd_dx",   GPUINSTRUCTION_TYPE.Dnn);
+		String2GPUInstructionType.put( "batch_norm2d_bwd_dgamma",   GPUINSTRUCTION_TYPE.Dnn);
 		
 		// Matrix Multiply Operators
 		String2GPUInstructionType.put( "ba+*",  GPUINSTRUCTION_TYPE.AggregateBinary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index 6094b6c..4ad4155 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -127,7 +127,7 @@ public class DnnGPUInstruction extends GPUInstruction {
 	public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr, 
 			double intermediateMemoryBudget) throws DMLRuntimeException {
 		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
-		if( !( opcode.equals("update_nesterov_x")) ) {
+		if( !( opcode.equals("update_nesterov_x") || opcode.equals("batch_norm2d_bwd_dgamma")) ) {
 			throw new DMLRuntimeException("Incorrect opcode: " + opcode);
 		}
 		_input1 = in1;
@@ -339,6 +339,15 @@ public class DnnGPUInstruction extends GPUInstruction {
 			CPOperand out = new CPOperand(parts[5]);
 			return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
 		}
+		else if (opcode.equalsIgnoreCase("batch_norm2d_bwd_dgamma")) {
+			InstructionUtils.checkNumFields(parts, 5);
+			CPOperand in = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand in3 = new CPOperand(parts[3]);
+			CPOperand in4 = new CPOperand(parts[4]);
+			CPOperand out = new CPOperand(parts[5]);
+			return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
+		}
 		else if (opcode.equalsIgnoreCase("lstm")) {
 			InstructionUtils.checkNumFields(parts, 8);
 			CPOperand in1 = new CPOperand(parts[1]);
@@ -586,6 +595,42 @@ public class DnnGPUInstruction extends GPUInstruction {
 		}
 	}
 	
+	// "ema_mean", "dout", "X", "ema_var"
+		private void processBatchNorm2dBackwardDGammaInstruction(ExecutionContext ec) {
+			try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+				fetcher.add("ema_mean", _input1).add("dout", _input2).add("X", _input3)
+				.add("ema_var", _input4);
+				MatrixObject ema_mean = fetcher.getInputMatrixObject("ema_mean");
+				MatrixObject dout = fetcher.getInputMatrixObject("dout");
+				long C = ema_mean.getNumRows();
+				long N = dout.getNumRows();
+				long CHW = dout.getNumColumns();
+				fetcher.validateDimensions("ema_mean", C, 1);
+				fetcher.validateDimensions("dout", N, CHW);
+				fetcher.validateDimensions("X", N, CHW);
+				fetcher.validateDimensions("ema_var", C, 1);
+				if(CHW % C != 0) {
+					throw new DMLRuntimeException("Incorrect dimensions: C=" + C + ", CHW=" + CHW);
+				}
+				long HW = CHW / C;
+				Pointer tmp = gCtx.allocate(instName, N*CHW*LibMatrixCUDA.sizeOfDataType);
+				// jcuda.runtime.JCuda.cudaDeviceSynchronize();
+				LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("backward_dgamma_tmp", 
+						ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(N*CHW)),
+						fetcher.getInputPointer("ema_mean"), 
+						fetcher.getInputPointer("dout"),
+						fetcher.getInputPointer("X"),
+						fetcher.getInputPointer("ema_var"),
+						tmp,
+						// N, C, HW, CHW, NCHW
+						toInt(N), toInt(C), toInt(HW), toInt(CHW), N*CHW);
+				
+				LibMatrixCUDA.channelSums(gCtx, instName, 
+						tmp, fetcher.getOutputPointer(C, 1), N, C, HW);
+				gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+			}
+		}
+	
 	private static int toInt(long num) throws DMLRuntimeException {
 		if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) {
 			throw new DMLRuntimeException("GPU : Exceeded supported size " + num);
@@ -734,6 +779,10 @@ public class DnnGPUInstruction extends GPUInstruction {
 			processNesterovUpdateInstruction(ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("batch_norm2d_bwd_dgamma")) {
+			processBatchNorm2dBackwardDGammaInstruction(ec);
+			return;
+		}
 		else if (instOpcode.equalsIgnoreCase("update_ema_var")) {
 			processUpdateEMAVarInstruction(ec);
 			return;
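
Note that processBatchNorm2dBackwardDGammaInstruction above materializes a
single N x CHW temporary (the same size as dout) before reducing it, replacing
the several intermediates (centered, norm, dout*norm) that the unfused script
would allocate. A quick sizing example in Java; the layer dimensions are
hypothetical and double precision is assumed:

  // Hypothetical sizing of the tmp buffer allocated before the kernel launch.
  long N = 64, C = 64, HW = 56 * 56;   // e.g. one ResNet-style activation map
  long tmpBytes = N * C * HW * 8L;     // N*CHW elements * 8 bytes per double
  System.out.println(tmpBytes);        // 102760448 bytes, roughly 98 MB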

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
index 1ab3420..06bd1df 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
@@ -94,10 +94,10 @@ public class GPUDenseInputPointerFetcher implements java.lang.AutoCloseable {
 	public void validateDimensions(String var, long numRows, long numCols) {
 		MatrixObject mo = getInputMatrixObject(var);
 		if(numRows > 0 && mo.getNumRows() != numRows) {
-			throw new DMLRuntimeException("Expected number of rows of subgrp_means to be " + numRows + ", but found " + mo.getNumRows());
+			throw new DMLRuntimeException("Expected number of rows of " + var + " to be " + numRows + ", but found " + mo.getNumRows());
 		}
 		else if(numCols > 0 && mo.getNumColumns() != numCols) {
-			throw new DMLRuntimeException("Expected number of columns of subgrp_means to be " + numCols + ", but found " + mo.getNumColumns());
+			throw new DMLRuntimeException("Expected number of columns of " + var + " to be " + numCols + ", but found " + mo.getNumColumns());
 		}
 	}
 	@Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 46ab3f7..00aa578 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -362,10 +362,25 @@ public class LibMatrixCUDA {
 		}
 		Pointer imagePointer = getDensePointer(gCtx, input, instName);
 		Pointer outputPointer = getDensePointer(gCtx, outputBlock, instName);
-		
+		channelSums(gCtx, instName, imagePointer, outputPointer, N, C, HW);
+	}
+	
+	/**
+	 * Performs the channel_sums operation: out = rowSums(matrix(colSums(A), rows=C, cols=HW))
+	 * 
+	 * @param gCtx a valid {@link GPUContext}
+	 * @param instName the invoking instruction's name, for recording {@link Statistics}
+	 * @param imagePointer  input image pointer
+	 * @param outputPointer output pointer
+	 * @param N number of rows (images)
+	 * @param C number of channels
+	 * @param HW height*width
+	 */
+	public static void channelSums(GPUContext gCtx, String instName, Pointer imagePointer, Pointer outputPointer, long N, long C, long HW) {
+		int cols = toInt(C*HW);
 		// We can replace this with CuDNN tensor reduce
 		Pointer tmp = gCtx.allocate(instName, cols*sizeOfDataType);
-		reduceCol(gCtx, instName, "reduce_col_sum", imagePointer, tmp, N, cols);
+		reduceCol(gCtx, instName, "reduce_col_sum", imagePointer, tmp, toInt(N), cols);
 		reduceRow(gCtx, instName, "reduce_row_sum", tmp, outputPointer, toInt(C), toInt(HW));
 		gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
 	}