You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by na...@apache.org on 2017/07/13 22:01:41 UTC

systemml git commit: [SYSTEMML-1713] Added mem estimates for various GPU ops

Repository: systemml
Updated Branches:
  refs/heads/master 4e47b5e10 -> 32ba9cf9f


[SYSTEMML-1713] Added mem estimates for various GPU ops

Closes #553


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/32ba9cf9
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/32ba9cf9
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/32ba9cf9

Branch: refs/heads/master
Commit: 32ba9cf9fdff2aba7432c7a4e51317b6e5bf1a18
Parents: 4e47b5e
Author: Nakul Jindal <na...@gmail.com>
Authored: Thu Jul 13 15:01:11 2017 -0700
Committer: Nakul Jindal <na...@gmail.com>
Committed: Thu Jul 13 15:01:11 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggBinaryOp.java |  57 ++++-
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  44 +++-
 .../java/org/apache/sysml/hops/BinaryOp.java    |  32 ++-
 src/main/java/org/apache/sysml/hops/Hop.java    |   4 +-
 .../java/org/apache/sysml/hops/ReorgOp.java     |   4 +-
 .../java/org/apache/sysml/hops/TernaryOp.java   |  17 +-
 .../java/org/apache/sysml/hops/UnaryOp.java     |  16 +-
 .../instructions/gpu/context/CSRPointer.java    |   6 +-
 .../runtime/matrix/data/LibMatrixCUDA.java      | 214 ++++++++++++-------
 9 files changed, 279 insertions(+), 115 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/32ba9cf9/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggBinaryOp.java b/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
index eb83549..9077976 100644
--- a/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
@@ -21,19 +21,19 @@ package org.apache.sysml.hops;
 
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.Hop.MultiThreadedHop;
 import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.lops.Aggregate;
 import org.apache.sysml.lops.Binary;
 import org.apache.sysml.lops.DataPartition;
 import org.apache.sysml.lops.Group;
-import org.apache.sysml.hops.Hop.MultiThreadedHop;
 import org.apache.sysml.lops.Lop;
 import org.apache.sysml.lops.LopProperties.ExecType;
 import org.apache.sysml.lops.LopsException;
 import org.apache.sysml.lops.MMCJ;
+import org.apache.sysml.lops.MMCJ.MMCJType;
 import org.apache.sysml.lops.MMRJ;
 import org.apache.sysml.lops.MMTSJ;
-import org.apache.sysml.lops.MMCJ.MMCJType;
 import org.apache.sysml.lops.MMTSJ.MMTSJType;
 import org.apache.sysml.lops.MMZip;
 import org.apache.sysml.lops.MapMult;
@@ -343,11 +343,48 @@ public class AggBinaryOp extends Hop implements MultiThreadedHop
 	protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz )
 	{
 		double ret = 0;
-		
+
+		if (DMLScript.USE_ACCELERATOR) {
+			// In GPU mode, intermediate memory is only needed when exactly one of the matrix blocks is sparse
+			// When sparse block is converted to dense and a dense MM takes place, we need (dim1 * dim2)
+			// When dense block is converted to sparse and a sparse MM takes place, we need (dim1 * dim2 * 2)
+
+			Hop in1 = _input.get(0);
+			Hop in2 = _input.get(1);
+			double in1Sparsity = OptimizerUtils.getSparsity(in1.getDim1(), in1.getDim2(), in1.getNnz());
+			double in2Sparsity = OptimizerUtils.getSparsity(in2.getDim1(), in2.getDim2(), in2.getNnz());
+
+			boolean in1Sparse = in1Sparsity < MatrixBlock.SPARSITY_TURN_POINT;
+			boolean in2Sparse = in2Sparsity < MatrixBlock.SPARSITY_TURN_POINT;
+
+			boolean in1UltraSparse = in1Sparsity < MatrixBlock.ULTRA_SPARSITY_TURN_POINT;
+			boolean in2UltraSparse = in2Sparsity < MatrixBlock.ULTRA_SPARSITY_TURN_POINT;
+
+			// For Matmult X * Y, if X is sparse, Y is dense, X is converted to dense
+			// If X is ultrasparse, Y is converted to sparse
+			if (in1Sparse ^ in2Sparse) { // one sparse, one dense
+				if (in1Sparse) {
+					if (in1UltraSparse) {
+						ret += 2 * OptimizerUtils.estimateSizeExactSparsity(in2.getDim1(), in2.getDim2(), in2.getNnz());
+					} else {
+						ret += OptimizerUtils.estimateSizeExactSparsity(in1.getDim1(), in1.getDim2(), in1.getNnz());
+					}
+				} else if (in2Sparse) {
+					if (in2UltraSparse) {
+						ret += 2 * OptimizerUtils.estimateSizeExactSparsity(in1.getDim1(), in1.getDim2(), in1.getNnz());
+					} else {
+						ret += OptimizerUtils.estimateSizeExactSparsity(in2.getDim1(), in2.getDim2(), in2.getNnz());
+					}
+				}
+
+			}
+
+		}
+
 		//account for potential final dense-sparse transformation (worst-case sparse representation)
 		if( dim2 >= 2 ) //vectors always dense
-			ret = OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, MatrixBlock.SPARSITY_TURN_POINT);
-		
+			ret += OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, MatrixBlock.SPARSITY_TURN_POINT);
+
 		return ret;
 	}
 	
@@ -544,8 +581,8 @@ public class AggBinaryOp extends Hop implements MultiThreadedHop
 		int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
 		
 		ExecType et = ExecType.CP;
-		if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < GPUContextPool
-				.initialGPUMemBudget())) {
+		if (DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR
+				|| getMemEstimate() < Math.min(GPUContextPool.initialGPUMemBudget(), OptimizerUtils.getLocalMemBudget()))) {
 			et = ExecType.GPU;
 		}
 		
@@ -623,9 +660,9 @@ public class AggBinaryOp extends Hop implements MultiThreadedHop
 		throws HopsException, LopsException
 	{	
 		Lop matmultCP = null;
-		
-		if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < GPUContextPool
-				.initialGPUMemBudget())) {
+
+		if (DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR
+				|| getMemEstimate() < Math.min(GPUContextPool.initialGPUMemBudget(), OptimizerUtils.getLocalMemBudget()))) {
 			Hop h1 = getInput().get(0);
 			Hop h2 = getInput().get(1);
 			Lop left; Lop right;

http://git-wip-us.apache.org/repos/asf/systemml/blob/32ba9cf9/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index e2f2a8e..e94aaf3 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -149,8 +149,8 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 				}				
 				else { //general case		
 					int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
-					if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < GPUContextPool
-							.initialGPUMemBudget())) {
+					if (DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR
+							|| getMemEstimate() < Math.min(GPUContextPool.initialGPUMemBudget(), OptimizerUtils.getLocalMemBudget()))) {
 						// Only implemented methods for GPU
 						if ((_op == AggOp.SUM    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
 						 || (_op == AggOp.SUM_SQ && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
@@ -328,8 +328,15 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 
 	@Override
 	protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
-	{		
-		double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
+	{
+		double sparsity = -1;
+		if (DMLScript.USE_ACCELERATOR) {
+			// The GPU version (for the time being) only does dense outputs
+			sparsity = 1.0;
+		} else {
+			sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
+		}
+
 		return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
 	}
 	
@@ -351,14 +358,14 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 				break;
 			case SUM:
 			case SUM_SQ:
-				//worst-case correction LASTROW / LASTCOLUMN 
+				//worst-case correction LASTROW / LASTCOLUMN
 				if( _direction == Direction.Col ) //(potentially sparse)
 					val = OptimizerUtils.estimateSizeExactSparsity(1, dim2, sparsity);
 				else if( _direction == Direction.Row ) //(always dense)
 					val = OptimizerUtils.estimateSizeExactSparsity(dim1, 1, 1.0);
 				break;
 			case MEAN:
-				//worst-case correction LASTTWOROWS / LASTTWOCOLUMNS 
+				//worst-case correction LASTTWOROWS / LASTTWOCOLUMNS
 				if( _direction == Direction.Col ) //(potentially sparse)
 					val = OptimizerUtils.estimateSizeExactSparsity(2, dim2, sparsity);
 				else if( _direction == Direction.Row ) //(always dense)
@@ -366,10 +373,31 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 				break;
 			case VAR:
 				//worst-case correction LASTFOURROWS / LASTFOURCOLUMNS
-				if( _direction == Direction.Col ) //(potentially sparse)
+				if (DMLScript.USE_ACCELERATOR) {
+					// The GPU implementation only operates on dense data
+					// It allocates 2 dense blocks to help with these ops:
+					// Assume Y = var(X) Or colVars(X), Or rowVars(X)
+					// 1. Y = mean/rowMeans/colMeans(X)               <-- Y is a scalar or row-vector or col-vector
+					// 2. temp1 = X - Y                               <-- temp1 is a matrix of size(X)
+					// 3. temp2 = temp1 ^ 2                           <-- temp2 is a matrix of size(X)
+					// 4. temp3 = sum/rowSums/colSums(temp2)          <-- temp3 is a scalar or a row-vector or col-vector
+					// 5. Y = temp3 / (size(X) or nrow(X) or ncol(X)) <-- Y is a scalar or a row-vector or col-vector
+
+					long in1dim1 = getInput().get(0).getDim1();
+					long in1dim2 = getInput().get(0).getDim2();
+
+					val = 2 * OptimizerUtils.estimateSize(in1dim1, in1dim2);    // For temp1 & temp2
+					if (_direction == Direction.Col){
+						val += OptimizerUtils.estimateSize(1, in1dim2);   // For temp3 = colSums(temp2), a 1 x ncol(X) row-vector
+					} else if (_direction == Direction.Row){
+						val += OptimizerUtils.estimateSize(in1dim1, 1);  // For temp3 = rowSums(temp2), a nrow(X) x 1 col-vector
+					}
+
+				} else if( _direction == Direction.Col ) { //(potentially sparse)
 					val = OptimizerUtils.estimateSizeExactSparsity(4, dim2, sparsity);
-				else if( _direction == Direction.Row ) //(always dense)
+				} else if( _direction == Direction.Row ) { //(always dense)
 					val = OptimizerUtils.estimateSizeExactSparsity(dim1, 4, 1.0);
+				}
 				break;
 			case MAXINDEX:
 			case MININDEX:

http://git-wip-us.apache.org/repos/asf/systemml/blob/32ba9cf9/src/main/java/org/apache/sysml/hops/BinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java
index 9155203..2c88a9e 100644
--- a/src/main/java/org/apache/sysml/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java
@@ -583,9 +583,9 @@ public class BinaryOp extends Hop
 				ot = Unary.OperationTypes.MULTIPLY2;
 			else //general case
 				ot = HopsOpOp2LopsU.get(op);
-			
-			if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < GPUContextPool
-					.initialGPUMemBudget())
+
+			if (DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR
+					|| getMemEstimate() < Math.min(GPUContextPool.initialGPUMemBudget(), OptimizerUtils.getLocalMemBudget()))
 					&& (op == OpOp2.MULT || op == OpOp2.PLUS || op == OpOp2.MINUS || op == OpOp2.DIV || op == OpOp2.POW
 					|| op == OpOp2.MINUS_NZ || op == OpOp2.MINUS1_MULT || op == OpOp2.MODULUS || op == OpOp2.INTDIV
 					|| op == OpOp2.LESS || op == OpOp2.LESSEQUAL || op == OpOp2.EQUAL || op == OpOp2.NOTEQUAL
@@ -606,8 +606,8 @@ public class BinaryOp extends Hop
 			ExecType et = optFindExecType();
 			if ( et == ExecType.CP ) 
 			{
-				if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < GPUContextPool
-						.initialGPUMemBudget())
+				if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR
+						|| getMemEstimate() < Math.min(GPUContextPool.initialGPUMemBudget(), OptimizerUtils.getLocalMemBudget()))
 						&& (op == OpOp2.MULT || op == OpOp2.PLUS || op == OpOp2.MINUS || op == OpOp2.DIV || op == OpOp2.POW
 						|| op == OpOp2.SOLVE || op == OpOp2.MINUS1_MULT || op == OpOp2.MODULUS || op == OpOp2.INTDIV
 						|| op == OpOp2.LESS || op == OpOp2.LESSEQUAL || op == OpOp2.EQUAL || op == OpOp2.NOTEQUAL
@@ -829,10 +829,24 @@ public class BinaryOp extends Hop
 			ret = getInput().get(0).getMemEstimate() * 3; 
 		}
 		else if ( op == OpOp2.SOLVE ) {
-			// x=solve(A,b) relies on QR decomposition of A, which is done using Apache commons-math
-			// matrix of size same as the first input
-			double interOutput = OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0); 
-			return interOutput;
+			if (DMLScript.USE_ACCELERATOR) {
+				// Solve on the GPU takes an awful lot of intermediate space
+				// First the inputs are converted from row-major to column major
+				// Then a workspace and a temporary output (workSize, tauSize) are needed
+				long m = getInput().get(0).getDim1();
+				long n = getInput().get(0).getDim2();
+				long tauSize = OptimizerUtils.estimateSize(m, 1);
+				long workSize = OptimizerUtils.estimateSize(m, n);
+				long AtmpSize = OptimizerUtils.estimateSize(m, n);
+				long BtmpSize = OptimizerUtils.estimateSize(n, 1);
+				return (tauSize + workSize + AtmpSize + BtmpSize);
+			} else {
+				// x=solve(A,b) relies on QR decomposition of A, which is done using Apache commons-math
+				// matrix of size same as the first input
+				double interOutput = OptimizerUtils
+						.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0);
+				return interOutput;
+			}
 
 		}
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/32ba9cf9/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 80d33f1..4529d04 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -790,8 +790,8 @@ public abstract class Hop
 	}
 	
 	protected ExecType findGPUExecTypeByMemEstimate(ExecType et) {
-		if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < GPUContextPool
-				.initialGPUMemBudget())) {
+		if (DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR
+				|| getMemEstimate() < Math.min(GPUContextPool.initialGPUMemBudget(), OptimizerUtils.getLocalMemBudget()))) {
 			return ExecType.GPU;
 		}
 		return et;

http://git-wip-us.apache.org/repos/asf/systemml/blob/32ba9cf9/src/main/java/org/apache/sysml/hops/ReorgOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ReorgOp.java b/src/main/java/org/apache/sysml/hops/ReorgOp.java
index 20cd68d..3e27eb3 100644
--- a/src/main/java/org/apache/sysml/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysml/hops/ReorgOp.java
@@ -151,8 +151,8 @@ public class ReorgOp extends Hop implements MultiThreadedHop
 					setLops(lin); //if input of size 1x1, avoid unnecessary transpose
 				else { //general case
 					int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
-					if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < GPUContextPool
-							.initialGPUMemBudget())) {
+					if (DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR
+							|| getMemEstimate() < Math.min(GPUContextPool.initialGPUMemBudget(), OptimizerUtils.getLocalMemBudget()))) {
 						et = ExecType.GPU;
 					}
 					Transform transform1 = new Transform( lin, 

http://git-wip-us.apache.org/repos/asf/systemml/blob/32ba9cf9/src/main/java/org/apache/sysml/hops/TernaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/TernaryOp.java b/src/main/java/org/apache/sysml/hops/TernaryOp.java
index 5a12dea..e1bef3e 100644
--- a/src/main/java/org/apache/sysml/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/TernaryOp.java
@@ -650,11 +650,12 @@ public class TernaryOp extends Hop
 			throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.PLUS_MULT + " or" +  OpOp3.MINUS_MULT);
 		
 		ExecType et = null;
-		if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < GPUContextPool
-				.initialGPUMemBudget()) )
+		if (DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR
+				|| getMemEstimate() < Math.min(GPUContextPool.initialGPUMemBudget(), OptimizerUtils.getLocalMemBudget()))) {
 			et = ExecType.GPU;
-		else
+		} else {
 			et = optFindExecType();
+		}
 		PlusMult plusmult = null;
 		
 		if( et == ExecType.CP || et == ExecType.SPARK || et == ExecType.GPU ) {
@@ -727,9 +728,15 @@ public class TernaryOp extends Hop
 				// Output is a vector of length = #of quantiles to be computed, and it is likely to be dense.
 				return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
 			case PLUS_MULT:
-			case MINUS_MULT:
-				sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); 
+			case MINUS_MULT: {
+				if (DMLScript.USE_ACCELERATOR) {
+					// For the GPU, the input is converted to dense
+					sparsity = 1.0;
+				} else {
+					sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
+				}
 				return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
+			}
 			default:
 				throw new RuntimeException("Memory for operation (" + _op + ") can not be estimated.");
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/32ba9cf9/src/main/java/org/apache/sysml/hops/UnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java
index d90fcdf..61ebedf 100644
--- a/src/main/java/org/apache/sysml/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops;
 
 import java.util.ArrayList;
 
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.Hop.MultiThreadedHop;
 import org.apache.sysml.lops.Aggregate;
 import org.apache.sysml.lops.Aggregate.OperationTypes;
@@ -536,7 +537,7 @@ public class UnaryOp extends Hop implements MultiThreadedHop
 	{
 		//overwrites default hops behavior
 		super.computeMemEstimate(memo);
-		
+
 		if( _op == Hop.OpOp1.NROW || _op == Hop.OpOp1.NCOL ) //specific case for meta data ops
 		{
 			_memEstimate = OptimizerUtils.INT_SIZE;
@@ -547,8 +548,13 @@ public class UnaryOp extends Hop implements MultiThreadedHop
 
 	@Override
 	protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
-	{		
-		double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
+	{
+		double sparsity = -1;
+		if (DMLScript.USE_ACCELERATOR) {
+			sparsity = 1.0; // Output is always dense (for now) on the GPU
+		} else {
+			sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
+		}
 		return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
 	}
 	
@@ -562,6 +568,10 @@ public class UnaryOp extends Hop implements MultiThreadedHop
 			// getMemEstimate works for both cases of known dims and worst-case stats
 			ret = getInput().get(0).getMemEstimate() * 3; 
 		}
+
+		if (DMLScript.USE_ACCELERATOR) {
+			ret += OptimizerUtils.estimateSize(dim1, dim2); // Add the intermediate memory required to convert sparse to dense
+		}
 		
 		return ret;
 	}

http://git-wip-us.apache.org/repos/asf/systemml/blob/32ba9cf9/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
index a4bff9a..7244938 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
@@ -54,7 +54,7 @@ public class CSRPointer {
 
 	private static final Log LOG = LogFactory.getLog(CSRPointer.class.getName());
 
-	private static final double ULTRA_SPARSITY_TURN_POINT = 0.0004;
+	private static final double ULTRA_SPARSITY_TURN_POINT = 0.00004;
 	public static cusparseMatDescr matrixDescriptor;
 	/**
 	 * {@link GPUContext} instance to track the GPU to do work on
@@ -242,7 +242,7 @@ public class CSRPointer {
 	 * Estimates the number of non-zero elements from the result of a sparse matrix multiplication C = A * B
 	 * and returns the {@link CSRPointer} to C with the appropriate GPU memory.
 	 *
-	 * @param gCtx   ?
+	 * @param gCtx   a valid {@link GPUContext}
 	 * @param handle a valid {@link cusparseHandle}
 	 * @param A      Sparse Matrix A on GPU
 	 * @param transA 'T' if A is to be transposed, 'N' otherwise
@@ -268,7 +268,7 @@ public class CSRPointer {
 	/**
 	 * Factory method to allocate an empty CSR Sparse matrix on the GPU
 	 *
-	 * @param gCtx ?
+	 * @param gCtx a valid {@link GPUContext}
 	 * @param nnz2 number of non-zeroes
 	 * @param rows number of rows
 	 * @return a {@link CSRPointer} instance that encapsulates the CSR matrix on GPU

http://git-wip-us.apache.org/repos/asf/systemml/blob/32ba9cf9/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 17f6b22..b8b4f8b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -355,17 +355,41 @@ public class LibMatrixCUDA {
 			throw new DMLRuntimeException("Error status returned by CuDNN:" + jcuda.jcudnn.cudnnStatus.stringFor(status));
 	}
 
-	public static void conv2dBiasAdd(GPUContext gCtx, String instName, MatrixObject image, MatrixObject bias, MatrixObject filter, MatrixObject outputBlock, int N, int C, int H, int W,
+	/**
+	 * Does a 2D convolution followed by a bias_add
+	 *
+	 * @param gCtx     a valid {@link GPUContext}
+	 * @param instName the invoking instruction's name for record {@link Statistics}.
+	 * @param image    input image matrix object
+	 * @param bias     bias matrix object
+	 * @param filter   filter matrix object
+	 * @param output   output matrix object
+	 * @param N        number of input images
+	 * @param C        number of channels
+	 * @param H        height of each image
+	 * @param W        width of each image
+	 * @param K        number of output "channels"
+	 * @param R        height of filter
+	 * @param S        width of filter
+	 * @param pad_h    padding height
+	 * @param pad_w    padding width
+	 * @param stride_h stride height
+	 * @param stride_w stride width
+	 * @param P        output height
+	 * @param Q        output width
+	 * @throws DMLRuntimeException if error
+	 */
+	public static void conv2dBiasAdd(GPUContext gCtx, String instName, MatrixObject image, MatrixObject bias, MatrixObject filter, MatrixObject output, int N, int C, int H, int W,
 			int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q)
 					throws DMLRuntimeException {
 		/*
-		int rows = (int) outputBlock.getNumRows();
-		int cols = (int) outputBlock.getNumColumns();
+		int rows = (int) output.getNumRows();
+		int cols = (int) output.getNumColumns();
 		long size  = rows * cols * Sizeof.DOUBLE;
 
 		Pointer imagePointer = getDensePointer(image, instName);
 		Pointer biasPointer = getDensePointer(bias, instName);
-		Pointer outputPointer = getDensePointer(outputBlock, instName);
+		Pointer outputPointer = getDensePointer(output, instName);
 		Pointer filterPointer = getDensePointer(filter, instName);
 
 		Pointer tmp = allocate(size);
@@ -377,15 +401,15 @@ public class LibMatrixCUDA {
 		if(k1 != bias.getNumColumns() || bias.getNumColumns() != 1 || cols % k1 != 0) {
 			throw new DMLRuntimeException("Incorrect inputs for bias_add: input[" + rows + " X " + cols + "] and bias[" + K + " X " + bias.getNumColumns() + "]");
 		}
-		// biasAdd(instName, outputBlock, bias, outputBlock);
+		// biasAdd(instName, output, bias, output);
 		biasAdd(instName, tmp, biasPointer, outputPointer, rows, cols, (int)k1);
 
 		cudaFreeHelper(tmp);
 		*/
 		LOG.trace("GPU : conv2dBiasAdd" + ", GPUContext=" + gCtx);
-		conv2d(gCtx, instName, image, filter, outputBlock, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+		conv2d(gCtx, instName, image, filter, output, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
 		//cudaDeviceSynchronize;
-		biasAdd(gCtx, instName, outputBlock, bias, outputBlock);
+		biasAdd(gCtx, instName, output, bias, output);
 	}
 
 	public static void conv2d(GPUContext gCtx, String instName, MatrixObject image, MatrixObject filter, MatrixObject outputBlock, int N, int C, int H, int W,
@@ -398,6 +422,31 @@ public class LibMatrixCUDA {
 		conv2d(gCtx, instName, imagePointer, filterPointer, dstPointer, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
 	}
 
+	/**
+	 * Performs 2D convolution
+	 * Takes up an insignificant amount of intermediate space when CONVOLUTION_PREFERENCE is set to CUDNN_CONVOLUTION_FWD_NO_WORKSPACE
+	 * Intermediate space is required by the filter descriptor and convolution descriptor which are metadata structures and don't scale with the size of the input
+	 *
+	 * @param gCtx     a valid {@link GPUContext}
+	 * @param instName the invoking instruction's name for record {@link Statistics}.
+	 * @param image    the input matrix (or image) allocated on the GPU
+	 * @param filter   the filter allocated on the GPU
+	 * @param output   the output matrix allocated on the GPU
+	 * @param N        number of input images
+	 * @param C        number of channels
+	 * @param H        height of each image
+	 * @param W        width of each image
+	 * @param K        number of output "channels"
+	 * @param R        height of filter
+	 * @param S        width of filter
+	 * @param pad_h    padding height
+	 * @param pad_w    padding width
+	 * @param stride_h stride height
+	 * @param stride_w stride width
+	 * @param P        output height
+	 * @param Q        output width
+	 * @throws DMLRuntimeException if error
+	 */
 	public static void conv2d(GPUContext gCtx, String instName, Pointer image, Pointer filter, Pointer output, int N,
 														 int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q)
 					throws DMLRuntimeException {
@@ -1225,12 +1274,15 @@ public class LibMatrixCUDA {
 
 	/**
 	 * Performs tsmm, A %*% A' or A' %*% A, on GPU by exploiting cublasDsyrk(...)
+	 * <p>
+	 * Memory Usage - If dense, input space - rows * cols, no intermediate memory, output - Max(rows*rows, cols*cols)
+	 * If sparse, calls matmult
 	 *
-	 * @param ec execution context
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param left input matrix, as in a tsmm expression like A %*% A' or A' %*% A, we just need to check whether the left one is transposed or not, I named it 'left'
-	 * @param outputName output matrix name
+	 * @param ec               execution context
+	 * @param gCtx             a valid {@link GPUContext}
+	 * @param instName         the invoking instruction's name for record {@link Statistics}.
+	 * @param left             input matrix, as in a tsmm expression like A %*% A' or A' %*% A, we just need to check whether the left one is transposed or not, I named it 'left'
+	 * @param outputName       output matrix name
 	 * @param isLeftTransposed if true, left transposed
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
@@ -1285,9 +1337,10 @@ public class LibMatrixCUDA {
 	 * Used for all version of TSMM where the result is known to be symmetric.
 	 * Hence, we compute only the upper triangular matrix and copy this partial
 	 * result down to lower triangular matrix once.
-	 * @param gCtx   a valid {@link GPUContext}
+	 *
+	 * @param gCtx     a valid {@link GPUContext}
 	 * @param instName instruction name
-	 * @param ret upper triangular matrix
+	 * @param ret      upper triangular matrix
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
 	private static void copyUpperToLowerTriangle(GPUContext gCtx, String instName, MatrixObject ret) throws DMLRuntimeException {
@@ -1319,16 +1372,22 @@ public class LibMatrixCUDA {
 	 * Examines sparsity and shapes and routes call to appropriate method
 	 * from cuBLAS or cuSparse
 	 * C = op(A) x op(B)
-	 * @param ec                    Current {@link ExecutionContext} instance
-	 * @param gCtx                  a valid {@link GPUContext}
-	 * @param instName              name of the invoking instruction to record{@link Statistics}.
-	 * @param left                  Matrix A
-	 * @param right                 Matrix B
-	 * @param outputName            Name of the output matrix C (in code generated after LOP layer)
-	 * @param isLeftTransposed      op for A, transposed or not
-	 * @param isRightTransposed     op for B, tranposed or not
-	 * @return	output of matrix multiply
+	 * <p>
+	 * Memory Requirements -
+	 * Both dense - inputs, output, no intermediate
+	 * Both sparse - inputs, output, no intermediate
+	 * One sparse, one dense - inputs, output, intermediates - (input_dim1 * input_dim2) OR (input_dim1 * input_dim2 + input in sparse format)
+	 *
+	 * @param ec                Current {@link ExecutionContext} instance
+	 * @param gCtx              a valid {@link GPUContext}
+	 * @param instName          name of the invoking instruction to record {@link Statistics}.
+	 * @param left              Matrix A
+	 * @param right             Matrix B
+	 * @param outputName        Name of the output matrix C (in code generated after LOP layer)
+	 * @param isLeftTransposed  op for A, transposed or not
+	 * @param isRightTransposed op for B, transposed or not
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
+	 * @return output of matrix multiply
 	 */
 	public static MatrixObject matmult(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject left, MatrixObject right, String outputName,
 																		 boolean isLeftTransposed, boolean isRightTransposed) throws DMLRuntimeException {
@@ -1364,6 +1423,7 @@ public class LibMatrixCUDA {
 	/**
 	 * One of the matrices is sparse, the other dense
 	 * C = op(A) x op(B)
+	 *
 	 * @param gCtx              a valid {@link GPUContext}
 	 * @param instName          the invoking instruction's name for record {@link Statistics}.
 	 * @param output            allocated output object for C on host to which GPU output will be attached
@@ -1400,16 +1460,17 @@ public class LibMatrixCUDA {
 	 * C = op(A) * op(B) where A is dense and B is sparse
 	 * If B is ultrasparse, A is converted to a sparse matrix and {@code sparseSparseMatmult(MatrixObject, int, int, int, int, int, CSRPointer, CSRPointer)} is invoked
 	 * otherwise B is converted to a dense matrix and {@code denseDenseMatmult(Pointer, int, int, int, int, boolean, boolean, Pointer, Pointer)} is invoked.
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param left {@link MatrixObject} of A
-	 * @param right {@link MatrixObject} of B
-	 * @param output {@link MatrixObject} of the output matrix C
-	 * @param isLeftTransposed whether matrix A needs to be transposed
+	 *
+	 * @param gCtx              a valid {@link GPUContext}
+	 * @param instName          the invoking instruction's name for record {@link Statistics}.
+	 * @param left              {@link MatrixObject} of A
+	 * @param right             {@link MatrixObject} of B
+	 * @param output            {@link MatrixObject} of the output matrix C
+	 * @param isLeftTransposed  whether matrix A needs to be transposed
 	 * @param isRightTransposed whether matrix B needs to be transposed
-	 * @param m ?
-	 * @param n ?
-	 * @param k ?
+	 * @param m                 ?
+	 * @param n                 ?
+	 * @param k                 ?
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
 	private static void denseSparseMatmult(GPUContext gCtx, String instName, MatrixObject left, MatrixObject right, MatrixObject output,
@@ -1473,16 +1534,17 @@ public class LibMatrixCUDA {
 	 * * C = op(A) * op(B) where A is sparse and B is dense
 	 * If A is ultrasparse, B is converted to a sparse matrix and {@code sparseSparseMatmult(MatrixObject, int, int, int, int, int, CSRPointer, CSRPointer)} is invoked
 	 * otherwise A is converted to a dense matrix and {@code denseDenseMatmult(Pointer, int, int, int, int, boolean, boolean, Pointer, Pointer)} is invoked.
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param output the output matrix object
-	 * @param left matrix A
-	 * @param right matrix B
-	 * @param isLeftTransposed if A needs to be transposed
+	 *
+	 * @param gCtx              a valid {@link GPUContext}
+	 * @param instName          the invoking instruction's name for record {@link Statistics}.
+	 * @param output            the output matrix object
+	 * @param left              matrix A
+	 * @param right             matrix B
+	 * @param isLeftTransposed  if A needs to be transposed
 	 * @param isRightTransposed if B needs to be transposed
-	 * @param m ?
-	 * @param n ?
-	 * @param k ?
+	 * @param m                 ?
+	 * @param n                 ?
+	 * @param k                 ?
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
 	private static void sparseDenseMatmult(GPUContext gCtx, String instName, MatrixObject output, MatrixObject left, MatrixObject right,
@@ -1553,14 +1615,15 @@ public class LibMatrixCUDA {
 	/**
 	 * C = op(A) x B
 	 * A is a sparse matrix, B is a dense vector
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param output	allocated output on the host, to which the GPU output C will be attached
-	 * @param A			sparse matrix A on the GPU
-	 * @param B_dense	dense matrix/vector B on the GPU
-	 * @param isATranposed	op for A, tranposed or not
-	 * @param m			number of rows in A (not op(A))
-	 * @param k			number of cols in A or number of rows in B (not op(A) or op(B))
+	 *
+	 * @param gCtx         a valid {@link GPUContext}
+	 * @param instName     the invoking instruction's name for record {@link Statistics}.
+	 * @param output       allocated output on the host, to which the GPU output C will be attached
+	 * @param A            sparse matrix A on the GPU
+	 * @param B_dense      dense matrix/vector B on the GPU
+	 * @param isATranposed op for A, transposed or not
+	 * @param m            number of rows in A (not op(A))
+	 * @param k            number of cols in A or number of rows in B (not op(A) or op(B))
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
 	private static void sparseMatrixDenseVectorMult(GPUContext gCtx, String instName, MatrixObject output, CSRPointer A, Pointer B_dense, boolean isATranposed,
@@ -1585,13 +1648,14 @@ public class LibMatrixCUDA {
 	/**
 	 * Sparse C = Sparse op(A) * Sparse op(B)
 	 * Reroutes call to sparse matrix-vector mult if needed
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName the invoking instruction's name for record {@link Statistics}.
-	 * @param output ?
-	 * @param instName name of the invoking instruction to record{@link Statistics}.
-	 * @param left ?
-	 * @param right ?
-	 * @param isLeftTransposed ?
+	 *
+	 * @param gCtx              a valid {@link GPUContext}
+	 * @param instName          the invoking instruction's name for record {@link Statistics}.
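+	 * @param output            output matrix C (sparse)
+	 * @param instName          name of the invoking instruction to record {@link Statistics}.
+	 * @param left              left input matrix A (sparse)
+	 * @param right             right input matrix B (sparse)
+	 * @param isLeftTransposed  whether A is to be transposed (the op in op(A))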
 	 * @param isRightTransposed ?
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
@@ -1622,15 +1686,16 @@ public class LibMatrixCUDA {
 	/**
 	 * Does a sparse matrix-vector multiply.
 	 * C = op(A) x B, A is a sparse matrix, B is a sparse vector with numCols = 1.
-	 * @param gCtx   a valid {@link GPUContext}
-	 * @param instName      the invoking instruction's name for record {@link Statistics}.
-	 * @param output        allocated output object C to which the GPU output matrix will be attached
-	 * @param isATranposed  if A is to be transposed or not (the op in op(A))
-	 * @param m             number of rows in A (not op(A))
-	 * @param n             number of cols in A (not op(A))
-	 * @param k             number of rows in B, (cols in B is assumed to be 1)
-	 * @param A             left sparse matrix on GPU
-	 * @param B             right sparse vector on GPU
+	 *
+	 * @param gCtx         a valid {@link GPUContext}
+	 * @param instName     the invoking instruction's name for record {@link Statistics}.
+	 * @param output       allocated output object C to which the GPU output matrix will be attached
+	 * @param isATranposed if A is to be transposed or not (the op in op(A))
+	 * @param m            number of rows in A (not op(A))
+	 * @param n            number of cols in A (not op(A))
+	 * @param k            number of rows in B, (cols in B is assumed to be 1)
+	 * @param A            left sparse matrix on GPU
+	 * @param B            right sparse vector on GPU
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
 	private static void sparseMatrixVectorMult(GPUContext gCtx, String instName, MatrixObject output, boolean isATranposed, int m, int n, int k,
@@ -1645,6 +1710,7 @@ public class LibMatrixCUDA {
 	/**
 	 * Does a sparse-sparse Matrix multiply
 	 * C = op(A) x op(B), A, B are sparse matrices
+	 *
 	 * @param gCtx              a valid {@link GPUContext}
 	 * @param instName          the invoking instruction's name for record {@link Statistics}.
 	 * @param A                 left sparse matrix on GPU
@@ -1683,6 +1749,7 @@ public class LibMatrixCUDA {
 	/**
 	 * Dense dense matrix multiply
 	 * C = op(A) * op(B), A and B are dense matrices
+	 *
 	 * @param gCtx              a valid {@link GPUContext}
 	 * @param instName          name of the invoking instruction to record{@link Statistics}.
 	 * @param output            output object C on host with GPU data allocated
@@ -1715,6 +1782,7 @@ public class LibMatrixCUDA {
 	 * We do t(B) %*% t(A) to get t(C);
 	 * If we were to calculate t(t(C), we would get the resultant matrix C, but this would be in column-major format.
 	 * What we really want is t(C). This we already have as the result of t(B) %*% t(A).
+	 *
 	 * @param gCtx               a valid {@link GPUContext}
 	 * @param instName           name of the invoking instruction to record{@link Statistics}.
 	 * @param output             output allocated on GPU in column major format
@@ -1809,16 +1877,16 @@ public class LibMatrixCUDA {
 	//****************  UNARY AGGREGATE Functions ************************/
 	//********************************************************************/
 
-
 	/**
 	 * Entry point to perform Unary aggregate operations on the GPU.
 	 * The execution context object is used to allocate memory for the GPU.
-	 * @param ec			Instance of {@link ExecutionContext}, from which the output variable will be allocated
-	 * @param gCtx    a valid {@link GPUContext}
+	 *
+	 * @param ec       Instance of {@link ExecutionContext}, from which the output variable will be allocated
+	 * @param gCtx     a valid {@link GPUContext}
 	 * @param instName name of the invoking instruction to record{@link Statistics}.
-	 * @param in1			input matrix
-	 * @param output	output matrix/scalar name
-	 * @param op			Instance of {@link AggregateUnaryOperator} which encapsulates the direction of reduction/aggregation and the reduction operation.
+	 * @param in1      input matrix
+	 * @param output   output matrix/scalar name
+	 * @param op       Instance of {@link AggregateUnaryOperator} which encapsulates the direction of reduction/aggregation and the reduction operation.
 	 * @throws DMLRuntimeException if {@link DMLRuntimeException} occurs
 	 */
 	public static void unaryAggregate(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, String output, AggregateUnaryOperator op)
@@ -1852,7 +1920,7 @@ public class LibMatrixCUDA {
 		IndexFunction indexFn = op.indexFn;
 		AggregateOperator aggOp = op.aggOp;
 
-		// Convert Reduction direction to a number to pass to CUDA kernel
+		// Convert Reduction direction to a number
 		int reductionDirection = -1;
 		if (indexFn instanceof ReduceAll){
 			reductionDirection = REDUCTION_ALL;
@@ -1867,7 +1935,7 @@ public class LibMatrixCUDA {
 		}
 		assert reductionDirection !=-1 : "Internal Error - Incorrect type of reduction direction set for aggregate unary GPU instruction";
 
-		// Convert function type to a number to pass to the CUDA Kernel
+		// Convert function type to a number
 		int opIndex = -1;
 		if (aggOp.increOp.fn instanceof KahanPlus) {
 			opIndex = OP_PLUS;