You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by na...@apache.org on 2017/08/08 20:29:53 UTC

systemml git commit: [MINOR] bug fixes in the GPU backend

Repository: systemml
Updated Branches:
  refs/heads/master 98a9d653d -> 815ca4f2a


[MINOR] bug fixes in the GPU backend

- Each thread is assigned a cuda library handle
- JCudaKernels is also made thread safe
- Removed setting GPUContext to null
- Bug fix in initial gpu budget estimate
- Cuda Kernels use blockIdx.x and threadIdx.x only

Closes #607


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/815ca4f2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/815ca4f2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/815ca4f2

Branch: refs/heads/master
Commit: 815ca4f2aedcbe491d10a873db99a9b5e6f29226
Parents: 98a9d65
Author: Nakul Jindal <na...@gmail.com>
Authored: Tue Aug 8 13:29:11 2017 -0700
Committer: Nakul Jindal <na...@gmail.com>
Committed: Tue Aug 8 13:29:11 2017 -0700

----------------------------------------------------------------------
 src/main/cpp/kernels/SystemML.cu                |  54 +--
 src/main/cpp/kernels/SystemML.ptx               | 333 +++++++++----------
 .../controlprogram/ParForProgramBlock.java      |   3 -
 .../controlprogram/parfor/LocalParWorker.java   |  12 +-
 .../cp/FunctionCallCPInstruction.java           |   7 -
 .../gpu/context/ExecutionConfig.java            |  26 +-
 .../instructions/gpu/context/GPUContext.java    |  94 +++---
 .../gpu/context/GPUContextPool.java             |   2 +-
 .../instructions/gpu/context/JCudaKernels.java  |   5 +-
 .../org/apache/sysml/test/gpu/GPUTests.java     |  18 +
 .../test/gpu/MatrixMultiplicationOpTest.java    |   1 +
 11 files changed, 303 insertions(+), 252 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 297269f..dcd64b2 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -35,12 +35,13 @@ nvcc -ptx -arch=sm_30 SystemML.cu
  */
 extern "C"
 __global__ void copy_u2l_dense(double* ret, int dim, int N) {
-	int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	int iy = blockIdx.y * blockDim.y + threadIdx.y;
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / dim;
+	int iy = tid % dim;
 	int id_dest = iy * dim + ix;
 	if(iy > ix && id_dest < N) {
 		// TODO: Potential to reduce the number of threads by half
-		int id_src = ix * dim + iy;
+		int id_src = tid;
 		ret[id_dest] = ret[id_src];
 	}
 }
