You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original page and is not preserved in this plain-text version.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/10/11 22:52:47 UTC
systemml git commit: [SYSTEMML-445] Refactoring to avoid potential memory leaks
Repository: systemml
Updated Branches:
refs/heads/master 8f786aa22 -> 96ae6c7eb
[SYSTEMML-445] Refactoring to avoid potential memory leaks
- Removed tensor descriptors from GPUObject
- Created a closeable LibMatrixCuDNNPoolingDescriptors class to manage the cuDNN descriptors (pooling and tensor descriptors) required by max pooling
- Enabled JCuda exceptions to catch CUDA errors eagerly
- Added debug messages to the eviction logic. Printing these messages is guarded by log-level checks to avoid additional overhead
- Removed unused batch normalization methods from LibMatrixCuDNN
Closes #679.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/96ae6c7e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/96ae6c7e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/96ae6c7e
Branch: refs/heads/master
Commit: 96ae6c7eb34e792f9fe7c3a8b37c9130fb0ea7ae
Parents: 8f786aa
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Wed Oct 11 15:47:59 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Wed Oct 11 15:49:14 2017 -0700
----------------------------------------------------------------------
.../instructions/gpu/context/GPUContext.java | 18 +-
.../instructions/gpu/context/GPUObject.java | 75 +---
.../runtime/matrix/data/LibMatrixCuDNN.java | 363 ++-----------------
.../LibMatrixCuDNNConvolutionAlgorithm.java | 56 ++-
.../data/LibMatrixCuDNNPoolingDescriptors.java | 164 +++++++++
5 files changed, 251 insertions(+), 425 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/96ae6c7e/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
index 118602b..55cb95f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
@@ -401,8 +401,12 @@ public class GPUContext {
*/
public void cudaFreeHelper(String instructionName, final Pointer toFree, boolean eager) {
Pointer dummy = new Pointer();
- if (toFree == dummy) // trying to free a null pointer
+ if (toFree == dummy) { // trying to free a null pointer
+ if (LOG.isTraceEnabled()) {
+ LOG.trace("GPU : trying to free an empty pointer");
+ }
return;
+ }
long t0 = 0;
if (!cudaBlockSizeMap.containsKey(toFree))
throw new RuntimeException(
@@ -410,7 +414,7 @@ public class GPUContext {
long size = cudaBlockSizeMap.get(toFree);
if (eager) {
if (LOG.isTraceEnabled()) {
- LOG.trace("GPU : eagerly freeing cuda memory [ " + toFree + " ] for instruction " + instructionName
+ LOG.trace("GPU : eagerly freeing cuda memory [ " + toFree + " ] of size " + size + " for instruction " + instructionName
+ " on " + this);
}
if (DMLScript.STATISTICS)
@@ -426,7 +430,7 @@ public class GPUContext {
System.nanoTime() - t0);
} else {
if (LOG.isTraceEnabled()) {
- LOG.trace("GPU : lazily freeing cuda memory for instruction " + instructionName + " on " + this);
+ LOG.trace("GPU : lazily freeing cuda memory of size " + size + " for instruction " + instructionName + " on " + this);
}
Set<Pointer> freeList = freeCUDASpaceMap.get(size);
if (freeList == null) {
@@ -492,6 +496,10 @@ public class GPUContext {
LOG.trace("GPU : evict called from " + instructionName + " for size " + neededSize + " on " + this);
}
GPUStatistics.cudaEvictionCount.add(1);
+ if (LOG.isDebugEnabled()) {
+ printMemoryInfo("EVICTION_CUDA_FREE_SPACE");
+ }
+
// Release the set of free blocks maintained in a GPUObject.freeCUDASpaceMap
// to free up space
LRUCacheMap<Long, Set<Pointer>> lruCacheMap = freeCUDASpaceMap;
@@ -560,6 +568,9 @@ public class GPUContext {
});
while (neededSize > getAvailableMemory() && allocatedGPUObjects.size() > 0) {
+ if (LOG.isDebugEnabled()) {
+ printMemoryInfo("EVICTION_UNLOCKED");
+ }
GPUObject toBeRemoved = allocatedGPUObjects.get(allocatedGPUObjects.size() - 1);
if (toBeRemoved.isLocked()) {
throw new DMLRuntimeException(
@@ -569,7 +580,6 @@ public class GPUContext {
if (toBeRemoved.dirty) {
toBeRemoved.copyFromDeviceToHost();
}
-
toBeRemoved.clearData(true);
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/96ae6c7e/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index 31bf151..feb34bc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -19,11 +19,6 @@
package org.apache.sysml.runtime.instructions.gpu.context;
import static jcuda.jcublas.cublasOperation.CUBLAS_OP_T;
-import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
-import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
-import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
import static jcuda.jcusparse.JCusparse.cusparseDdense2csr;
import static jcuda.jcusparse.JCusparse.cusparseDnnz;
import static jcuda.runtime.JCuda.cudaMemcpy;
@@ -32,7 +27,6 @@ import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
-import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;
@@ -55,7 +49,6 @@ import org.apache.sysml.utils.GPUStatistics;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.jcublas.JCublas2;
-import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.jcusparse.JCusparse;
import jcuda.jcusparse.cusparseDirection;
import jcuda.jcusparse.cusparseHandle;
@@ -84,17 +77,6 @@ public class GPUObject {
private CSRPointer jcudaSparseMatrixPtr = null;
/**
- * An optional tensor descriptor (and shape) that can be set by a tensor instruction such as convolution,
- * maxpooling and exploited by a subsequent non-tensor instruction such as relu
- */
- private cudnnTensorDescriptor tensorDescriptor = null;
-
- /**
- * the shape of this tensor, if in fact this is a tensor
- */
- private int[] tensorShape = null;
-
- /**
* whether the block attached to this {@link GPUContext} is dirty on the device and needs to be copied back to host
*/
protected boolean dirty = false;
@@ -132,13 +114,7 @@ public class GPUObject {
public Object clone() {
GPUObject me = this;
GPUObject that = new GPUObject(me.gpuContext, me.mat);
- if (me.tensorShape != null) {
- that.tensorShape = new int[me.tensorShape.length];
- System.arraycopy(me.tensorShape, 0, that.tensorShape, 0, me.tensorShape.length);
- that.allocateTensorDescriptor(me.tensorShape[0], me.tensorShape[1], me.tensorShape[2], me.tensorShape[3]);
- }
that.dirty = me.dirty;
- // TODO Nakul: Should the locks be cloned here ?
// The only place clone is getting called: LibMatrixCUDA's solve
that.readLocks.reset();
that.writeLock = false;
@@ -498,51 +474,7 @@ public class GPUObject {
public boolean isSparse() {
return isSparse;
}
-
- /**
- * Returns a previously allocated tensor shape or null
- *
- * @return int array of four elements or null
- */
- public int[] getTensorShape() {
- return tensorShape;
- }
-
- /**
- * Returns a previously allocated tensor descriptor or null
- *
- * @return cudnn tensor descriptor
- */
- public cudnnTensorDescriptor getTensorDescriptor() {
- return tensorDescriptor;
- }
-
- /**
- * Returns a previously allocated or allocates and returns a tensor descriptor
- *
- * @param N number of images
- * @param C number of channels
- * @param H height
- * @param W width
- * @return cudnn tensor descriptor
- */
- public cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) {
- if(LOG.isTraceEnabled()) {
- LOG.trace("GPU : allocateTensorDescriptor with [N=" + N + ",C=" + C + ",H=" + H + ",W=" + W + "] on " + this);
- }
- if (tensorDescriptor == null) {
- tensorDescriptor = new cudnnTensorDescriptor();
- cudnnCreateTensorDescriptor(tensorDescriptor);
- cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
- tensorShape = new int[4];
- tensorShape[0] = N;
- tensorShape[1] = C;
- tensorShape[2] = H;
- tensorShape[3] = W;
- }
- return tensorDescriptor;
- }
-
+
private static long getDoubleSizeOf(long numElems) {
return numElems * ((long) jcuda.Sizeof.DOUBLE);
}
@@ -829,10 +761,6 @@ public class GPUObject {
}
jcudaDenseMatrixPtr = null;
jcudaSparseMatrixPtr = null;
- if (tensorDescriptor != null) {
- cudnnDestroyTensorDescriptor(tensorDescriptor);
- tensorDescriptor = null;
- }
resetReadWriteLock();
getGPUContext().removeRecordedUsage(this);
}
@@ -1094,7 +1022,6 @@ public class GPUObject {
@Override
public String toString() {
final StringBuilder sb = new StringBuilder("GPUObject{");
- sb.append(", tensorShape=").append(Arrays.toString(tensorShape));
sb.append(", dirty=").append(dirty);
sb.append(", readLocks=").append(readLocks.longValue());
sb.append(", writeLock=").append(writeLock);
http://git-wip-us.apache.org/repos/asf/systemml/blob/96ae6c7e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index 25dc604..bb74aa2 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -19,43 +19,27 @@
package org.apache.sysml.runtime.matrix.data;
import static jcuda.jcudnn.JCudnn.cudnnActivationForward;
-import static jcuda.jcudnn.JCudnn.cudnnBatchNormalizationBackward;
-import static jcuda.jcudnn.JCudnn.cudnnBatchNormalizationForwardInference;
-import static jcuda.jcudnn.JCudnn.cudnnBatchNormalizationForwardTraining;
import static jcuda.jcudnn.JCudnn.cudnnConvolutionBackwardData;
import static jcuda.jcudnn.JCudnn.cudnnConvolutionBackwardFilter;
import static jcuda.jcudnn.JCudnn.cudnnConvolutionForward;
import static jcuda.jcudnn.JCudnn.cudnnCreateActivationDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreateConvolutionDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreateFilterDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreatePoolingDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnPoolingBackward;
import static jcuda.jcudnn.JCudnn.cudnnPoolingForward;
import static jcuda.jcudnn.JCudnn.cudnnSetActivationDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetConvolution2dDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetFilter4dDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetPooling2dDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
import static jcuda.jcudnn.cudnnActivationMode.CUDNN_ACTIVATION_RELU;
-import static jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION;
import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
import static jcuda.jcudnn.cudnnNanPropagation.CUDNN_PROPAGATE_NAN;
-import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_MAX;
import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
-import static jcuda.runtime.JCuda.cudaMemcpy;
import static jcuda.runtime.JCuda.cudaMemset;
-import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.jcudnn.cudnnActivationDescriptor;
-import jcuda.jcudnn.cudnnBatchNormMode;
-import jcuda.jcudnn.cudnnConvolutionDescriptor;
import jcuda.jcudnn.cudnnConvolutionFwdPreference;
-import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnHandle;
-import jcuda.jcudnn.cudnnPoolingDescriptor;
import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
@@ -115,6 +99,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
//cudaDeviceSynchronize;
biasAdd(gCtx, instName, output, bias, output);
}
+
/**
* Performs a 2D convolution
@@ -145,7 +130,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
long CHW = C*H*W; long KPQ = K*P*Q; long CRS = C*R*S;
long NCHW = N*CHW; long NKPQ = N*KPQ; long KCRS = K*CRS;
-
+
if(NCHW < maxNumDoublesOfCuDNNTensor && NKPQ < maxNumDoublesOfCuDNNTensor && KCRS < maxNumDoublesOfCuDNNTensor) {
// Filter and output are accounted as dense in the memory estimation for conv2d
double overhead = isInSparseFormat(gCtx, filter) ? OptimizerUtils.estimateSizeExactSparsity(K, CRS, 1.0) : 0;
@@ -489,14 +474,12 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
Pointer y = getDensePointerForCuDNN(gCtx, outputBlock, instName);
if(overhead <= intermediateMemoryBudget) {
Pointer x = getDensePointerForCuDNN(gCtx, image, instName);
- cudnnTensorDescriptor xDesc = allocateTensorDescriptor(gCtx, image, N, C, H, W);
- cudnnMaxpooling(gCtx, instName, x, xDesc, y, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+ cudnnMaxpooling(gCtx, instName, x, y, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
}
else {
LibMatrixCuDNNInputRowFetcher imgFetcher = new LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
- cudnnTensorDescriptor xDesc = allocateTensorDescriptor(gCtx, image, N, C, H, W);
for(int n = 0; n < N; n++) {
- cudnnMaxpooling(gCtx, instName, imgFetcher.getNthRow(n), xDesc, y.withByteOffset(n*CPQ*Sizeof.DOUBLE), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+ cudnnMaxpooling(gCtx, instName, imgFetcher.getNthRow(n), y.withByteOffset(n*CPQ*Sizeof.DOUBLE), 1, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
}
imgFetcher.close();
}
@@ -506,7 +489,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
}
}
- private static void cudnnMaxpooling(GPUContext gCtx, String instName, Pointer x, cudnnTensorDescriptor xDesc,
+ private static void cudnnMaxpooling(GPUContext gCtx, String instName, Pointer x,
Pointer y, int N, int C, int H, int W, int K, int R,
int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
int Q) throws DMLRuntimeException {
@@ -514,33 +497,21 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
LOG.trace("GPU : performMaxpooling" + ", GPUContext=" + gCtx);
}
- cudnnPoolingDescriptor poolingDesc = null;
-
- try {
+ try(LibMatrixCuDNNPoolingDescriptors desc =
+ LibMatrixCuDNNPoolingDescriptors.cudnnMaxpoolingDescriptors(gCtx, instName, N, C, H, W, K, R, S,
+ pad_h, pad_w, stride_h, stride_w, P, Q)) {
long t1=0,t2=0;
if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
- // Allocate descriptors
- cudnnTensorDescriptor yDesc = allocateTensorDescriptor(N, C, P, Q);
- poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-
if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
- int status = cudnnPoolingForward(getCudnnHandle(gCtx), poolingDesc, one(), xDesc, x, zero(), yDesc, y);
+ int status = cudnnPoolingForward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.xDesc, x, zero(), desc.yDesc, y);
if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - t2);
-
if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
throw new DMLRuntimeException("Could not executed cudnnPoolingForward: " + jcuda.jcudnn.cudnnStatus.stringFor(status));
}
} catch (CudaException e) {
throw new DMLRuntimeException("Error in conv2d in GPUContext " + gCtx.toString() + " from Thread " + Thread.currentThread().toString(), e);
}
- finally {
- long t3=0;
- if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
- if(poolingDesc != null)
- jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor(poolingDesc);
- if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t3);
- }
}
/**
@@ -611,28 +582,22 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
LOG.trace("GPU : maxpoolingBackward" + ", GPUContext=" + gCtx);
}
Pointer y = null;
- cudnnPoolingDescriptor poolingDesc = null;
- try {
+ try(LibMatrixCuDNNPoolingDescriptors desc =
+ LibMatrixCuDNNPoolingDescriptors.cudnnMaxpoolingBackwardDescriptors(gCtx, instName, N, C, H, W, K, R, S,
+ pad_h, pad_w, stride_h, stride_w, P, Q)) {
long t1=0, t2=0, t3=0;
if (GPUStatistics.DISPLAY_STATISTICS) t1 = System.nanoTime();
- // Allocate descriptors
- cudnnTensorDescriptor xDesc = allocateTensorDescriptor(N, C, H, W);
- cudnnTensorDescriptor yDesc = allocateTensorDescriptor(N, C, P, Q);
- cudnnTensorDescriptor dxDesc = allocateTensorDescriptor(N, C, H, W);
- cudnnTensorDescriptor dyDesc = allocateTensorDescriptor(N, C, P, Q);
-
- poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
-
+
// Calling PoolForward first, y is one of the inputs for poolBackward
// TODO: Remove calling poolForward after necessary changes at language level for poolBackward
long numBytes = N*C*P*Q*Sizeof.DOUBLE;
y = gCtx.allocate(numBytes);
if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-
+
if (GPUStatistics.DISPLAY_STATISTICS) t2 = System.nanoTime();
- int status = cudnnPoolingForward(getCudnnHandle(gCtx), poolingDesc, one(), xDesc, x, zero(), yDesc, y);
+ int status = cudnnPoolingForward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.xDesc, x, zero(), desc.yDesc, y);
if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - t2);
if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
@@ -640,7 +605,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
}
if (GPUStatistics.DISPLAY_STATISTICS) t3 = System.nanoTime();
- status = cudnnPoolingBackward(getCudnnHandle(gCtx), poolingDesc, one(), yDesc, y, dyDesc, dy, xDesc, x, zero(), dxDesc, dx);
+ status = cudnnPoolingBackward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.yDesc, y, desc.dyDesc, dy, desc.xDesc, x, zero(), desc.dxDesc, dx);
if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MAXPOOLING_BACKWARD_LIB, System.nanoTime() - t3);
if(status != jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
@@ -652,297 +617,12 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
finally {
long t4=0;
if (GPUStatistics.DISPLAY_STATISTICS) t4 = System.nanoTime();
-
if(y != null)
gCtx.cudaFreeHelper(instName, y);
- if(poolingDesc != null)
- jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor(poolingDesc);
-
if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t4);
}
}
- static cudnnConvolutionDescriptor allocateConvolutionDescriptor(int padding [], int strides []) {
- cudnnConvolutionDescriptor convDesc = new cudnnConvolutionDescriptor();
- cudnnCreateConvolutionDescriptor(convDesc);
- cudnnSetConvolution2dDescriptor(convDesc, padding[0], padding[1], strides[0], strides[1], 1, 1, CUDNN_CROSS_CORRELATION);
- return convDesc;
- }
-
- protected static cudnnFilterDescriptor allocateFilterDescriptor(int K, int C, int R, int S) {
- cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
- cudnnCreateFilterDescriptor(filterDesc);
- cudnnSetFilter4dDescriptor(filterDesc, CUDNN_DATA_DOUBLE, CUDNN_TENSOR_NCHW, K, C, R, S);
- return filterDesc;
- }
-
- /**
- * allocates pooling descriptor, used in poolingForward and poolingBackward
- * @param R pooling window height
- * @param S pooling window width
- * @param pad_h vertical padding
- * @param pad_w horizontal padding
- * @param stride_h pooling vertical stride
- * @param stride_w pooling horizontal stride
- * @return cudnn pooling descriptor
- */
- private static cudnnPoolingDescriptor allocatePoolingDescriptor(int R, int S, int pad_h, int pad_w, int stride_h, int stride_w) {
- cudnnPoolingDescriptor poolingDesc = new cudnnPoolingDescriptor();
- cudnnCreatePoolingDescriptor(poolingDesc);
- cudnnSetPooling2dDescriptor(poolingDesc, CUDNN_POOLING_MAX, CUDNN_PROPAGATE_NAN, R, S, pad_h, pad_w, stride_h, stride_w);
- return poolingDesc;
- }
-
- /**
- * Convenience method to get tensor descriptor
- * @param N number of images
- * @param C number of channels
- * @param H height
- * @param W width
- * @return cudnn tensor descriptor
- * @throws DMLRuntimeException if the input descriptor and matrix dimensions don't match
- */
- static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) throws DMLRuntimeException {
- cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
- cudnnCreateTensorDescriptor(tensorDescriptor);
- cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
- return tensorDescriptor;
- }
-
- /**
- * Convenience method to get tensor descriptor from underlying GPUObject
- * @param gCtx a valid {@link GPUContext}
- * @param mat matrix object
- * @param N number of images
- * @param C number of channels
- * @param H height
- * @param W width
- * @return cudnn tensor descriptor
- * @throws DMLRuntimeException if the input descriptor and matrix dimensions don't match
- */
- private static cudnnTensorDescriptor allocateTensorDescriptor(GPUContext gCtx, MatrixObject mat, int N, int C, int H, int W) throws DMLRuntimeException {
- if(mat.getNumRows() != N || mat.getNumColumns() != C*H*W) {
- throw new DMLRuntimeException("Mismatch descriptor-matrix dimensions:" + mat.getNumRows() + " != " + N
- + " || " + mat.getNumColumns() + " != " + (C*H*W));
- }
- return mat.getGPUObject(gCtx).allocateTensorDescriptor(N, C, H, W);
- }
-
- /**
- * Performs the forward BatchNormalization layer computation for inference
- * @param gCtx a valid {@link GPUContext}
- * @param instName name of the instruction
- * @param image input image
- * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
- * @param bias bias (as per CuDNN) and beta as per original paper: shape [1, C, 1, 1]
- * @param runningMean running mean accumulated during training phase: shape [1, C, 1, 1]
- * @param runningVar running variance accumulated during training phase: shape [1, C, 1, 1]
- * @param ret normalized input
- * @param epsilon epsilon value used in the batch normalization formula
- * @throws DMLRuntimeException if error occurs
- */
- public static void batchNormalizationForwardInference(GPUContext gCtx, String instName, MatrixObject image,
- MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar,
- MatrixObject ret, double epsilon) throws DMLRuntimeException {
- if(LOG.isTraceEnabled()) {
- LOG.trace("GPU : batchNormalizationForwardInference" + ", GPUContext=" + gCtx);
- }
- int mode = cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL;
-
- int N = toInt(image.getNumRows());
- int C = toInt(scale.getNumColumns());
- long CHW = image.getNumColumns();
- validateBatchNormalizationDimensions(scale, bias, runningMean, runningVar, C);
-
- // Allocate descriptors
- cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
- new MatrixObject[] {image}, new MatrixObject[] {ret});
- cudnnTensorDescriptor scaleTensorDesc = allocateTensorDescriptor(gCtx, scale, 1, C, 1, 1);
-
- // Get underlying dense pointer
- Pointer imagePtr = getDensePointerForCuDNN(gCtx, image, instName);
- Pointer retPtr = getDensePointerForCuDNN(gCtx, ret, instName);
- Pointer biasPtr = getDensePointerForCuDNN(gCtx, bias, instName);
- Pointer scalePtr = getDensePointerForCuDNN(gCtx, scale, instName);
- Pointer runningMeanPtr = getDensePointerForCuDNN(gCtx, runningMean, instName);
- Pointer runningVarPtr = getDensePointerForCuDNN(gCtx, runningVar, instName);
-
- checkStatus(cudnnBatchNormalizationForwardInference(getCudnnHandle(gCtx), mode, one(), zero(),
- nCHWDescriptor, imagePtr, nCHWDescriptor, retPtr,
- scaleTensorDesc, scalePtr, biasPtr,
- runningMeanPtr, runningVarPtr, epsilon));
- }
-
- /**
- * Performs the forward BatchNormalization layer computation for training
- * @param gCtx a valid {@link GPUContext}
- * @param instName name of the instruction
- * @param image input image
- * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
- * @param bias bias (as per CuDNN) and beta as per original paper: shape [1, C, 1, 1]
- * @param runningMean running mean accumulated during training phase: shape [1, C, 1, 1]
- * @param runningVar running variance accumulated during training phase: shape [1, C, 1, 1]
- * @param ret (output) normalized input
- * @param retRunningMean (output) running mean accumulated during training phase: shape [1, C, 1, 1]
- * @param retRunningVar (output) running variance accumulated during training phase: shape [1, C, 1, 1]
- * @param epsilon epsilon value used in the batch normalization formula
- * @param exponentialAverageFactor factor used in the moving average computation
- * @throws DMLRuntimeException if error occurs
- */
- public static void batchNormalizationForwardTraining(GPUContext gCtx, String instName, MatrixObject image,
- MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar,
- MatrixObject ret, MatrixObject retRunningMean, MatrixObject retRunningVar, double epsilon, double exponentialAverageFactor) throws DMLRuntimeException {
- if(LOG.isTraceEnabled()) {
- LOG.trace("GPU : batchNormalizationForwardTraining" + ", GPUContext=" + gCtx);
- }
- int mode = cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL;
-
- int N = toInt(image.getNumRows());
- int C = toInt(scale.getNumColumns());
- long CHW = image.getNumColumns();
- validateBatchNormalizationDimensions(scale, bias, runningMean, runningVar, C);
-
- // Allocate descriptors
- cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
- new MatrixObject[] {image}, new MatrixObject[] {ret});
- cudnnTensorDescriptor scaleTensorDesc = allocateTensorDescriptor(gCtx, scale, 1, C, 1, 1);
-
- // Get underlying dense pointer
- Pointer imagePtr = getDensePointerForCuDNN(gCtx, image, instName);
- Pointer retPtr = getDensePointerForCuDNN(gCtx, ret, instName);
- Pointer biasPtr = getDensePointerForCuDNN(gCtx, bias, instName);
- Pointer scalePtr = getDensePointerForCuDNN(gCtx, scale, instName);
- Pointer runningMeanPtr = getDensePointerForCuDNN(gCtx, runningMean, instName);
- Pointer runningVarPtr = getDensePointerForCuDNN(gCtx, runningVar, instName);
-
- // To allow for copy-on-write
- Pointer retRunningMeanPtr = getDensePointerForCuDNN(gCtx, retRunningMean, instName);
- Pointer retRunningVarPtr = getDensePointerForCuDNN(gCtx, retRunningVar, instName);
- cudaMemcpy(retRunningMeanPtr, runningMeanPtr, C * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
- cudaMemcpy(retRunningVarPtr, runningVarPtr, C * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
-
- // ignoring resultSaveMean and resultSaveVariance as it requires state management
- checkStatus(cudnnBatchNormalizationForwardTraining(getCudnnHandle(gCtx), mode, one(), zero(),
- nCHWDescriptor, imagePtr, nCHWDescriptor, retPtr,
- scaleTensorDesc, scalePtr, biasPtr, exponentialAverageFactor,
- retRunningMeanPtr, retRunningVarPtr, epsilon, new Pointer(), new Pointer()));
- }
-
- private static void validateBatchNormalizationDimensions(MatrixObject scale, MatrixObject bias, MatrixObject runningMean, MatrixObject runningVar, int C) throws DMLRuntimeException {
- if(scale.getNumRows() != 1 || scale.getNumColumns() != C) {
- throw new DMLRuntimeException("Incorrect dimensions for scale");
- }
- if(bias.getNumRows() != 1 || bias.getNumColumns() != C) {
- throw new DMLRuntimeException("Incorrect dimensions for bias");
- }
- if(runningMean.getNumRows() != 1 || runningMean.getNumColumns() != C) {
- throw new DMLRuntimeException("Incorrect dimensions for running mean");
- }
- if(runningVar.getNumRows() != 1 || runningVar.getNumColumns() != C) {
- throw new DMLRuntimeException("Incorrect dimensions for running variance");
- }
- }
-
- /**
- * Convenient utility for batch normalization that returns a NCHW descriptor
- * @param gCtx a valid {@link GPUContext}
- * @param N number of images
- * @param C number of channels
- * @param CHW channels*height*width
- * @param input input matrix objects
- * @param output output matrix objects
- * @return one of the NCHW descriptor
- * @throws DMLRuntimeException if error occurs
- */
- private static cudnnTensorDescriptor allocateNCHWDescriptors(GPUContext gCtx, int N, int C, long CHW, MatrixObject [] input, MatrixObject [] output) throws DMLRuntimeException {
- cudnnTensorDescriptor ret = null; // Return any one
- if(CHW > ((long)Integer.MAX_VALUE)*C) {
- throw new DMLRuntimeException("image size (height*width) should be less than " + Integer.MAX_VALUE);
- }
- cudnnTensorDescriptor knownNCHWdescriptor = null;
- int H = -1; int W = -1;
- for(int i = 0; i < input.length; i++) {
- knownNCHWdescriptor = input[i].getGPUObject(gCtx).getTensorDescriptor();
- if(knownNCHWdescriptor != null) {
- int [] shape = input[i].getGPUObject(gCtx).getTensorShape();
- if(shape[0] != N || shape[1] != C) {
- throw new DMLRuntimeException("Incorrect N and C:" + shape[0] + " != " + N + " || " + shape[1] + " != " + C);
- }
- H = shape[2];
- W = shape[3];
- break;
- }
- }
- if(knownNCHWdescriptor != null) {
- // We precisely know N, C, H, W
- for(int i = 0; i < input.length; i++) {
- ret = allocateTensorDescriptor(gCtx, input[i], N, C, H, W);
- }
- for(int i = 0; i < output.length; i++) {
- ret = allocateTensorDescriptor(gCtx, output[i], N, C, H, W);
- }
- }
- else {
- int HW = (int) (CHW / C);
- H = HW; W = 1; // If not known
- double potentialH = Math.sqrt(HW);
- if(potentialH == ((int) potentialH)) {
- H = (int) potentialH;
- W = H;
- }
- // We are not sure about H and W, hence don't allocate them.
- ret = new cudnnTensorDescriptor();
- cudnnCreateTensorDescriptor(ret);
- cudnnSetTensor4dDescriptor(ret, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
- }
- return ret;
- }
-
- /**
- * This method computes the backpropagation errors for image, scale and bias of batch normalization layer
- * @param gCtx a valid {@link GPUContext}
- * @param instName name of the instruction
- * @param image input image
- * @param dout input errors of shape C, H, W
- * @param scale scale (as per CuDNN) and gamma as per original paper: shape [1, C, 1, 1]
- * @param ret (output) backpropagation errors for previous layer
- * @param retScale backpropagation error for scale
- * @param retBias backpropagation error for bias
- * @param epsilon epsilon value used in the batch normalization formula
- * @throws DMLRuntimeException if error occurs
- */
- public static void batchNormalizationBackward(GPUContext gCtx, String instName, MatrixObject image, MatrixObject dout,
- MatrixObject scale, MatrixObject ret, MatrixObject retScale, MatrixObject retBias,
- double epsilon) throws DMLRuntimeException {
- if(LOG.isTraceEnabled()) {
- LOG.trace("GPU : batchNormalizationBackward" + ", GPUContext=" + gCtx);
- }
- int mode = cudnnBatchNormMode.CUDNN_BATCHNORM_SPATIAL;
-
- int N = toInt(image.getNumRows());
- int C = toInt(scale.getNumColumns());
- long CHW = image.getNumColumns();
-
- // Allocate descriptors
- cudnnTensorDescriptor nCHWDescriptor = allocateNCHWDescriptors(gCtx, N, C, CHW,
- new MatrixObject[] {image, dout}, new MatrixObject[] {ret});
- cudnnTensorDescriptor scaleTensorDesc = allocateTensorDescriptor(gCtx, scale, 1, C, 1, 1);
-
- // Get underlying dense pointer
- Pointer imagePtr = getDensePointerForCuDNN(gCtx, image, instName);
- Pointer doutPtr = getDensePointerForCuDNN(gCtx, dout, instName);
- Pointer scalePtr = getDensePointerForCuDNN(gCtx, scale, instName);
- Pointer retPtr = getDensePointerForCuDNN(gCtx, ret, instName);
- Pointer retScalePtr = getDensePointerForCuDNN(gCtx, retScale, instName);
- Pointer retBiasPtr = getDensePointerForCuDNN(gCtx, retBias, instName);
-
- // ignoring resultSaveMean and resultSaveVariance as it requires state management
- checkStatus(cudnnBatchNormalizationBackward(getCudnnHandle(gCtx), mode, one(), zero(), one(), zero(),
- nCHWDescriptor, imagePtr, nCHWDescriptor, doutPtr, nCHWDescriptor, retPtr,
- scaleTensorDesc, scalePtr, retScalePtr, retBiasPtr, epsilon, new Pointer(), new Pointer()));
- }
-
-
private static void cudnnReLU(GPUContext gCtx, String instName, MatrixObject in, Pointer dstData, cudnnTensorDescriptor srcTensorDesc) throws DMLRuntimeException {
long t0=0;
try {
@@ -988,8 +668,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
MatrixObject output = ec.getMatrixObject(outputName);
getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, in.getNumRows(), in.getNumColumns()); // Allocated the dense output matrix
long t0=0;
- cudnnTensorDescriptor srcTensorDesc = in.getGPUObject(gCtx).getTensorDescriptor();
- if(N*CHW >= maxNumDoublesOfCuDNNTensor || srcTensorDesc == null) {
+ if(N*CHW >= maxNumDoublesOfCuDNNTensor) {
if(LOG.isTraceEnabled()) {
LOG.trace("GPU : relu custom kernel" + ", GPUContext=" + gCtx);
}
@@ -1003,7 +682,11 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_RELU_KERNEL, System.nanoTime() - t0);
}
else {
- cudnnReLU(gCtx, instName, in, getDensePointerForCuDNN(gCtx, output, instName), srcTensorDesc);
+ cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
+ cudnnCreateTensorDescriptor(tensorDescriptor);
+ cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, toInt(N), 1, 1, toInt(CHW));
+ cudnnReLU(gCtx, instName, in, getDensePointerForCuDNN(gCtx, output, instName), tensorDescriptor);
+ cudnnDestroyTensorDescriptor(tensorDescriptor);
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/96ae6c7e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
index 2243b58..871194e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNConvolutionAlgorithm.java
@@ -31,9 +31,18 @@ import jcuda.jcudnn.cudnnConvolutionDescriptor;
import jcuda.jcudnn.cudnnConvolutionFwdPreference;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnCreateConvolutionDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnCreateFilterDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnDestroyConvolutionDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnDestroyFilterDescriptor;
import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetConvolution2dDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetFilter4dDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
+import static jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION;
+import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
+import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
/**
* This class is a wrapper that contain necessary data structures to invoke
@@ -48,6 +57,9 @@ import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
*
*/
public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseable {
+	// Upper bound (1 GB = 1e9 bytes) on the workspace limit passed to cudnnGetConvolution*Algorithm
+	private static final long MAX_WORKSPACE_LIMIT_BYTES = 1_000_000_000L;
+
public int algo = -1;
public Pointer workSpace = new Pointer();
public long sizeInBytes = 0;
@@ -61,12 +73,12 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q) throws DMLRuntimeException {
int padding[] = {pad_h, pad_w};
int strides[] = {stride_h, stride_w};
- convDesc = LibMatrixCuDNN.allocateConvolutionDescriptor(padding, strides);
+ convDesc = allocateConvolutionDescriptor(padding, strides);
this.gCtx = gCtx;
this.instName = instName;
- nchwTensorDesc = LibMatrixCuDNN.allocateTensorDescriptor(N, C, H, W);
- nkpqTensorDesc = LibMatrixCuDNN.allocateTensorDescriptor(N, K, P, Q);
- filterDesc = LibMatrixCuDNN.allocateFilterDescriptor(K, C, R, S);
+ nchwTensorDesc = allocateTensorDescriptor(N, C, H, W);
+ nkpqTensorDesc = allocateTensorDescriptor(N, K, P, Q);
+ filterDesc = allocateFilterDescriptor(K, C, R, S);
}
/**
@@ -125,7 +137,7 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
}
else {
int[] algos = {-1};
- long sizeInBytesArray[] = {workspaceLimit};
+ long sizeInBytesArray[] = {Math.min(workspaceLimit, MAX_WORKSPACE_LIMIT_BYTES)};
jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardAlgorithm(LibMatrixCuDNN.getCudnnHandle(gCtx),
ret.nchwTensorDesc, ret.filterDesc, ret.convDesc, ret.nkpqTensorDesc,
cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, sizeInBytesArray[0], algos);
@@ -177,7 +189,7 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
}
else {
int[] algos = {-1};
- long sizeInBytesArray[] = {workspaceLimit};
+ long sizeInBytesArray[] = {Math.min(workspaceLimit, MAX_WORKSPACE_LIMIT_BYTES)};
jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardFilterAlgorithm(
LibMatrixCuDNN.getCudnnHandle(gCtx),
ret.nchwTensorDesc, ret.nkpqTensorDesc, ret.convDesc, ret.filterDesc,
@@ -230,7 +242,7 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
}
else {
int[] algos = {-1};
- long sizeInBytesArray[] = {workspaceLimit};
+ long sizeInBytesArray[] = {Math.min(workspaceLimit, MAX_WORKSPACE_LIMIT_BYTES)};
jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardDataAlgorithm(
LibMatrixCuDNN.getCudnnHandle(gCtx),
ret.filterDesc, ret.nkpqTensorDesc, ret.convDesc, ret.nchwTensorDesc,
@@ -246,4 +258,75 @@ public class LibMatrixCuDNNConvolutionAlgorithm implements java.lang.AutoCloseab
GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
return ret;
}
+
+	/**
+	 * Convenience method to allocate and initialize a 4D NCHW tensor descriptor of doubles.
+	 * If initialization fails, the newly created descriptor is destroyed before the
+	 * exception propagates, so the native handle is not leaked.
+	 *
+	 * @param N number of images
+	 * @param C number of channels
+	 * @param H height
+	 * @param W width
+	 * @return cudnn tensor descriptor; the caller must destroy it
+	 * @throws DMLRuntimeException declared for caller compatibility; not thrown by this method itself
+	 */
+	private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) throws DMLRuntimeException {
+		cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
+		cudnnCreateTensorDescriptor(tensorDescriptor);
+		try {
+			cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
+		}
+		catch(RuntimeException e) {
+			// with JCuda exceptions enabled, a failed call throws (unchecked): release the handle first
+			cudnnDestroyTensorDescriptor(tensorDescriptor);
+			throw e;
+		}
+		return tensorDescriptor;
+	}
+
+	/**
+	 * Convenience method to allocate and initialize a K x C x R x S filter descriptor of doubles (NCHW layout).
+	 * If initialization fails, the newly created descriptor is destroyed before the exception propagates.
+	 *
+	 * @param K number of filters
+	 * @param C number of channels
+	 * @param R height of filter
+	 * @param S width of filter
+	 * @return cudnn filter descriptor; the caller must destroy it
+	 */
+	private static cudnnFilterDescriptor allocateFilterDescriptor(int K, int C, int R, int S) {
+		cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
+		cudnnCreateFilterDescriptor(filterDesc);
+		try {
+			cudnnSetFilter4dDescriptor(filterDesc, CUDNN_DATA_DOUBLE, CUDNN_TENSOR_NCHW, K, C, R, S);
+		}
+		catch(RuntimeException e) {
+			cudnnDestroyFilterDescriptor(filterDesc);
+			throw e;
+		}
+		return filterDesc;
+	}
+
+	/**
+	 * Convenience method to allocate and initialize a 2D cross-correlation convolution descriptor
+	 * (both upscale factors passed as 1).
+	 * If initialization fails, the newly created descriptor is destroyed before the exception propagates.
+	 *
+	 * @param padding {pad_h, pad_w}
+	 * @param strides {stride_h, stride_w}
+	 * @return cudnn convolution descriptor; the caller must destroy it
+	 */
+	private static cudnnConvolutionDescriptor allocateConvolutionDescriptor(int padding [], int strides []) {
+		cudnnConvolutionDescriptor convDesc = new cudnnConvolutionDescriptor();
+		cudnnCreateConvolutionDescriptor(convDesc);
+		try {
+			cudnnSetConvolution2dDescriptor(convDesc, padding[0], padding[1], strides[0], strides[1], 1, 1, CUDNN_CROSS_CORRELATION);
+		}
+		catch(RuntimeException e) {
+			cudnnDestroyConvolutionDescriptor(convDesc);
+			throw e;
+		}
+		return convDesc;
+	}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/96ae6c7e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
new file mode 100644
index 0000000..f817bd5
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNPoolingDescriptors.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.matrix.data;
+
+import static jcuda.jcudnn.JCudnn.cudnnCreatePoolingDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetPooling2dDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
+import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
+import static jcuda.jcudnn.cudnnNanPropagation.CUDNN_PROPAGATE_NAN;
+import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_MAX;
+import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+
+import jcuda.jcudnn.cudnnPoolingDescriptor;
+import jcuda.jcudnn.cudnnTensorDescriptor;
+
+/**
+ * This class is a wrapper that contains the descriptors necessary to invoke
+ * the cudnn max-pooling functions (such as cudnnPoolingForward and cudnnPoolingBackward).
+ *
+ * It implements AutoCloseable to simplify the LibMatrixCuDNN code (try-with-resources)
+ * and to avoid leaking the underlying native descriptors.
+ */
+public class LibMatrixCuDNNPoolingDescriptors implements java.lang.AutoCloseable {
+
+	// descriptor of x: N x C x H x W
+	public cudnnTensorDescriptor xDesc;
+	// descriptor of y: N x C x P x Q
+	public cudnnTensorDescriptor yDesc;
+	// descriptor of dx: N x C x H x W (allocated only for the backward operation)
+	public cudnnTensorDescriptor dxDesc;
+	// descriptor of dy: N x C x P x Q (allocated only for the backward operation)
+	public cudnnTensorDescriptor dyDesc;
+	// max-pooling window, padding and stride
+	public cudnnPoolingDescriptor poolingDesc;
+
+	/**
+	 * Destroys every descriptor allocated so far and clears its reference, so that close()
+	 * is idempotent and safe to call on a partially initialized instance.
+	 */
+	@Override
+	public void close() {
+		if(xDesc != null) {
+			cudnnDestroyTensorDescriptor(xDesc);
+			xDesc = null;
+		}
+		if(yDesc != null) {
+			cudnnDestroyTensorDescriptor(yDesc);
+			yDesc = null;
+		}
+		if(dxDesc != null) {
+			cudnnDestroyTensorDescriptor(dxDesc);
+			dxDesc = null;
+		}
+		if(dyDesc != null) {
+			cudnnDestroyTensorDescriptor(dyDesc);
+			dyDesc = null;
+		}
+		if(poolingDesc != null) {
+			cudnnDestroyPoolingDescriptor(poolingDesc);
+			poolingDesc = null;
+		}
+	}
+
+	/**
+	 * Get descriptors for maxpooling backward operation
+	 *
+	 * @param gCtx gpu context (currently unused)
+	 * @param instName instruction name (currently unused)
+	 * @param N batch size
+	 * @param C number of channels
+	 * @param H height of image
+	 * @param W width of image
+	 * @param K number of filters (unused by pooling)
+	 * @param R height of pooling window
+	 * @param S width of pooling window
+	 * @param pad_h vertical padding
+	 * @param pad_w horizontal padding
+	 * @param stride_h vertical stride
+	 * @param stride_w horizontal stride
+	 * @param P output height, (H + 2*pad_h - R)/stride_h + 1
+	 * @param Q output width, (W + 2*pad_w - S)/stride_w + 1
+	 * @return descriptor wrapper holding xDesc, yDesc, dxDesc, dyDesc and poolingDesc; the caller must close it
+	 * @throws DMLRuntimeException if error occurs
+	 */
+	public static LibMatrixCuDNNPoolingDescriptors cudnnMaxpoolingBackwardDescriptors(GPUContext gCtx,
+			String instName, int N, int C, int H, int W, int K, int R,
+			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
+			int Q) throws DMLRuntimeException {
+		LibMatrixCuDNNPoolingDescriptors ret = new LibMatrixCuDNNPoolingDescriptors();
+		boolean allocated = false;
+		try {
+			ret.xDesc = allocateTensorDescriptor(N, C, H, W);
+			ret.yDesc = allocateTensorDescriptor(N, C, P, Q);
+			ret.dxDesc = allocateTensorDescriptor(N, C, H, W);
+			ret.dyDesc = allocateTensorDescriptor(N, C, P, Q);
+			ret.poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
+			allocated = true;
+			return ret;
+		}
+		finally {
+			// the caller never receives ret on failure, so release whatever was allocated so far
+			if(!allocated)
+				ret.close();
+		}
+	}
+
+	/**
+	 * Get descriptors for maxpooling operation
+	 *
+	 * @param gCtx gpu context (currently unused)
+	 * @param instName instruction name (currently unused)
+	 * @param N batch size
+	 * @param C number of channels
+	 * @param H height of image
+	 * @param W width of image
+	 * @param K number of filters (unused by pooling)
+	 * @param R height of pooling window
+	 * @param S width of pooling window
+	 * @param pad_h vertical padding
+	 * @param pad_w horizontal padding
+	 * @param stride_h vertical stride
+	 * @param stride_w horizontal stride
+	 * @param P output height, (H + 2*pad_h - R)/stride_h + 1
+	 * @param Q output width, (W + 2*pad_w - S)/stride_w + 1
+	 * @return descriptor wrapper holding xDesc, yDesc and poolingDesc; the caller must close it
+	 * @throws DMLRuntimeException if error occurs
+	 */
+	public static LibMatrixCuDNNPoolingDescriptors cudnnMaxpoolingDescriptors(GPUContext gCtx,
+			String instName, int N, int C, int H, int W, int K, int R,
+			int S, int pad_h, int pad_w, int stride_h, int stride_w, int P,
+			int Q) throws DMLRuntimeException {
+		LibMatrixCuDNNPoolingDescriptors ret = new LibMatrixCuDNNPoolingDescriptors();
+		boolean allocated = false;
+		try {
+			ret.xDesc = allocateTensorDescriptor(N, C, H, W);
+			ret.yDesc = allocateTensorDescriptor(N, C, P, Q);
+			ret.poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w);
+			allocated = true;
+			return ret;
+		}
+		finally {
+			// the caller never receives ret on failure, so release whatever was allocated so far
+			if(!allocated)
+				ret.close();
+		}
+	}
+
+	/**
+	 * Convenience method to allocate and initialize a 4D NCHW tensor descriptor of doubles.
+	 * If initialization fails, the newly created descriptor is destroyed before the exception propagates.
+	 *
+	 * @param N number of images
+	 * @param C number of channels
+	 * @param H height
+	 * @param W width
+	 * @return cudnn tensor descriptor; the caller must destroy it
+	 * @throws DMLRuntimeException declared for caller compatibility; not thrown by this method itself
+	 */
+	private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) throws DMLRuntimeException {
+		cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
+		cudnnCreateTensorDescriptor(tensorDescriptor);
+		try {
+			cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, N, C, H, W);
+		}
+		catch(RuntimeException e) {
+			// with JCuda exceptions enabled, a failed call throws (unchecked): release the handle first
+			cudnnDestroyTensorDescriptor(tensorDescriptor);
+			throw e;
+		}
+		return tensorDescriptor;
+	}
+
+	/**
+	 * Allocates a NaN-propagating max-pooling descriptor, used in poolingForward and poolingBackward.
+	 * If initialization fails, the newly created descriptor is destroyed before the exception propagates.
+	 *
+	 * @param R pooling window height
+	 * @param S pooling window width
+	 * @param pad_h vertical padding
+	 * @param pad_w horizontal padding
+	 * @param stride_h pooling vertical stride
+	 * @param stride_w pooling horizontal stride
+	 * @return cudnn pooling descriptor; the caller must destroy it
+	 */
+	private static cudnnPoolingDescriptor allocatePoolingDescriptor(int R, int S, int pad_h, int pad_w, int stride_h, int stride_w) {
+		cudnnPoolingDescriptor poolingDesc = new cudnnPoolingDescriptor();
+		cudnnCreatePoolingDescriptor(poolingDesc);
+		try {
+			cudnnSetPooling2dDescriptor(poolingDesc, CUDNN_POOLING_MAX, CUDNN_PROPAGATE_NAN, R, S, pad_h, pad_w, stride_h, stride_w);
+		}
+		catch(RuntimeException e) {
+			cudnnDestroyPoolingDescriptor(poolingDesc);
+			throw e;
+		}
+		return poolingDesc;
+	}
+}