Posted to commits@systemml.apache.org by na...@apache.org on 2017/05/01 04:45:51 UTC

incubator-systemml git commit: [SYSTEMML-1034] Initial implementation of "solve" for GPU

Repository: incubator-systemml
Updated Branches:
  refs/heads/master f2a927f87 -> e8fbc7539


[SYSTEMML-1034] Initial implementation of "solve" for GPU

Closes #476


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/e8fbc753
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/e8fbc753
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/e8fbc753

Branch: refs/heads/master
Commit: e8fbc753988dc94e97a8e8b723e22e89483a1fc6
Parents: f2a927f
Author: Nakul Jindal <na...@gmail.com>
Authored: Sun Apr 30 21:45:21 2017 -0700
Committer: Nakul Jindal <na...@gmail.com>
Committed: Sun Apr 30 21:45:21 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/BinaryOp.java    |   2 +-
 .../instructions/GPUInstructionParser.java      |  17 ++-
 .../gpu/BuiltinBinaryGPUInstruction.java        |  78 +++++++++++
 .../gpu/BuiltinUnaryGPUInstruction.java         |   2 +-
 .../instructions/gpu/GPUInstruction.java        |   2 +-
 .../gpu/MatrixMatrixBuiltinGPUInstruction.java  |  58 ++++++++
 .../instructions/gpu/context/CSRPointer.java    |  29 +++-
 .../instructions/gpu/context/GPUContext.java    |  35 ++++-
 .../instructions/gpu/context/GPUObject.java     |  72 +++++++---
 .../runtime/matrix/data/LibMatrixCUDA.java      | 133 ++++++++++++++++++-
 10 files changed, 391 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e8fbc753/src/main/java/org/apache/sysml/hops/BinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java
index 7ddc656..17a099f 100644
--- a/src/main/java/org/apache/sysml/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java
@@ -592,7 +592,7 @@ public class BinaryOp extends Hop
 			if ( et == ExecType.CP ) 
 			{
 				if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < OptimizerUtils.GPU_MEMORY_BUDGET) 
-						&& (op == OpOp2.MULT || op == OpOp2.PLUS || op == OpOp2.MINUS || op == OpOp2.DIV || op == OpOp2.POW)) {
+						&& (op == OpOp2.MULT || op == OpOp2.PLUS || op == OpOp2.MINUS || op == OpOp2.DIV || op == OpOp2.POW || op == OpOp2.SOLVE)) {
 					et = ExecType.GPU;
 				}
 				

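For readers skimming this hunk: the hop compiles to ExecType.GPU only when the accelerator is enabled and either forced or the operator's memory estimate fits the GPU budget; the change simply adds OpOp2.SOLVE to the list of eligible operators. A minimal standalone sketch of that guard (hypothetical names, not the SystemML API):

    // Mirrors the condition in BinaryOp.java above: accelerator on, and either
    // forced or the estimated memory footprint fits on the device.
    static boolean chooseGPU(boolean useAccelerator, boolean forceAccelerator,
                             double memEstimate, double gpuMemoryBudget) {
        return useAccelerator && (forceAccelerator || memEstimate < gpuMemoryBudget);
    }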
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e8fbc753/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index e5b3326..ef0412c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -23,6 +23,7 @@ import java.util.HashMap;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.gpu.AggregateBinaryGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.ArithmeticBinaryGPUInstruction;
+import org.apache.sysml.runtime.instructions.gpu.BuiltinBinaryGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.BuiltinUnaryGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.ConvolutionGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
@@ -68,12 +69,15 @@ public class GPUInstructionParser  extends InstructionParser
 		String2GPUInstructionType.put( "^2"   , GPUINSTRUCTION_TYPE.ArithmeticBinary); //special ^ case
 		String2GPUInstructionType.put( "*2"   , GPUINSTRUCTION_TYPE.ArithmeticBinary); //special * case
 		String2GPUInstructionType.put( "-nz"  , GPUINSTRUCTION_TYPE.ArithmeticBinary); //special - case
-		String2GPUInstructionType.put( "+*"  , GPUINSTRUCTION_TYPE.ArithmeticBinary); 
-		String2GPUInstructionType.put( "-*"  , GPUINSTRUCTION_TYPE.ArithmeticBinary); 
+		String2GPUInstructionType.put( "+*"  	, GPUINSTRUCTION_TYPE.ArithmeticBinary);
+		String2GPUInstructionType.put( "-*"  	, GPUINSTRUCTION_TYPE.ArithmeticBinary);
 		
 		// Builtin functions
-		String2GPUInstructionType.put( "sel+"  , GPUINSTRUCTION_TYPE.BuiltinUnary);
-		String2GPUInstructionType.put( "exp"  , GPUINSTRUCTION_TYPE.BuiltinUnary);
+		String2GPUInstructionType.put( "sel+"  	, GPUINSTRUCTION_TYPE.BuiltinUnary);
+		String2GPUInstructionType.put( "exp"  	, GPUINSTRUCTION_TYPE.BuiltinUnary);
+
+		String2GPUInstructionType.put( "solve"  , GPUINSTRUCTION_TYPE.BuiltinBinary);
 
 		// Aggregate Unary
 		String2GPUInstructionType.put( "ua+"	 	 , GPUINSTRUCTION_TYPE.AggregateUnary);	// Sum
@@ -95,7 +99,7 @@ public class GPUInstructionParser  extends InstructionParser
 		String2GPUInstructionType.put( "uasqk+"	 , GPUINSTRUCTION_TYPE.AggregateUnary);	// Sum of Squares
 		String2GPUInstructionType.put( "uarsqk+" , GPUINSTRUCTION_TYPE.AggregateUnary);	// Row Sum of Squares
 		String2GPUInstructionType.put( "uacsqk+" , GPUINSTRUCTION_TYPE.AggregateUnary);	// Col Sum of Squares
-		String2GPUInstructionType.put( "uavar" 	 , GPUINSTRUCTION_TYPE.AggregateUnary);		// Variance
+		String2GPUInstructionType.put( "uavar" 	 , GPUINSTRUCTION_TYPE.AggregateUnary);	// Variance
 		String2GPUInstructionType.put( "uarvar"  , GPUINSTRUCTION_TYPE.AggregateUnary);	// Row Variance
 		String2GPUInstructionType.put( "uacvar"  , GPUINSTRUCTION_TYPE.AggregateUnary);	// Col Variance
 	}
@@ -132,6 +136,9 @@ public class GPUInstructionParser  extends InstructionParser
 			
 			case BuiltinUnary:
 				return BuiltinUnaryGPUInstruction.parseInstruction(str);
+
+			case BuiltinBinary:
+				return BuiltinBinaryGPUInstruction.parseInstruction(str);
 			
 			case Convolution:
 				return ConvolutionGPUInstruction.parseInstruction(str);

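The parser change follows this file's existing two-step pattern: a static opcode-to-type map, then a switch that routes the instruction string to the matching parseInstruction(). A self-contained sketch of that pattern with made-up names (the real map and enum are String2GPUInstructionType and GPUINSTRUCTION_TYPE; the real instruction format uses SystemML's operand delimiter, whitespace here is just for the sketch):

    import java.util.HashMap;
    import java.util.Map;

    enum MiniType { BuiltinUnary, BuiltinBinary }

    class MiniParser {
        static final Map<String, MiniType> OPCODES = new HashMap<>();
        static {
            OPCODES.put("exp",   MiniType.BuiltinUnary);
            OPCODES.put("solve", MiniType.BuiltinBinary);
        }

        static String parse(String instr) {
            String opcode = instr.split("\\s+")[0];   // first token is the opcode
            MiniType t = OPCODES.get(opcode);
            if (t == null)
                throw new IllegalArgumentException("Unsupported GPU opcode: " + opcode);
            switch (t) {
                case BuiltinUnary:  return "unary:"  + opcode;
                case BuiltinBinary: return "binary:" + opcode;
                default: throw new IllegalStateException();
            }
        }
    }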
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e8fbc753/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
new file mode 100644
index 0000000..372f883
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.gpu;
+
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.functionobjects.Builtin;
+import org.apache.sysml.runtime.functionobjects.ValueFunction;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+
+public abstract class BuiltinBinaryGPUInstruction extends GPUInstruction {
+
+  private int _arity;
+  CPOperand output;
+  CPOperand input1, input2;
+
+
+  public BuiltinBinaryGPUInstruction(Operator op, CPOperand input1, CPOperand input2, CPOperand output, String opcode, String istr, int _arity) {
+    super(op, opcode, istr);
+    this._arity = _arity;
+    this.output = output;
+    this.input1 = input1;
+    this.input2 = input2;
+  }
+
+  public static BuiltinBinaryGPUInstruction parseInstruction(String str) throws DMLRuntimeException {
+    CPOperand in1 = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
+    CPOperand in2 = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
+    CPOperand out = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
+
+    String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+    InstructionUtils.checkNumFields ( parts, 3 );
+
+    String opcode = parts[0];
+    in1.split(parts[1]);
+    in2.split(parts[2]);
+    out.split(parts[3]);
+
+    // check for valid data type of output
+    if((in1.getDataType() == Expression.DataType.MATRIX || in2.getDataType() == Expression.DataType.MATRIX) && out.getDataType() != Expression.DataType.MATRIX)
+      throw new DMLRuntimeException("Element-wise matrix operations between variables " + in1.getName() +
+              " and " + in2.getName() + " must produce a matrix, which " + out.getName() + " is not");
+
+    // Determine appropriate Function Object based on opcode
+    ValueFunction func = Builtin.getBuiltinFnObject(opcode);
+
+    // Currently only matrix-matrix builtins (i.e. "solve") are supported
+    if ( in1.getDataType() == Expression.DataType.SCALAR && in2.getDataType() == Expression.DataType.SCALAR )
+      throw new DMLRuntimeException("GPU : Unsupported GPU builtin operation on two scalars");
+    else if ( in1.getDataType() == Expression.DataType.MATRIX && in2.getDataType() == Expression.DataType.MATRIX )
+      return new MatrixMatrixBuiltinGPUInstruction(new BinaryOperator(func), in1, in2, out, opcode, str, 2);
+    else
+      throw new DMLRuntimeException("GPU : Unsupported GPU builtin operation on a matrix and a scalar");
+
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e8fbc753/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinUnaryGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinUnaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinUnaryGPUInstruction.java
index 181af4e..7529b05 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinUnaryGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinUnaryGPUInstruction.java
@@ -43,7 +43,7 @@ public abstract class BuiltinUnaryGPUInstruction  extends GPUInstruction {
 		_gputype = GPUINSTRUCTION_TYPE.BuiltinUnary;
 		this._arity = _arity;
 		_input = in;
-        _output = out;
+    _output = out;
 	}
 
 	public int getArity() {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e8fbc753/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
index 0b69b5e..9eef072 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
@@ -32,7 +32,7 @@ import org.apache.sysml.utils.Statistics;
 
 public abstract class GPUInstruction extends Instruction
 {
-	public enum GPUINSTRUCTION_TYPE { AggregateUnary, AggregateBinary, Convolution, MMTSJ, Reorg, ArithmeticBinary, BuiltinUnary, Builtin };
+	public enum GPUINSTRUCTION_TYPE { AggregateUnary, AggregateBinary, Convolution, MMTSJ, Reorg, ArithmeticBinary, BuiltinUnary, BuiltinBinary, Builtin };
 
 	// Memory/conversions
 	public final static String MISC_TIMER_HOST_TO_DEVICE = 				"H2D";	// time spent in bringing data to gpu (from host)

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e8fbc753/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
new file mode 100644
index 0000000..f492b6e
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.gpu;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.utils.GPUStatistics;
+
+
+public class MatrixMatrixBuiltinGPUInstruction extends BuiltinBinaryGPUInstruction {
+
+  public MatrixMatrixBuiltinGPUInstruction(Operator op, CPOperand input1, CPOperand input2, CPOperand output, String opcode, String istr, int _arity) {
+    super(op, input1, input2, output, opcode, istr, _arity);
+    _gputype = GPUINSTRUCTION_TYPE.BuiltinBinary;
+
+  }
+
+  @Override
+  public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
+    GPUStatistics.incrementNoOfExecutedGPUInst();
+
+    String opcode = getOpcode();
+    MatrixObject mat1 = getMatrixInputForGPUInstruction(ec, input1.getName());
+    MatrixObject mat2 = getMatrixInputForGPUInstruction(ec, input2.getName());
+
+    if(opcode.equals("solve")) {
+      LibMatrixCUDA.solve(ec, ec.getGPUContext(), getExtendedOpcode(), mat1, mat2, output.getName());
+
+    } else {
+      throw new DMLRuntimeException("Unsupported GPU operator:" + opcode);
+    }
+    ec.releaseMatrixInputForGPUInstruction(input1.getName());
+    ec.releaseMatrixInputForGPUInstruction(input2.getName());
+    ec.releaseMatrixOutputForGPUInstruction(output.getName());
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e8fbc753/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
index ef549a1..05257e5 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
@@ -29,6 +29,7 @@ import static jcuda.jcusparse.JCusparse.cusparseXcsrgemmNnz;
 import static jcuda.jcusparse.cusparseIndexBase.CUSPARSE_INDEX_BASE_ZERO;
 import static jcuda.jcusparse.cusparseMatrixType.CUSPARSE_MATRIX_TYPE_GENERAL;
 import static jcuda.runtime.JCuda.cudaMemcpy;
+import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
 
@@ -39,6 +40,7 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.utils.GPUStatistics;
 
 import jcuda.Pointer;
+import jcuda.Sizeof;
 import jcuda.jcublas.cublasHandle;
 import jcuda.jcusparse.cusparseHandle;
 import jcuda.jcusparse.cusparseMatDescr;
@@ -52,11 +54,11 @@ public class CSRPointer {
 
   private static final Log LOG = LogFactory.getLog(CSRPointer.class.getName());
 
+  private static final double ULTRA_SPARSITY_TURN_POINT = 0.0004;
+
   /** {@link GPUContext} instance to track the GPU to do work on */
   private final GPUContext gpuContext;
 
-  private static final double ULTRA_SPARSITY_TURN_POINT = 0.0004;
-
   public static cusparseMatDescr matrixDescriptor;
 
   /** Number of non zeroes */
@@ -74,6 +76,27 @@ public class CSRPointer {
   /** descriptor of matrix, only CUSPARSE_MATRIX_TYPE_GENERAL supported */
   public cusparseMatDescr descr;
 
+
+  public CSRPointer clone(int rows) throws DMLRuntimeException {
+    CSRPointer me = this;
+    CSRPointer that = new CSRPointer(me.getGPUContext());
+
+    that.allocateMatDescrPointer();
+    long totalSize = estimateSize(me.nnz, rows);
+    that.gpuContext.ensureFreeSpace(totalSize);
+
+    that.nnz = me.nnz;
+    that.val = allocate(that.nnz * Sizeof.DOUBLE);
+    that.rowPtr = allocate((rows + 1) * Sizeof.INT);   // CSR row pointers: int[rows+1]
+    that.colInd = allocate(that.nnz * Sizeof.INT);     // CSR column indices: int[nnz]
+
+    cudaMemcpy(that.val, me.val, that.nnz * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
+    cudaMemcpy(that.rowPtr, me.rowPtr, (rows + 1) * Sizeof.INT, cudaMemcpyDeviceToDevice);
+    cudaMemcpy(that.colInd, me.colInd, that.nnz * Sizeof.INT, cudaMemcpyDeviceToDevice);
+
+    return that;
+  }
+
   /**
    * Default constructor to help with Factory method {@link #allocateEmpty(GPUContext, long, long)}
    * @param gCtx   a valid {@link GPUContext}
@@ -114,7 +137,7 @@ public class CSRPointer {
     return numElems * ((long)jcuda.Sizeof.INT);
   }
 
-  private GPUContext getGPUContext() throws DMLRuntimeException {
+  private GPUContext getGPUContext() {
     return gpuContext;
   }
 

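A note for reading clone() above: a CSR matrix with `rows` rows and `nnz` non-zeros is three device arrays, and their element types and lengths determine the byte counts to allocate and copy. A host-side illustration, with plain Java arrays standing in for the device buffers:

    // CSR layout for a matrix with m rows and nnz non-zeros:
    //   val    : double[nnz]  -> nnz * Sizeof.DOUBLE bytes on the device
    //   rowPtr : int[m + 1]   -> (m + 1) * Sizeof.INT bytes
    //   colInd : int[nnz]     -> nnz * Sizeof.INT bytes
    //
    // Example: 3x3 matrix [[10,0,0],[20,0,30],[0,0,0]]
    double[] val    = {10, 20, 30};
    int[]    rowPtr = {0, 1, 3, 3};   // row i spans [rowPtr[i], rowPtr[i+1])
    int[]    colInd = {0, 0, 2};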
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e8fbc753/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
index d6f3a71..d71f725 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
@@ -22,8 +22,13 @@ import static jcuda.jcublas.JCublas2.cublasCreate;
 import static jcuda.jcublas.JCublas2.cublasDestroy;
 import static jcuda.jcudnn.JCudnn.cudnnCreate;
 import static jcuda.jcudnn.JCudnn.cudnnDestroy;
+import static jcuda.jcusolver.JCusolverDn.cusolverDnDestroy;
+import static jcuda.jcusolver.JCusolverSp.cusolverSpDestroy;
 import static jcuda.jcusparse.JCusparse.cusparseCreate;
 import static jcuda.jcusparse.JCusparse.cusparseDestroy;
+import static jcuda.jcusolver.JCusolverDn.cusolverDnCreate;
+import static jcuda.jcusolver.JCusolverSp.cusolverSpCreate;
+
 import static jcuda.runtime.JCuda.cudaDeviceScheduleBlockingSync;
 import static jcuda.runtime.JCuda.cudaFree;
 import static jcuda.runtime.JCuda.cudaGetDeviceCount;
@@ -54,6 +59,8 @@ import org.apache.sysml.utils.LRUCacheMap;
 import jcuda.Pointer;
 import jcuda.jcublas.cublasHandle;
 import jcuda.jcudnn.cudnnHandle;
+import jcuda.jcusolver.cusolverDnHandle;
+import jcuda.jcusolver.cusolverSpHandle;
 import jcuda.jcusparse.cusparseHandle;
 import jcuda.runtime.JCuda;
 import jcuda.runtime.cudaDeviceProp;
@@ -90,15 +97,21 @@ public class GPUContext {
    * so that an extraneous host to dev transfer can be avoided */
   private ArrayList<GPUObject> allocatedGPUObjects = new ArrayList<>();
 
-  /** cudnnHandle specific to the active GPU for this GPUContext */
+  /** cudnnHandle for Deep Neural Network operations on the GPU */
   private cudnnHandle cudnnHandle;
 
-  /** cublasHandle specific to the active GPU for this GPUContext */
+  /** cublasHandle for BLAS operations on the GPU */
   private cublasHandle cublasHandle;
 
-  /** cusparseHandle specific to the active GPU for this GPUContext */
+  /** cusparseHandle for certain sparse BLAS operations on the GPU */
   private cusparseHandle cusparseHandle;
 
+  /** cusolverDnHandle for invoking solve() function on dense matrices on the GPU */
+  private cusolverDnHandle cusolverDnHandle;
+
+  /** cusolverSpHandle for invoking solve() function on sparse matrices on the GPU */
+  private cusolverSpHandle cusolverSpHandle;
+
   /** to launch custom CUDA kernel, specific to the active GPU for this GPUContext */
   private JCudaKernels kernels;
 
@@ -133,6 +146,12 @@ public class GPUContext {
     // cublasSetPointerMode(LibMatrixCUDA.cublasHandle, cublasPointerMode.CUBLAS_POINTER_MODE_DEVICE);
     cusparseHandle = new cusparseHandle();
     cusparseCreate(cusparseHandle);
+
+    cusolverDnHandle = new cusolverDnHandle();
+    cusolverDnCreate(cusolverDnHandle);
+    cusolverSpHandle = new cusolverSpHandle();
+    cusolverSpCreate(cusolverSpHandle);
+
     kernels = new JCudaKernels(deviceNum);
 
     GPUStatistics.cudaLibrariesInitTime = System.nanoTime() - start;
@@ -553,6 +572,14 @@ public class GPUContext {
     return cusparseHandle;
   }
 
+  public cusolverDnHandle getCusolverDnHandle() {
+    return cusolverDnHandle;
+  }
+
+  public cusolverSpHandle getCusolverSpHandle() {
+    return cusolverSpHandle;
+  }
+
   public JCudaKernels getKernels() {
     return kernels;
   }
@@ -569,6 +596,8 @@ public class GPUContext {
     cudnnDestroy(cudnnHandle);
     cublasDestroy(cublasHandle);
     cusparseDestroy(cusparseHandle);
+    cusolverDnDestroy(cusolverDnHandle);
+    cusolverSpDestroy(cusolverSpHandle);
     cudnnHandle = null;
     cublasHandle = null;
     cusparseHandle = null;

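The new cusolver handles follow the same create-on-init / destroy-on-teardown discipline as the existing cuDNN, cuBLAS and cuSPARSE handles. A minimal, self-contained sketch of the JCusolver dense-handle lifecycle (assumes the JCuda jars and a CUDA-capable device are available):

    import jcuda.jcusolver.JCusolverDn;
    import jcuda.jcusolver.cusolverDnHandle;

    public class CusolverHandleDemo {
        public static void main(String[] args) {
            cusolverDnHandle handle = new cusolverDnHandle();
            JCusolverDn.cusolverDnCreate(handle);       // acquire once per GPUContext
            try {
                // ... issue cusolverDn* calls (geqrf, ormqr, ...) against this handle ...
            } finally {
                JCusolverDn.cusolverDnDestroy(handle);  // release exactly once, on teardown
            }
        }
    }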
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e8fbc753/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index dd5ba41..1d2285d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -26,7 +26,9 @@ import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
 import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
 import static jcuda.jcusparse.JCusparse.cusparseDdense2csr;
 import static jcuda.jcusparse.JCusparse.cusparseDnnz;
+import static jcuda.runtime.JCuda.cudaMalloc;
 import static jcuda.runtime.JCuda.cudaMemcpy;
+import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
 
@@ -50,6 +52,7 @@ import org.apache.sysml.runtime.matrix.data.SparseBlockMCSR;
 import org.apache.sysml.utils.GPUStatistics;
 
 import jcuda.Pointer;
+import jcuda.Sizeof;
 import jcuda.jcublas.JCublas2;
 import jcuda.jcudnn.cudnnTensorDescriptor;
 import jcuda.jcusparse.JCusparse;
@@ -100,6 +103,43 @@ public class GPUObject {
 //		return getGPUContext().allocate(instName, size);
 //	}
 
+	@Override
+	public Object clone() {
+		GPUObject me = this;
+		GPUObject that = new GPUObject(me.gpuContext, me.mat);
+		if (me.tensorShape != null) {
+			that.tensorShape = new int[me.tensorShape.length];
+			System.arraycopy(me.tensorShape, 0, that.tensorShape, 0, me.tensorShape.length);
+			that.allocateTensorDescriptor(me.tensorShape[0], me.tensorShape[1], me.tensorShape[2], me.tensorShape[3]);
+		}
+		that.dirty = me.dirty;
+		that.readLocks = new AtomicInteger(me.readLocks.get());
+		that.timestamp = new AtomicLong(me.timestamp.get());
+		that.isSparse = me.isSparse;
+
+		try {
+			if (me.jcudaDenseMatrixPtr != null) {
+				long rows = me.mat.getNumRows();
+				long cols = me.mat.getNumColumns();
+				long size = rows * cols * Sizeof.DOUBLE;
+				me.gpuContext.ensureFreeSpace((int)size);
+				that.jcudaDenseMatrixPtr = allocate(size);
+				cudaMemcpy(that.jcudaDenseMatrixPtr, me.jcudaDenseMatrixPtr, size, cudaMemcpyDeviceToDevice);
+			}
+
+			if (me.jcudaSparseMatrixPtr != null) {
+				long rows = me.mat.getNumRows();
+				that.jcudaSparseMatrixPtr = me.jcudaSparseMatrixPtr.clone((int)rows);
+			}
+		} catch (DMLRuntimeException e) {
+			throw new RuntimeException(e);
+		}
+
+		return that;
+	}
+
 	private Pointer allocate(long size) throws DMLRuntimeException {
 		return getGPUContext().allocate(size);
 	}
@@ -116,7 +156,7 @@ public class GPUObject {
 		getGPUContext().cudaFreeHelper(instName, toFree, eager);
 	}
 
-	private GPUContext getGPUContext() throws DMLRuntimeException {
+	private GPUContext getGPUContext() {
 		return gpuContext;
 	}
 
@@ -275,7 +315,7 @@ public class GPUObject {
 		if(getJcudaDenseMatrixPtr() == null || !isAllocated())
 			throw new DMLRuntimeException("Expected allocated dense matrix before denseToSparse() call");
 
-		convertDensePtrFromRowMajorToColumnMajor();
+		denseRowMajorToColumnMajor();
 		setSparseMatrixCudaPointer(columnMajorDenseToRowMajorSparse(getGPUContext(), cusparseHandle, getJcudaDenseMatrixPtr(), rows, cols));
 		// TODO: What if mat.getNnz() is -1 ?
 		if (DMLScript.STATISTICS) GPUStatistics.cudaDenseToSparseTime.addAndGet(System.nanoTime() - t0);
@@ -283,10 +323,10 @@ public class GPUObject {
 	}
 
 	/**
-	 * Convenience method. Converts Row Major Dense Matrix --> Column Major Dense Matrix
+	 * Convenience method. Converts Row Major Dense Matrix to Column Major Dense Matrix
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */
-	private void convertDensePtrFromRowMajorToColumnMajor() throws DMLRuntimeException {
+	public void denseRowMajorToColumnMajor() throws DMLRuntimeException {
 		LOG.trace("GPU : dense Ptr row-major -> col-major on " + this + ", GPUContext=" + getGPUContext());
 		int m = toIntExact(mat.getNumRows());
 		int n = toIntExact(mat.getNumColumns());
@@ -301,7 +341,11 @@ public class GPUObject {
 		setDenseMatrixCudaPointer(tmp);
 	}
 
-	private void convertDensePtrFromColMajorToRowMajor() throws DMLRuntimeException {
+	/**
+	 * Convenience method. Converts Column Major Dense Matrix to Row Major Dense Matrix
+	 * @throws DMLRuntimeException if DMLRuntimeException occurs
+	 */
+	public void denseColumnMajorToRowMajor() throws DMLRuntimeException {
 		LOG.trace("GPU : dense Ptr col-major -> row-major on " + this + ", GPUContext=" + getGPUContext());
 
 		int n = toIntExact(mat.getNumRows());
@@ -340,7 +384,7 @@ public class GPUObject {
 			throw new DMLRuntimeException("Expected allocated sparse matrix before sparseToDense() call");
 
 		sparseToColumnMajorDense();
-		convertDensePtrFromColMajorToRowMajor();
+		denseColumnMajorToRowMajor();
 		if (DMLScript.STATISTICS) end = System.nanoTime();
 		if (instructionName != null && GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instructionName, GPUInstruction.MISC_TIMER_SPARSE_TO_DENSE, end - start);
 		if (DMLScript.STATISTICS) GPUStatistics.cudaSparseToDenseTime.addAndGet(end - start);
@@ -431,17 +475,12 @@ public class GPUObject {
 	}
 
 	public boolean isInputAllocated() {
-		try {
-			boolean eitherAllocated = (getJcudaDenseMatrixPtr() != null || getJcudaSparseMatrixPtr() != null);
-			boolean isAllocatedOnThisGPUContext = getGPUContext().isBlockRecorded(this);
-			if (eitherAllocated && !isAllocatedOnThisGPUContext) {
-				LOG.warn("GPU : A block was allocated but was not on this GPUContext, GPUContext=" + getGPUContext());
-			}
-			return eitherAllocated && isAllocatedOnThisGPUContext;
-		} catch (DMLRuntimeException e){
-			LOG.info("GPU : System is in an inconsistent state");
-			throw new RuntimeException(e);
+		boolean eitherAllocated = (getJcudaDenseMatrixPtr() != null || getJcudaSparseMatrixPtr() != null);
+		boolean isAllocatedOnThisGPUContext = getGPUContext().isBlockRecorded(this);
+		if (eitherAllocated && !isAllocatedOnThisGPUContext) {
+			LOG.warn("GPU : A block was allocated but was not on this GPUContext, GPUContext=" + getGPUContext());
 		}
+		return eitherAllocated && isAllocatedOnThisGPUContext;
 	}
 
 	/**
@@ -863,5 +902,4 @@ public class GPUObject {
 		return sb.toString();
 	}
 
-
 }

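denseRowMajorToColumnMajor()/denseColumnMajorToRowMajor() perform the layout conversion on the device; conceptually it is just a transpose of the index mapping. A host-side reference, for intuition only:

    // Element (i, j) of an m x n matrix sits at i*n + j in row-major order
    // and at j*m + i in column-major order.
    static double[] rowMajorToColumnMajor(double[] rm, int m, int n) {
        double[] cm = new double[m * n];
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++)
                cm[j * m + i] = rm[i * n + j];
        return cm;
    }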
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e8fbc753/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 56360f8..23304b5 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -95,10 +95,10 @@ import org.apache.sysml.runtime.functionobjects.ReduceRow;
 import org.apache.sysml.runtime.functionobjects.ValueFunction;
 import org.apache.sysml.runtime.instructions.cp.DoubleObject;
 import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
+import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer;
 import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUObject;
-import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer;
 import org.apache.sysml.runtime.instructions.gpu.context.JCudaKernels;
 import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
@@ -114,9 +114,11 @@ import jcuda.CudaException;
 import jcuda.Pointer;
 import jcuda.Sizeof;
 import jcuda.jcublas.JCublas2;
+import jcuda.jcublas.cublasDiagType;
 import jcuda.jcublas.cublasFillMode;
 import jcuda.jcublas.cublasHandle;
 import jcuda.jcublas.cublasOperation;
+import jcuda.jcublas.cublasSideMode;
 import jcuda.jcudnn.cudnnActivationDescriptor;
 import jcuda.jcudnn.cudnnBatchNormMode;
 import jcuda.jcudnn.cudnnConvolutionDescriptor;
@@ -126,6 +128,7 @@ import jcuda.jcudnn.cudnnHandle;
 import jcuda.jcudnn.cudnnPoolingDescriptor;
 import jcuda.jcudnn.cudnnStatus;
 import jcuda.jcudnn.cudnnTensorDescriptor;
+import jcuda.jcusolver.JCusolverDn;
 import jcuda.jcusparse.JCusparse;
 import jcuda.jcusparse.cusparseHandle;
 
@@ -306,15 +309,31 @@ public class LibMatrixCUDA {
 	/**
 	 * Convenience method to get jcudaDenseMatrixPtr. This method explicitly converts sparse to dense format, so use it judiciously.
 	 * @param gCtx a valid {@link GPUContext}
-	 * @param image input matrix object
+	 * @param input input matrix object
+	 * @param instName  the invoking instruction's name for recording {@link Statistics}.
 	 * @return jcuda pointer
 	 * @throws DMLRuntimeException if error occurs while sparse to dense conversion
 	 */
-	private static Pointer getDensePointer(GPUContext gCtx, MatrixObject image, String instName) throws DMLRuntimeException {
-		if(isInSparseFormat(gCtx, image)) {
-			image.getGPUObject(gCtx).sparseToDense(instName);
+	private static Pointer getDensePointer(GPUContext gCtx, MatrixObject input, String instName) throws DMLRuntimeException {
+		if(isInSparseFormat(gCtx, input)) {
+			input.getGPUObject(gCtx).sparseToDense(instName);
+		}
+		return input.getGPUObject(gCtx).getJcudaDenseMatrixPtr();
+	}
+
+	/**
+	 * Convenience method to get the sparse matrix pointer from a {@link MatrixObject}. Converts dense to sparse if necessary.
+	 * @param gCtx a valid {@link GPUContext}
+	 * @param input input matrix
+	 * @param instName the invoking instruction's name for recording {@link Statistics}.
+	 * @return a sparse matrix pointer
+	 * @throws DMLRuntimeException if error occurs
+	 */
+	private static CSRPointer getSparsePointer(GPUContext gCtx, MatrixObject input, String instName) throws DMLRuntimeException {
+		if(!isInSparseFormat(gCtx, input)) {
+			input.getGPUObject(gCtx).denseToSparse();
 		}
-		return image.getGPUObject(gCtx).getJcudaDenseMatrixPtr();
+		return input.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
 	}
 	
 	/**
@@ -2927,6 +2946,108 @@ public class LibMatrixCUDA {
 		}
 	}
 
+
+    /**
+     * Implements the "solve" function for SystemML: computes the least-squares solution x to Ax = b (A is of size m*n, b is of size m*1, x is of size n*1)
+     *
+     * @param ec         a valid {@link ExecutionContext}
+     * @param gCtx       a valid {@link GPUContext}
+     * @param instName   the invoking instruction's name for recording {@link Statistics}.
+     * @param in1        input matrix A
+     * @param in2        input matrix B
+     * @param outputName name of the output matrix
+     * @throws DMLRuntimeException if an error occurs
+     */
+    public static void solve(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, MatrixObject in2, String outputName) throws DMLRuntimeException {
+        if (ec.getGPUContext() != gCtx)
+            throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same as the one used to run this LibMatrixCUDA function");
+
+        // x = solve(A, b)
+
+        // Both dense
+        if (!isInSparseFormat(gCtx, in1) && !isInSparseFormat(gCtx, in2)) {
+            GPUObject Aobj = in1.getGPUObject(gCtx);
+            GPUObject bobj = in2.getGPUObject(gCtx);
+            int m = (int) in1.getNumRows();
+            int n = (int) in1.getNumColumns();
+            if ((int) in2.getNumRows() != m)
+                throw new DMLRuntimeException("GPU : Incorrect input for solve(), rows in A should be the same as rows in B");
+            if ((int) in2.getNumColumns() != 1)
+                throw new DMLRuntimeException("GPU : Incorrect input for solve(), columns in B should be 1");
+
+
+            // Copy the input matrices and convert the dense copies from
+            // row-major to column-major: cuSolver and cuBlas operate on
+            // column-major dense matrices, and the factorization below is
+            // destructive to its input
+            GPUObject ATobj = (GPUObject) Aobj.clone();
+            ATobj.denseRowMajorToColumnMajor();
+            Pointer A = ATobj.getJcudaDenseMatrixPtr();
+
+            GPUObject bTobj = (GPUObject) bobj.clone();
+            bTobj.denseRowMajorToColumnMajor();
+            Pointer b = bTobj.getJcudaDenseMatrixPtr();
+
+            // The following set of operations is done following the example in the cusolver documentation
+            // http://docs.nvidia.com/cuda/cusolver/#ormqr-example1
+
+            // step 3: query working space of geqrf and ormqr
+            int[] lwork = {0};
+            JCusolverDn.cusolverDnDgeqrf_bufferSize(gCtx.getCusolverDnHandle(), m, n, A, m, lwork);
+
+            // step 4: compute QR factorization
+            Pointer work = gCtx.allocate(lwork[0] * Sizeof.DOUBLE);
+            Pointer tau = gCtx.allocate(Math.max(m, n) * Sizeof.DOUBLE);
+            Pointer devInfo = gCtx.allocate(Sizeof.INT);
+            JCusolverDn.cusolverDnDgeqrf(gCtx.getCusolverDnHandle(), m, n, A, m, tau, work, lwork[0], devInfo);
+
+            int[] qrError = {-1};
+            cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost);
+            if (qrError[0] != 0) {
+                throw new DMLRuntimeException("GPU : Error in call to geqrf (QR factorization) as part of solve, argument " + qrError[0] + " was wrong");
+            }
+
+            // step 5: compute Q^T*B
+            JCusolverDn.cusolverDnDormqr(gCtx.getCusolverDnHandle(), cublasSideMode.CUBLAS_SIDE_LEFT, cublasOperation.CUBLAS_OP_T, m, 1, n, A, m, tau, b, m, work, lwork[0], devInfo);
+            cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost);
+            if (qrError[0] != 0) {
+                throw new DMLRuntimeException("GPU : Error in call to ormqr (to compute Q^T*B after QR factorization) as part of solve, argument " + qrError[0] + " was wrong");
+            }
+
+            // step 6: compute x = R \ Q^T*B
+            JCublas2.cublasDtrsm(gCtx.getCublasHandle(),
+                cublasSideMode.CUBLAS_SIDE_LEFT, cublasFillMode.CUBLAS_FILL_MODE_UPPER, cublasOperation.CUBLAS_OP_N, cublasDiagType.CUBLAS_DIAG_NON_UNIT,
+                n, 1, pointerTo(1.0), A, m, b, m);
+
+            bTobj.denseColumnMajorToRowMajor();
+
+            // TODO  : Find a way to assign bTobj directly to the output and set the correct flags so as to not crash
+            // There is an avoidable copy happening here
+            MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName);
+            cudaMemcpy(out.getGPUObject(gCtx).getJcudaDenseMatrixPtr(), bTobj.getJcudaDenseMatrixPtr(), n * 1 * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
+
+            gCtx.cudaFreeHelper(work);
+            gCtx.cudaFreeHelper(tau);
+            gCtx.cudaFreeHelper(devInfo);
+            ATobj.clearData();
+            bTobj.clearData();
+
+        } else if (isInSparseFormat(gCtx, in1) && isInSparseFormat(gCtx, in2)) { // Both sparse
+            throw new DMLRuntimeException("GPU : solve on sparse inputs not supported");
+        } else if (!isInSparseFormat(gCtx, in1) && isInSparseFormat(gCtx, in2)) { // A is dense, b is sparse
+            // Pointer A = getDensePointer(gCtx, in1, instName);
+            // Pointer B = getDensePointer(gCtx, in2, instName);
+            throw new DMLRuntimeException("GPU : solve on sparse inputs not supported");
+        } else if (isInSparseFormat(gCtx, in1) && !isInSparseFormat(gCtx, in2)) { // A is sparse, b is dense
+            throw new DMLRuntimeException("GPU : solve on sparse inputs not supported");
+        }
+
+
+    }
+
 	//********************************************************************/
 	//*****************  END OF Builtin Functions ************************/
 	//********************************************************************/
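
For the record, the dense solve path above is the standard QR route for least squares, following the linked cusolver ormqr example: geqrf factors A = QR in place (R in the upper triangle, Q held implicitly as Householder reflectors plus tau), ormqr forms Q^T*b, and trsm back-solves the triangular system. In equations:

    A x = b,   A = Q R   (A is m x n, m >= n, full column rank assumed)
    min ||A x - b||_2  is attained where  R x = Q^T b
    x = R^{-1} (Q^T b)   -- computed by the triangular solve; R^{-1} is never formed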