@@ -104,8 +105,9 @@ __forceinline__ __device__ double binaryOp(double x, double y, int op) {
 
 extern "C"
 __global__ void relu(double* A,  double* ret, int rlen, int clen) {
-	int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	int iy = blockIdx.y * blockDim.y + threadIdx.y;
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / clen;
+	int iy = tid % clen;
 	if(ix < rlen && iy < clen) {
 		int index = ix * clen + iy;
 		ret[index] = max(0.0, A[index]);
@@ -115,8 +117,9 @@ __global__ void relu(double* A,  double* ret, int rlen, int clen) {
 // This method computes the backpropagation errors for previous layer of relu operation
 extern "C"
 __global__ void relu_backward(double* X,  double* dout, double* ret, int rlen, int clen) {
-	int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	int iy = blockIdx.y * blockDim.y + threadIdx.y;
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / clen;
+	int iy = tid % clen;
 	if(ix < rlen && iy < clen) {
 		int index = ix * clen + iy;
 		ret[index] = X[index] > 0 ?  dout[index] : 0;
@@ -129,8 +132,9 @@ __global__ void relu_backward(double* X,  double* dout, double* ret, int rlen, i
 // This operation is often followed by conv2d and hence we have introduced bias_add(input, bias) built-in function
 extern "C"
 __global__ void bias_add(double* input,  double* bias, double* ret, int rlen, int clen, int PQ) {
-	int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	int iy = blockIdx.y * blockDim.y + threadIdx.y;
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / clen;
+	int iy = tid % clen;
 	if(ix < rlen && iy < clen) {
 		int index = ix * clen + iy;
 		int biasIndex = iy / PQ;
@@ -141,8 +145,9 @@ __global__ void bias_add(double* input,  double* bias, double* ret, int rlen, in
 // Performs the operation "ret <- A + alpha*B", where B is a vector
 extern "C"
 __global__ void daxpy_matrix_vector(double* A,  double* B, double alpha, double* ret, int rlenA, int clenA, int rlenB, int clenB) {
-	int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	int iy = blockIdx.y * blockDim.y + threadIdx.y;
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / clenA;
+	int iy = tid % clenA;
 	if(ix < rlenA && iy < clenA) {
 		int index = ix * clenA + iy;
 		if(rlenB == 1) {
@@ -157,8 +162,9 @@ __global__ void daxpy_matrix_vector(double* A,  double* B, double alpha, double*
 // Performs similar operation as bias_add except elementwise multiplication instead of add
 extern "C"
 __global__ void bias_multiply(double* input,  double* bias, double* ret, int rlen, int clen, int PQ) {
-	int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	int iy = blockIdx.y * blockDim.y + threadIdx.y;
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / clen;
+	int iy = tid % clen;
 	if(ix < rlen && iy < clen) {
 		int index = ix * clen + iy;
 		int biasIndex = iy / PQ;
@@ -169,8 +175,9 @@ __global__ void bias_multiply(double* input,  double* bias, double* ret, int rle
 // Compares the value and set
 extern "C"
 __global__ void compare_and_set(double* A,  double* ret, int rlen, int clen, double compareVal, double tol, double ifEqualsVal, double ifLessThanVal, double ifGreaterThanVal) {
-	int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	int iy = blockIdx.y * blockDim.y + threadIdx.y;
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / clen;
+	int iy = tid % clen;
 	int index = ix * clen + iy;
 	if(ix < rlen && iy < clen) {
 		if(abs(A[index]-compareVal) < tol)
@@ -199,8 +206,9 @@ __global__ void compare_and_set(double* A,  double* ret, int rlen, int clen, dou
 extern "C"
 __global__ void matrix_matrix_cellwise_op(double* A, double* B, double* C,
 	int maxRlen, int maxClen, int vectorAStatus, int vectorBStatus, int op) {
-	int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	int iy = blockIdx.y * blockDim.y + threadIdx.y;
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / maxClen;
+	int iy = tid % maxClen;
 
 	if(ix < maxRlen && iy < maxClen) {
 		int outIndex = ix * maxClen + iy;
@@ -273,8 +281,10 @@ __global__ void fill(double* A, double scalar, int lenA) {
  */
 extern "C"
 __global__ void cbind(double *A, double *B, double *C, int rowsA, int colsA, int rowsB, int colsB) {
-	int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	int iy = blockIdx.y * blockDim.y + threadIdx.y;
+	int maxClen = max(colsA, colsB);
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / maxClen;
+	int iy = tid % maxClen;
 
 	int colsC = colsA + colsB;
 	int rowsC = rowsA;
@@ -310,8 +320,10 @@ __global__ void cbind(double *A, double *B, double *C, int rowsA, int colsA, int
  */
 extern "C"
 __global__ void rbind(double *A, double *B, double *C, int rowsA, int colsA, int rowsB, int colsB) {
-	int ix = blockIdx.x * blockDim.x + threadIdx.x;
-	int iy = blockIdx.y * blockDim.y + threadIdx.y;
+	int maxClen = max(colsA, colsB);
+	int tid = blockIdx.x * blockDim.x + threadIdx.x;
+	int ix = tid / maxClen;
+	int iy = tid % maxClen;
 
 	int rowsC = rowsA + rowsB;
 	int colsC = colsA;

http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index 6884d5b..7778317 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -34,36 +34,33 @@
 )
 {
 	.reg .pred 	%p<4>;
-	.reg .b32 	%r<13>;
+	.reg .b32 	%r<10>;
 	.reg .f64 	%fd<2>;
 	.reg .b64 	%rd<7>;
 
 
 	ld.param.u64 	%rd1, [copy_u2l_dense_param_0];
-	ld.param.u32 	%r4, [copy_u2l_dense_param_1];
-	ld.param.u32 	%r5, [copy_u2l_dense_param_2];
-	mov.u32 	%r6, %ntid.x;
-	mov.u32 	%r7, %ctaid.x;
-	mov.u32 	%r8, %tid.x;
-	mad.lo.s32 	%r1, %r6, %r7, %r8;
-	mov.u32 	%r9, %ntid.y;
-	mov.u32 	%r10, %ctaid.y;
-	mov.u32 	%r11, %tid.y;
-	mad.lo.s32 	%r2, %r9, %r10, %r11;
-	mad.lo.s32 	%r3, %r2, %r4, %r1;
-	setp.gt.s32	%p1, %r2, %r1;
-	setp.lt.s32	%p2, %r3, %r5;
+	ld.param.u32 	%r3, [copy_u2l_dense_param_1];
+	ld.param.u32 	%r4, [copy_u2l_dense_param_2];
+	mov.u32 	%r5, %ntid.x;
+	mov.u32 	%r6, %ctaid.x;
+	mov.u32 	%r7, %tid.x;
+	mad.lo.s32 	%r1, %r5, %r6, %r7;
+	div.s32 	%r8, %r1, %r3;
+	rem.s32 	%r9, %r1, %r3;
+	mad.lo.s32 	%r2, %r9, %r3, %r8;
+	setp.gt.s32	%p1, %r9, %r8;
+	setp.lt.s32	%p2, %r2, %r4;
 	and.pred  	%p3, %p1, %p2;
 	@!%p3 bra 	BB0_2;
 	bra.uni 	BB0_1;
 
 BB0_1:
 	cvta.to.global.u64 	%rd2, %rd1;
-	mad.lo.s32 	%r12, %r1, %r4, %r2;
-	mul.wide.s32 	%rd3, %r12, 8;
+	mul.wide.s32 	%rd3, %r1, 8;
 	add.s64 	%rd4, %rd2, %rd3;
 	ld.global.f64 	%fd1, [%rd4];
-	mul.wide.s32 	%rd5, %r3, 8;
+	mul.wide.s32 	%rd5, %r2, 8;
 	add.s64 	%rd6, %rd2, %rd5;
 	st.global.f64 	[%rd6], %fd1;
 
@@ -80,7 +77,7 @@ BB0_2:
 )
 {
 	.reg .pred 	%p<4>;
-	.reg .b32 	%r<12>;
+	.reg .b32 	%r<10>;
 	.reg .f64 	%fd<4>;
 	.reg .b64 	%rd<8>;
 
@@ -93,20 +90,18 @@ BB0_2:
 	mov.u32 	%r6, %ntid.x;
 	mov.u32 	%r7, %tid.x;
 	mad.lo.s32 	%r1, %r6, %r5, %r7;
-	mov.u32 	%r8, %ntid.y;
-	mov.u32 	%r9, %ctaid.y;
-	mov.u32 	%r10, %tid.y;
-	mad.lo.s32 	%r2, %r8, %r9, %r10;
-	setp.lt.s32	%p1, %r1, %r4;
-	setp.lt.s32	%p2, %r2, %r3;
+	div.s32 	%r2, %r1, %r3;
+	setp.lt.s32	%p1, %r2, %r4;
+	setp.gt.s32	%p2, %r3, -1;
 	and.pred  	%p3, %p1, %p2;
 	@!%p3 bra 	BB1_2;
 	bra.uni 	BB1_1;
 
 BB1_1:
+	rem.s32 	%r8, %r1, %r3;
 	cvta.to.global.u64 	%rd3, %rd1;
-	mad.lo.s32 	%r11, %r1, %r3, %r2;
-	mul.wide.s32 	%rd4, %r11, 8;
+	mad.lo.s32 	%r9, %r2, %r3, %r8;
+	mul.wide.s32 	%rd4, %r9, 8;
 	add.s64 	%rd5, %rd3, %rd4;
 	ld.global.f64 	%fd1, [%rd5];
 	mov.f64 	%fd2, 0d0000000000000000;
@@ -129,7 +124,7 @@ BB1_2:
 )
 {
 	.reg .pred 	%p<5>;
-	.reg .b32 	%r<12>;
+	.reg .b32 	%r<10>;
 	.reg .f64 	%fd<6>;
 	.reg .b64 	%rd<14>;
 
@@ -143,21 +138,19 @@ BB1_2:
 	mov.u32 	%r6, %ctaid.x;
 	mov.u32 	%r7, %tid.x;
 	mad.lo.s32 	%r1, %r5, %r6, %r7;
-	mov.u32 	%r8, %ntid.y;
-	mov.u32 	%r9, %ctaid.y;
-	mov.u32 	%r10, %tid.y;
-	mad.lo.s32 	%r2, %r8, %r9, %r10;
-	setp.lt.s32	%p1, %r1, %r4;
-	setp.lt.s32	%p2, %r2, %r3;
+	div.s32 	%r2, %r1, %r3;
+	setp.lt.s32	%p1, %r2, %r4;
+	setp.gt.s32	%p2, %r3, -1;
 	and.pred  	%p3, %p1, %p2;
 	@!%p3 bra 	BB2_4;
 	bra.uni 	BB2_1;
 
 BB2_1:
+	rem.s32 	%r8, %r1, %r3;
 	cvta.to.global.u64 	%rd5, %rd2;
-	mad.lo.s32 	%r11, %r1, %r3, %r2;
-	cvt.s64.s32	%rd1, %r11;
-	mul.wide.s32 	%rd6, %r11, 8;
+	mad.lo.s32 	%r9, %r2, %r3, %r8;
+	cvt.s64.s32	%rd1, %r9;
+	mul.wide.s32 	%rd6, %r9, 8;
 	add.s64 	%rd7, %rd5, %rd6;
 	ld.global.f64 	%fd4, [%rd7];
 	mov.f64 	%fd5, 0d0000000000000000;
@@ -190,7 +183,7 @@ BB2_4:
 )
 {
 	.reg .pred 	%p<4>;
-	.reg .b32 	%r<14>;
+	.reg .b32 	%r<12>;
 	.reg .f64 	%fd<4>;
 	.reg .b64 	%rd<12>;
 
@@ -205,24 +198,22 @@ BB2_4:
 	mov.u32 	%r7, %ntid.x;
 	mov.u32 	%r8, %tid.x;
 	mad.lo.s32 	%r1, %r7, %r6, %r8;
-	mov.u32 	%r9, %ntid.y;
-	mov.u32 	%r10, %ctaid.y;
-	mov.u32 	%r11, %tid.y;
-	mad.lo.s32 	%r2, %r9, %r10, %r11;
-	setp.lt.s32	%p1, %r1, %r5;
-	setp.lt.s32	%p2, %r2, %r3;
+	div.s32 	%r2, %r1, %r3;
+	setp.lt.s32	%p1, %r2, %r5;
+	setp.gt.s32	%p2, %r3, -1;
 	and.pred  	%p3, %p1, %p2;
 	@!%p3 bra 	BB3_2;
 	bra.uni 	BB3_1;
 
 BB3_1:
+	rem.s32 	%r9, %r1, %r3;
 	cvta.to.global.u64 	%rd4, %rd1;
-	mad.lo.s32 	%r12, %r1, %r3, %r2;
-	mul.wide.s32 	%rd5, %r12, 8;
+	mad.lo.s32 	%r10, %r2, %r3, %r9;
+	mul.wide.s32 	%rd5, %r10, 8;
 	add.s64 	%rd6, %rd4, %rd5;
-	div.s32 	%r13, %r2, %r4;
+	div.s32 	%r11, %r9, %r4;
 	cvta.to.global.u64 	%rd7, %rd2;
-	mul.wide.s32 	%rd8, %r13, 8;
+	mul.wide.s32 	%rd8, %r11, 8;
 	add.s64 	%rd9, %rd7, %rd8;
 	ld.global.f64 	%fd1, [%rd9];
 	ld.global.f64 	%fd2, [%rd6];
@@ -248,7 +239,7 @@ BB3_2:
 )
 {
 	.reg .pred 	%p<5>;
-	.reg .b32 	%r<13>;
+	.reg .b32 	%r<11>;
 	.reg .f64 	%fd<7>;
 	.reg .b64 	%rd<14>;
 
@@ -264,22 +255,20 @@ BB3_2:
 	mov.u32 	%r6, %ntid.x;
 	mov.u32 	%r7, %ctaid.x;
 	mov.u32 	%r8, %tid.x;
-	mad.lo.s32 	%r1, %r6, %r7, %r8;
-	mov.u32 	%r9, %ntid.y;
-	mov.u32 	%r10, %ctaid.y;
-	mov.u32 	%r11, %tid.y;
-	mad.lo.s32 	%r2, %r9, %r10, %r11;
+	mad.lo.s32 	%r9, %r6, %r7, %r8;
+	div.s32 	%r1, %r9, %r3;
+	rem.s32 	%r2, %r9, %r3;
 	setp.lt.s32	%p1, %r1, %r5;
-	setp.lt.s32	%p2, %r2, %r3;
+	setp.gt.s32	%p2, %r3, -1;
 	and.pred  	%p3, %p1, %p2;
 	@!%p3 bra 	BB4_4;
 	bra.uni 	BB4_1;
 
 BB4_1:
 	cvta.to.global.u64 	%rd6, %rd4;
-	mad.lo.s32 	%r12, %r1, %r3, %r2;
+	mad.lo.s32 	%r10, %r1, %r3, %r2;
 	cvta.to.global.u64 	%rd7, %rd3;
-	mul.wide.s32 	%rd8, %r12, 8;
+	mul.wide.s32 	%rd8, %r10, 8;
 	add.s64 	%rd9, %rd7, %rd8;
 	ld.global.f64 	%fd1, [%rd9];
 	add.s64 	%rd2, %rd6, %rd8;
@@ -317,7 +306,7 @@ BB4_4:
 )
 {
 	.reg .pred 	%p<4>;
-	.reg .b32 	%r<14>;
+	.reg .b32 	%r<12>;
 	.reg .f64 	%fd<4>;
 	.reg .b64 	%rd<12>;
 
@@ -332,24 +321,22 @@ BB4_4:
 	mov.u32 	%r7, %ntid.x;
 	mov.u32 	%r8, %tid.x;
 	mad.lo.s32 	%r1, %r7, %r6, %r8;
-	mov.u32 	%r9, %ntid.y;
-	mov.u32 	%r10, %ctaid.y;
-	mov.u32 	%r11, %tid.y;
-	mad.lo.s32 	%r2, %r9, %r10, %r11;
-	setp.lt.s32	%p1, %r1, %r5;
-	setp.lt.s32	%p2, %r2, %r3;
+	div.s32 	%r2, %r1, %r3;
+	setp.lt.s32	%p1, %r2, %r5;
+	setp.gt.s32	%p2, %r3, -1;
 	and.pred  	%p3, %p1, %p2;
 	@!%p3 bra 	BB5_2;
 	bra.uni 	BB5_1;
 
 BB5_1:
+	rem.s32 	%r9, %r1, %r3;
 	cvta.to.global.u64 	%rd4, %rd1;
-	mad.lo.s32 	%r12, %r1, %r3, %r2;
-	mul.wide.s32 	%rd5, %r12, 8;
+	mad.lo.s32 	%r10, %r2, %r3, %r9;
+	mul.wide.s32 	%rd5, %r10, 8;
 	add.s64 	%rd6, %rd4, %rd5;
-	div.s32 	%r13, %r2, %r4;
+	div.s32 	%r11, %r9, %r4;
 	cvta.to.global.u64 	%rd7, %rd2;
-	mul.wide.s32 	%rd8, %r13, 8;
+	mul.wide.s32 	%rd8, %r11, 8;
 	add.s64 	%rd9, %rd7, %rd8;
 	ld.global.f64 	%fd1, [%rd9];
 	ld.global.f64 	%fd2, [%rd6];
@@ -376,7 +363,7 @@ BB5_2:
 )
 {
 	.reg .pred 	%p<6>;
-	.reg .b32 	%r<12>;
+	.reg .b32 	%r<10>;
 	.reg .f64 	%fd<9>;
 	.reg .b64 	%rd<8>;
 
@@ -394,13 +381,11 @@ BB5_2:
 	mov.u32 	%r5, %ntid.x;
 	mov.u32 	%r6, %tid.x;
 	mad.lo.s32 	%r7, %r5, %r4, %r6;
-	mov.u32 	%r8, %ntid.y;
-	mov.u32 	%r9, %ctaid.y;
-	mov.u32 	%r10, %tid.y;
-	mad.lo.s32 	%r11, %r8, %r9, %r10;
-	mad.lo.s32 	%r1, %r7, %r3, %r11;
-	setp.lt.s32	%p1, %r7, %r2;
-	setp.lt.s32	%p2, %r11, %r3;
+	div.s32 	%r8, %r7, %r3;
+	rem.s32 	%r9, %r7, %r3;
+	mad.lo.s32 	%r1, %r8, %r3, %r9;
+	setp.lt.s32	%p1, %r8, %r2;
+	setp.gt.s32	%p2, %r3, -1;
 	and.pred  	%p3, %p1, %p2;
 	@!%p3 bra 	BB6_6;
 	bra.uni 	BB6_1;
@@ -451,7 +436,7 @@ BB6_6:
 )
 {
 	.reg .pred 	%p<73>;
-	.reg .b32 	%r<68>;
+	.reg .b32 	%r<66>;
 	.reg .f64 	%fd<56>;
 	.reg .b64 	%rd<19>;
 
@@ -467,13 +452,11 @@ BB6_6:
 	mov.u32 	%r15, %ntid.x;
 	mov.u32 	%r16, %ctaid.x;
 	mov.u32 	%r17, %tid.x;
-	mad.lo.s32 	%r1, %r15, %r16, %r17;
-	mov.u32 	%r18, %ntid.y;
-	mov.u32 	%r19, %ctaid.y;
-	mov.u32 	%r20, %tid.y;
-	mad.lo.s32 	%r2, %r18, %r19, %r20;
+	mad.lo.s32 	%r18, %r15, %r16, %r17;
+	div.s32 	%r1, %r18, %r10;
+	rem.s32 	%r2, %r18, %r10;
 	setp.lt.s32	%p2, %r1, %r14;
-	setp.lt.s32	%p3, %r2, %r10;
+	setp.gt.s32	%p3, %r10, -1;
 	and.pred  	%p4, %p2, %p3;
 	@!%p4 bra 	BB7_77;
 	bra.uni 	BB7_1;
@@ -481,34 +464,34 @@ BB6_6:
 BB7_1:
 	mad.lo.s32 	%r3, %r1, %r10, %r2;
 	setp.eq.s32	%p5, %r11, 1;
-	mov.u32 	%r66, %r1;
+	mov.u32 	%r64, %r1;
 	@%p5 bra 	BB7_5;
 
 	setp.ne.s32	%p6, %r11, 2;
-	mov.u32 	%r67, %r3;
+	mov.u32 	%r65, %r3;
 	@%p6 bra 	BB7_4;
 
-	mov.u32 	%r67, %r2;
+	mov.u32 	%r65, %r2;
 
 BB7_4:
-	mov.u32 	%r61, %r67;
-	mov.u32 	%r4, %r61;
-	mov.u32 	%r66, %r4;
+	mov.u32 	%r59, %r65;
+	mov.u32 	%r4, %r59;
+	mov.u32 	%r64, %r4;
 
 BB7_5:
-	mov.u32 	%r5, %r66;
+	mov.u32 	%r5, %r64;
 	setp.eq.s32	%p7, %r12, 1;
-	mov.u32 	%r64, %r1;
+	mov.u32 	%r62, %r1;
 	@%p7 bra 	BB7_9;
 
 	setp.ne.s32	%p8, %r12, 2;
-	mov.u32 	%r65, %r3;
+	mov.u32 	%r63, %r3;
 	@%p8 bra 	BB7_8;
 
-	mov.u32 	%r65, %r2;
+	mov.u32 	%r63, %r2;
 
 BB7_8:
-	mov.u32 	%r64, %r65;
+	mov.u32 	%r62, %r63;
 
 BB7_9:
 	cvta.to.global.u64 	%rd5, %rd3;
@@ -516,7 +499,7 @@ BB7_9:
 	mul.wide.s32 	%rd7, %r5, 8;
 	add.s64 	%rd8, %rd6, %rd7;
 	ld.global.f64 	%fd1, [%rd8];
-	mul.wide.s32 	%rd9, %r64, 8;
+	mul.wide.s32 	%rd9, %r62, 8;
 	add.s64 	%rd10, %rd5, %rd9;
 	ld.global.f64 	%fd2, [%rd10];
 	mov.f64 	%fd55, 0d7FEFFFFFFFFFFFFF;
@@ -570,10 +553,10 @@ BB7_58:
 	.reg .b32 %temp; 
 	mov.b64 	{%temp, %r9}, %fd2;
 	}
-	bfe.u32 	%r33, %r9, 20, 11;
-	add.s32 	%r34, %r33, -1012;
+	bfe.u32 	%r31, %r9, 20, 11;
+	add.s32 	%r32, %r31, -1012;
 	mov.b64 	 %rd15, %fd2;
-	shl.b64 	%rd1, %rd15, %r34;
+	shl.b64 	%rd1, %rd15, %r32;
 	setp.eq.s64	%p53, %rd1, -9223372036854775808;
 	abs.f64 	%fd19, %fd1;
 	// Callseq Start 0
@@ -603,14 +586,14 @@ BB7_58:
 BB7_59:
 	{
 	.reg .b32 %temp; 
-	mov.b64 	{%temp, %r35}, %fd54;
+	mov.b64 	{%temp, %r33}, %fd54;
 	}
-	xor.b32  	%r36, %r35, -2147483648;
+	xor.b32  	%r34, %r33, -2147483648;
 	{
 	.reg .b32 %temp; 
-	mov.b64 	{%r37, %temp}, %fd54;
+	mov.b64 	{%r35, %temp}, %fd54;
 	}
-	mov.b64 	%fd54, {%r37, %r36};
+	mov.b64 	%fd54, {%r35, %r34};
 
 BB7_60:
 	mov.f64 	%fd53, %fd54;
@@ -619,12 +602,12 @@ BB7_60:
 	bra.uni 	BB7_61;
 
 BB7_63:
-	selp.b32	%r38, %r8, 0, %p53;
-	or.b32  	%r39, %r38, 2146435072;
+	selp.b32	%r36, %r8, 0, %p53;
+	or.b32  	%r37, %r36, 2146435072;
 	setp.lt.s32	%p59, %r9, 0;
-	selp.b32	%r40, %r39, %r38, %p59;
-	mov.u32 	%r41, 0;
-	mov.b64 	%fd53, {%r41, %r40};
+	selp.b32	%r38, %r37, %r36, %p59;
+	mov.u32 	%r39, 0;
+	mov.b64 	%fd53, {%r39, %r38};
 	bra.uni 	BB7_64;
 
 BB7_35:
@@ -638,10 +621,10 @@ BB7_35:
 BB7_52:
 	cvt.rni.s64.f64	%rd11, %fd1;
 	cvt.rni.s64.f64	%rd12, %fd2;
-	cvt.u32.u64	%r27, %rd11;
-	cvt.u32.u64	%r28, %rd12;
-	or.b32  	%r29, %r28, %r27;
-	setp.eq.s32	%p45, %r29, 0;
+	cvt.u32.u64	%r25, %rd11;
+	cvt.u32.u64	%r26, %rd12;
+	or.b32  	%r27, %r26, %r25;
+	setp.eq.s32	%p45, %r27, 0;
 	selp.f64	%fd55, 0d0000000000000000, 0d3FF0000000000000, %p45;
 	bra.uni 	BB7_76;
 
@@ -701,17 +684,17 @@ BB7_46:
 
 	{
 	.reg .b32 %temp; 
-	mov.b64 	{%temp, %r24}, %fd55;
+	mov.b64 	{%temp, %r22}, %fd55;
 	}
-	and.b32  	%r25, %r24, 2147483647;
-	setp.ne.s32	%p42, %r25, 2146435072;
+	and.b32  	%r23, %r22, 2147483647;
+	setp.ne.s32	%p42, %r23, 2146435072;
 	@%p42 bra 	BB7_50;
 
 	{
 	.reg .b32 %temp; 
-	mov.b64 	{%r26, %temp}, %fd55;
+	mov.b64 	{%r24, %temp}, %fd55;
 	}
-	setp.eq.s32	%p43, %r26, 0;
+	setp.eq.s32	%p43, %r24, 0;
 	@%p43 bra 	BB7_76;
 
 BB7_50:
@@ -781,10 +764,10 @@ BB7_33:
 BB7_34:
 	cvt.rni.s64.f64	%rd13, %fd1;
 	cvt.rni.s64.f64	%rd14, %fd2;
-	cvt.u32.u64	%r30, %rd13;
-	cvt.u32.u64	%r31, %rd14;
-	and.b32  	%r32, %r31, %r30;
-	setp.eq.s32	%p46, %r32, 0;
+	cvt.u32.u64	%r28, %rd13;
+	cvt.u32.u64	%r29, %rd14;
+	and.b32  	%r30, %r29, %r28;
+	setp.eq.s32	%p46, %r30, 0;
 	selp.f64	%fd55, 0d0000000000000000, 0d3FF0000000000000, %p46;
 	bra.uni 	BB7_76;
 
@@ -820,17 +803,17 @@ BB7_41:
 
 	{
 	.reg .b32 %temp; 
-	mov.b64 	{%temp, %r21}, %fd55;
+	mov.b64 	{%temp, %r19}, %fd55;
 	}
-	and.b32  	%r22, %r21, 2147483647;
-	setp.ne.s32	%p36, %r22, 2146435072;
+	and.b32  	%r20, %r19, 2147483647;
+	setp.ne.s32	%p36, %r20, 2146435072;
 	@%p36 bra 	BB7_45;
 
 	{
 	.reg .b32 %temp; 
-	mov.b64 	{%r23, %temp}, %fd55;
+	mov.b64 	{%r21, %temp}, %fd55;
 	}
-	setp.eq.s32	%p37, %r23, 0;
+	setp.eq.s32	%p37, %r21, 0;
 	@%p37 bra 	BB7_76;
 
 BB7_45:
@@ -850,10 +833,10 @@ BB7_64:
 	add.f64 	%fd26, %fd1, %fd2;
 	{
 	.reg .b32 %temp; 
-	mov.b64 	{%temp, %r42}, %fd26;
+	mov.b64 	{%temp, %r40}, %fd26;
 	}
-	and.b32  	%r43, %r42, 2146435072;
-	setp.ne.s32	%p60, %r43, 2146435072;
+	and.b32  	%r41, %r40, 2146435072;
+	setp.ne.s32	%p60, %r41, 2146435072;
 	mov.f64 	%fd52, %fd25;
 	@%p60 bra 	BB7_73;
 
@@ -867,51 +850,51 @@ BB7_64:
 	mov.f64 	%fd52, %fd51;
 	@%p62 bra 	BB7_73;
 
-	and.b32  	%r44, %r9, 2147483647;
-	setp.ne.s32	%p63, %r44, 2146435072;
+	and.b32  	%r42, %r9, 2147483647;
+	setp.ne.s32	%p63, %r42, 2146435072;
 	@%p63 bra 	BB7_69;
 
 	{
 	.reg .b32 %temp; 
-	mov.b64 	{%r45, %temp}, %fd2;
+	mov.b64 	{%r43, %temp}, %fd2;
 	}
-	setp.eq.s32	%p64, %r45, 0;
+	setp.eq.s32	%p64, %r43, 0;
 	@%p64 bra 	BB7_72;
 
 BB7_69:
-	and.b32  	%r46, %r8, 2147483647;
-	setp.ne.s32	%p65, %r46, 2146435072;
+	and.b32  	%r44, %r8, 2147483647;
+	setp.ne.s32	%p65, %r44, 2146435072;
 	mov.f64 	%fd49, %fd25;
 	mov.f64 	%fd52, %fd49;
 	@%p65 bra 	BB7_73;
 
 	{
 	.reg .b32 %temp; 
-	mov.b64 	{%r47, %temp}, %fd1;
+	mov.b64 	{%r45, %temp}, %fd1;
 	}
-	setp.ne.s32	%p66, %r47, 0;
+	setp.ne.s32	%p66, %r45, 0;
 	mov.f64 	%fd52, %fd25;
 	@%p66 bra 	BB7_73;
 
-	shr.s32 	%r48, %r9, 31;
-	and.b32  	%r49, %r48, -2146435072;
-	add.s32 	%r50, %r49, 2146435072;
-	or.b32  	%r51, %r50, -2147483648;
-	selp.b32	%r52, %r51, %r50, %p1;
-	mov.u32 	%r53, 0;
-	mov.b64 	%fd52, {%r53, %r52};
+	shr.s32 	%r46, %r9, 31;
+	and.b32  	%r47, %r46, -2146435072;
+	add.s32 	%r48, %r47, 2146435072;
+	or.b32  	%r49, %r48, -2147483648;
+	selp.b32	%r50, %r49, %r48, %p1;
+	mov.u32 	%r51, 0;
+	mov.b64 	%fd52, {%r51, %r50};
 	bra.uni 	BB7_73;
 
 BB7_72:
 	setp.gt.f64	%p67, %fd19, 0d3FF0000000000000;
-	selp.b32	%r54, 2146435072, 0, %p67;
-	xor.b32  	%r55, %r54, 2146435072;
+	selp.b32	%r52, 2146435072, 0, %p67;
+	xor.b32  	%r53, %r52, 2146435072;
 	setp.lt.s32	%p68, %r9, 0;
-	selp.b32	%r56, %r55, %r54, %p68;
+	selp.b32	%r54, %r53, %r52, %p68;
 	setp.eq.f64	%p69, %fd1, 0dBFF0000000000000;
-	selp.b32	%r57, 1072693248, %r56, %p69;
-	mov.u32 	%r58, 0;
-	mov.b64 	%fd52, {%r58, %r57};
+	selp.b32	%r55, 1072693248, %r54, %p69;
+	mov.u32 	%r56, 0;
+	mov.b64 	%fd52, {%r56, %r55};
 
 BB7_73:
 	setp.eq.f64	%p70, %fd2, 0d0000000000000000;
@@ -1825,7 +1808,7 @@ BB9_2:
 )
 {
 	.reg .pred 	%p<7>;
-	.reg .b32 	%r<19>;
+	.reg .b32 	%r<18>;
 	.reg .f64 	%fd<3>;
 	.reg .b64 	%rd<15>;
 
@@ -1841,11 +1824,10 @@ BB9_2:
 	mov.u32 	%r8, %ntid.x;
 	mov.u32 	%r9, %ctaid.x;
 	mov.u32 	%r10, %tid.x;
-	mad.lo.s32 	%r1, %r8, %r9, %r10;
-	mov.u32 	%r11, %ntid.y;
-	mov.u32 	%r12, %ctaid.y;
-	mov.u32 	%r13, %tid.y;
-	mad.lo.s32 	%r2, %r11, %r12, %r13;
+	mad.lo.s32 	%r11, %r8, %r9, %r10;
+	max.s32 	%r12, %r4, %r6;
+	div.s32 	%r1, %r11, %r12;
+	rem.s32 	%r2, %r11, %r12;
 	add.s32 	%r3, %r6, %r4;
 	setp.lt.s32	%p1, %r1, %r7;
 	setp.lt.s32	%p2, %r2, %r4;
@@ -1855,12 +1837,12 @@ BB9_2:
 
 BB10_1:
 	cvta.to.global.u64 	%rd5, %rd2;
-	mad.lo.s32 	%r14, %r1, %r4, %r2;
-	mul.wide.s32 	%rd6, %r14, 8;
+	mad.lo.s32 	%r13, %r1, %r4, %r2;
+	mul.wide.s32 	%rd6, %r13, 8;
 	add.s64 	%rd7, %rd5, %rd6;
 	ld.global.f64 	%fd1, [%rd7];
-	mad.lo.s32 	%r15, %r1, %r3, %r2;
-	mul.wide.s32 	%rd8, %r15, 8;
+	mad.lo.s32 	%r14, %r1, %r3, %r2;
+	mul.wide.s32 	%rd8, %r14, 8;
 	add.s64 	%rd9, %rd1, %rd8;
 	st.global.f64 	[%rd9], %fd1;
 
@@ -1873,13 +1855,13 @@ BB10_2:
 
 BB10_3:
 	cvta.to.global.u64 	%rd10, %rd3;
-	mad.lo.s32 	%r16, %r1, %r6, %r2;
-	mul.wide.s32 	%rd11, %r16, 8;
+	mad.lo.s32 	%r15, %r1, %r6, %r2;
+	mul.wide.s32 	%rd11, %r15, 8;
 	add.s64 	%rd12, %rd10, %rd11;
 	ld.global.f64 	%fd2, [%rd12];
-	mad.lo.s32 	%r17, %r1, %r3, %r4;
-	add.s32 	%r18, %r17, %r2;
-	mul.wide.s32 	%rd13, %r18, 8;
+	add.s32 	%r16, %r2, %r4;
+	mad.lo.s32 	%r17, %r1, %r3, %r16;
+	mul.wide.s32 	%rd13, %r17, 8;
 	add.s64 	%rd14, %rd1, %rd13;
 	st.global.f64 	[%rd14], %fd2;
 
@@ -1899,7 +1881,7 @@ BB10_4:
 )
 {
 	.reg .pred 	%p<7>;
-	.reg .b32 	%r<17>;
+	.reg .b32 	%r<16>;
 	.reg .f64 	%fd<3>;
 	.reg .b64 	%rd<14>;
 
@@ -1915,11 +1897,10 @@ BB10_4:
 	mov.u32 	%r7, %ntid.x;
 	mov.u32 	%r8, %ctaid.x;
 	mov.u32 	%r9, %tid.x;
-	mad.lo.s32 	%r1, %r7, %r8, %r9;
-	mov.u32 	%r10, %ntid.y;
-	mov.u32 	%r11, %ctaid.y;
-	mov.u32 	%r12, %tid.y;
-	mad.lo.s32 	%r2, %r10, %r11, %r12;
+	mad.lo.s32 	%r10, %r7, %r8, %r9;
+	max.s32 	%r11, %r4, %r6;
+	div.s32 	%r1, %r10, %r11;
+	rem.s32 	%r2, %r10, %r11;
 	setp.lt.s32	%p1, %r1, %r3;
 	setp.lt.s32	%p2, %r2, %r4;
 	and.pred  	%p3, %p1, %p2;
@@ -1928,8 +1909,8 @@ BB10_4:
 
 BB11_1:
 	cvta.to.global.u64 	%rd5, %rd2;
-	mad.lo.s32 	%r13, %r1, %r4, %r2;
-	mul.wide.s32 	%rd6, %r13, 8;
+	mad.lo.s32 	%r12, %r1, %r4, %r2;
+	mul.wide.s32 	%rd6, %r12, 8;
 	add.s64 	%rd7, %rd5, %rd6;
 	ld.global.f64 	%fd1, [%rd7];
 	add.s64 	%rd8, %rd1, %rd6;
@@ -1944,13 +1925,13 @@ BB11_2:
 
 BB11_3:
 	cvta.to.global.u64 	%rd9, %rd3;
-	mad.lo.s32 	%r14, %r1, %r6, %r2;
-	mul.wide.s32 	%rd10, %r14, 8;
+	mad.lo.s32 	%r13, %r1, %r6, %r2;
+	mul.wide.s32 	%rd10, %r13, 8;
 	add.s64 	%rd11, %rd9, %rd10;
 	ld.global.f64 	%fd2, [%rd11];
-	add.s32 	%r15, %r1, %r3;
-	mad.lo.s32 	%r16, %r15, %r4, %r2;
-	mul.wide.s32 	%rd12, %r16, 8;
+	add.s32 	%r14, %r1, %r3;
+	mad.lo.s32 	%r15, %r14, %r4, %r2;
+	mul.wide.s32 	%rd12, %r15, 8;
 	add.s64 	%rd13, %rd1, %rd12;
 	st.global.f64 	[%rd13], %fd2;
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
index a2d361c..169c3bb 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
@@ -828,9 +828,6 @@ public class ParForProgramBlock extends ForProgramBlock
 			// Frees up the GPUContexts used in the threaded Parfor and sets
 			// the main thread to use the GPUContext
 			if (DMLScript.USE_ACCELERATOR) {
-				for (int i = 0; i < _numThreads; i++) {
-					workers[i].getExecutionContext().setGPUContexts(null);
-				}
 				ec.getGPUContext(0).initializeThread();
 			}
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/LocalParWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/LocalParWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/LocalParWorker.java
index 636b1f8..f77c22e 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/LocalParWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/LocalParWorker.java
@@ -25,6 +25,7 @@ import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.CompilerConfig;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Stat;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.StatisticMonitor;
@@ -82,8 +83,15 @@ public class LocalParWorker extends ParWorker implements Runnable
 		}
 
 		// Initialize this GPUContext to this thread
-		if (DMLScript.USE_ACCELERATOR)
-			_ec.getGPUContext(0).initializeThread();
+		if (DMLScript.USE_ACCELERATOR) {
+			try {
+				_ec.getGPUContext(0).initializeThread();
+			} catch(DMLRuntimeException e) {
+				LOG.error("Error executing task because of failure in GPU backend: ",e);
+				LOG.error("Stopping LocalParWorker.");
+				return;
+			}
+		}
 		
 		//setup compiler config for worker thread
 		ConfigurationManager.setLocalConfig(_cconf);

http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
index 3cd2633..77c48a7 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -169,7 +169,6 @@ public class FunctionCallCPInstruction extends CPInstruction
 		ExecutionContext fn_ec = ExecutionContextFactory.createContext(false, ec.getProgram());
 		if (DMLScript.USE_ACCELERATOR) {
 			fn_ec.setGPUContexts(ec.getGPUContexts());
-			ec.setGPUContexts(null);
 			fn_ec.getGPUContext(0).initializeThread();
 		}
 		fn_ec.setVariables(functionVariables);
@@ -205,12 +204,6 @@ public class FunctionCallCPInstruction extends CPInstruction
 		// Unpin the pinned variables
 		ec.unpinVariables(_boundInputParamNames, pinStatus);
 
-		if (DMLScript.USE_ACCELERATOR) {
-			ec.setGPUContexts(fn_ec.getGPUContexts());
-			fn_ec.setGPUContexts(null);
-			ec.getGPUContext(0).initializeThread();
-		}
-		
 		// add the updated binding for each return variable to the variables in original symbol table
 		for (int i=0; i< fpb.getOutputParams().size(); i++){
 		

http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.java
index ef000c2..5a0a772 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.java
@@ -89,16 +89,34 @@ public class ExecutionConfig {
 	 * @return execution configuration
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
-	public static ExecutionConfig getConfigForSimpleMatrixOperations(int rlen, int clen) throws DMLRuntimeException {
+	public static ExecutionConfig getConfigForMatrixOperations(int rlen, int clen) throws DMLRuntimeException {
 		int deviceNumber = 0;
 		int maxBlockDim = getMaxBlockDim(deviceNumber);
 		int blockDimX = (int) Math.min(maxBlockDim, rlen);
 		int gridDimX = (int) Math.ceil((double) rlen / blockDimX);
 		int blockDimY = (int) Math.min(Math.floor(((double) maxBlockDim) / blockDimX), clen);
 		int gridDimY = (int) Math.ceil((double) clen / blockDimY);
+		if (gridDimY > 65535)
+			throw new DMLRuntimeException("Internal Error: gridDimY must be less than 65535 for all supported CUDA compute capabilites!");
 		return new ExecutionConfig(gridDimX, gridDimY, blockDimX, blockDimY);
 	}
 
+	/**
+	 * Use this for simple matrix operations and use the following in the kernel
+	 * <code>
+	 * int index = blockIdx.x * blockDim.x + threadIdx.x
+	 * </code>
+	 * <p>
+	 * @param rlen number of rows
+	 * @param clen number of columns
+	 * @return execution configuration
+	 * @throws DMLRuntimeException if DMLRuntimeException occurs
+	 */
+	public static ExecutionConfig getConfigForSimpleMatrixOperations(int rlen, int clen) throws DMLRuntimeException {
+		return getConfigForSimpleVectorOperations(rlen * clen);
+	}
+
+
 	public ExecutionConfig(int gridDimX, int blockDimX) {
 		this.gridDimX = gridDimX;
 		this.blockDimX = blockDimX;
@@ -134,4 +152,10 @@ public class ExecutionConfig {
 		return ret;
 	}
 
+	@Override
+	public String toString() {
+		return "ExecutionConfig{" + "gridDimX=" + gridDimX + ", gridDimY=" + gridDimY + ", gridDimZ=" + gridDimZ
+				+ ", blockDimX=" + blockDimX + ", blockDimY=" + blockDimY + ", blockDimZ=" + blockDimZ
+				+ ", sharedMemBytes=" + sharedMemBytes + '}';
+	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
index b3c19ef..4c0562d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
@@ -108,27 +108,27 @@ public class GPUContext {
 	/**
 	 * cudnnHandle for Deep Neural Network operations on the GPU
 	 */
-	private cudnnHandle cudnnHandle;
+	private final ThreadLocal<cudnnHandle> cudnnHandle = new ThreadLocal<>();
 	/**
 	 * cublasHandle for BLAS operations on the GPU
 	 */
-	private cublasHandle cublasHandle;
+	private final ThreadLocal<cublasHandle> cublasHandle = new ThreadLocal<>();
 	/**
 	 * cusparseHandle for certain sparse BLAS operations on the GPU
 	 */
-	private cusparseHandle cusparseHandle;
+	private final ThreadLocal<cusparseHandle> cusparseHandle = new ThreadLocal<>();
 	/**
 	 * cusolverDnHandle for invoking solve() function on dense matrices on the GPU
 	 */
-	private cusolverDnHandle cusolverDnHandle;
+	private final ThreadLocal<cusolverDnHandle> cusolverDnHandle = new ThreadLocal<>();
 	/**
 	 * cusolverSpHandle for invoking solve() function on sparse matrices on the GPU
 	 */
-	private cusolverSpHandle cusolverSpHandle;
+	private final ThreadLocal<cusolverSpHandle> cusolverSpHandle = new ThreadLocal<>();
 	/**
 	 * to launch custom CUDA kernel, specific to the active GPU for this GPUContext
 	 */
-	private JCudaKernels kernels;
+	private final ThreadLocal<JCudaKernels> kernels = new ThreadLocal<>();
 
 	protected GPUContext(int deviceNum) throws DMLRuntimeException {
 		this.deviceNum = deviceNum;
@@ -140,28 +140,51 @@ public class GPUContext {
 		long total[] = { 0 };
 		cudaMemGetInfo(free, total);
 
-		long start = System.nanoTime();
-		cudnnHandle = new cudnnHandle();
-		cudnnCreate(cudnnHandle);
-		cublasHandle = new cublasHandle();
-		cublasCreate(cublasHandle);
+		long start = -1;
+		if (DMLScript.STATISTICS)
+			start = System.nanoTime();
+		initializeCudaLibraryHandles();
+
+		if (DMLScript.STATISTICS)
+			GPUStatistics.cudaLibrariesInitTime = System.nanoTime() - start;
+		
+		LOG.info(" GPU memory - Total: " + (total[0] * (1e-6)) + " MB, Available: " + (free[0] * (1e-6)) + " MB on "
+				+ this);
+
+	}
+
+	private void initializeCudaLibraryHandles() throws DMLRuntimeException {
+		if (cudnnHandle.get() == null) {
+			cudnnHandle.set(new cudnnHandle());
+			cudnnCreate(cudnnHandle.get());
+		}
+
+		if (cublasHandle.get() == null) {
+			cublasHandle.set(new cublasHandle());
+			cublasCreate(cublasHandle.get());
+		}
 		// For cublas v2, cublasSetPointerMode tells Cublas whether to expect scalar arguments on device or on host
 		// This applies to arguments like "alpha" in Dgemm, and "y" in Ddot.
 		// cublasSetPointerMode(LibMatrixCUDA.cublasHandle, cublasPointerMode.CUBLAS_POINTER_MODE_DEVICE);
-		cusparseHandle = new cusparseHandle();
-		cusparseCreate(cusparseHandle);
 
-		cusolverDnHandle = new cusolverDnHandle();
-		cusolverDnCreate(cusolverDnHandle);
-		cusolverSpHandle = new cusolverSpHandle();
-		cusolverSpCreate(cusolverSpHandle);
+		if (cusparseHandle.get() == null) {
+			cusparseHandle.set(new cusparseHandle());
+			cusparseCreate(cusparseHandle.get());
+		}
 
-		kernels = new JCudaKernels(deviceNum);
+		if (cusolverDnHandle.get() == null) {
+			cusolverDnHandle.set(new cusolverDnHandle());
+			cusolverDnCreate(cusolverDnHandle.get());
+		}
 
-		GPUStatistics.cudaLibrariesInitTime = System.nanoTime() - start;
-		LOG.info(" GPU memory - Total: " + (total[0] * (1e-6)) + " MB, Available: " + (free[0] * (1e-6)) + " MB on "
-				+ this);
+		if (cusolverSpHandle.get() == null) {
+			cusolverSpHandle.set(new cusolverSpHandle());
+			cusolverSpCreate(cusolverSpHandle.get());
+		}
 
+		if (kernels.get() == null) {
+			kernels.set(new JCudaKernels());
+		}
 	}
 
 	public static int cudaGetDevice() {
@@ -181,8 +204,9 @@ public class GPUContext {
 	 * If in a multi-threaded env like parfor, this method must be called when in the
 	 * appropriate thread
 	 */
-	public void initializeThread() {
+	public void initializeThread() throws DMLRuntimeException {
 		cudaSetDevice(deviceNum);
+		initializeCudaLibraryHandles();
 	}
 
 	/**
@@ -595,27 +619,27 @@ public class GPUContext {
 	}
 
 	public cudnnHandle getCudnnHandle() {
-		return cudnnHandle;
+		return cudnnHandle.get();
 	}
 
 	public cublasHandle getCublasHandle() {
-		return cublasHandle;
+		return cublasHandle.get();
 	}
 
 	public cusparseHandle getCusparseHandle() {
-		return cusparseHandle;
+		return cusparseHandle.get();
 	}
 
 	public cusolverDnHandle getCusolverDnHandle() {
-		return cusolverDnHandle;
+		return cusolverDnHandle.get();
 	}
 
 	public cusolverSpHandle getCusolverSpHandle() {
-		return cusolverSpHandle;
+		return cusolverSpHandle.get();
 	}
 
 	public JCudaKernels getKernels() {
-		return kernels;
+		return kernels.get();
 	}
 
 	/**
@@ -626,15 +650,11 @@ public class GPUContext {
 	public void destroy() throws DMLRuntimeException {
 		LOG.trace("GPU : this context was destroyed, this = " + this.toString());
 		clearMemory();
-		cudnnDestroy(cudnnHandle);
-		cublasDestroy(cublasHandle);
-		cusparseDestroy(cusparseHandle);
-		cusolverDnDestroy(cusolverDnHandle);
-		cusolverSpDestroy(cusolverSpHandle);
-		cudnnHandle = null;
-		cublasHandle = null;
-		cusparseHandle = null;
-
+		cudnnDestroy(cudnnHandle.get());
+		cublasDestroy(cublasHandle.get());
+		cusparseDestroy(cusparseHandle.get());
+		cusolverDnDestroy(cusolverDnHandle.get());
+		cusolverSpDestroy(cusolverSpHandle.get());
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java
index a9b1333..e030180 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java
@@ -130,7 +130,7 @@ public class GPUContextPool {
 		// initially available memory is set to the GPU with the lowest memory
 		// This is because at runtime, we wouldn't know which GPU a certain
 		// operation gets scheduled on
-		long minAvailableMemory = Integer.MAX_VALUE;
+		long minAvailableMemory = Long.MAX_VALUE;
 		for (GPUContext gCtx : pool) {
 			gCtx.initializeThread();
 			minAvailableMemory = Math.min(minAvailableMemory, gCtx.getAvailableMemory());

http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaKernels.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaKernels.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaKernels.java
index 246aecc..9cfab2b 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaKernels.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaKernels.java
@@ -49,17 +49,14 @@ public class JCudaKernels {
 	private final static String ptxFileName = "/kernels/SystemML.ptx";
 	private HashMap<String, CUfunction> kernels = new HashMap<String, CUfunction>();
 	private CUmodule module;
-	//	private final int deviceNum;
 
 	/**
 	 * Loads the kernels in the file ptxFileName. Though cubin files are also supported, we will stick with
 	 * ptx file as they are target-independent similar to Java's .class files.
 	 *
-	 * @param deviceNum the device number for which to initiate the driver API
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
-	JCudaKernels(int deviceNum) throws DMLRuntimeException {
-		//		this.deviceNum = deviceNum;
+	JCudaKernels() throws DMLRuntimeException {
 		module = new CUmodule();
 		// Load the kernels specified in the ptxFileName file
 		checkResult(cuModuleLoadDataEx(module, initKernels(ptxFileName), 0, new int[0], Pointer.to(new int[0])));

http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
index 195968a..d40b7a1 100644
--- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java
@@ -98,6 +98,24 @@ public abstract class GPUTests extends AutomatedTestBase {
 	}
 
 	/**
+	 * Generates an input matrix which is a sequence of integers
+	 * @param spark valid instance of {@link SparkSession}
+	 * @param m number of rows
+	 * @param n number of columns
+	 * @return a matrix with a sequence of integers
+	 */
+	protected Matrix generateIntegerSequenceMatrix(SparkSession spark, int m, int n) {
+		MLContext genMLC = new MLContext(spark);
+		String scriptStr;
+		scriptStr = "temp = seq(1, " + (m*n) + ")" +
+				    "in1 = matrix(temp, rows=" + m + ", cols=" + n + ")";
+		Script generateScript = ScriptFactory.dmlFromString(scriptStr).out("in1");
+		Matrix in1 = genMLC.execute(generateScript).getMatrix("in1");
+		genMLC.close();
+		return in1;
+	}
+
+	/**
 	 * Generates a random input matrix with a given size and sparsity
 	 *
 	 * @param spark    valid instance of {@link SparkSession}

http://git-wip-us.apache.org/repos/asf/systemml/blob/815ca4f2/src/test/java/org/apache/sysml/test/gpu/MatrixMultiplicationOpTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/MatrixMultiplicationOpTest.java b/src/test/java/org/apache/sysml/test/gpu/MatrixMultiplicationOpTest.java
index f7c7851..81bc254 100644
--- a/src/test/java/org/apache/sysml/test/gpu/MatrixMultiplicationOpTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/MatrixMultiplicationOpTest.java
@@ -153,6 +153,7 @@ public class MatrixMultiplicationOpTest extends GPUTests {
 			for (int j = 0; j < sparsities.length; j++) {
 				int side = sizes[i];
 				double sparsity = sparsities[j];
+				System.out.println("Transpose Self matrix multiply, size = " + side + ", sparsity = " + sparsity);
 				Matrix X = generateInputMatrix(spark, side, side, sparsity, seed);
 				HashMap<String, Object> inputs = new HashMap<>();
 				inputs.put("X", X);