You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/11/16 01:45:27 UTC
[2/2] systemml git commit: [SYSTEMML-1958] Performance sparse conv2d
(im2col, alloc, load balance)
[SYSTEMML-1958] Performance sparse conv2d (im2col, alloc, load balance)
This patch makes three performance improvements to sparse conv2d
operations and also fixes a potential source of result incorrectness
with sparse im2col.
(1) Special case sparse im2col: We now use a greatly simplified sparse
im2col for the special case of stride=1, pad=0, Q=1, and S=W.
(2) Preallocation sparse im2col output: In order to account for skew in
the sparse input data, we now allocate the temporary block with 4x the
capacity derived from the average input sparsity.
(3) Load balance: To improve the load balance and thus utilization of
conv2d operations, we now create twice as many tasks as threads, which is
a good trade-off between load balance and task overhead.
On an end-to-end cnn scoring application with sparse data, this patch
improved performance from 772s to 687s.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8c4d5165
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8c4d5165
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8c4d5165
Branch: refs/heads/master
Commit: 8c4d516521e62a9ada190f76c8c4f20faaa3921b
Parents: 1336d32
Author: Matthias Boehm <mb...@gmail.com>
Authored: Wed Nov 15 13:54:12 2017 -0800
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Wed Nov 15 17:46:52 2017 -0800
----------------------------------------------------------------------
.../runtime/matrix/data/LibMatrixDNNHelper.java | 7 ++-
.../matrix/data/LibMatrixDNNIm2ColHelper.java | 45 +++++++++++---------
2 files changed, 28 insertions(+), 24 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/8c4d5165/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
index dfd0778..be8a833 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
@@ -118,11 +118,10 @@ public class LibMatrixDNNHelper {
public static ArrayList<Callable<Long>> getConv2dWorkers(ConvolutionParameters params) throws DMLRuntimeException {
ArrayList<Callable<Long>> ret = new ArrayList<>();
- // Try to create as many tasks as threads.
- // Creating more tasks will help in tail, but would have additional overhead of maintaining the intermediate
- // data structures such as im2col blocks.
+ // Try to create twice as many tasks as threads for improved load balance
+ // (due to constant-sized intermediates, GC works well, so the overhead per task is small)
int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
- int taskSize = (int)(Math.ceil((double)params.N / k));
+ int taskSize = (int)(Math.ceil((double)params.N / k / 2));
// TODO: Decide here based on params whether to use LoopedIm2ColConv2dAllChannels or LoopedIm2ColConv2dOneChannel
// For now, let's stick to the existing approach of converting [1, CHW] to [CRS, PQ] as it allows matrix multiplication large enough matrix.
http://git-wip-us.apache.org/repos/asf/systemml/blob/8c4d5165/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
index a4b1877..1fa863e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2ColHelper.java
@@ -55,13 +55,10 @@ public class LibMatrixDNNIm2ColHelper {
LOG.trace("Using SparseIm2colWorker operator to perform im2col.");
out.reset(out.rlen, out.clen, true);
out.allocateSparseRowsBlock();
- //preallocate sparse-rows
- double sparsity = Math.min(MatrixBlock.SPARSITY_TURN_POINT,
- (input.getNonZeros()*2.0) / (input.getNumRows()*input.getNumColumns()));
- int estnnz = (int)Math.ceil((trans ? params.C*params.R*params.S : params.P*params.Q)*sparsity);
+ //preallocate sparse-rows (larger than average sparsity to account for skew)
+ int estnnz = (int)Math.ceil(4*input.getSparsity()*out.clen);
for(int r = 0; r < out.rlen; r++)
- out.getSparseBlock().allocate(r, estnnz);
-
+ out.getSparseBlock().allocate(r, Math.min(estnnz, out.clen));
return new SparseSparseIm2colWorkerAllChan(input, out, params, trans);
}
}
@@ -248,17 +245,20 @@ public class LibMatrixDNNIm2ColHelper {
*/
private static class SparseSparseIm2colWorkerAllChan implements Im2colWorker {
private final MatrixBlock input, output;
- private final int S, R, P, Q, W, HW;
+ private final int S, R, P, Q, W, HW, RS;
private final int stride_h, stride_w, pad_h, pad_w;
private final boolean trans;
+ private final boolean simple;
public SparseSparseIm2colWorkerAllChan(MatrixBlock input, MatrixBlock im2ColOutBlock, ConvolutionParameters params, boolean trans) {
this.input = input;
this.output = im2ColOutBlock;
this.HW = params.H * params.W;
+ this.RS = params.R * params.S;
this.W = params.W; this.R = params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
this.stride_h = params.stride_h; this.stride_w = params.stride_w;
this.pad_h = params.pad_h; this.pad_w = params.pad_w;
this.trans = trans;
+ this.simple = params.isStride1Pad0() && W == S && Q == 1;
if(!input.isInSparseFormat())
throw new RuntimeException("Incorrect operator selection. Expected dense input for SparseIm2colWorkerAllChannels");
}
@@ -267,7 +267,7 @@ public class LibMatrixDNNIm2ColHelper {
public void execute(int n, int c) {
throw new RuntimeException("Not supported");
}
-
+
@Override
public void execute(int n) {
output.reset();
@@ -289,13 +289,15 @@ public class LibMatrixDNNIm2ColHelper {
int hInput = (chw - cInput*HW)/W;
int wInput = chw % W;
- appendInputValueToIm2colOutput(output, cInput, hInput, wInput, avals[j],
+ if( simple )
+ appendInputValueToIm2colOutputSimple(output, cInput, hInput, wInput,
+ avals[j], R, S, RS, P, trans);
+ else
+ appendInputValueToIm2colOutput(output, cInput, hInput, wInput, avals[j],
R, S, P, Q, stride_h, stride_w, pad_h, pad_w, trans);
}
- // Since the chw are appended in sorted order, no need to sort the output rows
- // unless in trans mode, then sorting is needed
- if( trans )
- output.sortSparseRows();
+
+ output.sortSparseRows();
}
}
@@ -349,10 +351,7 @@ public class LibMatrixDNNIm2ColHelper {
R, S, P, Q, stride_h, stride_w, pad_h, pad_w, trans);
}
}
- // Since the chw are appended in sorted order, no need to sort the output rows
- // unless in trans mode, then sorting is needed
- if( trans )
- output.sortSparseRows();
+ output.sortSparseRows();
}
}
@@ -375,8 +374,6 @@ public class LibMatrixDNNIm2ColHelper {
*/
private static void appendInputValueToIm2colOutput(MatrixBlock output, int cInput, int hInput, int wInput, double value,
int R, int S, int P, int Q, int stride_h, int stride_w, int pad_h, int pad_w, boolean trans) {
- if(value == 0)
- return;
int RS = R*S;
// For the given h,w index, insert avals[j] into respective r,s,p,q locations
@@ -407,8 +404,16 @@ public class LibMatrixDNNIm2ColHelper {
output.appendValue(pQ + q, outRowIndex + s, value);
else
output.appendValue(outRowIndex + s, pQ + q, value);
- // Since the chw are appended in sorted order, no need to sort the output rows
}
}
}
+
+ private static void appendInputValueToIm2colOutputSimple(MatrixBlock output, int c, int h, int w,
+ double value, int R, int S, int RS, int P, boolean trans) {
+ int rMin = Math.max(0, h - P + 1);
+ int rMax = Math.min(R-1, h);
+ int cix = c*RS+w+rMin*S;
+ for(int p=h-rMin; p >= h-rMax; p--, cix+=S)
+ output.appendValue(trans?p:cix, trans?cix:p, value);
+ }
}