Posted to commits@systemml.apache.org by ni...@apache.org on 2018/10/09 21:59:31 UTC
systemml git commit: [SYSTEMML-445] Improved the performance of batchnorm backward
Repository: systemml
Updated Branches:
refs/heads/master 512fb9e11 -> 3702df7c1
[SYSTEMML-445] Improved the performance of batchnorm backward
- Added a custom kernel for computing dgamma in the batch normalization layer.
- Also fixed a minor bug in the GPUDenseInputPointerFetcher class.
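For reference, the quantity being fused here matches the dgamma computation of
the batchnorm backward pass: with mu_c taken from ema_mean and v_c from ema_var
(which, in the rewrite pattern added to RewriteGPUSpecificOps below, carries the
cache_inv_var input of the matched expression), the new kernel produces an
elementwise temporary that channel_sums then reduces per channel:

    \mathrm{tmp}_{n,c,h,w} = \mathrm{dout}_{n,c,h,w}\,\bigl(X_{n,c,h,w} - \mu_c\bigr)\,v_c,
    \qquad
    \mathrm{dgamma}_c = \sum_{n,h,w} \mathrm{tmp}_{n,c,h,w}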
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3702df7c
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3702df7c
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3702df7c
Branch: refs/heads/master
Commit: 3702df7c1890b8c87c42715260240c604a5c3c64
Parents: 512fb9e
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Tue Oct 9 14:58:09 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Tue Oct 9 14:58:09 2018 -0700
----------------------------------------------------------------------
src/main/cpp/kernels/SystemML.cu | 21 +++
src/main/cpp/kernels/SystemML.ptx | 188 ++++++++++++++++---
src/main/java/org/apache/sysml/hops/DnnOp.java | 8 +-
src/main/java/org/apache/sysml/hops/Hop.java | 3 +-
.../hops/rewrite/RewriteGPUSpecificOps.java | 22 ++-
.../org/apache/sysml/lops/DnnTransform.java | 7 +-
.../instructions/GPUInstructionParser.java | 1 +
.../instructions/gpu/DnnGPUInstruction.java | 51 ++++-
.../gpu/GPUDenseInputPointerFetcher.java | 4 +-
.../runtime/matrix/data/LibMatrixCUDA.java | 19 +-
10 files changed, 285 insertions(+), 39 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index a53d07a..26d7f43 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2385,3 +2385,24 @@ extern "C" __global__ void invVar_f(float *X, float *C, double eps, unsigned int
invVar(X, C, eps, size);
}
+template <typename T>
+__device__ void backward_dgamma_tmp(T *ema_mean, T *dout, T *X, T *ema_var, T *ret, int N, int C,
+ int HW, int CHW, unsigned int NCHW) {
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ int ix = tid / CHW;
+ int iy = tid % CHW;
+ if (ix < N && iy < CHW) {
+ int c = iy / HW;
+ ret[tid] = dout[tid] * ((X[tid] - ema_mean[c]) * ema_var[c]);
+ }
+}
+
+extern "C" __global__ void backward_dgamma_tmp_d(double *ema_mean, double *dout, double *X, double* ema_var, double* ret,
+ int N, int C, int HW, int CHW, unsigned int NCHW) {
+ backward_dgamma_tmp(ema_mean, dout, X, ema_var, ret, N, C, HW, CHW, NCHW);
+}
+
+extern "C" __global__ void backward_dgamma_tmp_f(double *ema_mean, double *dout, double *X, double* ema_var, double* ret,
+ int N, int C, int HW, int CHW, int NCHW) {
+ backward_dgamma_tmp(ema_mean, dout, X, ema_var, ret, N, C, HW, CHW, NCHW);
+}
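
The harness below is a sketch for checking the kernel's thread-to-element
mapping on its own: tid runs over the row-major N x CHW matrix, and the channel
of an element is (tid % CHW) / HW. It duplicates the kernel body from the diff
above so that it compiles standalone with nvcc; SystemML itself launches the
kernel through JCuda from DnnGPUInstruction, and everything in main below is
local to the example.

// check_dgamma_tmp.cu -- build with: nvcc check_dgamma_tmp.cu -o check
#include <cstdio>
#include <cmath>
#include <vector>
#include <cuda_runtime.h>

// Copy of the committed kernel (double variant) so this file is self-contained.
extern "C" __global__ void backward_dgamma_tmp_d(double *ema_mean, double *dout,
    double *X, double *ema_var, double *ret,
    int N, int C, int HW, int CHW, unsigned int NCHW) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int ix = tid / CHW;            // sample (row) index
  int iy = tid % CHW;            // position within the flattened C*H*W columns
  if (ix < N && iy < CHW) {
    int c = iy / HW;             // channel index
    ret[tid] = dout[tid] * ((X[tid] - ema_mean[c]) * ema_var[c]);
  }
}

int main() {
  const int N = 2, C = 3, HW = 4, CHW = C * HW, NCHW = N * CHW;
  std::vector<double> mean(C), var(C), dout(NCHW), X(NCHW), ret(NCHW), ref(NCHW);
  for (int c = 0; c < C; ++c) { mean[c] = 0.1 * c; var[c] = 1.0 + c; }
  for (int i = 0; i < NCHW; ++i) { dout[i] = 0.5 * i; X[i] = 1.0 - 0.01 * i; }
  // CPU reference with the same layout and the same association order.
  for (int i = 0; i < NCHW; ++i) {
    int c = (i % CHW) / HW;
    ref[i] = dout[i] * ((X[i] - mean[c]) * var[c]);
  }
  double *dMean, *dVar, *dDout, *dX, *dRet;
  cudaMalloc(&dMean, C * sizeof(double));
  cudaMalloc(&dVar,  C * sizeof(double));
  cudaMalloc(&dDout, NCHW * sizeof(double));
  cudaMalloc(&dX,    NCHW * sizeof(double));
  cudaMalloc(&dRet,  NCHW * sizeof(double));
  cudaMemcpy(dMean, mean.data(), C * sizeof(double), cudaMemcpyHostToDevice);
  cudaMemcpy(dVar,  var.data(),  C * sizeof(double), cudaMemcpyHostToDevice);
  cudaMemcpy(dDout, dout.data(), NCHW * sizeof(double), cudaMemcpyHostToDevice);
  cudaMemcpy(dX,    X.data(),    NCHW * sizeof(double), cudaMemcpyHostToDevice);
  // One thread per element, mirroring getConfigForSimpleVectorOperations(N*CHW).
  int threads = 256, blocks = (NCHW + threads - 1) / threads;
  backward_dgamma_tmp_d<<<blocks, threads>>>(dMean, dDout, dX, dVar, dRet,
                                             N, C, HW, CHW, NCHW);
  cudaMemcpy(ret.data(), dRet, NCHW * sizeof(double), cudaMemcpyDeviceToHost);
  cudaFree(dMean); cudaFree(dVar); cudaFree(dDout); cudaFree(dX); cudaFree(dRet);
  for (int i = 0; i < NCHW; ++i)
    if (std::fabs(ret[i] - ref[i]) > 1e-12) { printf("mismatch at %d\n", i); return 1; }
  printf("ok\n");
  return 0;
}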
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index ac04967..3043373 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -15084,12 +15084,146 @@ BB123_2:
ret;
}
+ // .globl backward_dgamma_tmp_d
+.visible .entry backward_dgamma_tmp_d(
+ .param .u64 backward_dgamma_tmp_d_param_0,
+ .param .u64 backward_dgamma_tmp_d_param_1,
+ .param .u64 backward_dgamma_tmp_d_param_2,
+ .param .u64 backward_dgamma_tmp_d_param_3,
+ .param .u64 backward_dgamma_tmp_d_param_4,
+ .param .u32 backward_dgamma_tmp_d_param_5,
+ .param .u32 backward_dgamma_tmp_d_param_6,
+ .param .u32 backward_dgamma_tmp_d_param_7,
+ .param .u32 backward_dgamma_tmp_d_param_8,
+ .param .u32 backward_dgamma_tmp_d_param_9
+)
+{
+ .reg .pred %p<4>;
+ .reg .b32 %r<11>;
+ .reg .f64 %fd<8>;
+ .reg .b64 %rd<18>;
+
+
+ ld.param.u64 %rd1, [backward_dgamma_tmp_d_param_0];
+ ld.param.u64 %rd2, [backward_dgamma_tmp_d_param_1];
+ ld.param.u64 %rd3, [backward_dgamma_tmp_d_param_2];
+ ld.param.u64 %rd4, [backward_dgamma_tmp_d_param_3];
+ ld.param.u64 %rd5, [backward_dgamma_tmp_d_param_4];
+ ld.param.u32 %r4, [backward_dgamma_tmp_d_param_5];
+ ld.param.u32 %r2, [backward_dgamma_tmp_d_param_7];
+ ld.param.u32 %r3, [backward_dgamma_tmp_d_param_8];
+ mov.u32 %r5, %ctaid.x;
+ mov.u32 %r6, %ntid.x;
+ mov.u32 %r7, %tid.x;
+ mad.lo.s32 %r1, %r6, %r5, %r7;
+ div.s32 %r8, %r1, %r3;
+ setp.lt.s32 %p1, %r8, %r4;
+ setp.gt.s32 %p2, %r3, -1;
+ and.pred %p3, %p1, %p2;
+ @!%p3 bra BB124_2;
+ bra.uni BB124_1;
+
+BB124_1:
+ rem.s32 %r9, %r1, %r3;
+ cvta.to.global.u64 %rd6, %rd2;
+ mul.wide.s32 %rd7, %r1, 8;
+ add.s64 %rd8, %rd6, %rd7;
+ cvta.to.global.u64 %rd9, %rd3;
+ add.s64 %rd10, %rd9, %rd7;
+ div.s32 %r10, %r9, %r2;
+ cvta.to.global.u64 %rd11, %rd1;
+ mul.wide.s32 %rd12, %r10, 8;
+ add.s64 %rd13, %rd11, %rd12;
+ ld.global.f64 %fd1, [%rd13];
+ ld.global.f64 %fd2, [%rd10];
+ sub.f64 %fd3, %fd2, %fd1;
+ cvta.to.global.u64 %rd14, %rd4;
+ add.s64 %rd15, %rd14, %rd12;
+ ld.global.f64 %fd4, [%rd15];
+ mul.f64 %fd5, %fd3, %fd4;
+ ld.global.f64 %fd6, [%rd8];
+ mul.f64 %fd7, %fd6, %fd5;
+ cvta.to.global.u64 %rd16, %rd5;
+ add.s64 %rd17, %rd16, %rd7;
+ st.global.f64 [%rd17], %fd7;
+
+BB124_2:
+ ret;
+}
+
+ // .globl backward_dgamma_tmp_f
+.visible .entry backward_dgamma_tmp_f(
+ .param .u64 backward_dgamma_tmp_f_param_0,
+ .param .u64 backward_dgamma_tmp_f_param_1,
+ .param .u64 backward_dgamma_tmp_f_param_2,
+ .param .u64 backward_dgamma_tmp_f_param_3,
+ .param .u64 backward_dgamma_tmp_f_param_4,
+ .param .u32 backward_dgamma_tmp_f_param_5,
+ .param .u32 backward_dgamma_tmp_f_param_6,
+ .param .u32 backward_dgamma_tmp_f_param_7,
+ .param .u32 backward_dgamma_tmp_f_param_8,
+ .param .u32 backward_dgamma_tmp_f_param_9
+)
+{
+ .reg .pred %p<4>;
+ .reg .b32 %r<11>;
+ .reg .f64 %fd<8>;
+ .reg .b64 %rd<18>;
+
+
+ ld.param.u64 %rd1, [backward_dgamma_tmp_f_param_0];
+ ld.param.u64 %rd2, [backward_dgamma_tmp_f_param_1];
+ ld.param.u64 %rd3, [backward_dgamma_tmp_f_param_2];
+ ld.param.u64 %rd4, [backward_dgamma_tmp_f_param_3];
+ ld.param.u64 %rd5, [backward_dgamma_tmp_f_param_4];
+ ld.param.u32 %r4, [backward_dgamma_tmp_f_param_5];
+ ld.param.u32 %r2, [backward_dgamma_tmp_f_param_7];
+ ld.param.u32 %r3, [backward_dgamma_tmp_f_param_8];
+ mov.u32 %r5, %ctaid.x;
+ mov.u32 %r6, %ntid.x;
+ mov.u32 %r7, %tid.x;
+ mad.lo.s32 %r1, %r6, %r5, %r7;
+ div.s32 %r8, %r1, %r3;
+ setp.lt.s32 %p1, %r8, %r4;
+ setp.gt.s32 %p2, %r3, -1;
+ and.pred %p3, %p1, %p2;
+ @!%p3 bra BB125_2;
+ bra.uni BB125_1;
+
+BB125_1:
+ rem.s32 %r9, %r1, %r3;
+ cvta.to.global.u64 %rd6, %rd2;
+ mul.wide.s32 %rd7, %r1, 8;
+ add.s64 %rd8, %rd6, %rd7;
+ cvta.to.global.u64 %rd9, %rd3;
+ add.s64 %rd10, %rd9, %rd7;
+ div.s32 %r10, %r9, %r2;
+ cvta.to.global.u64 %rd11, %rd1;
+ mul.wide.s32 %rd12, %r10, 8;
+ add.s64 %rd13, %rd11, %rd12;
+ ld.global.f64 %fd1, [%rd13];
+ ld.global.f64 %fd2, [%rd10];
+ sub.f64 %fd3, %fd2, %fd1;
+ cvta.to.global.u64 %rd14, %rd4;
+ add.s64 %rd15, %rd14, %rd12;
+ ld.global.f64 %fd4, [%rd15];
+ mul.f64 %fd5, %fd3, %fd4;
+ ld.global.f64 %fd6, [%rd8];
+ mul.f64 %fd7, %fd6, %fd5;
+ cvta.to.global.u64 %rd16, %rd5;
+ add.s64 %rd17, %rd16, %rd7;
+ st.global.f64 [%rd17], %fd7;
+
+BB125_2:
+ ret;
+}
+
.func (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
.param .b64 __internal_trig_reduction_slowpathd_param_0,
.param .b64 __internal_trig_reduction_slowpathd_param_1
)
{
- .local .align 8 .b8 __local_depot124[40];
+ .local .align 8 .b8 __local_depot126[40];
.reg .b64 %SP;
.reg .b64 %SPL;
.reg .pred %p<9>;
@@ -15098,7 +15232,7 @@ BB123_2:
.reg .b64 %rd<102>;
- mov.u64 %rd101, __local_depot124;
+ mov.u64 %rd101, __local_depot126;
cvta.local.u64 %SP, %rd101;
ld.param.f64 %fd4, [__internal_trig_reduction_slowpathd_param_0];
ld.param.u64 %rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -15112,7 +15246,7 @@ BB123_2:
shr.u32 %r3, %r1, 20;
bfe.u32 %r4, %r1, 20, 11;
setp.eq.s32 %p1, %r4, 2047;
- @%p1 bra BB124_13;
+ @%p1 bra BB126_13;
add.s32 %r15, %r4, -1024;
shr.u32 %r16, %r15, 6;
@@ -15125,7 +15259,7 @@ BB123_2:
mov.u64 %rd94, 0;
setp.ge.s32 %p2, %r5, %r6;
mov.u64 %rd93, %rd1;
- @%p2 bra BB124_4;
+ @%p2 bra BB126_4;
mov.b64 %rd41, %fd4;
shl.b64 %rd42, %rd41, 11;
@@ -15142,7 +15276,7 @@ BB123_2:
mov.u64 %rd91, %rd1;
mov.u32 %r39, %r5;
-BB124_3:
+BB126_3:
.pragma "nounroll";
ld.const.u64 %rd47, [%rd89];
// inline asm
@@ -15172,15 +15306,15 @@ BB124_3:
add.s64 %rd93, %rd93, 8;
add.s64 %rd89, %rd89, 8;
setp.lt.s32 %p3, %r39, %r6;
- @%p3 bra BB124_3;
+ @%p3 bra BB126_3;
-BB124_4:
+BB126_4:
st.local.u64 [%rd93], %rd94;
ld.local.u64 %rd95, [%rd1+16];
ld.local.u64 %rd96, [%rd1+24];
and.b32 %r9, %r3, 63;
setp.eq.s32 %p4, %r9, 0;
- @%p4 bra BB124_6;
+ @%p4 bra BB126_6;
mov.u32 %r27, 64;
sub.s32 %r28, %r27, %r9;
@@ -15192,7 +15326,7 @@ BB124_4:
shr.u64 %rd55, %rd54, %r28;
or.b64 %rd95, %rd55, %rd53;
-BB124_6:
+BB126_6:
cvta.to.local.u64 %rd56, %rd37;
shr.u64 %rd57, %rd96, 62;
cvt.u32.u64 %r29, %rd57;
@@ -15209,7 +15343,7 @@ BB124_6:
selp.b32 %r34, %r32, %r33, %p5;
st.local.u32 [%rd56], %r34;
setp.eq.s32 %p6, %r31, 0;
- @%p6 bra BB124_8;
+ @%p6 bra BB126_8;
mov.u64 %rd64, 0;
// inline asm
@@ -15229,10 +15363,10 @@ BB124_6:
// inline asm
xor.b32 %r40, %r40, -2147483648;
-BB124_8:
+BB126_8:
clz.b64 %r41, %rd98;
setp.eq.s32 %p7, %r41, 0;
- @%p7 bra BB124_10;
+ @%p7 bra BB126_10;
shl.b64 %rd67, %rd98, %r41;
mov.u32 %r35, 64;
@@ -15240,7 +15374,7 @@ BB124_8:
shr.u64 %rd68, %rd97, %r36;
or.b64 %rd98, %rd68, %rd67;
-BB124_10:
+BB126_10:
mov.u64 %rd72, -3958705157555305931;
// inline asm
{
@@ -15261,7 +15395,7 @@ BB124_10:
}
// inline asm
setp.lt.s64 %p8, %rd100, 1;
- @%p8 bra BB124_12;
+ @%p8 bra BB126_12;
// inline asm
{
@@ -15280,7 +15414,7 @@ BB124_10:
// inline asm
add.s32 %r41, %r41, 1;
-BB124_12:
+BB126_12:
cvt.u64.u32 %rd79, %r40;
shl.b64 %rd80, %rd79, 32;
mov.u32 %r37, 1022;
@@ -15295,7 +15429,7 @@ BB124_12:
or.b64 %rd88, %rd87, %rd80;
mov.b64 %fd4, %rd88;
-BB124_13:
+BB126_13:
st.param.f64 [func_retval0+0], %fd4;
ret;
}
@@ -15323,7 +15457,7 @@ BB124_13:
}
shr.u32 %r51, %r50, 20;
setp.ne.s32 %p1, %r51, 0;
- @%p1 bra BB125_2;
+ @%p1 bra BB127_2;
mul.f64 %fd14, %fd12, 0d4350000000000000;
{
@@ -15337,13 +15471,13 @@ BB124_13:
shr.u32 %r16, %r50, 20;
add.s32 %r51, %r16, -54;
-BB125_2:
+BB127_2:
add.s32 %r52, %r51, -1023;
and.b32 %r17, %r50, -2146435073;
or.b32 %r18, %r17, 1072693248;
mov.b64 %fd135, {%r49, %r18};
setp.lt.u32 %p2, %r18, 1073127583;
- @%p2 bra BB125_4;
+ @%p2 bra BB127_4;
{
.reg .b32 %temp;
@@ -15357,7 +15491,7 @@ BB125_2:
mov.b64 %fd135, {%r19, %r21};
add.s32 %r52, %r51, -1022;
-BB125_4:
+BB127_4:
add.f64 %fd15, %fd135, 0d3FF0000000000000;
rcp.approx.ftz.f64 %fd16, %fd15;
neg.f64 %fd17, %fd15;
@@ -15520,13 +15654,13 @@ BB125_4:
mov.b32 %f2, %r35;
abs.f32 %f1, %f2;
setp.lt.f32 %p4, %f1, 0f4086232B;
- @%p4 bra BB125_7;
+ @%p4 bra BB127_7;
setp.lt.f64 %p5, %fd4, 0d0000000000000000;
add.f64 %fd129, %fd4, 0d7FF0000000000000;
selp.f64 %fd136, 0d0000000000000000, %fd129, %p5;
setp.geu.f32 %p6, %f1, 0f40874800;
- @%p6 bra BB125_7;
+ @%p6 bra BB127_7;
mov.f64 %fd134, 0d4338000000000000;
mov.f64 %fd133, 0d3FF71547652B82FE;
@@ -15548,26 +15682,26 @@ BB125_4:
mov.b64 %fd131, {%r44, %r43};
mul.f64 %fd136, %fd130, %fd131;
-BB125_7:
+BB127_7:
{
.reg .b32 %temp;
mov.b64 {%temp, %r45}, %fd136;
}
and.b32 %r46, %r45, 2147483647;
setp.ne.s32 %p7, %r46, 2146435072;
- @%p7 bra BB125_9;
+ @%p7 bra BB127_9;
{
.reg .b32 %temp;
mov.b64 {%r47, %temp}, %fd136;
}
setp.eq.s32 %p8, %r47, 0;
- @%p8 bra BB125_10;
+ @%p8 bra BB127_10;
-BB125_9:
+BB127_9:
fma.rn.f64 %fd136, %fd136, %fd5, %fd136;
-BB125_10:
+BB127_10:
st.param.f64 [func_retval0+0], %fd136;
ret;
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java
index c4ce466..7cf5061 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -141,6 +141,7 @@ public class DnnOp extends MultiThreadedHop
case UPDATE_EMA:
case INV_VAR:
case BATCH_NORM2D_BACKWARD_DX:
+ case BATCH_NORM2D_BACKWARD_DGAMMA:
{
// GPU-specific operators
setLops(constructDnnLops(ExecType.GPU, inputs));
@@ -181,6 +182,7 @@ public class DnnOp extends MultiThreadedHop
case CHANNEL_SUMS:
case UPDATE_EMA:
return 3;
+ case BATCH_NORM2D_BACKWARD_DGAMMA:
case UPDATE_NESTEROV_X:
return 4;
default:
@@ -538,7 +540,7 @@ public class DnnOp extends MultiThreadedHop
if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
- op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
+ op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) {
// Same dimension as the first input
MatrixCharacteristics[] mc = memo.getAllInputStats(getInput());
ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
@@ -755,7 +757,7 @@ public class DnnOp extends MultiThreadedHop
{
if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
- op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
+ op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) {
// Same dimension as the first input
Hop input1 = getInput().get(0);
setDim1(input1.getDim1());
@@ -873,7 +875,7 @@ public class DnnOp extends MultiThreadedHop
if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.RESHAPE_COLMEANS ||
op == OpOpDnn.UPDATE_EMA_VAR || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
- op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
+ op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX || op == OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA) {
throw new RuntimeException("getDim method should not be invoked for " + op.name());
}
try {
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index c8356e0..82a6669 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1101,7 +1101,7 @@ public abstract class Hop implements ParseInfo
CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS,
UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
- BATCH_NORM2D_BACKWARD_DX
+ BATCH_NORM2D_BACKWARD_DX, BATCH_NORM2D_BACKWARD_DGAMMA
}
public enum DataGenMethod {
@@ -1182,6 +1182,7 @@ public abstract class Hop implements ParseInfo
HopsConv2Lops.put(OpOpDnn.UPDATE_EMA, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA);
HopsConv2Lops.put(OpOpDnn.INV_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.INV_VAR);
HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DX, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DX);
+ HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DGAMMA);
}
protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 577adc3..ab40d7b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -170,6 +170,25 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
return hi;
};
+ // Avoids unnecessary intermediates:
+ // mean = cache_mean
+ // centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
+ // norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win)
+ // # Compute gradients during training
+ // dgamma = util::channel_sums(dout*norm, C, Hin, Win)
+ private static final HopDagPatternMatcher _batchNormDGamma;
+ static {
+ _batchNormDGamma = util_channel_sums(
+ mult( leaf("dout", MATRIX).fitsOnGPU(3),
+ bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", MATRIX))),
+ leaf("ema_var", MATRIX))), leaf("C", SCALAR), leaf("HW", SCALAR));
+ }
+ private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi -> {
+ LOG.debug("Applied batchNormDGamma rewrite.");
+ Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA,
+ "ema_mean", "dout", "X", "ema_var");
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+ };
// Pattern 3:
private static final HopDagPatternMatcher _batchNormTest;
@@ -282,8 +301,9 @@ public class RewriteGPUSpecificOps extends HopRewriteRuleWithPatternMatcher {
if(_rewriters == null) {
ArrayList<HopPatternRewriter> rewriters = new ArrayList<>();
rewriters.add(new HopPatternRewriter("batchNormdX", _batchNormdX, _batchNormdXReplacer));
- rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer));
rewriters.add(new HopPatternRewriter("batchNormTest", _batchNormTest, _batchNormTestReplacer));
+ rewriters.add(new HopPatternRewriter("batchNormUpdatedVar", _batchNormUpdatedVar, _batchNormUpdatedVarReplacer));
+ // rewriters.add(new HopPatternRewriter("batchNormDGamma", _batchNormDGamma, _batchNormDGammaReplacer));
rewriters.add(new HopPatternRewriter("channelSums", _channelSums, _channelSumsReplacer));
rewriters.add(new HopPatternRewriter("updateNesterovX", _updateNesterovX, _updateNesterovXReplacer));
rewriters.add(new HopPatternRewriter("reshapeColMeans", _reshapeColMeans, _reshapeColMeansReplacer));
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/lops/DnnTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java
index 2d2d5f1..3496b5b 100644
--- a/src/main/java/org/apache/sysml/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java
@@ -33,7 +33,7 @@ public class DnnTransform extends Lop
CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST,
UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
- BATCH_NORM2D_BACKWARD_DX
+ BATCH_NORM2D_BACKWARD_DX, BATCH_NORM2D_BACKWARD_DGAMMA
}
private OperationTypes operation;
@@ -174,6 +174,9 @@ public class DnnTransform extends Lop
case UPDATE_NESTEROV_X:
return "update_nesterov_x";
+ case BATCH_NORM2D_BACKWARD_DGAMMA:
+ return "batch_norm2d_bwd_dgamma";
+
case BATCH_NORM2D_TEST:
return "batch_norm2d_test";
@@ -254,7 +257,7 @@ public class DnnTransform extends Lop
@Override
public String getInstructions(String input1, String input2, String input3, String input4, String output) {
- if(operation == OperationTypes.UPDATE_NESTEROV_X) {
+ if(operation == OperationTypes.UPDATE_NESTEROV_X || operation == OperationTypes.BATCH_NORM2D_BACKWARD_DGAMMA) {
StringBuilder sb = new StringBuilder();
sb.append( getExecType() );
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 3480504..c8a0e8d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -66,6 +66,7 @@ public class GPUInstructionParser extends InstructionParser
String2GPUInstructionType.put( "reshape_colmeans", GPUINSTRUCTION_TYPE.Dnn);
String2GPUInstructionType.put( "inv_var", GPUINSTRUCTION_TYPE.Dnn);
String2GPUInstructionType.put( "batch_norm2d_bwd_dx", GPUINSTRUCTION_TYPE.Dnn);
+ String2GPUInstructionType.put( "batch_norm2d_bwd_dgamma", GPUINSTRUCTION_TYPE.Dnn);
// Matrix Multiply Operators
String2GPUInstructionType.put( "ba+*", GPUINSTRUCTION_TYPE.AggregateBinary);
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index 6094b6c..4ad4155 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -127,7 +127,7 @@ public class DnnGPUInstruction extends GPUInstruction {
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr,
double intermediateMemoryBudget) throws DMLRuntimeException {
super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
- if( !( opcode.equals("update_nesterov_x")) ) {
+ if( !( opcode.equals("update_nesterov_x") || opcode.equals("batch_norm2d_bwd_dgamma")) ) {
throw new DMLRuntimeException("Incorrect opcode: " + opcode);
}
_input1 = in1;
@@ -339,6 +339,15 @@ public class DnnGPUInstruction extends GPUInstruction {
CPOperand out = new CPOperand(parts[5]);
return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
}
+ else if (opcode.equalsIgnoreCase("batch_norm2d_bwd_dgamma")) {
+ InstructionUtils.checkNumFields(parts, 5);
+ CPOperand in = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand in3 = new CPOperand(parts[3]);
+ CPOperand in4 = new CPOperand(parts[4]);
+ CPOperand out = new CPOperand(parts[5]);
+ return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
+ }
else if (opcode.equalsIgnoreCase("lstm")) {
InstructionUtils.checkNumFields(parts, 8);
CPOperand in1 = new CPOperand(parts[1]);
@@ -586,6 +595,42 @@ public class DnnGPUInstruction extends GPUInstruction {
}
}
+ // "ema_mean", "dout", "X", "ema_var"
+ private void processBatchNorm2dBackwardDGammaInstruction(ExecutionContext ec) {
+ try(GPUDenseInputPointerFetcher fetcher = new GPUDenseInputPointerFetcher(ec, gCtx, instName, _output)) {
+ fetcher.add("ema_mean", _input1).add("dout", _input2).add("X", _input3)
+ .add("ema_var", _input4);
+ MatrixObject ema_mean = fetcher.getInputMatrixObject("ema_mean");
+ MatrixObject dout = fetcher.getInputMatrixObject("dout");
+ long C = ema_mean.getNumRows();
+ long N = dout.getNumRows();
+ long CHW = dout.getNumColumns();
+ fetcher.validateDimensions("ema_mean", C, 1);
+ fetcher.validateDimensions("dout", N, CHW);
+ fetcher.validateDimensions("X", N, CHW);
+ fetcher.validateDimensions("ema_var", C, 1);
+ if(CHW % C != 0) {
+ throw new DMLRuntimeException("Incorrect dimensions: C=" + C + ", CHW=" + CHW);
+ }
+ long HW = CHW / C;
+ Pointer tmp = gCtx.allocate(instName, N*CHW*LibMatrixCUDA.sizeOfDataType);
+ // jcuda.runtime.JCuda.cudaDeviceSynchronize();
+ LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("backward_dgamma_tmp",
+ ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(N*CHW)),
+ fetcher.getInputPointer("ema_mean"),
+ fetcher.getInputPointer("dout"),
+ fetcher.getInputPointer("X"),
+ fetcher.getInputPointer("ema_var"),
+ tmp,
+ // N, C, HW, CHW, NCHW
+ toInt(N), toInt(C), toInt(HW), toInt(CHW), N*CHW);
+
+ LibMatrixCUDA.channelSums(gCtx, instName,
+ tmp, fetcher.getOutputPointer(C, 1), N, C, HW);
+ gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
+ }
+ }
+
private static int toInt(long num) throws DMLRuntimeException {
if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) {
throw new DMLRuntimeException("GPU : Exceeded supported size " + num);
@@ -734,6 +779,10 @@ public class DnnGPUInstruction extends GPUInstruction {
processNesterovUpdateInstruction(ec);
return;
}
+ else if (instOpcode.equalsIgnoreCase("batch_norm2d_bwd_dgamma")) {
+ processBatchNorm2dBackwardDGammaInstruction(ec);
+ return;
+ }
else if (instOpcode.equalsIgnoreCase("update_ema_var")) {
processUpdateEMAVarInstruction(ec);
return;
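
End to end, the new instruction is tmp = backward_dgamma_tmp(ema_mean, dout, X,
ema_var) followed by channelSums(tmp, C, HW), so its semantics can be pinned
down by a few lines of host code. The reference below is a sketch; dgammaRef is
a name invented for the example, not a SystemML API.

#include <vector>

// Reference semantics of batch_norm2d_bwd_dgamma's reduction step:
// sum the N x (C*HW) temporary over N, H, and W, keeping one value per channel.
std::vector<double> dgammaRef(const std::vector<double>& tmp,
                              int N, int C, int HW) {
  const int CHW = C * HW;
  std::vector<double> dgamma(C, 0.0);
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int i = 0; i < HW; ++i)
        dgamma[c] += tmp[n * CHW + c * HW + i];
  return dgamma;
}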
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
index 1ab3420..06bd1df 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
@@ -94,10 +94,10 @@ public class GPUDenseInputPointerFetcher implements java.lang.AutoCloseable {
public void validateDimensions(String var, long numRows, long numCols) {
MatrixObject mo = getInputMatrixObject(var);
if(numRows > 0 && mo.getNumRows() != numRows) {
- throw new DMLRuntimeException("Expected number of rows of subgrp_means to be " + numRows + ", but found " + mo.getNumRows());
+ throw new DMLRuntimeException("Expected number of rows of " + var + " to be " + numRows + ", but found " + mo.getNumRows());
}
else if(numCols > 0 && mo.getNumColumns() != numCols) {
- throw new DMLRuntimeException("Expected number of columns of subgrp_means to be " + numCols + ", but found " + mo.getNumColumns());
+ throw new DMLRuntimeException("Expected number of columns of " + var + " to be " + numCols + ", but found " + mo.getNumColumns());
}
}
@Override
http://git-wip-us.apache.org/repos/asf/systemml/blob/3702df7c/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 46ab3f7..00aa578 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -362,10 +362,25 @@ public class LibMatrixCUDA {
}
Pointer imagePointer = getDensePointer(gCtx, input, instName);
Pointer outputPointer = getDensePointer(gCtx, outputBlock, instName);
-
+ channelSums(gCtx, instName, imagePointer, outputPointer, N, C, HW);
+ }
+
+ /**
+ * Performs the channel_sums operation: out = rowSums(matrix(colSums(A), rows=C, cols=HW))
+ *
+ * @param gCtx a valid {@link GPUContext}
+ * @param instName the invoking instruction's name for recording {@link Statistics}.
+ * @param imagePointer input image pointer
+ * @param outputPointer output pointer
+ * @param N number of rows
+ * @param C number of channels
+ * @param HW height*width
+ */
+ public static void channelSums(GPUContext gCtx, String instName, Pointer imagePointer, Pointer outputPointer, long N, long C, long HW) {
+ int cols = toInt(C*HW);
// We can replace this with CuDNN tensor reduce
Pointer tmp = gCtx.allocate(instName, cols*sizeOfDataType);
- reduceCol(gCtx, instName, "reduce_col_sum", imagePointer, tmp, N, cols);
+ reduceCol(gCtx, instName, "reduce_col_sum", imagePointer, tmp, toInt(N), cols);
reduceRow(gCtx, instName, "reduce_row_sum", tmp, outputPointer, toInt(C), toInt(HW));
gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
}
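
For contrast, a naive single-kernel channel_sums would have every thread
atomically accumulate into one of only C output cells, as in the sketch below
(an illustration, not SystemML code; double-precision atomicAdd assumes sm_60
or newer). The staged reduce_col_sum + reduce_row_sum above pays for an extra
1 x (C*HW) temporary but avoids funneling all N*C*HW additions through C
contended counters.

// Naive alternative for illustration only; out must be zeroed beforehand.
extern "C" __global__ void channel_sums_naive_d(const double *A, double *out,
                                                int N, int C, int HW) {
  int CHW = C * HW;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < N * CHW) {
    int c = (tid % CHW) / HW;   // channel of this element
    atomicAdd(&out[c], A[tid]); // one contended add per element
  }
}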