Posted to commits@systemml.apache.org by de...@apache.org on 2016/08/11 17:39:06 UTC

incubator-systemml git commit: [SYSTEMML-446][SYSTEMML-727][SYSTEMML-742] maxpool (+bwd), tsmm GPU inst and GPU eviction policy

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 5f4b8bc3f -> 256deb4c4


[SYSTEMML-446][SYSTEMML-727][SYSTEMML-742] maxpool (+bwd), tsmm GPU inst and GPU eviction policy

Closes #206.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/256deb4c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/256deb4c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/256deb4c

Branch: refs/heads/master
Commit: 256deb4c45cb5fe8ab4eadd18dd6d509608f0752
Parents: 5f4b8bc
Author: taasawat <ta...@ece.ubc.ca>
Authored: Thu Aug 11 10:36:26 2016 -0700
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Thu Aug 11 10:36:26 2016 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggBinaryOp.java |  10 +-
 .../org/apache/sysml/hops/ConvolutionOp.java    |  30 ++-
 .../controlprogram/caching/CacheableData.java   |  20 +-
 .../context/ExecutionContext.java               |   4 +-
 .../instructions/GPUInstructionParser.java      |   8 +-
 .../gpu/AggregateBinaryGPUInstruction.java      |   5 +-
 .../gpu/ConvolutionGPUInstruction.java          | 135 +++++++++---
 .../instructions/gpu/GPUInstruction.java        |   2 +-
 .../instructions/gpu/MMTSJGPUInstruction.java   | 123 +++++++++++
 .../instructions/gpu/context/GPUObject.java     | 142 ++++++------
 .../instructions/gpu/context/JCudaObject.java   |  47 +++-
 .../runtime/matrix/data/LibMatrixCUDA.java      | 216 +++++++++++++++++++
 .../sysml/runtime/util/ConvolutionUtils.java    |  73 ++++---
 .../TransposeSelfMatrixMultiplication.java      | 175 +++++++++++++++
 .../matrix/TransposeSelfMatrixMultiplication.R  |  37 ++++
 .../TransposeSelfMatrixMultiplication.dml       |  28 +++
 16 files changed, 901 insertions(+), 154 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggBinaryOp.java b/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
index ea58ebd..fa08f60 100644
--- a/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
@@ -555,8 +555,16 @@ public class AggBinaryOp extends Hop implements MultiThreadedHop
 		throws HopsException, LopsException
 	{
 		int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+		
+		ExecType et = ExecType.CP;
+//		if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < OptimizerUtils.GPU_MEMORY_BUDGET)) {
+		//TODO: Fix me. Currently forcing the instruction to GPU if gpu flag is set
+		if(DMLScript.USE_ACCELERATOR) {
+			et = ExecType.GPU;
+		}
+		
 		Lop matmultCP = new MMTSJ(getInput().get(mmtsj.isLeft()?1:0).constructLops(),
-				                 getDataType(), getValueType(), ExecType.CP, mmtsj, k);
+				                 getDataType(), getValueType(), et, mmtsj, k);
 	
 		matmultCP.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
 		setLineNumbers( matmultCP );

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index 3da2cd2..fe277d1 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops;
 
 import java.util.ArrayList;
 
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.Hop.MultiThreadedHop;
 import org.apache.sysml.lops.ConvolutionTransform;
@@ -111,8 +112,6 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 			case RESHAPE_COL:
 			case ROTATE180:
 			case COL2IM:
-			case MAX_POOLING:
-			case MAX_POOLING_BACKWARD:
 			{	
 				et = ExecType.CP; // TODO: Since max_backwards and other Convolution Ops only implemented for CP
 				
@@ -127,7 +126,27 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 					throw new HopsException("Unimplemented ConvolutionOp for execution type: " + et.name());
 				}
 				// break;
-			}	
+			}
+			case MAX_POOLING:
+			case MAX_POOLING_BACKWARD:
+			{	
+				//TODO: Fix me. Currently forcing the instruction to GPU if gpu flag is set
+				if(DMLScript.USE_ACCELERATOR) {
+					et = ExecType.GPU;
+					setLops(constructConvolutionLops(et, inputs));
+					break;
+				}
+				else if(et == ExecType.CP) {
+					setLops(constructConvolutionLops(et, inputs));
+					break;
+				}			
+				else {
+					// TODO: Add support for SPARK/MR backends once we are happy with the performance of
+					// single node Lenet script. 
+					throw new HopsException("Unimplemented ConvolutionOp for execution type: " + et.name());
+				}
+				// break;
+			}
 			case DIRECT_CONV2D:
 			case DIRECT_CONV2D_BACKWARD_DATA:
 			case DIRECT_CONV2D_BACKWARD_FILTER:
@@ -385,6 +404,11 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
 	protected ExecType optFindExecType() throws HopsException {
 		
 		checkAndSetForcedPlatform();
+		
+		//TODO: Remove this once memEstimate is fixed for these instructions 
+		if((op == ConvOp.MAX_POOLING || op == ConvOp.MAX_POOLING_BACKWARD) && DMLScript.USE_ACCELERATOR) {
+			return ExecType.GPU;
+		}
 	
 		ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
 		

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index c7425c1..d043879 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -428,10 +428,13 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 		//get object from cache
 		if( _data == null )
 			getCache();
-		
-		if( _gpuHandle != null )
+			
+		//call acquireHostRead if gpuHandle is set as well as is allocated  
+		if( _gpuHandle != null && _gpuHandle.isAllocated()) {
 			_gpuHandle.acquireHostRead();
-		
+			if( _data == null )
+				getCache();
+		}
 		//read data from HDFS/RDD if required
 		//(probe data for cache_nowrite / jvm_reuse)  
 		if( isEmpty(true) && _data==null ) 
@@ -446,7 +449,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 					//check filename
 					if( _hdfsFileName == null )
 						throw new CacheException("Cannot read matrix for empty filename.");
-
+					
 					//read cacheable data from hdfs
 					_data = readBlobFromHDFS( _hdfsFileName );
 					
@@ -462,7 +465,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 					//mark for initial local write (prevent repeated execution of rdd operations)
 					if( writeStatus.booleanValue() )
 						_requiresLocalWrite = CACHING_WRITE_CACHE_ON_READ;
-					else		
+					else
 						_requiresLocalWrite = true;
 				}
 				
@@ -571,18 +574,19 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 		if (! isAvailableToModify ())
 			throw new CacheException ("CacheableData not available to modify.");
 		
-		//clear old data 
-		clearData(); 
+		//clear old data
+		clearData();
 		
 		//cache status maintenance
 		acquire (true, false); //no need to load evicted matrix
+		
 		setDirty(true);
 		_isAcquireFromEmpty = false;
 		
 		//set references to new data
 		if (newData == null)
 			throw new CacheException("acquireModify with empty cache block.");
-		_data = newData; 
+		_data = newData;
 		updateStatusPinned(true);
 		
 		if( DMLScript.STATISTICS ){
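
[Note on the acquireRead() change above: the host copy is only read after any allocated GPU copy has been synchronized back, and the cache is re-probed because the copy-back may have populated it. A minimal standalone sketch of that protocol, with hypothetical names rather than the SystemML API:]

    // Sketch: the device copy is treated as the source of truth until it has
    // been copied back; only then is the host block (re)loaded and returned.
    interface GpuHandle { boolean isAllocated(); void copyToHostIfModified(); }

    class HostDeviceBlock {
        double[] hostData;   // host copy, may have been dropped to the local cache
        GpuHandle gpu;       // device copy, may hold the newest values

        double[] acquireRead() {
            if (gpu != null && gpu.isAllocated()) {
                gpu.copyToHostIfModified();          // device -> host transfer when dirty
                if (hostData == null)
                    hostData = restoreFromCache();   // re-probe the cache after the copy-back
            }
            return hostData;
        }

        double[] restoreFromCache() { return new double[0]; }  // placeholder for getCache()
    }
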

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index 3e85a76..70a5b4f 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -309,7 +309,7 @@ public class ExecutionContext
 		throws DMLRuntimeException 
 	{
 		MatrixObject mo = getMatrixObject(varName);
-		mo.getGPUObject().release(false);
+		mo.getGPUObject().releaseInput();
 	}
 	
 	/**
@@ -383,7 +383,7 @@ public class ExecutionContext
 		if(mo.getGPUObject() == null || !mo.getGPUObject().isAllocated) {
 			throw new DMLRuntimeException("No output is allocated on GPU");
 		}
-		mo.getGPUObject().release(true);
+		mo.getGPUObject().releaseOutput();
 	}
 	
 	/**

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 48e67e1..20527df 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -25,6 +25,7 @@ import org.apache.sysml.runtime.instructions.gpu.AggregateBinaryGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.ConvolutionGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
+import org.apache.sysml.runtime.instructions.gpu.MMTSJGPUInstruction;
 
 public class GPUInstructionParser  extends InstructionParser 
 {
@@ -34,8 +35,10 @@ public class GPUInstructionParser  extends InstructionParser
 		String2GPUInstructionType.put( "conv2d",                 GPUINSTRUCTION_TYPE.Convolution);
 		String2GPUInstructionType.put( "conv2d_backward_filter", GPUINSTRUCTION_TYPE.Convolution);
 		String2GPUInstructionType.put( "conv2d_backward_data",   GPUINSTRUCTION_TYPE.Convolution);
+		String2GPUInstructionType.put( "maxpooling",             GPUINSTRUCTION_TYPE.Convolution);
+		String2GPUInstructionType.put( "maxpooling_backward",    GPUINSTRUCTION_TYPE.Convolution);
 		String2GPUInstructionType.put( "ba+*",                   GPUINSTRUCTION_TYPE.AggregateBinary);
-		
+		String2GPUInstructionType.put( "tsmm",                   GPUINSTRUCTION_TYPE.MMTSJ);
 	}
 	
 	public static GPUInstruction parseSingleInstruction (String str ) 
@@ -68,6 +71,9 @@ public class GPUInstructionParser  extends InstructionParser
 			case Convolution:
 				return ConvolutionGPUInstruction.parseInstruction(str);
 				
+			case MMTSJ:
+				return MMTSJGPUInstruction.parseInstruction(str);
+				
 			default: 
 				throw new DMLRuntimeException("Invalid GPU Instruction Type: " + gputype );
 		}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
index 3dc98ba..9c413d0 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
@@ -68,16 +68,15 @@ public class AggregateBinaryGPUInstruction extends GPUInstruction
 		String opcode = parts[0];
 
 		if ( !opcode.equalsIgnoreCase("ba+*")) {
-			throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
+ 			throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
 		}
 		
 		InstructionUtils.checkNumFields( parts, 5 );
 		CPOperand in1 = new CPOperand(parts[1]);
 		CPOperand in2 = new CPOperand(parts[2]);
-		CPOperand out = new CPOperand(parts[3]);		
+		CPOperand out = new CPOperand(parts[3]);
 		boolean isLeftTransposed = Boolean.parseBoolean(parts[4]);
 		boolean isRightTransposed = Boolean.parseBoolean(parts[5]);
-		
 		AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
 		AggregateBinaryOperator aggbin = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg, 1);
 		return new AggregateBinaryGPUInstruction(aggbin, in1, in2, out, opcode, str, isLeftTransposed, isRightTransposed);	

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
index ab6d7a4..e489f1c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
@@ -23,9 +23,12 @@ import java.util.ArrayList;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.functionobjects.SwapIndex;
+import org.apache.sysml.runtime.instructions.Instruction;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.ConvolutionCPInstruction;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
@@ -64,43 +67,88 @@ public class ConvolutionGPUInstruction extends GPUInstruction
 		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
 		String opcode = parts[0];
 		
-		if( !( opcode.equalsIgnoreCase("conv2d")
+		if( ( opcode.equalsIgnoreCase("conv2d")
 			 || opcode.equalsIgnoreCase("conv2d_backward_filter")
-			 || opcode.equalsIgnoreCase("conv2d_backward_data")) ) 
-		{
-			throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionGPUInstruction: " + str);	
+			 || opcode.equalsIgnoreCase("conv2d_backward_data")
+			 || opcode.equalsIgnoreCase("maxpooling_backward")) ) {
+			InstructionUtils.checkNumFields(parts, 15);
+			CPOperand in1 = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand out = new CPOperand(parts[15]);
+		
+			ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
+			stride.add(new CPOperand(parts[3]));
+			stride.add(new CPOperand(parts[4]));
+			padding.add(new CPOperand(parts[5]));
+			padding.add(new CPOperand(parts[6]));
+			input_shape.add(new CPOperand(parts[7]));
+			input_shape.add(new CPOperand(parts[8]));
+			input_shape.add(new CPOperand(parts[9]));
+			input_shape.add(new CPOperand(parts[10]));
+			filter_shape.add(new CPOperand(parts[11]));
+			filter_shape.add(new CPOperand(parts[12]));
+			filter_shape.add(new CPOperand(parts[13]));
+			filter_shape.add(new CPOperand(parts[14]));
+
+			return new ConvolutionGPUInstruction(in1, in2, out, opcode, str, stride,
+					padding, input_shape, filter_shape);
 		}
+		else if (opcode.equalsIgnoreCase("maxpooling")) {
+			InstructionUtils.checkNumFields(parts, 14);
+			CPOperand in1 = new CPOperand(parts[1]);
+			CPOperand out = new CPOperand(parts[14]);
 		
-		InstructionUtils.checkNumFields(parts, 15);
-		CPOperand in1 = new CPOperand(parts[1]);
-		CPOperand in2 = new CPOperand(parts[2]);
-		CPOperand out = new CPOperand(parts[15]);
-	
-		ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
-		ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
-		ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
-		ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
-		stride.add(new CPOperand(parts[3]));
-		stride.add(new CPOperand(parts[4]));
-		padding.add(new CPOperand(parts[5]));
-		padding.add(new CPOperand(parts[6]));
-		input_shape.add(new CPOperand(parts[7]));
-		input_shape.add(new CPOperand(parts[8]));
-		input_shape.add(new CPOperand(parts[9]));
-		input_shape.add(new CPOperand(parts[10]));
-		filter_shape.add(new CPOperand(parts[11]));
-		filter_shape.add(new CPOperand(parts[12]));
-		filter_shape.add(new CPOperand(parts[13]));
-		filter_shape.add(new CPOperand(parts[14]));
+			ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
+			stride.add(new CPOperand(parts[2]));
+			stride.add(new CPOperand(parts[3]));
+			padding.add(new CPOperand(parts[4]));
+			padding.add(new CPOperand(parts[5]));
+			input_shape.add(new CPOperand(parts[6]));
+			input_shape.add(new CPOperand(parts[7]));
+			input_shape.add(new CPOperand(parts[8]));
+			input_shape.add(new CPOperand(parts[9]));
+			filter_shape.add(new CPOperand(parts[10]));
+			filter_shape.add(new CPOperand(parts[11]));
+			filter_shape.add(new CPOperand(parts[12]));
+			filter_shape.add(new CPOperand(parts[13]));
 
-		return new ConvolutionGPUInstruction(in1, in2, out, opcode, str, stride,
-				padding, input_shape, filter_shape);
+			return new ConvolutionGPUInstruction(in1, null, out, opcode, str, stride,
+					padding, input_shape, filter_shape);
+		}
+		else {
+			throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionGPUInstruction: " + str);	
+		}
+	}
+	
+	private boolean isSparse(ExecutionContext ec, String var) throws DMLRuntimeException {
+		MatrixObject mo = ec.getMatrixObject(var);
+		return LibMatrixCUDA.isInSparseFormat(mo);
 	}
 	
 	@Override
 	public void processInstruction(ExecutionContext ec) 
 			throws DMLRuntimeException 
 	{
+		// TODO: Fix Me. Currently calling CP if data is sparse
+		if (instOpcode.equalsIgnoreCase("maxpooling")) {
+			if(	isSparse(ec, _input1.getName())) {
+				ConvolutionCPInstruction.parseInstruction(this.toString() + Instruction.OPERAND_DELIM + InfrastructureAnalyzer.getLocalParallelism()).processInstruction(ec);
+				return;
+			}
+		}
+		else {
+			if(	isSparse(ec, _input1.getName()) || isSparse(ec, _input2.getName())) {
+				ConvolutionCPInstruction.parseInstruction(this.toString() + Instruction.OPERAND_DELIM + InfrastructureAnalyzer.getLocalParallelism()).processInstruction(ec);
+				return;
+			}
+		}
+		
 		Statistics.incrementNoOfExecutedGPUInst();
 					
 		int pad_h = getScalarInput(ec, _padding, 0);
@@ -136,7 +184,6 @@ public class ConvolutionGPUInstruction extends GPUInstruction
 			MatrixObject out = ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
 			LibMatrixCUDA.conv2d(image, filter, out, N, C, H, W,
 					K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
-			
 		}
 		else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
 			MatrixObject image = ec.getMatrixInputForGPUInstruction(_input1.getName());
@@ -172,13 +219,43 @@ public class ConvolutionGPUInstruction extends GPUInstruction
 			LibMatrixCUDA.conv2d_backward_data(filter, dout, out, N, C, H, W,
 					K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
 		}
+		else if (instOpcode.equalsIgnoreCase("maxpooling")) {
+			MatrixObject image = ec.getMatrixInputForGPUInstruction(_input1.getName());
+			if(LibMatrixCUDA.isInSparseFormat(image))
+				throw new DMLRuntimeException("Sparse maxpooling not implemented");
+			if(image.getNumRows() != N || image.getNumColumns() != C*H*W) 
+				throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling: " + 
+						image.getNumRows() + " != " +  N + " || " + image.getNumColumns() + " != " + C*H*W);
+			
+			ec.setMetaData(_output.getName(), N, C * P * Q);
+			MatrixObject out = ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
+			LibMatrixCUDA.maxpooling(image, out, N, C, H, W,
+					K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+		}
+		else if (instOpcode.equalsIgnoreCase("maxpooling_backward")) {
+			MatrixObject image = ec.getMatrixInputForGPUInstruction(_input1.getName());
+			MatrixObject dout = ec.getMatrixInputForGPUInstruction(_input2.getName());
+			if(LibMatrixCUDA.isInSparseFormat(image) || LibMatrixCUDA.isInSparseFormat(dout))
+				throw new DMLRuntimeException("Sparse maxpooling_backward_data not implemented");
+			if(dout.getNumRows() != N || dout.getNumColumns() != C*P*Q) 
+				throw new DMLRuntimeException("Incorrect dimensions for dout in maxpooling_backward");
+			if(image.getNumRows() != N || image.getNumColumns() != C*H*W) 
+				throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling_backward: " + 
+						image.getNumRows() + " != " +  N + " || " + image.getNumColumns() + " != " + K*P*Q);
+			
+			ec.setMetaData(_output.getName(), N, C * H * W);
+			MatrixObject out = ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
+			LibMatrixCUDA.maxpooling_backward(image, dout, out, N, C, H, W,
+					K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+		}
 		else {
 			throw new DMLRuntimeException("Unsupported GPU context for " + instOpcode);
 		}
 		
 		// release inputs/outputs
 		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+		if (!instOpcode.equalsIgnoreCase("maxpooling"))
+			ec.releaseMatrixInputForGPUInstruction(_input2.getName());
 		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
 	}
 	

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
index ce9646b..d842ac8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
@@ -28,7 +28,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
 
 public abstract class GPUInstruction extends Instruction 
 {
-	public enum GPUINSTRUCTION_TYPE { AggregateBinary, Convolution }; 
+	public enum GPUINSTRUCTION_TYPE { AggregateBinary, Convolution, MMTSJ }; 
 	
 	protected GPUINSTRUCTION_TYPE _gputype;
 	protected Operator _optr;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/instructions/gpu/MMTSJGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MMTSJGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MMTSJGPUInstruction.java
new file mode 100644
index 0000000..4709085
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MMTSJGPUInstruction.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.gpu;
+
+/*
+ * Parses and processes the MMTSJ GPU Instruction
+ * @function	GPUInstruction(...)
+ * @function	parseInstruction(...)
+ * @function	processInstruction(...)
+ * @function	getMMTSJType(...)
+ */
+import org.apache.sysml.lops.MMTSJ.MMTSJType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.utils.Statistics;
+
+public class MMTSJGPUInstruction extends GPUInstruction
+{
+
+        private MMTSJType _type = null;
+        
+        CPOperand _input;
+        CPOperand _output;
+
+        /**
+         * @param op	operator
+         * @param in1	input
+         * @param type	left/right, left-> A' %*% A, right-> A %*% A'
+         * @param out	output
+         * @param opcode
+         * @param istr
+         */
+        public MMTSJGPUInstruction(Operator op, CPOperand in1, MMTSJType type, CPOperand out,  String opcode, String istr)
+        {
+                super(op, opcode, istr);
+                _gputype = GPUINSTRUCTION_TYPE.MMTSJ;
+                _type = type;
+                _input = in1;
+                _output = out;
+        }
+
+        /**
+         * parse MMTSJ GPU instruction
+         * @param str
+         * @return
+         * @throws DMLRuntimeException
+         */
+        public static MMTSJGPUInstruction parseInstruction ( String str )
+        	throws DMLRuntimeException
+        {
+                String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+                InstructionUtils.checkNumFields ( parts, 3 );
+
+                String opcode = parts[0];
+                CPOperand in1 = new CPOperand(parts[1]);
+                CPOperand out = new CPOperand(parts[2]);
+                MMTSJType titype = MMTSJType.valueOf(parts[3]);
+
+                if(!opcode.equalsIgnoreCase("tsmm"))
+                        throw new DMLRuntimeException("Unknown opcode while parsing an MMTSJGPUInstruction: " + str);
+                else
+                        return new MMTSJGPUInstruction(new Operator(true), in1, titype, out, opcode, str);
+        }
+
+        /**
+         * process MMTSJ GPU instruction 
+         * @param ec	execution context
+         * @throws DMLRuntimeException
+         */
+        @Override
+        public void processInstruction(ExecutionContext ec)
+                throws DMLRuntimeException
+        {
+                Statistics.incrementNoOfExecutedGPUInst();
+
+                //get input
+                MatrixObject mat = ec.getMatrixInputForGPUInstruction(_input.getName());
+               
+                boolean isLeftTransposed = ( _type == MMTSJType.LEFT);
+
+                int rlen = (int) (isLeftTransposed? mat.getNumColumns() : mat.getNumRows());
+                int clen = rlen;
+
+                //execute operations 
+                ec.setMetaData(_output.getName(), rlen, clen);
+                MatrixObject out = ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
+                LibMatrixCUDA.matmultTSMM(mat, out, isLeftTransposed);
+                
+                ec.releaseMatrixInputForGPUInstruction(_input.getName());
+                ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+        }
+
+        /**
+         * returns left/right depending on the type of MMTSJ instruction
+         * @return _type
+         */
+        public MMTSJType getMMTSJType()
+        {
+                return _type;
+        }
+}
\ No newline at end of file
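
[For context on processInstruction() above: LibMatrixCUDA.matmultTSMM (added further below) computes the symmetric product t(A) %*% A or A %*% t(A) with two cublasDsyrk rank-k updates, one per triangle, since DSYRK only fills the referenced triangle. A plain-Java sketch of the same idea, with a hypothetical helper name and no SystemML dependencies:]

    // Sketch: C = t(A) %*% A via its upper triangle, mirrored into the lower one.
    static double[][] tsmmLeft(double[][] A) {
        int m = A.length, n = A[0].length;
        double[][] C = new double[n][n];
        for (int i = 0; i < n; i++)
            for (int j = i; j < n; j++) {
                double sum = 0;
                for (int k = 0; k < m; k++)
                    sum += A[k][i] * A[k][j];
                C[i][j] = sum;   // upper triangle
                C[j][i] = sum;   // mirror, exploiting symmetry
            }
        return C;
    }
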

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index 45b8c5b..33f7099 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -21,6 +21,7 @@ package org.apache.sysml.runtime.instructions.gpu.context;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.CacheException;
@@ -30,8 +31,14 @@ import org.apache.sysml.utils.Statistics;
 //FIXME merge JCudaObject into GPUObject to avoid unnecessary complexity
 public abstract class GPUObject 
 {
+	public enum EvictionPolicy {
+        LRU, LFU, MIN_EVICT
+    }
+	public static final EvictionPolicy evictionPolicy = EvictionPolicy.LRU;
 	protected boolean isDeviceCopyModified = false;
 	protected AtomicInteger numLocks = new AtomicInteger(0);
+	AtomicLong timestamp = new AtomicLong(0);
+	
 	protected boolean isInSparseFormat = false;
 	public boolean isAllocated = false;
 	protected MatrixObject mat = null;
@@ -52,8 +59,8 @@ public abstract class GPUObject
 	public abstract void acquireDenseDeviceModify(int numElemsToAllocate) throws DMLRuntimeException;
 	public abstract void acquireHostRead() throws CacheException;
 	public abstract void acquireHostModify() throws CacheException;
-	public abstract void release(boolean isGPUCopyModified) throws CacheException;
-	
+	public abstract void releaseInput() throws CacheException;
+	public abstract void releaseOutput() throws CacheException;
 	
 	// package-level visibility as these methods are guarded by underlying GPUContext
 	abstract void allocateMemoryOnDevice(int numElemToAllocate) throws DMLRuntimeException;
@@ -64,72 +71,73 @@ public abstract class GPUObject
 	
 	
 	/**
-	 * It finds matrix toBeRemoved such that toBeRemoved.GPUSize >= size
-	 * // TODO: it is the smallest matrix size that satisfy the above condition. For now just evicting the largest pointer.
-	 * Then returns toBeRemoved. 
-	 * 
+	 * It finds matrix toBeRemoved such that toBeRemoved.GPUSize is the smallest one whose size is greater than the eviction size
+	 * // TODO: update it with hybrid policy
+	 * @return toBeRemoved
 	 */
-	protected void evict(long GPUSize) throws DMLRuntimeException {
-		if(GPUContext.allocatedPointers.size() == 0) {
-			throw new DMLRuntimeException("There is not enough memory on device for this matrix!");
-		}
-		
-		Statistics.cudaEvictionCount.addAndGet(1);
-		
-		synchronized(evictionLock) {
-			Collections.sort(GPUContext.allocatedPointers, new Comparator<GPUObject>() {
-	
-				@Override
-				public int compare(GPUObject p1, GPUObject p2) {
-					int p1Val = p1.numLocks.get();
-					int p2Val = p2.numLocks.get();
-					
-					if(p1Val < 0 || p2Val < 0) {
-						throw new RuntimeException("Number of locks cannot be negative");
-					}
-					else if(p1Val == 0 && p2Val == 0) {
-						// Both p1 and p2 are unlocked, return largest object
-						// TODO: Modify this !!
-						long p1Size = 0; long p2Size = 0;
-						try {
-							p1Size = p1.getSizeOnDevice();
-							p2Size = p2.getSizeOnDevice();
-						} catch (DMLRuntimeException e) {
-							throw new RuntimeException(e);
-						}
-						if(p1Size == p2Size) {
-							return 0;
-						}
-						else if(p1Size < p2Size) {
-							return 1;
-						}
-						else {
-							return -1;
-						}
-					}
-					else if(p1Val > p2Val) {
-						// There are more locks on p1
-						return 1;
-					}
-					else {
-						// There are more locks on p2
-						return -1;
-					}
-				}
-			});
-			
-			
-			while(GPUSize > getAvailableMemory() && GPUContext.allocatedPointers.size() > 0) {
-				GPUObject toBeRemoved = GPUContext.allocatedPointers.get(GPUContext.allocatedPointers.size() - 1);
-				if(toBeRemoved.numLocks.get() != 0) {
-					throw new DMLRuntimeException("There is not enough memory on device for this matrix!");
-				}
-				if(toBeRemoved.isDeviceCopyModified) {
-					toBeRemoved.copyFromDeviceToHost();
-				}
-				toBeRemoved.clearData();
-			}
-		}
+	protected void evict(final long GPUSize) throws DMLRuntimeException {
+        if(GPUContext.allocatedPointers.size() == 0) {
+                throw new DMLRuntimeException("There is not enough memory on device for this matrix!");
+        }
+        
+        Statistics.cudaEvictionCount.addAndGet(1);
+
+        synchronized(evictionLock) {
+        	Collections.sort(GPUContext.allocatedPointers, new Comparator<GPUObject>() {
+
+        		@Override
+                public int compare(GPUObject p1, GPUObject p2) {
+                	long p1Val = p1.numLocks.get();
+                 	long p2Val = p2.numLocks.get();
+
+                	if(p1Val>0 && p2Val>0) {
+                		// Both are locked, so don't sort
+                        return 0;
+                	}
+                	else if(p1Val>0 || p2Val>0) {
+                		// Put the unlocked one to RHS
+                		return Long.compare(p2Val, p1Val);
+                    }
+                	else {
+                		// Both are unlocked
+
+                		if(evictionPolicy == EvictionPolicy.MIN_EVICT) {
+                			long p1Size = 0; long p2Size = 0;
+                          	try {
+                          		p1Size = p1.getSizeOnDevice() - GPUSize;
+                            	p2Size = p2.getSizeOnDevice() - GPUSize;
+                         	} catch (DMLRuntimeException e) {
+                         		throw new RuntimeException(e);
+                        	}
+
+                          	if(p1Size>=0 && p2Size>=0 ) {
+                          		return Long.compare(p2Size, p1Size);
+                          	}
+                          	else {
+                          		return Long.compare(p1Size, p2Size);
+                          	}
+                     	}
+                		else if(evictionPolicy == EvictionPolicy.LRU || evictionPolicy == EvictionPolicy.LFU) {
+                			return Long.compare(p2.timestamp.get(), p1.timestamp.get());
+                    	}
+                     	else {
+                     		throw new RuntimeException("Unsupported eviction policy:" + evictionPolicy.name());
+                    	}
+                	}
+              	}
+        	});
+
+        	while(GPUSize > getAvailableMemory() && GPUContext.allocatedPointers.size() > 0) {
+        		GPUObject toBeRemoved = GPUContext.allocatedPointers.get(GPUContext.allocatedPointers.size() - 1);
+               	if(toBeRemoved.numLocks.get() > 0) {
+               		throw new DMLRuntimeException("There is not enough memory on device for this matrix!");
+              	}
+               	if(toBeRemoved.isDeviceCopyModified) {
+               		toBeRemoved.copyFromDeviceToHost();
+            	}
+             	toBeRemoved.clearData();
+        	}
+        }
 	}
 	
 	public void clearData() throws CacheException {
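
[A note on the new eviction order above: the comparator pushes the preferred victim to the tail of GPUContext.allocatedPointers. Locked objects stay at the front; among unlocked objects, LRU/LFU place the least-recently/least-frequently used last (descending timestamp/counter), while MIN_EVICT places the smallest object that still covers the requested size last. The loop then evicts from the tail until enough memory is free. A standalone sketch of the MIN_EVICT ordering over plain candidate sizes (hypothetical helper, no SystemML types):]

    // Order candidate sizes so the preferred victim sorts last:
    // if both candidates can satisfy the request, the smaller one goes last;
    // otherwise the candidate closer to (or above) the requested size goes last.
    static java.util.Comparator<Long> minEvictOrder(long requested) {
        return (s1, s2) -> {
            long d1 = s1 - requested, d2 = s2 - requested;
            if (d1 >= 0 && d2 >= 0)
                return Long.compare(d2, d1);   // both fit: descending size, smallest fit at the tail
            return Long.compare(d1, d2);       // otherwise: the largest (best-fitting) candidate at the tail
        };
    }

[Sorting a list of candidate sizes with minEvictOrder(needed) and evicting from the end mirrors the while loop above.]
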

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java
index 811f2dd..7f0b26b 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java
@@ -91,6 +91,9 @@ public class JCudaObject extends GPUObject {
 				throw new CacheException(e);
 			}
 		}
+		else {
+			throw new CacheException("Cannot perform acquireHostRead as the GPU data is not allocated:" + mat.getVarName());
+		}
 	}
 	
 	@Override
@@ -108,11 +111,47 @@ public class JCudaObject extends GPUObject {
 		}
 	}
 	
-	public void release(boolean isGPUCopyModified) throws CacheException {
+	/**
+	 * updates the locks depending on the eviction policy selected
+	 * @throws CacheException if there is no locked GPU Object
+	 */
+	private void updateReleaseLocks() throws CacheException {
 		if(numLocks.addAndGet(-1) < 0) {
-			throw new CacheException("Redundant release of GPU object");
+            throw new CacheException("Redundant release of GPU object");
+		}
+		if(evictionPolicy == EvictionPolicy.LRU) {
+            timestamp.set(System.nanoTime());
+		}
+		else if(evictionPolicy == EvictionPolicy.LFU) {
+            timestamp.addAndGet(1);
+		}
+		else if(evictionPolicy == EvictionPolicy.MIN_EVICT) {
+            // Do Nothing
+		}
+		else {
+            throw new CacheException("The eviction policy is not supported:" + evictionPolicy.name());
 		}
-		isDeviceCopyModified = isGPUCopyModified;
+	}
+	
+	/**
+	 * releases input allocated on GPU
+	 * @throws CacheException if data is not allocated
+	 */
+	public void releaseInput() throws CacheException {
+		updateReleaseLocks();
+		if(!isAllocated)
+			throw new CacheException("Attempting to release an input before allocating it");
+	}
+	
+	/**
+	 * releases output allocated on GPU
+	 * @throws CacheException if data is not allocated
+	 */
+	public void releaseOutput() throws CacheException {
+		updateReleaseLocks();
+		isDeviceCopyModified = true;
+		if(!isAllocated)
+			throw new CacheException("Attempting to release an output before allocating it");
 	}
 
 	@Override
@@ -214,7 +253,7 @@ public class JCudaObject extends GPUObject {
 				double [] data = tmp.getDenseBlock();
 				
 				cudaMemcpy(Pointer.to(data), jcudaPointer, data.length * Sizeof.DOUBLE, cudaMemcpyDeviceToHost);
-				
+
 				tmp.recomputeNonZeros();
 				mat.acquireModify(tmp);
 				mat.release();
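
[The former release(boolean) is split into releaseInput()/releaseOutput() here and in GPUObject: both decrement the lock count and refresh the eviction metadata (LRU stores the release time, LFU bumps a usage counter), and releaseOutput() additionally marks the device copy as modified so eviction copies it back to the host. A minimal standalone sketch of that bookkeeping, with hypothetical names:]

    import java.util.concurrent.atomic.AtomicInteger;
    import java.util.concurrent.atomic.AtomicLong;

    // Sketch: per-object metadata updated on every release, driving LRU or LFU eviction order.
    class EvictionMeta {
        final AtomicInteger locks = new AtomicInteger();
        final AtomicLong score = new AtomicLong();

        void release(boolean lru) {
            if (locks.decrementAndGet() < 0)
                throw new IllegalStateException("Redundant release of GPU object");
            if (lru)
                score.set(System.nanoTime());   // LRU: timestamp of the last release
            else
                score.incrementAndGet();        // LFU: usage counter
        }
    }
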

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 4a94f6a..52272a0 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -22,29 +22,39 @@ package org.apache.sysml.runtime.matrix.data;
 import static jcuda.jcudnn.JCudnn.cudnnConvolutionBackwardData;
 import static jcuda.jcudnn.JCudnn.cudnnConvolutionBackwardFilter;
 import static jcuda.jcudnn.JCudnn.cudnnConvolutionForward;
+import static jcuda.jcudnn.JCudnn.cudnnPoolingForward;
+import static jcuda.jcudnn.JCudnn.cudnnPoolingBackward;
 import static jcuda.jcudnn.JCudnn.cudnnCreateConvolutionDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnCreateFilterDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnCreatePoolingDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnDestroyConvolutionDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnDestroyFilterDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize;
 import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize;
 import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardWorkspaceSize;
 import static jcuda.jcudnn.JCudnn.cudnnSetConvolution2dDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetFilter4dDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetPooling2dDescriptor;
 import static jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION;
 import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
 import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
+import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_MAX;
 import jcuda.jcudnn.cudnnConvolutionFwdPreference;
+import static jcuda.runtime.JCuda.cudaMalloc;
 import static jcuda.runtime.JCuda.cudaFree;
 import jcuda.Pointer;
+import jcuda.Sizeof;
 import jcuda.jcublas.JCublas;
+import jcuda.jcublas.JCublas2;
 import jcuda.jcublas.cublasHandle;
 import jcuda.jcudnn.cudnnConvolutionDescriptor;
 import jcuda.jcudnn.cudnnFilterDescriptor;
 import jcuda.jcudnn.cudnnHandle;
+import jcuda.jcudnn.cudnnPoolingDescriptor;
 import jcuda.jcudnn.cudnnTensorDescriptor;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
@@ -175,6 +185,23 @@ public class LibMatrixCUDA {
 		return filterDesc;
 	}
 	
+	/**
+	 * allocates pooling descriptor, used in poolingForward and poolingBackward
+	 * @param R			pooling window height
+	 * @param S			pooling window width
+	 * @param pad_h		vertical padding
+	 * @param pad_w		horizontal padding
+	 * @param stride_h	pooling vertical stride
+	 * @param stride_w	pooling horizontal stride
+	 * @return
+	 */
+	private static cudnnPoolingDescriptor allocatePoolingDescriptor(int R, int S, int pad_h, int pad_w, int stride_h, int stride_w) {
+		cudnnPoolingDescriptor poolingDesc = new cudnnPoolingDescriptor();
+		cudnnCreatePoolingDescriptor(poolingDesc);
+		cudnnSetPooling2dDescriptor(poolingDesc, CUDNN_POOLING_MAX, R, S, pad_h, pad_w, stride_h, stride_w);
+		return poolingDesc;
+	}
+	
 	public static void conv2d_backward_filter(MatrixObject image, MatrixObject dout,
 			MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
 			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
@@ -240,6 +267,48 @@ public class LibMatrixCUDA {
 		
 	}
 
+	/**
+	 * Performs tsmm, A %*% A' or A' %*% A, on GPU by exploiting cublasDsyrk(...)
+	 * @param left	input matrix of the tsmm expression A %*% A' or A' %*% A; only whether it is transposed needs to be checked, hence the name 'left'
+	 * @param output
+	 * @param isLeftTransposed
+	 * @throws DMLRuntimeException
+	 */
+	public static void matmultTSMM(MatrixObject left, MatrixObject output,
+            boolean isLeftTransposed) throws DMLRuntimeException {
+	    if(isInSparseFormat(left)) {
+	            throw new DMLRuntimeException("Sparse GPU TSMM is not implemented");
+	    }
+	
+	    // Since CuBLAS expects inputs in column-major format,
+	    // reverse the order of matrix-multiplication and take care of dimension mismatch.      
+	    char transa = isLeftTransposed ? 'N' : 'T';
+	    // Note: the dimensions are swapped
+	    int m = (int) (isLeftTransposed ? left.getNumColumns() : left.getNumRows());
+	    int k = (int) (isLeftTransposed ? left.getNumRows() : left.getNumColumns());
+	
+	    if(m == -1)
+	            throw new DMLRuntimeException("Incorrect dimensions");
+	
+	    double alpha = 1.0d;
+	    double beta = 0.0d;
+	
+	    int lda = (int) (isLeftTransposed ? m : k);
+	    int ldc = m;
+	
+	    if(!left.getGPUObject().isAllocated)
+	            throw new DMLRuntimeException("Input is not allocated:" + left.getGPUObject().isAllocated);
+	    if(!output.getGPUObject().isAllocated)
+	            throw new DMLRuntimeException("Output is not allocated:" + output.getGPUObject().isAllocated);
+	
+	    Pointer A = ((JCudaObject)left.getGPUObject()).jcudaPointer;
+	    Pointer C = ((JCudaObject)output.getGPUObject()).jcudaPointer;
+	    
+	    //TODO: Fix it if there is a cuBLAS API to do flipping
+	    JCublas.cublasDsyrk('U',transa, m, k, alpha, A, lda, beta, C, ldc);
+	    JCublas.cublasDsyrk('L',transa, m, k, alpha, A, lda, beta, C, ldc);
+	}
+	
 	public static void matmult(MatrixObject left1, MatrixObject right1, MatrixObject output, 
 			boolean isLeftTransposed1, boolean isRightTransposed1) throws DMLRuntimeException {
 		if(isInSparseFormat(left1) || isInSparseFormat(right1)) {
@@ -349,6 +418,153 @@ public class LibMatrixCUDA {
 		}
 	}
 	
+	/**
+	 * performs maxpooling on GPU by exploiting cudnnPoolingForward(...)
+	 * @param image
+	 * @param outputBlock
+	 * @param N				batch size
+	 * @param C				number of channels
+	 * @param H				height of image
+	 * @param W				width of image
+	 * @param K				number of filters
+	 * @param R				height of filter
+	 * @param S				width of filter
+	 * @param pad_h			vertical padding
+	 * @param pad_w			horizontal padding
+	 * @param stride_h		vertical stride
+	 * @param stride_w		horizontal stride
+	 * @param P				(H - R + 1 + 2*pad_h)/stride_h
+	 * @param Q				(W - S + 1 + 2*pad_w)/stride_w
+	 * @throws DMLRuntimeException
+	 */
+	public static void maxpooling(MatrixObject image,
+			MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
+			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
+			int Q) throws DMLRuntimeException {
+		Pointer alpha = null;
+		Pointer beta = null;
+		cudnnTensorDescriptor xDesc = null;
+		cudnnTensorDescriptor yDesc = null;
+		cudnnPoolingDescriptor poolingDesc = null;
+
+		try {
+			// Allocate descriptors
+			yDesc = allocateTensorDescriptor(N, C, P, Q);
+			xDesc = allocateTensorDescriptor(N, C, H, W);
+			poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
+			
+			// Allocate data
+			Pointer x = ((JCudaObject)image.getGPUObject()).jcudaPointer; 
+			Pointer y = ((JCudaObject)outputBlock.getGPUObject()).jcudaPointer; 
+			
+			alpha = pointerTo(1.0);
+			beta = pointerTo(0.0f);
+			
+			int status = cudnnPoolingForward(cudnnHandle, poolingDesc, alpha, xDesc, x, beta, yDesc, y);
+			
+			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
+				throw new DMLRuntimeException("Could not execute cudnnPoolingForward: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
+			}
+		}
+		finally {
+			if(alpha != null)
+				cudaFree(alpha);
+			if(beta != null)
+				cudaFree(beta);
+			if(yDesc != null)
+				cudnnDestroyTensorDescriptor(yDesc);
+			if(xDesc != null)
+				cudnnDestroyTensorDescriptor(xDesc);
+			if(poolingDesc != null)
+				cudnnDestroyPoolingDescriptor(poolingDesc);
+		}
+	}
+	
+	/**
+	 * performs maxpoolingBackward on GPU by exploiting cudnnPoolingBackward(...)
+	 * @param image
+	 * @param dout			delta matrix, output of previous layer
+	 * @param outputBlock
+	 * @param N				batch size
+	 * @param C				number of channels
+	 * @param H				height of image
+	 * @param W				width of image
+	 * @param K				number of filters
+	 * @param R				height of filter
+	 * @param S				width of filter
+	 * @param pad_h			vertical padding
+	 * @param pad_w			horizontal padding
+	 * @param stride_h		vertical stride
+	 * @param stride_w		horizontal stride
+	 * @param P				(H - R + 1 + 2*pad_h)/stride_h
+	 * @param Q				(W - S + 1 + 2*pad_w)/stride_w
+	 * @throws DMLRuntimeException
+	 */
+	public static void maxpooling_backward(MatrixObject image, MatrixObject dout,
+			MatrixObject outputBlock, int N, int C, int H, int W, int K, int R,
+			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
+			int Q) throws DMLRuntimeException {
+		Pointer alpha = null;
+		Pointer beta = null;
+		cudnnTensorDescriptor xDesc = null;
+		cudnnTensorDescriptor yDesc = null;
+		cudnnTensorDescriptor dyDesc = null;
+		cudnnTensorDescriptor dxDesc = null;
+		cudnnPoolingDescriptor poolingDesc = null;
+
+		try {
+			// Allocate descriptors
+			xDesc = allocateTensorDescriptor(N, C, H, W);
+			yDesc = allocateTensorDescriptor(N, C, P, Q);
+			dxDesc = allocateTensorDescriptor(N, C, H, W);
+			dyDesc = allocateTensorDescriptor(N, C, P, Q);
+			
+			poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
+			
+			// Calling PoolForward first, y is one of the inputs for poolBackward
+			// TODO: Remove calling poolForward after necessary changes at language level for poolBackward
+			Pointer y = new Pointer();
+			long numBytes = N*C*P*Q*Sizeof.DOUBLE;
+			cudaMalloc(y, numBytes);
+			
+			// Allocate data
+			Pointer x = ((JCudaObject)image.getGPUObject()).jcudaPointer; 
+			Pointer dx = ((JCudaObject)outputBlock.getGPUObject()).jcudaPointer;
+			Pointer dy = ((JCudaObject)dout.getGPUObject()).jcudaPointer;
+			
+			alpha = pointerTo(1.0);
+			beta = pointerTo(0.0f);
+			
+			int status = cudnnPoolingForward(cudnnHandle, poolingDesc, alpha, xDesc, x, beta, yDesc, y);
+			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
+				throw new DMLRuntimeException("Could not execute cudnnPoolingForward before cudnnPoolingBackward: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
+			}
+			
+			status = cudnnPoolingBackward(cudnnHandle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx);
+			
+			if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
+				throw new DMLRuntimeException("Could not execute cudnnPoolingBackward: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
+			}
+			
+			cudaFree(y);
+		}
+		finally {
+			if(alpha != null)
+				cudaFree(alpha);
+			if(beta != null)
+				cudaFree(beta);
+			if(yDesc != null)
+				cudnnDestroyTensorDescriptor(yDesc);
+			if(xDesc != null)
+				cudnnDestroyTensorDescriptor(xDesc);
+			if(dyDesc != null)
+				cudnnDestroyTensorDescriptor(dyDesc);
+			if(dxDesc != null)
+				cudnnDestroyTensorDescriptor(dxDesc);
+			if(poolingDesc != null)
+				cudnnDestroyPoolingDescriptor(poolingDesc);	
+		}	
+	}
 	public static boolean isInSparseFormat(MatrixObject mo) {
 		if(mo.getGPUObject() != null && mo.getGPUObject().isAllocated())
 			return mo.getGPUObject().isInSparseFormat();
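
[For context on the two new kernels above: maxpooling slides an R x S window with the given stride and padding over every channel of the N x C x H x W input and keeps the maximum, producing an N x C x P x Q output (this is what cudnnPoolingForward does on the NCHW descriptors); maxpooling_backward first re-runs the forward pass because cudnnPoolingBackward needs the pooled output y as an input. A plain-Java, single-channel sketch of the forward computation (illustrative only; padded out-of-range positions simply do not contribute, and P and Q are taken as given):]

    // Sketch: max pooling of one H x W channel into a P x Q output.
    static double[][] maxpool(double[][] in, int P, int Q, int R, int S,
                              int pad_h, int pad_w, int stride_h, int stride_w) {
        int H = in.length, W = in[0].length;
        double[][] out = new double[P][Q];
        for (int p = 0; p < P; p++)
            for (int q = 0; q < Q; q++) {
                double max = Double.NEGATIVE_INFINITY;
                for (int r = 0; r < R; r++)
                    for (int s = 0; s < S; s++) {
                        int h = p * stride_h + r - pad_h;   // input row covered by this window cell
                        int w = q * stride_w + s - pad_w;   // input column covered by this window cell
                        if (h >= 0 && h < H && w >= 0 && w < W)
                            max = Math.max(max, in[h][w]);
                    }
                out[p][q] = max;
            }
        return out;
    }
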

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java b/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
index f2482e1..db244ff 100644
--- a/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
@@ -190,42 +190,45 @@ public class ConvolutionUtils {
 	}
 	
 	public static Lop constructConvolutionBackwardDataLops(Hop currentHop, ExecType et) throws HopsException, LopsException {
-		if(DMLScript.USE_ACCELERATOR)
-			et = ExecType.GPU; // TODO: Add memory estimate checks
-		else
-			return null;
-		
-		if(currentHop != null && isConvolutionOp(currentHop, ConvOp.COL2IM)) {
-			Hop temp = currentHop.getInput().get(0);
-			if(temp != null && isTranspose(temp)) {
-				Hop matMult = temp.getInput().get(0);
-				if(matMult != null && isMatMult(matMult)) {
-					Hop rotate180 = matMult.getInput().get(0);
-					Hop filter = matMult.getInput().get(1);
-					if(isConvolutionOp(rotate180, ConvOp.ROTATE180)) {
-						ArrayList<Hop> inputs = new ArrayList<Hop>();
-						inputs.add(filter);
-						inputs.add(rotate180.getInput().get(0));
-						for(int i = 1; i < rotate180.getInput().size(); i++) {
-							inputs.add(rotate180.getInput().get(i));
-						}
-						
-						// N, C * H * W
-						long N = currentHop.computeSizeInformation(inputs.get(6));
-						long C = currentHop.computeSizeInformation(inputs.get(7));
-						long H = currentHop.computeSizeInformation(inputs.get(8));
-						long W = currentHop.computeSizeInformation(inputs.get(9));
-						long rlen = N;
-						long clen = ConvolutionOp.getExtractedVal(C, H, W);
-						return ConvolutionOp.constructFusedConvolutionLops(et, inputs, ConvOp.DIRECT_CONV2D_BACKWARD_DATA, (ConvolutionOp) rotate180, rlen, clen);
-						
-						
-					}
-				}
-			}
-		}
+		return null; // Until we add CP conv2d_backward_data
 		
-		return null;
+		//TODO: uncomment the following after CP conv2d_backward_data is added
+//		if(DMLScript.USE_ACCELERATOR)
+//			et = ExecType.GPU; // TODO: Add memory estimate checks
+//		else
+//			return null;
+//		
+//		if(currentHop != null && isConvolutionOp(currentHop, ConvOp.COL2IM)) {
+//			Hop temp = currentHop.getInput().get(0);
+//			if(temp != null && isTranspose(temp)) {
+//				Hop matMult = temp.getInput().get(0);
+//				if(matMult != null && isMatMult(matMult)) {
+//					Hop rotate180 = matMult.getInput().get(0);
+//					Hop filter = matMult.getInput().get(1);
+//					if(isConvolutionOp(rotate180, ConvOp.ROTATE180)) {
+//						ArrayList<Hop> inputs = new ArrayList<Hop>();
+//						inputs.add(filter);
+//						inputs.add(rotate180.getInput().get(0));
+//						for(int i = 1; i < rotate180.getInput().size(); i++) {
+//							inputs.add(rotate180.getInput().get(i));
+//						}
+//						
+//						// N, C * H * W
+//						long N = currentHop.computeSizeInformation(inputs.get(6));
+//						long C = currentHop.computeSizeInformation(inputs.get(7));
+//						long H = currentHop.computeSizeInformation(inputs.get(8));
+//						long W = currentHop.computeSizeInformation(inputs.get(9));
+//						long rlen = N;
+//						long clen = ConvolutionOp.getExtractedVal(C, H, W);
+//						return ConvolutionOp.constructFusedConvolutionLops(et, inputs, ConvOp.DIRECT_CONV2D_BACKWARD_DATA, (ConvolutionOp) rotate180, rlen, clen);
+//						
+//						
+//					}
+//				}
+//			}
+//		}
+//		
+//		return null;
 	}
 	
 	

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/TransposeSelfMatrixMultiplication.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/TransposeSelfMatrixMultiplication.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/TransposeSelfMatrixMultiplication.java
new file mode 100644
index 0000000..b69248d
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/TransposeSelfMatrixMultiplication.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.binary.matrix;
+
+import java.util.HashMap;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+/**
+ * This test covers the transpose-self matrix multiplication patterns
+ * t(X) %*% X and X %*% t(X), which compile to the MMTSJ (tsmm) operator.
+ */
+public class TransposeSelfMatrixMultiplication extends AutomatedTestBase
+{
+
+        private final static String TEST_NAME1 = "TransposeSelfMatrixMultiplication";
+        private final static String TEST_DIR = "functions/binary/matrix/";
+        private final static String TEST_CLASS_DIR = TEST_DIR + TransposeSelfMatrixMultiplication.class.getSimpleName() + "/";
+        private final static double eps = 1e-10;
+
+        //small single-block case
+        private final static int rowsA1 = 3;
+        private final static int colsA1 = 3;
+
+        //larger multi-block case
+        private final static int rowsA2 = 2407;
+        private final static int colsA2 = 73;
+
+
+        private final static double sparsity1 = 0.7;
+        private final static double sparsity2 = 0.1;
+
+
+        @Override
+        public void setUp()
+        {
+                addTestConfiguration( TEST_NAME1,
+                        new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "C" }) );
+
+                if (TEST_CACHE_ENABLED) {
+                        setOutAndExpectedDeletionDisabled(true);
+                }
+        }
+
+        @BeforeClass
+        public static void init()
+        {
+                TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+        }
+
+        @AfterClass
+        public static void cleanUp()
+        {
+                if (TEST_CACHE_ENABLED) {
+                        TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+                }
+        }
+
+        @Test
+        public void testTransposeMMDenseDenseCP1()
+        {
+                /*
+                 * Tests the patterns t(A) %*% A and A %*% t(A) via
+                 * runTransposeSelfMatrixMultiplication(sparseM1, instType, vectorM2, isLeftTransposed):
+                 * sparseM1 selects the sparsity, instType the execution backend,
+                 * vectorM2 the matrix dimensions, and isLeftTransposed is
+                 * true for t(A) %*% A and false for A %*% t(A).
+                 */
+                runTransposeSelfMatrixMultiplication(false, ExecType.CP, false, true);
+        }
+
+        @Test
+        public void testTransposeMMDenseDenseCP2()
+        {
+                runTransposeSelfMatrixMultiplication(false, ExecType.CP, false, false);
+        }
+        
+        /**
+         * @param sparseM1
+         * @param instType
+         * @param vectorM2
+         * @param isLeftTransposed
+         */
+        private void runTransposeSelfMatrixMultiplication( boolean sparseM1, ExecType instType, boolean vectorM2, boolean isLeftTransposed)
+        {
+                //rtplatform for MR
+                RUNTIME_PLATFORM platformOld = rtplatform;
+                switch( instType ){
+                        case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+                        case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+                        default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+                }
+
+                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+                if( rtplatform == RUNTIME_PLATFORM.SPARK )
+                	DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+                int rowsA = vectorM2 ? rowsA2 : rowsA1;
+                int colsA = vectorM2 ? colsA2 : colsA1;
+
+                String TEST_NAME = TEST_NAME1;
+
+                try
+                {
+                        TestConfiguration config = getTestConfiguration(TEST_NAME);
+
+                        double sparsityM1 = sparseM1?sparsity2:sparsity1;
+
+                        String TEST_CACHE_DIR = "";
+                        if (TEST_CACHE_ENABLED)
+                        {
+                                TEST_CACHE_DIR = sparsityM1 + "_" + vectorM2 + "_" + isLeftTransposed + "/";
+                        }
+
+                        loadTestConfiguration(config, TEST_CACHE_DIR);
+
+                        /* construct the junit test arguments directly */
+                        String HOME = SCRIPT_DIR + TEST_DIR;
+                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                        programArgs = new String[]{"-explain","-args",
+                                input("A"), Integer.toString(rowsA), Integer.toString(colsA),
+                                ("" + isLeftTransposed).toUpperCase(),
+                                output("C")};
+
+                        fullRScriptName = HOME + TEST_NAME + ".R";
+                        rCmd = "Rscript" + " " + fullRScriptName + " " +
+                        inputDir() + " " + isLeftTransposed + " " + expectedDir();
+
+                        //generate actual dataset
+                        double[][] A = getRandomMatrix(rowsA, colsA, 0, 1, sparsityM1, 7);
+                        writeInputMatrix("A", A, true);
+
+                        boolean exceptionExpected = false;
+                        runTest(true, exceptionExpected, null, -1);
+                        runRScript(true);
+
+                        //compare matrices 
+                        HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("C");
+                        HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("C");
+                        TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+                }
+                finally
+                {
+                        rtplatform = platformOld;
+                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                }
+        }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/test/scripts/functions/binary/matrix/TransposeSelfMatrixMultiplication.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/binary/matrix/TransposeSelfMatrixMultiplication.R b/src/test/scripts/functions/binary/matrix/TransposeSelfMatrixMultiplication.R
new file mode 100644
index 0000000..1695d40
--- /dev/null
+++ b/src/test/scripts/functions/binary/matrix/TransposeSelfMatrixMultiplication.R
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+if(args[2] == "true") {
+        C = t(A) %*% A;
+}
+if(args[2] == "false") {
+        C = A %*% t(A);
+}
+
+
+writeMM(as(C, "CsparseMatrix"), paste(args[3], "C", sep=""));
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/256deb4c/src/test/scripts/functions/binary/matrix/TransposeSelfMatrixMultiplication.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/binary/matrix/TransposeSelfMatrixMultiplication.dml b/src/test/scripts/functions/binary/matrix/TransposeSelfMatrixMultiplication.dml
new file mode 100644
index 0000000..2825876
--- /dev/null
+++ b/src/test/scripts/functions/binary/matrix/TransposeSelfMatrixMultiplication.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1, rows=$2, cols=$3, format="text");
+isLeftTransposed = $4;
+if(isLeftTransposed)
+        C = t(A) %*% A;
+else
+        C = A %*% t(A);
+write(C, $5, format="text");
\ No newline at end of file