You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2016/07/09 03:30:23 UTC
incubator-systemml git commit: [SYSTEMML-769] Improved performance of
LibMatrixDNN's conv2d and conv2d_backward_filter
Repository: incubator-systemml
Updated Branches:
refs/heads/master 20e05458b -> 2ebf885a6
[SYSTEMML-769] Improved performance of LibMatrixDNN's conv2d and
conv2d_backward_filter
- Fixed bug while iterating through sparse conv2d_backward_filter
- Also added vectorized conv2d
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2ebf885a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2ebf885a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2ebf885a
Branch: refs/heads/master
Commit: 2ebf885a6919e1cb0598e2aab4d0ffb46b8e0ab5
Parents: 20e0545
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Fri Jul 8 20:26:16 2016 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Fri Jul 8 20:28:25 2016 -0700
----------------------------------------------------------------------
.../sysml/runtime/matrix/data/LibMatrixDNN.java | 558 ++++++++++++++-----
1 file changed, 410 insertions(+), 148 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2ebf885a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index d9faf7e..26e2b8b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -20,6 +20,7 @@ package org.apache.sysml.runtime.matrix.data;
import java.lang.ref.SoftReference;
import java.util.ArrayList;
+import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
@@ -29,12 +30,16 @@ import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.util.ConvolutionUtils;
public class LibMatrixDNN {
+
+ protected static final Log LOG = LogFactory.getLog(LibMatrixDNN.class.getName());
public static final boolean ALLOW_MULTI_THREADED_OPS = true;
// Using hashmap to avoid any performance impacts of multimap
@@ -62,13 +67,14 @@ public class LibMatrixDNN {
enum TaskType {
ReshapeCol, Rotate180, Im2Col, Col2Im, MaxPooling_Forward, MaxPooling_Backward, LoopBasedConv2d
}
- public static final int TASK_SIZE = 64; // to take care of extremely small tasks
public static class TemporaryConvolutionData {
public int [] minIndexArrR;
public int [] minIndexArrS;
public int [] maxIndexArrR;
public int [] maxIndexArrS;
+ int minCommonIndexS;
+ int maxCommonIndexS;
}
public static class ConvolutionParameters {
@@ -159,6 +165,9 @@ public class LibMatrixDNN {
dout.getNumRows() != params.N || dout.getNumColumns() != params.K*params.P*params.Q) {
throw new DMLRuntimeException("Incorrect input to conv2d_backward_filter");
}
+ if(params.stride_h <= 0 || params.stride_w <= 0) {
+ throw new DMLRuntimeException("Only positive strides supported");
+ }
int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) {
@@ -198,7 +207,7 @@ public class LibMatrixDNN {
}
}
- public static void doConv2d_Backward_Filter(int k, int c, int r, int s, ConvolutionParameters params) {
+ private static void doConv2d_Backward_Filter(int k, int c, int r, int s, ConvolutionParameters params) throws DMLRuntimeException {
double [] inputArray = null;
if (!params.input1.isInSparseFormat())
inputArray = params.input1.getDenseBlock();
@@ -207,62 +216,125 @@ public class LibMatrixDNN {
doutArray = params.input2.getDenseBlock();
double [] outputArray = params.output.getDenseBlock();
- long outputVal = 0;
- if(doutArray != null) {
- for (int n = 0; n < params.N; n++) {
- for (int p = 0; p < params.P; p++) {
- for (int q = 0; q < params.Q; q++) {
- int h = p*params.stride_h + r - params.pad_h;
- int w = q*params.stride_w + s - params.pad_w;
- if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
- double doutVal = doutArray[n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q + q];
- if(doutVal != 0) {
- if(inputArray != null)
- outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w];
- else
- outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
- }
- }
- }
- }
- }
+ double outputVal = 0;
+ if(inputArray == null && doutArray == null) {
+ outputVal = doConv2d_Backward_Filter_SparseSparse(k, c, r, s, params);
+ }
+ else if(inputArray != null && doutArray == null) {
+ outputVal = doConv2d_Backward_Filter_DenseSparse(k, c, r, s, params, inputArray);
+ }
+ else if(inputArray == null && doutArray != null) {
+ outputVal = doConv2d_Backward_Filter_SparseDense(k, c, r, s, params, doutArray);
}
else {
- MatrixBlock dout = params.input2;
- if( !dout.isEmptyBlock(false) ) {
- int start=0;
- int rlen = dout.getNumRows();
- int clen = dout.getNumColumns();
- for(int r1=0; r1<Math.min(dout.sparseBlock.numRows(), rlen); r1++, start+=clen)
- {
- if(dout.sparseBlock.isEmpty(r1))
- continue;
- int pos = dout.sparseBlock.pos(r1);
- int len = dout.sparseBlock.size(r1);
- int[] aix = dout.sparseBlock.indexes(r1);
- double[] avals = dout.sparseBlock.values(r1);
-
- for(int i=pos; i<pos+len; i++) {
- int index = start+aix[i];
- double doutVal = avals[i];
- int n = index / clen;
- int p = index / params.Q;
- int q = index % params.Q;
- int h = p*params.stride_h + r - params.pad_h;
- int w = q*params.stride_w + s - params.pad_w;
- if(h >= 0 && h < params.H && w >= 0 && w < params.W && doutVal != 0) {
- if(inputArray != null)
- outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w];
- else
- outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
- }
- }
+ outputVal = doConv2d_Backward_Filter_DenseDense(k, c, r, s, params, inputArray, doutArray);
+ }
+
+ outputArray[k*params.C*params.R*params.S + c*params.R*params.S + r*params.S + s] = outputVal;
+ }
+
+ private static double doConv2d_Backward_Filter_SparseDense(int k, int c, int r, int s, ConvolutionParameters params, double [] doutArray) throws DMLRuntimeException {
+ double outputVal = 0;
+ // Lower bounds on p and q: ensure h >= 0 and w >= 0
+ int pMin = (int) Math.max(0, Math.ceil(((double)(params.pad_h-r))/params.stride_h));
+ int qMin = (int) Math.max(0, Math.ceil(((double)(params.pad_w-s))/params.stride_w));
+ // Upper bounds on p and q: ensure h < params.H and w < params.W
+ int pMax = (int) Math.min(params.P, Math.ceil(((double)(params.H+params.pad_h-r))/params.stride_h));
+ int qMax = (int) Math.min(params.Q, Math.ceil(((double)(params.W+params.pad_w-s))/params.stride_w));
+
+ // TODO: Optimize this case
+ for (int n = 0; n < params.N; n++) {
+ int doutOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+ for (int p = pMin; p < pMax; p++) {
+ for (int q = qMin; q < qMax; q++) {
+ int h = p*params.stride_h + r - params.pad_h;
+ int w = q*params.stride_w + s - params.pad_w;
+ outputVal += doutArray[doutOffset + p*params.Q + q]*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
}
- }
+ }
}
+ return outputVal;
+ }
+
+ private static double doConv2d_Backward_Filter_DenseDense(int k, int c, int r, int s, ConvolutionParameters params, double [] inputArray, double [] doutArray) {
+ double outputVal = 0;
+ // Lower bounds on p and q: ensure h >= 0 and w >= 0
+ int pMin = (int) Math.max(0, Math.ceil(((double)(params.pad_h-r))/params.stride_h));
+ int qMin = (int) Math.max(0, Math.ceil(((double)(params.pad_w-s))/params.stride_w));
+ // Upper bounds on p and q: ensure h < params.H and w < params.W
+ int pMax = (int) Math.min(params.P, Math.ceil(((double)(params.H+params.pad_h-r))/params.stride_h));
+ int qMax = (int) Math.min(params.Q, Math.ceil(((double)(params.W+params.pad_w-s))/params.stride_w));
- outputArray[k*params.C*params.R*params.S + c*params.R*params.S + r*params.S + s] = outputVal;
+ for (int n = 0; n < params.N; n++) {
+ int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W + s - params.pad_w;
+ int doutOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+ for (int p = pMin; p < pMax; p++) {
+ int h = p*params.stride_h + r - params.pad_h;
+ for (int q = qMin; q < qMax; q++) {
+ int w = q*params.stride_w;
+ outputVal += doutArray[doutOffset + p*params.Q + q]*inputArray[inputOffset + h*params.W+w];
+ }
+ }
+ }
+
+ return outputVal;
+ }
+
+ private static void computeTensorIndexes(int i, int j, int [] ret, int N, int C, int H, int W) throws DMLRuntimeException {
+ ret[0] = i;
+ ret[1] = j / (H*W);
+ ret[2] = (j - ret[1]*(H*W))/W;
+ ret[3] = j % W;
+ }
+
+ private static double doConv2d_Backward_Filter_DenseSparse(int k, int c, int r, int s, ConvolutionParameters params, double [] inputArray) throws DMLRuntimeException {
+ MatrixBlock dout = params.input2;
+ double outputVal = 0;
+ Iterator<IJV> iter = dout.sparseBlock.getIterator();
+ int [] tensorIndexes = new int[4];
+ while(iter.hasNext()) {
+ IJV ijv = iter.next();
+ computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.K, params.P, params.Q);
+ if(k == tensorIndexes[1]) {
+ int n = tensorIndexes[0];
+ int p = tensorIndexes[2];
+ int q = tensorIndexes[3];
+
+ double doutVal = ijv.getV();
+ int h = p*params.stride_h + r - params.pad_h;
+ int w = q*params.stride_w + s - params.pad_w;
+ if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+ outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w];
+ }
+ }
+ }
+ return outputVal;
+ }
+
+ private static double doConv2d_Backward_Filter_SparseSparse(int k, int c, int r, int s, ConvolutionParameters params) throws DMLRuntimeException {
+ MatrixBlock dout = params.input2;
+ double outputVal = 0;
+ Iterator<IJV> iter = dout.sparseBlock.getIterator();
+ int [] tensorIndexes = new int[4];
+
+ while(iter.hasNext()) {
+ IJV ijv = iter.next();
+ computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.K, params.P, params.Q);
+ if(k == tensorIndexes[1]) {
+ int n = tensorIndexes[0];
+ int p = tensorIndexes[2];
+ int q = tensorIndexes[3];
+
+ double doutVal = ijv.getV();
+ int h = p*params.stride_h + r - params.pad_h;
+ int w = q*params.stride_w + s - params.pad_w;
+ if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+ outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
+ }
+ }
+ }
+ return outputVal;
}
private static class ConvBackwardFilterTask implements Callable<Object> {
@@ -294,25 +366,55 @@ public class LibMatrixDNN {
throw new DMLRuntimeException("Incorrect input to conv2d");
}
- params.tmpData = new TemporaryConvolutionData();
- params.tmpData.minIndexArrR = new int[params.R];
- params.tmpData.maxIndexArrR = new int[params.R];
- params.tmpData.minIndexArrS = new int[params.S];
- params.tmpData.maxIndexArrS = new int[params.S];
- for (int r = 0; r < params.R; r++) {
- params.tmpData.minIndexArrR[r] = getMinPQ(params.pad_h, r, params.stride_h);
- params.tmpData.maxIndexArrR[r] = getMaxPQ(params.pad_h, r, params.stride_h, params.P, params.H);
+ params.tmpData = new TemporaryConvolutionData();
+ if(input.isInSparseFormat()) {
+ params.tmpData.minIndexArrR = new int[params.H];
+ params.tmpData.minIndexArrS = new int[params.W];
+ for(int h = 0; h < params.H; h++) {
+ for (int r = 0; r < params.R; r++) {
+ // int h = p*params.stride_h + r - params.pad_h;
+ if((h + params.pad_h - r) % params.stride_h == 0) {
+ params.tmpData.minIndexArrR[h] = r;
+ break;
+ }
+ }
+ }
+ for(int w = 0; w < params.W; w++) {
+ for (int s = 0; s < params.S; s++) {
+ // int w = q*params.stride_w + s - params.pad_w;
+ if((w + params.pad_w - s) % params.stride_w == 0) {
+ params.tmpData.minIndexArrS[w] = s;
+ break;
+ }
+ }
+ }
}
- for (int s = 0; s < params.S; s++) {
- params.tmpData.minIndexArrS[s] = getMinPQ(params.pad_w, s, params.stride_w);
- params.tmpData.maxIndexArrS[s] = getMaxPQ(params.pad_w, s, params.stride_w, params.Q, params.W);
+ else {
+ params.tmpData.minIndexArrR = new int[params.R];
+ params.tmpData.maxIndexArrR = new int[params.R];
+ params.tmpData.minIndexArrS = new int[params.S];
+ params.tmpData.maxIndexArrS = new int[params.S];
+ for (int r = 0; r < params.R; r++) {
+ params.tmpData.minIndexArrR[r] = getMinPQ(params.pad_h, r, params.stride_h);
+ params.tmpData.maxIndexArrR[r] = getMaxPQ(params.pad_h, r, params.stride_h, params.P, params.H);
+ }
+ for (int s = 0; s < params.S; s++) {
+ params.tmpData.minIndexArrS[s] = getMinPQ(params.pad_w, s, params.stride_w);
+ params.tmpData.maxIndexArrS[s] = getMaxPQ(params.pad_w, s, params.stride_w, params.Q, params.W);
+ }
+ params.tmpData.minCommonIndexS = params.tmpData.minIndexArrS[0];
+ params.tmpData.maxCommonIndexS = params.tmpData.maxIndexArrS[0];
+ for (int s = 1; s < params.S; s++) {
+ params.tmpData.minCommonIndexS = Math.max(params.tmpData.minCommonIndexS, params.tmpData.minIndexArrS[s]);
+ params.tmpData.maxCommonIndexS = Math.min(params.tmpData.maxCommonIndexS, params.tmpData.maxIndexArrS[s]);
+ }
}
int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) {
for (int n = 0; n < params.N; n++) {
for (int k = 0; k < params.K; k++) {
- doLoopBasedConv2d(n, k, params);
+ doLoopBasedConv2d(n, n+1, k, params);
}
}
}
@@ -345,102 +447,255 @@ public class LibMatrixDNN {
}
}
- /**
- * This is essentially memory-less operation and can be used when the memory pressure is extremely high.
- * @param n
- * @param k
- * @param params
- */
- private static void doLoopBasedConv2d(int n, int k, ConvolutionParameters params) {
- double [] inputArray = null;
- if (!params.input1.isInSparseFormat())
- inputArray = params.input1.getDenseBlock();
- double [] filterArray = null;
- if (!params.input2.isInSparseFormat())
- filterArray = params.input2.getDenseBlock();
+ private static void doLoopBasedConv2dDenseDense(int n1, int n2, int k, ConvolutionParameters params,
+ double [] inputArray, double [] filterArray) {
double [] outputArray = params.output.getDenseBlock();
-
- int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
-
int [] minIndexArrR = params.tmpData.minIndexArrR;
int [] maxIndexArrR = params.tmpData.maxIndexArrR;
int [] minIndexArrS = params.tmpData.minIndexArrS;
int [] maxIndexArrS = params.tmpData.maxIndexArrS;
- if(inputArray != null && filterArray != null) {
- for (int c = 0; c < params.C; c++) {
- for (int r = 0; r < params.R; r++) {
- int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
- for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
- for (int s = 0; s < params.S; s++) {
- double filterVal = filterArray[filterOffset + s];
- if(filterVal != 0) {
- int h = p*params.stride_h + r - params.pad_h;
- for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
- int w = q*params.stride_w + s - params.pad_w;
- outputArray[outputOffset + p*params.Q + q] += denseConvMultiply(inputArray, filterVal, params, n, c, h, w);
+ int minCommonIndexS = params.tmpData.minCommonIndexS;
+ int maxCommonIndexS = params.tmpData.maxCommonIndexS;
+
+
+ int minS = 0;
+ if(params.S >= 4) {
+ minS = params.S - params.S % 4;
+ for (int n = n1; n < n2; n++) {
+ for (int c = 0; c < params.C; c++) {
+ for (int r = 0; r < params.R; r++) {
+ final int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+ for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+ final int h = p*params.stride_h + r - params.pad_h;
+ final int inputOffSet = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+ final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+ // ------------------------------------------------------------------------
+ // Efficient striding with vectorization
+ for (int q = minCommonIndexS; q < maxCommonIndexS; q++) {
+ final int wOffset = inputOffSet + q*params.stride_w;
+ final int outOffsetWithQ = outputOffset + q;
+ for (int s = 0; s < minS; s += 4) {
+ final int inOffsetWithS = wOffset + s;
+ final int filterOffsetWithS = filterOffset + s;
+ outputArray[outOffsetWithQ] += inputArray[inOffsetWithS]*filterArray[filterOffsetWithS]
+ + inputArray[inOffsetWithS+1]*filterArray[filterOffsetWithS+1]
+ + inputArray[inOffsetWithS+2]*filterArray[filterOffsetWithS+2]
+ + inputArray[inOffsetWithS+3]*filterArray[filterOffsetWithS+3];
}
}
+ // ------------------------------------------------------------------------
}
}
}
}
}
- else if(inputArray != null && filterArray == null) {
+
+ for (int n = n1; n < n2; n++) {
for (int c = 0; c < params.C; c++) {
for (int r = 0; r < params.R; r++) {
+ final int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
- for (int s = 0; s < params.S; s++) {
- double filterVal = params.input2.quickGetValue(k, c*params.R*params.S + r*params.S + s);
- if(filterVal != 0) {
- int h = p*params.stride_h + r - params.pad_h;
- for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
- int w = q*params.stride_w + s - params.pad_w;
- outputArray[outputOffset + p*params.Q + q] += denseConvMultiply(inputArray, filterVal, params, n, c, h, w);
- }
+ final int h = p*params.stride_h + r - params.pad_h;
+ final int inputOffSet = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+ final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+ // ------------------------------------------------------------------------
+ // Efficient striding
+ for (int q = minCommonIndexS; q < maxCommonIndexS; q++) {
+ final int wOffset = inputOffSet + q*params.stride_w;
+ for (int s = minS; s < params.S; s++) {
+ outputArray[outputOffset + q] += inputArray[wOffset + s]*filterArray[filterOffset + s];
}
}
+ // ------------------------------------------------------------------------
}
}
}
- }
- else if(inputArray == null && filterArray != null) {
+
+
for (int c = 0; c < params.C; c++) {
for (int r = 0; r < params.R; r++) {
- int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+ final int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+ final int h = p*params.stride_h + r - params.pad_h;
+ final int inputOffSet = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+ final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+ // ------------------------------------------------------------------------
+ // Inefficient striding
for (int s = 0; s < params.S; s++) {
- double filterVal = filterArray[filterOffset + s];
- if(filterVal != 0) {
- int h = p*params.stride_h + r - params.pad_h;
- for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
- int w = q*params.stride_w + s - params.pad_w;
- outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(inputArray, filterVal, params, n, c, h, w);
- }
+ for (int q = minIndexArrS[s]; q < minCommonIndexS; q++) {
+ final int w = q*params.stride_w + s;
+ outputArray[outputOffset + q] += inputArray[inputOffSet + w]*filterArray[filterOffset + s];
+ }
+ for (int q = maxCommonIndexS; q < maxIndexArrS[s]; q++) {
+ final int w = q*params.stride_w + s;
+ outputArray[outputOffset + q] += inputArray[inputOffSet + w]*filterArray[filterOffset + s];
}
}
+ // ------------------------------------------------------------------------
}
}
}
}
- else if(inputArray == null && filterArray == null) {
- for (int c = 0; c < params.C; c++) {
- for (int r = 0; r < params.R; r++) {
- for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
- for (int s = 0; s < params.S; s++) {
- double filterVal = params.input2.quickGetValue(k, c*params.R*params.S + r*params.S + s);
- if(filterVal != 0) {
- int h = p*params.stride_h + r - params.pad_h;
- for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
- int w = q*params.stride_w + s - params.pad_w;
- outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(inputArray, filterVal, params, n, c, h, w);
- }
+ }
+
+ private static void doLoopBasedConv2dDenseSparse(int n, int k, ConvolutionParameters params, double [] inputArray) throws DMLRuntimeException {
+ double [] outputArray = params.output.getDenseBlock();
+ int [] minIndexArrR = params.tmpData.minIndexArrR;
+ int [] maxIndexArrR = params.tmpData.maxIndexArrR;
+ int [] minIndexArrS = params.tmpData.minIndexArrS;
+ int [] maxIndexArrS = params.tmpData.maxIndexArrS;
+ final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+
+ Iterator<IJV> iter = params.input2.sparseBlock.getIterator();
+ int [] tensorIndexes = new int[4];
+
+ while(iter.hasNext()) {
+ IJV ijv = iter.next();
+ computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.K, params.C, params.R, params.S);
+ if(k == tensorIndexes[0]) {
+ int c = tensorIndexes[1];
+ int r = tensorIndexes[2];
+ int s = tensorIndexes[3];
+ double filterVal = ijv.getV();
+ final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W + s - params.pad_w;
+ for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+ final int hOffset = inputOffset + (p*params.stride_h + r - params.pad_h)*params.W;
+ final int pOffset = outputOffset + p*params.Q;
+ for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
+ final int w = q*params.stride_w;
+ outputArray[pOffset + q] += inputArray[hOffset + w]*filterVal;
+ }
+ }
+ }
+ }
+ }
+
+ private static void doLoopBasedConv2dSparseDense(int n, int k, ConvolutionParameters params, double [] filterArray) throws DMLRuntimeException {
+ double [] outputArray = params.output.getDenseBlock();
+ int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+
+ Iterator<IJV> iter = params.input1.sparseBlock.getIterator();
+ int [] tensorIndexes = new int[4];
+
+ int [] minIndexArrR = params.tmpData.minIndexArrR;
+ int [] minIndexArrS = params.tmpData.minIndexArrS;
+ while(iter.hasNext()) {
+ IJV ijv = iter.next();
+ computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.C, params.H, params.W);
+ if(n == tensorIndexes[0]) {
+ int c = tensorIndexes[1];
+ int h = tensorIndexes[2];
+ int w = tensorIndexes[3];
+ double imgVal = ijv.getV();
+ for (int r = minIndexArrR[h]; r < params.R; r += params.stride_h) {
+ int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+ for (int s = minIndexArrS[w]; s < params.S; s += params.stride_w) {
+ int p = (int)Math.ceil(((double)(h + params.pad_h - r)) / params.stride_h);
+ int q = (int)Math.ceil(((double)(w + params.pad_w - s)) / params.stride_w);
+ if(p >= 0 && p < params.P && q >= 0 && q < params.Q) {
+ double filterVal = filterArray[filterOffset + s];
+ outputArray[outputOffset + p*params.Q + q] += imgVal*filterVal;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ private static void doLoopBasedConv2dSparseSparse(int n, int k, ConvolutionParameters params) throws DMLRuntimeException {
+ double [] outputArray = params.output.getDenseBlock();
+ int [] minIndexArrR = params.tmpData.minIndexArrR;
+ int [] maxIndexArrR = params.tmpData.maxIndexArrR;
+ int [] minIndexArrS = params.tmpData.minIndexArrS;
+ int [] maxIndexArrS = params.tmpData.maxIndexArrS;
+ int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+
+
+ int [] tensorIndexesImage = new int[4];
+ int [] tensorIndexesFilter = new int[4];
+
+ Iterator<IJV> iter = params.input1.sparseBlock.getIterator();
+
+ while(iter.hasNext()) {
+ IJV ijv = iter.next();
+ computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexesImage, params.N, params.C, params.H, params.W);
+ if(n == tensorIndexesImage[0]) {
+ int c = tensorIndexesImage[1];
+ int h = tensorIndexesImage[2];
+ int w = tensorIndexesImage[3];
+ double imgVal = ijv.getV();
+
+ Iterator<IJV> iter1 = params.input2.sparseBlock.getIterator();
+ while(iter1.hasNext()) {
+ IJV ijv1 = iter1.next();
+ computeTensorIndexes(ijv1.getI(), ijv1.getJ(), tensorIndexesFilter, params.K, params.C, params.R, params.S);
+ if(k == tensorIndexesFilter[0] && c == tensorIndexesFilter[1]) {
+ int r = tensorIndexesFilter[2];
+ int s = tensorIndexesFilter[3];
+ if((r-minIndexArrR[h])%params.stride_h == 0 && (s-minIndexArrS[w])%params.stride_w == 0) {
+ int p = (int)Math.ceil(((double)(h + params.pad_h - r)) / params.stride_h);
+ int q = (int)Math.ceil(((double)(w + params.pad_w - s)) / params.stride_w);
+ if(p >= 0 && p < params.P && q >= 0 && q < params.Q) {
+ double filterVal = ijv1.getV();
+ outputArray[outputOffset + p*params.Q + q] += imgVal*filterVal;
}
}
}
}
}
}
+
+ while(iter.hasNext()) {
+ IJV ijv = iter.next();
+ computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexesFilter, params.K, params.C, params.R, params.S);
+ if(k == tensorIndexesFilter[0]) {
+ int c = tensorIndexesFilter[1];
+ int r = tensorIndexesFilter[2];
+ int s = tensorIndexesFilter[3];
+ double filterVal = ijv.getV();
+ for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+ int h = p*params.stride_h + r - params.pad_h;
+ for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
+ int w = q*params.stride_w + s - params.pad_w;
+ // TODO: Improve the performance of sparse sparse
+ outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(filterVal, params, n, c, h, w);
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * This is essentially memory-less operation and can be used when the memory pressure is extremely high.
+ * @param n1 index of the first image to process (inclusive); processing stops before index n2 (exclusive)
+ * @param k
+ * @param params
+ * @throws DMLRuntimeException
+ */
+ private static void doLoopBasedConv2d(int n1, int n2, int k, ConvolutionParameters params) throws DMLRuntimeException {
+ double [] inputArray = null;
+ if (!params.input1.isInSparseFormat())
+ inputArray = params.input1.getDenseBlock();
+ double [] filterArray = null;
+ if (!params.input2.isInSparseFormat())
+ filterArray = params.input2.getDenseBlock();
+
+ if(inputArray != null && filterArray != null) {
+ doLoopBasedConv2dDenseDense(n1, n2, k, params, inputArray, filterArray);
+ }
+ else if(inputArray != null && filterArray == null) {
+ for (int n = n1; n < n2; n++)
+ doLoopBasedConv2dDenseSparse(n, k, params, inputArray);
+ }
+ else if(inputArray == null && filterArray != null) {
+ for (int n = n1; n < n2; n++)
+ doLoopBasedConv2dSparseDense(n, k, params, filterArray);
+ }
+ else if(inputArray == null && filterArray == null) {
+ for (int n = n1; n < n2; n++)
+ doLoopBasedConv2dSparseSparse(n, k, params);
+ }
}
private static int getMinPQ(int pad, int filterSize, int stride) {
@@ -451,12 +706,7 @@ public class LibMatrixDNN {
return Math.min(outputSize, (int)Math.ceil(((double)(inputSize + pad - filterSize)) / stride));
}
- private static double denseConvMultiply(double [] inputArray, double filterVal, ConvolutionParameters params,
- int n, int c, int h, int w) {
- return inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w]*filterVal;
- }
-
- private static double sparseConvMultiply(double [] inputArray, double filterVal, ConvolutionParameters params,
+ private static double sparseConvMultiply(double filterVal, ConvolutionParameters params,
int n, int c, int h, int w) {
return params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w)*filterVal;
}
@@ -635,27 +885,41 @@ public class LibMatrixDNN {
outputBlock.setNonZeros(input.getNonZeros()); // As number of non-zeros doesnot change for reshape_col
}
- private static void runParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException {
- // Total number of compute units available: constrainedNumThreads
- // Static task allocation. TODO: Do this in dynamic way
- int taskSize = TASK_SIZE;
- while(true) {
- if(params.N * Math.ceil(Z/taskSize) > constrainedNumThreads || taskSize == 1) {
- doRunParallelConvTask(constrainedNumThreads, Z, type, params, taskSize);
- return;
+ private static int [] getTaskSize(int constrainedNumThreads, int maxNumTaskSize1, int maxNumTaskSize2) {
+ int taskSize1 = 1; int taskSize2 = 1;
+ // Why this heuristic? To reduce the impact of thread-creation overhead in case of small tasks
+ int approxNumTasksToCreate = 3*constrainedNumThreads;
+ while((maxNumTaskSize1*maxNumTaskSize2)/(taskSize1*taskSize2) > approxNumTasksToCreate) {
+ // Possibility of creating too many tasks, increase taskSize2
+ taskSize2 *= 2;
+ if(taskSize2 >= maxNumTaskSize2) {
+ taskSize2 = maxNumTaskSize2;
+ break;
}
- taskSize = Math.max(taskSize/2, 1);
}
+ while((maxNumTaskSize1*maxNumTaskSize2)/(taskSize1*taskSize2) > approxNumTasksToCreate) {
+ // Possibility of creating too many tasks, increase taskSize1
+ taskSize1 *= 2;
+ if(taskSize1 >= maxNumTaskSize1) {
+ taskSize1 = maxNumTaskSize1;
+ break;
+ }
+ }
+ int [] ret = new int[2];
+ ret[0] = taskSize1;
+ ret[1] = taskSize2;
+ return ret;
}
- private static void doRunParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params, int taskSize) throws DMLRuntimeException {
- ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();
-
- for (int n = 0; n < params.N; n++) {
- for (int z = 0; z < Z; z += taskSize) {
- tasks.add(new ConvTask(n, n+1, z, Math.min(Z, z+taskSize), type, params));
+ private static void runParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException {
+ ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();
+ int [] taskSizes = getTaskSize(constrainedNumThreads, params.N, Z);
+ for (int n = 0; n < params.N; n += taskSizes[0]) {
+ for (int z = 0; z < Z; z += taskSizes[1]) {
+ tasks.add(new ConvTask(n, Math.min(params.N, n+taskSizes[0]), z, Math.min(Z, z+taskSizes[1]), type, params));
}
}
+ LOG.debug("Reduce number of tasks from " + (params.N*Z) + "(" + params.N + "," + Z + ") to " + tasks.size());
ExecutorService pool = Executors.newFixedThreadPool( Math.min(constrainedNumThreads, tasks.size()) );
List<Future<Object>> taskret;
@@ -727,10 +991,8 @@ public class LibMatrixDNN {
}
break;
case LoopBasedConv2d:
- for (int n = n1; n < n2; n++) {
- for (int z = z1; z < z2; z++) {
- LibMatrixDNN.doLoopBasedConv2d(n, z, params);
- }
+ for (int z = z1; z < z2; z++) {
+ LibMatrixDNN.doLoopBasedConv2d(n1, n2, z, params);
}
break;
default: