You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/04 21:32:48 UTC
[systemml] 02/02: [SYSTEMML-540] Improved the performance of lstm
builtin function for sparse inputs
This is an automated email from the ASF dual-hosted git repository.
niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
commit 792da5d0aa2abd6e650a3a17f243795d0f9a4b35
Author: Niketan Pansare <np...@us.ibm.com>
AuthorDate: Mon Mar 4 13:32:35 2019 -0800
[SYSTEMML-540] Improved the performance of lstm builtin function for sparse inputs
This commit allows the matrix multiplication operator to exploit sparsity by separating lstm into three cases:
1. If W is sparse, perform cbind(X_t, out_prev) %*% W
2. Else if X_t is sparse, perform X_t %*% W1 + out_prev %*% W2 (where W1 is the first D rows of W and W2 is the remaining M rows)
3. If neither case applies, perform cbind(X_t, out_prev) %*% W to maximize parallelism within the matrix multiplication operator
---
.../sysml/runtime/matrix/data/LibMatrixDNN.java | 114 ++++++++++-----------
1 file changed, 53 insertions(+), 61 deletions(-)
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index e2742d8..365d7a2 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -284,26 +284,34 @@ public class LibMatrixDNN {
private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) {
BinaryOperator bop = new BinaryOperator(Plus.getPlusFnObject());
-// if(inplace) {
-// matBlock1.binaryOperationsInPlace(bop, matBlock2);
-// return matBlock1;
-// }
-// else {
+ if(inplace && matBlock1.isInSparseFormat() == matBlock2.isInSparseFormat() &&
+ matBlock1.getNumRows() == matBlock2.getNumRows() && matBlock1.getNumColumns() == matBlock2.getNumColumns()) {
+ matBlock1.binaryOperationsInPlace(bop, matBlock2);
+ return matBlock1;
+ }
+ else {
return (MatrixBlock) matBlock1.binaryOperations(bop, matBlock2, new MatrixBlock());
-// }
+ }
+ }
+ private static MatrixBlock plusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) {
+ return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()),
+ matBlock2, matBlock3, new MatrixBlock());
}
+
private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) {
BinaryOperator bop = new BinaryOperator(Multiply.getMultiplyFnObject());
-// if(inplace) {
-// matBlock1.binaryOperationsInPlace(bop, matBlock2);
-// return matBlock1;
-// }
-// else {
+ if(inplace && matBlock1.isInSparseFormat() == matBlock2.isInSparseFormat() &&
+ matBlock1.getNumRows() == matBlock2.getNumRows() && matBlock1.getNumColumns() == matBlock2.getNumColumns()) {
+ matBlock1.binaryOperationsInPlace(bop, matBlock2);
+ return matBlock1;
+ }
+ else {
return (MatrixBlock) matBlock1.binaryOperations(bop, matBlock2, new MatrixBlock());
-// }
+ }
}
+
// sigmoid(0)*c_prev + sigmoid(0)*tanh(0);
private static Builtin sigmoidOp = Builtin.getBuiltinFnObject(BuiltinCode.SIGMOID);
@@ -311,16 +319,10 @@ public class LibMatrixDNN {
private static MatrixBlock sigmoid(MatrixBlock in, int numThreads, boolean inPlace) {
return (MatrixBlock) in.unaryOperations(new UnaryOperator(sigmoidOp, numThreads, inPlace), new MatrixBlock());
}
-
private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean inPlace) {
return (MatrixBlock) in.unaryOperations(new UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock());
}
- private static MatrixBlock plusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) {
- return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()),
- matBlock2, matBlock3, new MatrixBlock());
- }
-
public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0,
boolean return_seq, int N, int T, int D, int M,
MatrixBlock out, MatrixBlock c, // output
@@ -329,61 +331,56 @@ public class LibMatrixDNN {
MatrixBlock out_prev = out0;
MatrixBlock c_prev = c0;
- MatrixBlock W1 = W.slice(0, D-1);
- MatrixBlock W2 = W.slice(D, D+M-1);
+ MatrixBlock W1 = null;
+ MatrixBlock W2 = null;
MatrixBlock c_t = null;
MatrixBlock out_t = null;
- boolean profile = true;
- long t1 = 0, t2 = 0, t3 = 0, t4 = 0, t5 = 0;
+ MatrixBlock input = null;
for(int t = 1; t <= T; t++) {
- long s = profile ? System.nanoTime() : 0;
- MatrixBlock X_t = X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock());
- if(profile) {
- long e = System.nanoTime();
- t1 += e - s;
- }
-
- s = profile ? System.nanoTime() : 0;
- MatrixBlock ifog_raw = add(add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads), true), b, true);
- if(profile) {
- long e = System.nanoTime();
- t2 += e - s;
+ final MatrixBlock X_t = (T == 1) ? X : X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock());
+ MatrixBlock ifog_raw = null;
+ // Logic: Exploit sparse matrix multiplication whenever possible:
+ // 1. If W is sparse, perform cbind(X_t, out_prev) %*% W
+ // 2. Else if X_t is sparse, perform X_t %*% W1 + out_prev %*% W2
+ // 3. If none of the case is applicable, perform cbind(X_t, out_prev) %*% W
+ boolean isCase1 = W.isInSparseFormat();
+ boolean isCase2 = !isCase1 && X_t.isInSparseFormat();
+ if(isCase2) {
+ // Perform X_t %*% W1 + out_prev %*% W2
+ if(W1 == null) {
+ // Lazy slicing: applicable only when atleast one X_t is sparse.
+ W1 = W.slice(0, D-1);
+ W2 = W.slice(D, D+M-1);
+ }
+ ifog_raw = add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads), true);
+ ifog_raw = add(ifog_raw, b, true);
+ }
+ else {
+ // Case 1 and 3:
+ // Perform input %*% W, where input = cbind(X_t, out_prev)
+ if(input == null) {
+ input = new MatrixBlock(N, D+M, false);
+ input.allocateDenseBlock();
+ }
+ input = X_t.append(out_prev, input);
+ ifog_raw = add(matmult(input, W, numThreads), b, true);
}
- s = profile ? System.nanoTime() : 0;
MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock());
ifo = sigmoid(ifo, numThreads, true);
MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock());
MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock());
MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
-
- MatrixBlock g = ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock());
- g = tanh(g, numThreads, true);
- if(profile) {
- long e = System.nanoTime();
- t3 += e - s;
- }
-
- s = profile ? System.nanoTime() : 0;
+ MatrixBlock g = tanh(ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()), numThreads, true);
+
// c_t = f*c_prev + i*g
c_t = plusMultiply(multiply(f, c_prev, true), i, g);
// out_t = o*tanh(c)
out_t = multiply(o, tanh(c_t, numThreads, false), true);
- if(profile) {
- long e = System.nanoTime();
- t4 += e - s;
- }
-
- s = profile ? System.nanoTime() : 0;
if(return_seq) {
out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE);
}
- if(profile) {
- long e = System.nanoTime();
- t5 += e - s;
- }
-
out_prev = out_t;
c_prev = c_t;
@@ -398,11 +395,6 @@ public class LibMatrixDNN {
c.copy(c_t);
else
c.copy(c0);
- System.out.println("Time taken in lstm forward call: [X_t indexing:" + String.format("%.3f", t1*1e-9) +
- ", ifog_raw computation:" + String.format("%.3f", t2*1e-9) +
- ", lstm_squash computation:" + String.format("%.3f", t3*1e-9) +
- ", c_t/out_t computation:" + String.format("%.3f", t4*1e-9) +
- ", out leftIndexing computation:" + String.format("%.3f", t5*1e-9));
}
/**
@@ -1009,4 +1001,4 @@ public class LibMatrixDNN {
params.end_indexes_w[q] = Math.min(ix+params.S, params.W);
}
}
-}
+}
\ No newline at end of file