You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2018/08/10 04:03:47 UTC
systemml git commit: [SYSTEMML-445] Added SGD Nesterov update
operator via rewrite for the GPU backend
Repository: systemml
Updated Branches:
refs/heads/master 5ca8706e9 -> 04bc667f3
[SYSTEMML-445] Added SGD Nesterov update operator via rewrite for the GPU backend
- This leads to 10-15% speedup for ResNet200 with batch size of 32.
- Also, added GPU tests for this operator.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/04bc667f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/04bc667f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/04bc667f
Branch: refs/heads/master
Commit: 04bc667f3650d57c0bc9de20e46e7624205cc1e6
Parents: 5ca8706
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Aug 9 21:00:21 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Aug 9 21:00:21 2018 -0700
----------------------------------------------------------------------
src/main/cpp/kernels/SystemML.cu | 23 ++-
src/main/cpp/kernels/SystemML.ptx | 161 +++++++++++++++----
src/main/java/org/apache/sysml/hops/DnnOp.java | 15 +-
src/main/java/org/apache/sysml/hops/Hop.java | 4 +-
.../hops/rewrite/RewriteGPUSpecificOps.java | 61 +++++++
.../org/apache/sysml/lops/DnnTransform.java | 33 +++-
.../instructions/GPUInstructionParser.java | 1 +
.../instructions/gpu/DnnGPUInstruction.java | 56 +++++++
.../gpu/context/GPUMemoryManager.java | 5 +-
.../org/apache/sysml/test/gpu/GPUTests.java | 10 ++
.../org/apache/sysml/test/gpu/SGDUpdate.java | 91 +++++++++++
11 files changed, 419 insertions(+), 41 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 485b7e2..9ddaaff 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2248,12 +2248,6 @@ extern "C" __global__ void prepare_lstm_dinput_f(float* smlInput, float* cudnnIn
}
-/**
- * Do an log over all the elements of a matrix
- * @param A the input matrix (of length = size)
- * @param C the pre-allocated output matrix (of length = size)
- * @param size the length of the input and output matrices
- */
template <typename T>
__device__ void colwise_reshape(T *A, T *C, unsigned int size,
unsigned int inRows, unsigned int inCols,
@@ -2278,4 +2272,21 @@ extern "C" __global__ void colwise_reshape_f(float *A, float *C, unsigned int si
unsigned int inRows, unsigned int inCols,
unsigned int outRows, unsigned int outCols) {
colwise_reshape(A, C, size, inRows, inCols, outRows, outCols);
+}
+
+// Performs the operation: out = X - mu*v_prev + (1+mu)*v
+template <typename T>
+__device__ void update_nesterov_x(T *X, T *v, T *v_prev, double mu, T *out, unsigned int size) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ if (index < size) {
+ out[index] = X[index] - mu*v_prev[index] + (1+mu)*v[index];
+ }
+}
+
+extern "C" __global__ void update_nesterov_x_d(double *X, double *v, double *v_prev, double mu, double *out, unsigned int size) {
+ update_nesterov_x(X, v, v_prev, mu, out, size);
+}
+
+extern "C" __global__ void update_nesterov_x_f(float *X, float *v, float *v_prev, double mu, float *out, unsigned int size) {
+ update_nesterov_x(X, v, v_prev, mu, out, size);
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index 93e5e35..8a14876 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -12977,12 +12977,119 @@ BB113_2:
ret;
}
+ // .globl update_nesterov_x_d
+.visible .entry update_nesterov_x_d(
+ .param .u64 update_nesterov_x_d_param_0,
+ .param .u64 update_nesterov_x_d_param_1,
+ .param .u64 update_nesterov_x_d_param_2,
+ .param .f64 update_nesterov_x_d_param_3,
+ .param .u64 update_nesterov_x_d_param_4,
+ .param .u32 update_nesterov_x_d_param_5
+)
+{
+ .reg .pred %p<2>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<9>;
+ .reg .b64 %rd<14>;
+
+
+ ld.param.u64 %rd1, [update_nesterov_x_d_param_0];
+ ld.param.u64 %rd2, [update_nesterov_x_d_param_1];
+ ld.param.u64 %rd3, [update_nesterov_x_d_param_2];
+ ld.param.f64 %fd1, [update_nesterov_x_d_param_3];
+ ld.param.u64 %rd4, [update_nesterov_x_d_param_4];
+ ld.param.u32 %r2, [update_nesterov_x_d_param_5];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB114_2;
+
+ cvta.to.global.u64 %rd5, %rd1;
+ mul.wide.s32 %rd6, %r1, 8;
+ add.s64 %rd7, %rd5, %rd6;
+ cvta.to.global.u64 %rd8, %rd3;
+ add.s64 %rd9, %rd8, %rd6;
+ ld.global.f64 %fd2, [%rd9];
+ mul.f64 %fd3, %fd2, %fd1;
+ ld.global.f64 %fd4, [%rd7];
+ sub.f64 %fd5, %fd4, %fd3;
+ cvta.to.global.u64 %rd10, %rd2;
+ add.s64 %rd11, %rd10, %rd6;
+ ld.global.f64 %fd6, [%rd11];
+ add.f64 %fd7, %fd1, 0d3FF0000000000000;
+ fma.rn.f64 %fd8, %fd7, %fd6, %fd5;
+ cvta.to.global.u64 %rd12, %rd4;
+ add.s64 %rd13, %rd12, %rd6;
+ st.global.f64 [%rd13], %fd8;
+
+BB114_2:
+ ret;
+}
+
+ // .globl update_nesterov_x_f
+.visible .entry update_nesterov_x_f(
+ .param .u64 update_nesterov_x_f_param_0,
+ .param .u64 update_nesterov_x_f_param_1,
+ .param .u64 update_nesterov_x_f_param_2,
+ .param .f64 update_nesterov_x_f_param_3,
+ .param .u64 update_nesterov_x_f_param_4,
+ .param .u32 update_nesterov_x_f_param_5
+)
+{
+ .reg .pred %p<2>;
+ .reg .f32 %f<5>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<9>;
+ .reg .b64 %rd<14>;
+
+
+ ld.param.u64 %rd1, [update_nesterov_x_f_param_0];
+ ld.param.u64 %rd2, [update_nesterov_x_f_param_1];
+ ld.param.u64 %rd3, [update_nesterov_x_f_param_2];
+ ld.param.f64 %fd1, [update_nesterov_x_f_param_3];
+ ld.param.u64 %rd4, [update_nesterov_x_f_param_4];
+ ld.param.u32 %r2, [update_nesterov_x_f_param_5];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB115_2;
+
+ cvta.to.global.u64 %rd5, %rd1;
+ mul.wide.s32 %rd6, %r1, 4;
+ add.s64 %rd7, %rd5, %rd6;
+ ld.global.f32 %f1, [%rd7];
+ cvt.f64.f32 %fd2, %f1;
+ cvta.to.global.u64 %rd8, %rd3;
+ add.s64 %rd9, %rd8, %rd6;
+ ld.global.f32 %f2, [%rd9];
+ cvt.f64.f32 %fd3, %f2;
+ mul.f64 %fd4, %fd3, %fd1;
+ sub.f64 %fd5, %fd2, %fd4;
+ cvta.to.global.u64 %rd10, %rd2;
+ add.s64 %rd11, %rd10, %rd6;
+ ld.global.f32 %f3, [%rd11];
+ cvt.f64.f32 %fd6, %f3;
+ add.f64 %fd7, %fd1, 0d3FF0000000000000;
+ fma.rn.f64 %fd8, %fd7, %fd6, %fd5;
+ cvt.rn.f32.f64 %f4, %fd8;
+ cvta.to.global.u64 %rd12, %rd4;
+ add.s64 %rd13, %rd12, %rd6;
+ st.global.f32 [%rd13], %f4;
+
+BB115_2:
+ ret;
+}
+
.func (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
.param .b64 __internal_trig_reduction_slowpathd_param_0,
.param .b64 __internal_trig_reduction_slowpathd_param_1
)
{
- .local .align 8 .b8 __local_depot114[40];
+ .local .align 8 .b8 __local_depot116[40];
.reg .b64 %SP;
.reg .b64 %SPL;
.reg .pred %p<9>;
@@ -12991,7 +13098,7 @@ BB113_2:
.reg .b64 %rd<102>;
- mov.u64 %rd101, __local_depot114;
+ mov.u64 %rd101, __local_depot116;
cvta.local.u64 %SP, %rd101;
ld.param.f64 %fd4, [__internal_trig_reduction_slowpathd_param_0];
ld.param.u64 %rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -13005,7 +13112,7 @@ BB113_2:
shr.u32 %r3, %r1, 20;
bfe.u32 %r4, %r1, 20, 11;
setp.eq.s32 %p1, %r4, 2047;
- @%p1 bra BB114_13;
+ @%p1 bra BB116_13;
add.s32 %r15, %r4, -1024;
shr.u32 %r16, %r15, 6;
@@ -13018,7 +13125,7 @@ BB113_2:
mov.u64 %rd94, 0;
setp.ge.s32 %p2, %r5, %r6;
mov.u64 %rd93, %rd1;
- @%p2 bra BB114_4;
+ @%p2 bra BB116_4;
mov.b64 %rd41, %fd4;
shl.b64 %rd42, %rd41, 11;
@@ -13035,7 +13142,7 @@ BB113_2:
mov.u64 %rd91, %rd1;
mov.u32 %r39, %r5;
-BB114_3:
+BB116_3:
.pragma "nounroll";
ld.const.u64 %rd47, [%rd89];
// inline asm
@@ -13065,15 +13172,15 @@ BB114_3:
add.s64 %rd93, %rd93, 8;
add.s64 %rd89, %rd89, 8;
setp.lt.s32 %p3, %r39, %r6;
- @%p3 bra BB114_3;
+ @%p3 bra BB116_3;
-BB114_4:
+BB116_4:
st.local.u64 [%rd93], %rd94;
ld.local.u64 %rd95, [%rd1+16];
ld.local.u64 %rd96, [%rd1+24];
and.b32 %r9, %r3, 63;
setp.eq.s32 %p4, %r9, 0;
- @%p4 bra BB114_6;
+ @%p4 bra BB116_6;
mov.u32 %r27, 64;
sub.s32 %r28, %r27, %r9;
@@ -13085,7 +13192,7 @@ BB114_4:
shr.u64 %rd55, %rd54, %r28;
or.b64 %rd95, %rd55, %rd53;
-BB114_6:
+BB116_6:
cvta.to.local.u64 %rd56, %rd37;
shr.u64 %rd57, %rd96, 62;
cvt.u32.u64 %r29, %rd57;
@@ -13102,7 +13209,7 @@ BB114_6:
selp.b32 %r34, %r32, %r33, %p5;
st.local.u32 [%rd56], %r34;
setp.eq.s32 %p6, %r31, 0;
- @%p6 bra BB114_8;
+ @%p6 bra BB116_8;
mov.u64 %rd64, 0;
// inline asm
@@ -13122,10 +13229,10 @@ BB114_6:
// inline asm
xor.b32 %r40, %r40, -2147483648;
-BB114_8:
+BB116_8:
clz.b64 %r41, %rd98;
setp.eq.s32 %p7, %r41, 0;
- @%p7 bra BB114_10;
+ @%p7 bra BB116_10;
shl.b64 %rd67, %rd98, %r41;
mov.u32 %r35, 64;
@@ -13133,7 +13240,7 @@ BB114_8:
shr.u64 %rd68, %rd97, %r36;
or.b64 %rd98, %rd68, %rd67;
-BB114_10:
+BB116_10:
mov.u64 %rd72, -3958705157555305931;
// inline asm
{
@@ -13154,7 +13261,7 @@ BB114_10:
}
// inline asm
setp.lt.s64 %p8, %rd100, 1;
- @%p8 bra BB114_12;
+ @%p8 bra BB116_12;
// inline asm
{
@@ -13173,7 +13280,7 @@ BB114_10:
// inline asm
add.s32 %r41, %r41, 1;
-BB114_12:
+BB116_12:
cvt.u64.u32 %rd79, %r40;
shl.b64 %rd80, %rd79, 32;
mov.u32 %r37, 1022;
@@ -13188,7 +13295,7 @@ BB114_12:
or.b64 %rd88, %rd87, %rd80;
mov.b64 %fd4, %rd88;
-BB114_13:
+BB116_13:
st.param.f64 [func_retval0+0], %fd4;
ret;
}
@@ -13216,7 +13323,7 @@ BB114_13:
}
shr.u32 %r51, %r50, 20;
setp.ne.s32 %p1, %r51, 0;
- @%p1 bra BB115_2;
+ @%p1 bra BB117_2;
mul.f64 %fd14, %fd12, 0d4350000000000000;
{
@@ -13230,13 +13337,13 @@ BB114_13:
shr.u32 %r16, %r50, 20;
add.s32 %r51, %r16, -54;
-BB115_2:
+BB117_2:
add.s32 %r52, %r51, -1023;
and.b32 %r17, %r50, -2146435073;
or.b32 %r18, %r17, 1072693248;
mov.b64 %fd135, {%r49, %r18};
setp.lt.u32 %p2, %r18, 1073127583;
- @%p2 bra BB115_4;
+ @%p2 bra BB117_4;
{
.reg .b32 %temp;
@@ -13250,7 +13357,7 @@ BB115_2:
mov.b64 %fd135, {%r19, %r21};
add.s32 %r52, %r51, -1022;
-BB115_4:
+BB117_4:
add.f64 %fd15, %fd135, 0d3FF0000000000000;
rcp.approx.ftz.f64 %fd16, %fd15;
neg.f64 %fd17, %fd15;
@@ -13413,13 +13520,13 @@ BB115_4:
mov.b32 %f2, %r35;
abs.f32 %f1, %f2;
setp.lt.f32 %p4, %f1, 0f4086232B;
- @%p4 bra BB115_7;
+ @%p4 bra BB117_7;
setp.lt.f64 %p5, %fd4, 0d0000000000000000;
add.f64 %fd129, %fd4, 0d7FF0000000000000;
selp.f64 %fd136, 0d0000000000000000, %fd129, %p5;
setp.geu.f32 %p6, %f1, 0f40874800;
- @%p6 bra BB115_7;
+ @%p6 bra BB117_7;
mov.f64 %fd134, 0d4338000000000000;
mov.f64 %fd133, 0d3FF71547652B82FE;
@@ -13441,26 +13548,26 @@ BB115_4:
mov.b64 %fd131, {%r44, %r43};
mul.f64 %fd136, %fd130, %fd131;
-BB115_7:
+BB117_7:
{
.reg .b32 %temp;
mov.b64 {%temp, %r45}, %fd136;
}
and.b32 %r46, %r45, 2147483647;
setp.ne.s32 %p7, %r46, 2146435072;
- @%p7 bra BB115_9;
+ @%p7 bra BB117_9;
{
.reg .b32 %temp;
mov.b64 {%r47, %temp}, %fd136;
}
setp.eq.s32 %p8, %r47, 0;
- @%p8 bra BB115_10;
+ @%p8 bra BB117_10;
-BB115_9:
+BB117_9:
fma.rn.f64 %fd136, %fd136, %fd5, %fd136;
-BB115_10:
+BB117_10:
st.param.f64 [func_retval0+0], %fd136;
ret;
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java
index 4e22f59..4ca90f8 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -136,6 +136,7 @@ public class DnnOp extends MultiThreadedHop
}
case BATCH_NORM2D_TEST:
case CHANNEL_SUMS:
+ case UPDATE_NESTEROV_X:
{
if(et == ExecType.GPU) {
setLops(constructDnnLops(et, inputs));
@@ -175,6 +176,8 @@ public class DnnOp extends MultiThreadedHop
return 6;
case CHANNEL_SUMS:
return 3;
+ case UPDATE_NESTEROV_X:
+ return 4;
default:
return 13;
}
@@ -528,7 +531,9 @@ public class DnnOp extends MultiThreadedHop
// [numRows, numCols, NNZ]
long[] ret = new long[3];
- if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST) {
+ if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
+ op == OpOpDnn.UPDATE_NESTEROV_X) {
+ // Same dimension as the first input
MatrixCharacteristics[] mc = memo.getAllInputStats(getInput());
ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
ret[1] = mc[0].colsKnown() ? mc[0].getCols() : -1;
@@ -734,7 +739,8 @@ public class DnnOp extends MultiThreadedHop
@Override
public void refreshSizeInformation()
{
- if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST) {
+ if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) {
+ // Same dimension as the first input
Hop input1 = getInput().get(0);
setDim1(input1.getDim1());
setDim2(input1.getDim2());
@@ -840,8 +846,9 @@ public class DnnOp extends MultiThreadedHop
* @return either -1 or value associated with the dimString
*/
private long getDim(String dimString) {
- if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS) {
- throw new RuntimeException("getDim method should not be invoked for batch_norm_test, channel_sums, bias_add and bias_multiply");
+ if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
+ op == OpOpDnn.UPDATE_NESTEROV_X) {
+ throw new RuntimeException("getDim method should not be invoked for " + op.name());
}
try {
parseInput();
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 6466575..73a58e3 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1099,7 +1099,8 @@ public abstract class Hop implements ParseInfo
public enum OpOpDnn {
MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
- BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS
+ BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS,
+ UPDATE_NESTEROV_X
}
public enum DataGenMethod {
@@ -1174,6 +1175,7 @@ public abstract class Hop implements ParseInfo
HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA);
HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST);
HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS);
+ HopsConv2Lops.put(OpOpDnn.UPDATE_NESTEROV_X, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_NESTEROV_X);
}
protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index 2a1699d..b603aa7 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -124,6 +124,7 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
}
hi = batchNormTest(hop, hi, i);
hi = channelSums(hop, hi, i);
+ hi = updateNesterovX(hop, hi, i);
if( !descendFirst )
rule_GPUKernels(roots, hi, descendFirst);
@@ -281,6 +282,11 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
&& getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
}
+ private static boolean isBinaryMMMinus(Hop h) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MINUS
+ && getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.MATRIX;
+ }
+
private static boolean isBinaryMSMult(Hop h, double expectedValue) {
return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT
&& getFirstInput(h).getDataType() == DataType.MATRIX && getSecondInput(h).getDataType() == DataType.SCALAR
@@ -323,6 +329,16 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
&& getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR;
}
+ private static boolean isBinarySMMult(Hop h, double expectedVal) {
+ return h instanceof BinaryOp && ((BinaryOp)h).getOp() == OpOp2.MULT
+ && getSecondInput(h).getDataType() == DataType.MATRIX && getFirstInput(h).getDataType() == DataType.SCALAR
+ && getValue(getFirstInput(h)) == expectedVal;
+ }
+
+ private static double getValue(Hop h) {
+ return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>());
+ }
+
/**
* Checks if the "mean" hop is a moving average of mean in batch normalization layer.
*
@@ -704,6 +720,51 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
// ------------------------------------------------------------
/**
+ * Checks for the nesterov_update_x pattern (X = X - mu*v_prev + (1+mu)*v)
+ * and returns a new DnnOp if matched
+ *
+ * @param parent parent of the input
+ * @param hi input to be matched
+ * @param pos position
+ * @return a new DnnOp or hi
+ */
+ private static Hop updateNesterovX(Hop parent, Hop hi, int pos) {
+ if(fitsOnGPU(hi, 4) && isBinaryMMAdd(hi) && isBinaryMMMinus(getFirstInput(hi))
+ && isBinarySMMult(getSecondInput(getFirstInput(hi)))
+ && isBinarySMMult(getSecondInput(hi))) {
+ Hop onePlusMu = getFirstInput(getSecondInput(hi));
+ Hop tmp = getSecondInput(getFirstInput(hi));
+ Hop mu = getFirstInput(tmp);
+ if(isOnePlusMu(onePlusMu, mu)) {
+ Hop v_prev = getSecondInput(tmp);
+ Hop v = getSecondInput(getSecondInput(hi));
+ Hop X = getFirstInput(getFirstInput(hi));
+ if(hasSameDimensions(X, v) && hasSameDimensions(X, v_prev)) {
+ ArrayList<Hop> inHops = new ArrayList<Hop>();
+ inHops.add(X);
+ inHops.add(v);
+ inHops.add(v_prev);
+ inHops.add(mu);
+ LOG.debug("Applied updateNesterovX rewrite.");
+ Hop newHop = new DnnOp(hi.getName(), hi.getDataType(), hi.getValueType(),
+ OpOpDnn.UPDATE_NESTEROV_X, inHops);
+ return HopRewriteUtils.rewireAllParentChildReferences(hi, newHop);
+ }
+ }
+ }
+ return hi;
+ }
+
+ private static boolean hasSameDimensions(Hop x, Hop y) {
+ return x.dimsKnown() && y.dimsKnown() && (x.getDim1() == y.getDim1()) && (x.getDim2() == y.getDim2());
+ }
+
+ private static boolean isOnePlusMu(Hop onePlusMu, Hop mu) {
+ return (isBinarySMMult(onePlusMu, 1.0) && getSecondInput(onePlusMu) == mu) ||
+ getValue(onePlusMu) == getValue(mu) + 1;
+ }
+
+ /**
* Checks for the batch norm (mode="test") pattern using the helper isBatchNormTrainMean and isBatchNormTrainVar
* and returns a new DnnOp if matched
*
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/lops/DnnTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/DnnTransform.java b/src/main/java/org/apache/sysml/lops/DnnTransform.java
index 6c61d4a..3183b5f 100644
--- a/src/main/java/org/apache/sysml/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysml/lops/DnnTransform.java
@@ -31,7 +31,8 @@ public class DnnTransform extends Lop
MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
RELU_MAX_POOLING, RELU_MAX_POOLING_BACKWARD, RELU_BACKWARD,
CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
- BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST
+ BIAS_ADD, CONV2D_BIAS_ADD, BIAS_MULTIPLY, CHANNEL_SUMS, BATCH_NORM2D_TEST,
+ UPDATE_NESTEROV_X
}
private OperationTypes operation;
@@ -165,6 +166,9 @@ public class DnnTransform extends Lop
case CHANNEL_SUMS:
return "channel_sums";
+
+ case UPDATE_NESTEROV_X:
+ return "update_nesterov_x";
case BATCH_NORM2D_TEST:
return "batch_norm2d_test";
@@ -232,6 +236,33 @@ public class DnnTransform extends Lop
}
@Override
+ public String getInstructions(String input1, String input2, String input3, String input4, String output) {
+ if(operation == OperationTypes.UPDATE_NESTEROV_X) {
+ StringBuilder sb = new StringBuilder();
+ sb.append( getExecType() );
+
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getOpcode() );
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getInputs().get(0).prepInputOperand(input1));
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getInputs().get(1).prepInputOperand(input2));
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getInputs().get(2).prepInputOperand(input3));
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getInputs().get(3).prepInputOperand(input4));
+ //output
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( this.prepOutputOperand(output));
+
+ return sb.toString();
+ }
+ else {
+ throw new LopsException("The operation is not supported with four operands:" + operation.name());
+ }
+ }
+
+ @Override
public String getInstructions(String[] inputs, String output) {
StringBuilder sb = new StringBuilder();
appendOpcode(sb);
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index c90f9f9..f4122d9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -63,6 +63,7 @@ public class GPUInstructionParser extends InstructionParser
String2GPUInstructionType.put( "batch_norm2d_backward", GPUINSTRUCTION_TYPE.Dnn);
String2GPUInstructionType.put( "batch_norm2d_test", GPUINSTRUCTION_TYPE.Dnn);
String2GPUInstructionType.put( "batch_norm2d_train", GPUINSTRUCTION_TYPE.Dnn);
+ String2GPUInstructionType.put( "update_nesterov_x", GPUINSTRUCTION_TYPE.Dnn);
// Matrix Multiply Operators
String2GPUInstructionType.put( "ba+*", GPUINSTRUCTION_TYPE.AggregateBinary);
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
index a36d0fc..8d89032 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java
@@ -124,6 +124,21 @@ public class DnnGPUInstruction extends GPUInstruction {
_intermediateMemoryBudget = intermediateMemoryBudget;
}
+ public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr,
+ double intermediateMemoryBudget) throws DMLRuntimeException {
+ super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
+ if( !opcode.equals("update_nesterov_x") ) {
+ throw new DMLRuntimeException("Incorrect opcode: " + opcode);
+ }
+ _input1 = in1;
+ _input2 = in2;
+ _input3 = in3;
+ _input4 = in4;
+ _gputype = GPUINSTRUCTION_TYPE.Dnn;
+ _output = out;
+ _intermediateMemoryBudget = intermediateMemoryBudget;
+ }
+
public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode,
String istr, ArrayList<CPOperand> stride,
ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
@@ -298,6 +313,15 @@ public class DnnGPUInstruction extends GPUInstruction {
CPOperand out = new CPOperand(parts[4]);
return new DnnGPUInstruction(in, in2, in3, out, opcode, str, 0);
}
+ else if (opcode.equalsIgnoreCase("update_nesterov_x")) {
+ InstructionUtils.checkNumFields(parts, 5);
+ CPOperand in = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand in3 = new CPOperand(parts[3]);
+ CPOperand in4 = new CPOperand(parts[4]);
+ CPOperand out = new CPOperand(parts[5]);
+ return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
+ }
else if (opcode.equalsIgnoreCase("lstm")) {
InstructionUtils.checkNumFields(parts, 8);
CPOperand in1 = new CPOperand(parts[1]);
@@ -552,6 +576,34 @@ public class DnnGPUInstruction extends GPUInstruction {
ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}
+ private void processNesterovUpdateInstruction(ExecutionContext ec) {
+ GPUStatistics.incrementNoOfExecutedGPUInst();
+ MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName());
+ MatrixObject v = getMatrixInputForGPUInstruction(ec, _input2.getName());
+ MatrixObject v_prev = getMatrixInputForGPUInstruction(ec, _input3.getName());
+ double mu = ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue();
+ int rows = LibMatrixCUDA.toInt(input.getNumRows());
+ int cols = LibMatrixCUDA.toInt(input.getNumColumns());
+ MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), rows, cols);
+
+ GPUContext gCtx = ec.getGPUContext(0);
+ String instName = getExtendedOpcode();
+ LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x",
+ ExecutionConfig.getConfigForSimpleVectorOperations(LibMatrixCUDA.toInt(rows*cols)),
+ LibMatrixCUDA.getDensePointer(gCtx, input, instName),
+ LibMatrixCUDA.getDensePointer(gCtx, v, instName),
+ LibMatrixCUDA.getDensePointer(gCtx, v_prev, instName),
+ mu,
+ LibMatrixCUDA.getDensePointer(gCtx, out, instName),
+ rows*cols);
+
+ // release inputs/outputs
+ ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+ ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+ ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+ ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+ }
+
private static int toInt(long num) throws DMLRuntimeException {
if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) {
throw new DMLRuntimeException("GPU : Exceeded supported size " + num);
@@ -697,6 +749,10 @@ public class DnnGPUInstruction extends GPUInstruction {
processChannelSumsInstruction(ec);
return;
}
+ else if (instOpcode.equalsIgnoreCase("update_nesterov_x")) {
+ processNesterovUpdateInstruction(ec);
+ return;
+ }
else if (instOpcode.equalsIgnoreCase("lstm")) {
processLstmInstruction(ec);
return;
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index 3d7ab2c..5d0e4bc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -405,7 +405,8 @@ public class GPUMemoryManager {
allPointers.remove(toFree);
lazyCudaFreeMemoryManager.removeIfPresent(size, toFree);
allocator.free(toFree);
- // JCuda.cudaDeviceSynchronize(); // Force a device synchronize after free-ing the pointer for debugging
+ if(DMLScript.SYNCHRONIZE_GPU)
+ jcuda.runtime.JCuda.cudaDeviceSynchronize(); // Force a device synchronize after free-ing the pointer for debugging
}
else {
throw new RuntimeException("Attempting to free an unaccounted pointer:" + toFree);
@@ -447,7 +448,7 @@ public class GPUMemoryManager {
public void removeGPUObject(GPUObject gpuObj) {
if(LOG.isDebugEnabled())
LOG.debug("Removing the GPU object: " + gpuObj);
- matrixMemoryManager.gpuObjects.removeIf(a -> a.equals(gpuObj));
+ matrixMemoryManager.gpuObjects.remove(gpuObj);
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
index cae2e33..e1ae1ae 100644
--- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
@@ -311,6 +311,16 @@ public abstract class GPUTests extends AutomatedTestBase {
Set<String> heavyHitterOpCodes = Statistics.getCPHeavyHitterOpCodes();
Assert.assertTrue(heavyHitterOpCodes.contains(heavyHitterOpCode));
}
+
+ /**
+ * asserts that the given op was NOT executed
+ *
+ * @param heavyHitterOpCode opcode of the heavy hitter for the unary op
+ */
+ protected void assertHeavyHitterNotPresent(String heavyHitterOpCode) {
+ Set<String> heavyHitterOpCodes = Statistics.getCPHeavyHitterOpCodes();
+ Assert.assertTrue(!heavyHitterOpCodes.contains(heavyHitterOpCode));
+ }
/**
* Runs a program on the CPU
http://git-wip-us.apache.org/repos/asf/systemml/blob/04bc667f/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java b/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java
new file mode 100644
index 0000000..c98a74d
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/SGDUpdate.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests update rewrites for SGD
+ */
+public class SGDUpdate extends GPUTests {
+
+ private final static String TEST_NAME = "SGDUpdateTests";
+ private final int seed = 42;
+
+ @Override
+ public void setUp() {
+ super.setUp();
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_DIR, TEST_NAME);
+ getAndLoadTestConfiguration(TEST_NAME);
+ }
+
+ @Test
+ public void testNesterovRewrite() {
+ String scriptStr = "mu=0.99; output = x - mu*v_prev + (1+mu)*v;" ;
+ int inRows = 10;
+ int inCols = 30;
+ HashMap<String, Object> inputs = new HashMap<>();
+ inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+ inputs.put("v_prev", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+ inputs.put("v", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+ List<String> outputs = Arrays.asList("output");
+ List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+ List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+ assertHeavyHitterPresent("gpu_update_nesterov_x");
+ assertEqualObjects(outCPU.get(0), outGPU.get(0));
+ }
+
+ @Test
+ public void testNoNesterovRewrite1() {
+ String scriptStr = "mu=0.99; output = x - mu*v_prev + (1+mu)*v;" ;
+ int inRows = 10;
+ int inCols = 30;
+ HashMap<String, Object> inputs = new HashMap<>();
+ inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+ inputs.put("v_prev", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+ inputs.put("v", generateInputMatrix(spark, inRows, 1, 0, 10, 0.9, seed));
+ List<String> outputs = Arrays.asList("output");
+ List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+ List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+ assertHeavyHitterNotPresent("gpu_update_nesterov_x");
+ assertEqualObjects(outCPU.get(0), outGPU.get(0));
+ }
+
+ @Test
+ public void testNoNesterovRewrite2() {
+ String scriptStr = "mu=0.99; output = x - mu*v_prev + mu*v;" ;
+ int inRows = 10;
+ int inCols = 30;
+ HashMap<String, Object> inputs = new HashMap<>();
+ inputs.put("x", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+ inputs.put("v_prev", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+ inputs.put("v", generateInputMatrix(spark, inRows, inCols, 0, 10, 0.9, seed));
+ List<String> outputs = Arrays.asList("output");
+ List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, outputs);
+ List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, outputs);
+ assertHeavyHitterNotPresent("gpu_update_nesterov_x");
+ assertEqualObjects(outCPU.get(0), outGPU.get(0));
+ }
+}