You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/04 21:32:48 UTC

[systemml] 02/02: [SYSTEMML-540] Improved the performance of lstm builtin function for sparse inputs

This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git

commit 792da5d0aa2abd6e650a3a17f243795d0f9a4b35
Author: Niketan Pansare <np...@us.ibm.com>
AuthorDate: Mon Mar 4 13:32:35 2019 -0800

    [SYSTEMML-540] Improved the performance of lstm builtin function for sparse inputs
    
    This commit allows the matrix multiplication operator to exploit sparsity by splitting lstm into three cases:
    1. If W is sparse, perform cbind(X_t, out_prev) %*% W
    2. Else, if X_t is sparse, perform X_t %*% W1 + out_prev %*% W2, where W1 and W2 are the first D rows and the last M rows of W
    3. Otherwise (neither case applies), perform cbind(X_t, out_prev) %*% W to maximize parallelism within the matrix multiplication operator
---
 .../sysml/runtime/matrix/data/LibMatrixDNN.java    | 114 ++++++++++-----------
 1 file changed, 53 insertions(+), 61 deletions(-)

diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index e2742d8..365d7a2 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -284,26 +284,34 @@ public class LibMatrixDNN {
 	
 	private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) {
 		BinaryOperator bop = new BinaryOperator(Plus.getPlusFnObject());
-//		if(inplace) {
-//			matBlock1.binaryOperationsInPlace(bop, matBlock2);
-//			return matBlock1;
-//		}
-//		else {
+		if(inplace && matBlock1.isInSparseFormat() == matBlock2.isInSparseFormat() &&
+			matBlock1.getNumRows() == matBlock2.getNumRows() && matBlock1.getNumColumns() == matBlock2.getNumColumns()) {
+			matBlock1.binaryOperationsInPlace(bop, matBlock2);
+			return matBlock1;
+		}
+		else {
 			return (MatrixBlock) matBlock1.binaryOperations(bop, matBlock2, new MatrixBlock());
-//		}
+		}
+	}
+	private static MatrixBlock plusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) {
+		return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()), 
+				matBlock2, matBlock3, new MatrixBlock());
 	}
 	
+		
 	private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) {
 		BinaryOperator bop = new BinaryOperator(Multiply.getMultiplyFnObject());
-//		if(inplace) {
-//			matBlock1.binaryOperationsInPlace(bop, matBlock2);
-//			return matBlock1;
-//		}
-//		else {
+		if(inplace && matBlock1.isInSparseFormat() == matBlock2.isInSparseFormat() &&
+			matBlock1.getNumRows() == matBlock2.getNumRows() && matBlock1.getNumColumns() == matBlock2.getNumColumns()) {
+			matBlock1.binaryOperationsInPlace(bop, matBlock2);
+			return matBlock1;
+		}
+		else {
 			return (MatrixBlock) matBlock1.binaryOperations(bop, matBlock2, new MatrixBlock());
-//		}
+		}
 	}
 	
+	
 	// sigmoid(0)*c_prev + sigmoid(0)*tanh(0);
 	
 	private static Builtin sigmoidOp = Builtin.getBuiltinFnObject(BuiltinCode.SIGMOID);
@@ -311,16 +319,10 @@ public class LibMatrixDNN {
 	private static MatrixBlock sigmoid(MatrixBlock in, int numThreads, boolean inPlace) {
 		return (MatrixBlock) in.unaryOperations(new UnaryOperator(sigmoidOp, numThreads, inPlace), new MatrixBlock());
 	}
-	
 	private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean inPlace) {
 		return (MatrixBlock) in.unaryOperations(new UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock());
 	}
 	
-	private static MatrixBlock plusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) {
-		return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()), 
-				matBlock2, matBlock3, new MatrixBlock());
-	}
-	
 	public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0, 
 			boolean return_seq, int N, int T, int D, int M,
 			MatrixBlock out, MatrixBlock c, // output 
@@ -329,61 +331,56 @@ public class LibMatrixDNN {
 		MatrixBlock out_prev = out0;
 		MatrixBlock c_prev = c0;
 		
-		MatrixBlock W1 = W.slice(0, D-1);
-		MatrixBlock W2 = W.slice(D, D+M-1);
+		MatrixBlock W1 = null;
+		MatrixBlock W2 = null;
 		MatrixBlock c_t = null;
 		MatrixBlock out_t = null;
 		
-		boolean profile = true;
-		long t1 = 0, t2 = 0, t3 = 0, t4 = 0, t5 = 0;
+		MatrixBlock input = null;
 		for(int t = 1; t <= T; t++) {
-			long s =  profile ? System.nanoTime() : 0;
-			MatrixBlock X_t = X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock());
-			if(profile) {
-				long e = System.nanoTime();
-				t1 += e - s;
-			}
-			
-			s =  profile ? System.nanoTime() : 0;
-			MatrixBlock ifog_raw = add(add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads), true), b, true);
-			if(profile) {
-				long e = System.nanoTime();
-				t2 += e - s;
+			final MatrixBlock X_t = (T == 1) ? X : X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock());
+			MatrixBlock ifog_raw = null;
+			// Logic: Exploit sparse matrix multiplication whenever possible:
+			// 1. If W is sparse, perform cbind(X_t, out_prev) %*% W
+			// 2. Else if X_t is sparse, perform X_t %*% W1 + out_prev %*% W2
+			// 3. If none of the case is applicable, perform cbind(X_t, out_prev) %*% W
+			boolean isCase1 = W.isInSparseFormat();
+			boolean isCase2 = !isCase1 && X_t.isInSparseFormat();
+			if(isCase2) {
+				// Perform X_t %*% W1 + out_prev %*% W2
+				if(W1 == null) {
+					// Lazy slicing: applicable only when atleast one X_t is sparse.
+					W1 = W.slice(0, D-1);
+					W2 = W.slice(D, D+M-1);
+				}
+				ifog_raw = add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads), true);
+				ifog_raw = add(ifog_raw, b, true);
+			} 
+			else {
+				// Case 1 and 3:
+				// Perform input %*% W, where input = cbind(X_t, out_prev)
+				if(input == null) {
+					input = new MatrixBlock(N, D+M, false);
+					input.allocateDenseBlock();
+				}
+				input = X_t.append(out_prev, input);
+				ifog_raw = add(matmult(input, W, numThreads), b, true);
 			}
 			
-			s =  profile ? System.nanoTime() : 0;
 			MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock());
 			ifo = sigmoid(ifo, numThreads, true);
 			MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock());
 			MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock());
 			MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock());
-			
-			MatrixBlock g = ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock());
-			g = tanh(g, numThreads, true);
-			if(profile) {
-				long e = System.nanoTime();
-				t3 += e - s;
-			}
-			
-			s =  profile ? System.nanoTime() : 0;
+			MatrixBlock g = tanh(ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()), numThreads, true);
+					
 			// c_t = f*c_prev + i*g
 			c_t = plusMultiply(multiply(f, c_prev, true), i, g);
 			// out_t = o*tanh(c)
 			out_t = multiply(o, tanh(c_t, numThreads, false), true);
-			if(profile) {
-				long e = System.nanoTime();
-				t4 += e - s;
-			}
-			
-			s =  profile ? System.nanoTime() : 0;
 			if(return_seq) {
 				out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE);
 			}
-			if(profile) {
-				long e = System.nanoTime();
-				t5 += e - s;
-			}
-			
 			out_prev = out_t;
 			c_prev = c_t;
 			
@@ -398,11 +395,6 @@ public class LibMatrixDNN {
 			c.copy(c_t);
 		else
 			c.copy(c0);
-		System.out.println("Time taken in lstm forward call: [X_t indexing:" + String.format("%.3f", t1*1e-9) + 
-				", ifog_raw computation:" + String.format("%.3f", t2*1e-9) + 
-				", lstm_squash computation:" + String.format("%.3f", t3*1e-9) +  
-				", c_t/out_t computation:" + String.format("%.3f", t4*1e-9) + 
-				", out leftIndexing computation:" + String.format("%.3f", t5*1e-9));
 	}
 	
 	/**
@@ -1009,4 +1001,4 @@ public class LibMatrixDNN {
 			params.end_indexes_w[q] = Math.min(ix+params.S, params.W);
 		}
 	}
-}
+}
\ No newline at end of file