You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/12 05:27:56 UTC

systemml git commit: [SYSTEMML-1761] Performance wsloss w/o weights (sparsity exploitation)

Repository: systemml
Updated Branches:
  refs/heads/master b67f18641 -> 083cc77f8


[SYSTEMML-1761] Performance wsloss w/o weights (sparsity exploitation)

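So far, the wsloss operator w/o weights, sum((X-U%*%t(V))^2), was not
sparsity-exploiting due to the missing sparse driver (a sparse matrix
combined with a sparse-safe operation such as multiply or divide).
However, this expression can be rewritten into a sparsity-exploiting form:
sum((X-U%*%t(V))^2) -> sum(X^2) - sum(2*X*(U%*%t(V))) +
sum((t(U)%*%U)*(t(V)%*%V)).
The first two terms vanish wherever X is zero and hence only need to be
evaluated over the non-zeros of X. The last term is independent of X and
follows from sum((U%*%t(V))^2) = sum((t(U)%*%U)*(t(V)%*%V)).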

This patch leverages this rewrite for a much more efficient,
sparsity-exploiting, and cache-conscious block-level implementation. The
performance improvements of the entire wsloss operation were as follows:

100K x 100K, sparsity=0.1, rank=100: 92.5s -> 8.5s
100K x 100K, sparsity=0.01, rank=100: 92.2s -> 1.3s
100K x 100K, sparsity=0.001, rank=100: 92.1s -> 0.4s


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/083cc77f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/083cc77f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/083cc77f

Branch: refs/heads/master
Commit: 083cc77f82b70748b70f8cc311fde040034bcc7c
Parents: b67f186
Author: Matthias Boehm <mb...@gmail.com>
Authored: Tue Jul 11 21:41:20 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Tue Jul 11 22:28:52 2017 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java      | 126 +++++++++++++------
 1 file changed, 88 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/083cc77f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 5996a51..da3b12b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -390,7 +390,7 @@ public class LibMatrixMult
 		
 		//check no parallelization benefit (fallback to sequential)
 		//check too small workload in terms of flops (fallback to sequential too)
-		if( ret.rlen == 1 
+		if( ret.rlen == 1 || k <= 1
 			|| leftTranspose && 1L * m1.rlen * m1.clen * m1.clen < PAR_MINFLOP_THRESHOLD
 			|| !leftTranspose && 1L * m1.clen * m1.rlen * m1.rlen < PAR_MINFLOP_THRESHOLD) 
 		{ 
@@ -533,6 +533,10 @@ public class LibMatrixMult
 		else
 			matrixMultWSLossGeneric(mX, mU, mV, mW, ret, wt, 0, mX.rlen);
 		
+		//add correction for sparse wsloss w/o weight
+		if( mX.sparse && wt==WeightsType.NONE )
+			addMatrixMultWSLossNoWeightCorrection(mU, mV, ret, 1);
+		
 		//System.out.println("MMWSLoss " +wt.toString()+ " ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" +
 		//                  "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
 	}
@@ -572,6 +576,10 @@ public class LibMatrixMult
 			throw new DMLRuntimeException(e);
 		}
 
+		//add correction for sparse wsloss w/o weight
+		if( mX.sparse && wt==WeightsType.NONE )
+			addMatrixMultWSLossNoWeightCorrection(mU, mV, ret, k);
+		
 		//System.out.println("MMWSLoss "+wt.toString()+" k="+k+" ("+mX.isInSparseFormat()+","+mX.getNumRows()+","+mX.getNumColumns()+","+mX.getNonZeros()+")x" +
 		//                   "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
 	}
@@ -2163,36 +2171,34 @@ public class LibMatrixMult
 		// Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
 		else if( wt==WeightsType.NONE )
 		{
-			// approach: iterate over all cells of X and 
-			for( int i=rl, uix=rl*cd; i<ru; i++, uix+=cd ) 
-			{
-				if( x.isEmpty(i) ) { //empty row
-					for( int j=0, vix=0; j<n; j++, vix+=cd) {
-						double uvij = dotProduct(u, v, uix, vix, cd);
-						wsloss += (-uvij)*(-uvij);
-					}
-				}
-				else { //non-empty row
-					int xpos = x.pos(i);
-					int xlen = x.size(i);
-					int[] xix = x.indexes(i);
-					double[] xval = x.values(i);
-					int last = -1;
-					for( int k=xpos; k<xpos+xlen; k++ ) {
-						//process last nnz til current nnz
-						for( int k2=last+1; k2<xix[k]; k2++ ){
-							double uvij = dotProduct(u, v, uix, k2*cd, cd);
-							wsloss += (-uvij)*(-uvij);							
+			//approach: use sparsity-exploiting pattern rewrite sum((X-(U%*%t(V)))^2) 
+			//-> sum(X^2)-sum(2*X*(U%*%t(V))))+sum((t(U)%*%U)*(t(V)%*%V)), where each
+			//parallel task computes sum(X^2)-sum(2*X*(U%*%t(V)))) and the last term
+			//sum((t(U)%*%U)*(t(V)%*%V)) is computed once via two tsmm operations.
+			
+			final int blocksizeIJ = (int) (8L*mX.rlen*mX.clen/mX.nonZeros); 
+			int[] curk = new int[blocksizeIJ];			
+			
+			for( int bi=rl; bi<ru; bi+=blocksizeIJ ) {
+				int bimin = Math.min(ru, bi+blocksizeIJ);
+				//prepare starting indexes for block row
+				Arrays.fill(curk, 0); 
+				//blocked execution over column blocks
+				for( int bj=0; bj<n; bj+=blocksizeIJ ) {
+					int bjmin = Math.min(n, bj+blocksizeIJ);
+					for( int i=bi, uix=bi*cd; i<bimin; i++, uix+=cd ) {
+						if( x.isEmpty(i) ) continue; 
+						int xpos = x.pos(i);
+						int xlen = x.size(i);
+						int[] xix = x.indexes(i);
+						double[] xval = x.values(i);
+						int k = xpos + curk[i-bi];
+						for( ; k<xpos+xlen && xix[k]<bjmin; k++ ) {
+							double xij = xval[k];
+							double uvij = dotProduct(u, v, uix, xix[k]*cd, cd);
+							wsloss += xij * xij - 2 * xij * uvij;
 						}
-						//process current nnz
-						double uvij = dotProduct(u, v, uix, xix[k]*cd, cd);
-						wsloss += (xval[k]-uvij)*(xval[k]-uvij);
-						last = xix[k];
-					}
-					//process last nnz til end of row
-					for( int k2=last+1; k2<n; k2++ ) { 
-						double uvij = dotProduct(u, v, uix, k2*cd, cd);
-						wsloss += (-uvij)*(-uvij);							
+						curk[i-bi] = k - xpos;
 					}
 				}
 			}
@@ -2291,18 +2297,52 @@ public class LibMatrixMult
 		// Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
 		else if( wt==WeightsType.NONE )
 		{
-			// approach: iterate over all cells of X and 
-			for( int i=rl; i<ru; i++ )
-				for( int j=0; j<n; j++)
-				{
-					double xij = mX.quickGetValue(i, j);
-					double uvij = dotProductGeneric(mU, mV, i, j, cd);
-					wsloss += (xij-uvij)*(xij-uvij);
-				}
+			//approach: use sparsity-exploiting pattern rewrite sum((X-(U%*%t(V)))^2) 
+			//-> sum(X^2)-sum(2*X*(U%*%t(V))))+sum((t(U)%*%U)*(t(V)%*%V)), where each
+			//parallel task computes sum(X^2)-sum(2*X*(U%*%t(V)))) and the last term
+			//sum((t(U)%*%U)*(t(V)%*%V)) is computed once via two tsmm operations.
+			
+			if( mW.sparse ) { //SPARSE
+				SparseBlock x = mX.sparseBlock;
+				for( int i=rl; i<ru; i++ ) {
+					if( x.isEmpty(i) ) continue;
+					int xpos = x.pos(i);
+					int xlen = x.size(i);
+					int[] xix = x.indexes(i);
+					double[] xval = x.values(i);
+					for( int k=xpos; k<xpos+xlen; k++ ) {
+						double xij = xval[k];
+						double uvij = dotProductGeneric(mU, mV, i, xix[k], cd);
+						wsloss += xij * xij - 2 * xij * uvij;
+					}
+				}	
+			}
+			else { //DENSE
+				double[] x = mX.denseBlock;
+				for( int i=rl, xix=rl*n; i<ru; i++, xix+=n )
+					for( int j=0; j<n; j++)
+						if( x[xix+j] != 0 ) {
+							double xij = x[xix+j];
+							double uvij = dotProductGeneric(mU, mV, i, j, cd);
+							wsloss += xij * xij - 2 * xij * uvij;
+						}
+			}
 		}
 
 		ret.quickSetValue(0, 0, wsloss);
 	}
+	//Adds sum((t(U)%*%U)*(t(V)%*%V)), the X-independent term of the no-weight wsloss rewrite, to ret[0,0]; invoked by the wsloss entry points when X is sparse and wt==NONE, with k as the degree of parallelism for the two tsmm calls.
+	private static void addMatrixMultWSLossNoWeightCorrection(MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, int k) 
+		throws DMLRuntimeException 
+	{
+		MatrixBlock tmp1 = new MatrixBlock(mU.clen, mU.clen, false); //receives t(U)%*%U
+		MatrixBlock tmp2 = new MatrixBlock(mU.clen, mU.clen, false); //receives t(V)%*%V; NOTE(review): sized by mU.clen, assumes mV.clen == mU.clen
+		matrixMultTransposeSelf(mU, tmp1, true, k); //leftTranspose=true: t(U)%*%U
+		matrixMultTransposeSelf(mV, tmp2, true, k); //leftTranspose=true: t(V)%*%V
+		ret.quickSetValue(0, 0, ret.quickGetValue(0, 0) + 
+			((tmp1.sparse || tmp2.sparse) ? dotProductGeneric(tmp1, tmp2) : //cell-wise fallback if either product is sparse
+			dotProduct(tmp1.denseBlock, tmp2.denseBlock, mU.clen*mU.clen))); //NOTE(review): assumes both dense blocks are allocated (e.g. if U or V is all zeros) - confirm
+	}
 
 	private static void matrixMultWSigmoidDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WSigmoidType wt, int rl, int ru) 
 		throws DMLRuntimeException 
@@ -3405,6 +3445,16 @@ public class LibMatrixMult
 		return val;
 	}
 	
+	private static double dotProductGeneric(MatrixBlock a, MatrixBlock b) //returns the sum of cell-wise products a[i,j]*b[i,j], read via quickGetValue
+	{
+		double val = 0;
+		for( int i=0; i<a.getNumRows(); i++ )
+			for( int j=0; j<a.getNumColumns(); j++ )
+				val += a.quickGetValue(i, j) * b.quickGetValue(i, j);
+		//NOTE(review): loop bounds come from a only (assumes b has at least a's dimensions); visits every cell regardless of sparsity
+		return val;
+	}
+	
 	/**
 	 * Used for all version of TSMM where the result is known to be symmetric.
 	 * Hence, we compute only the upper triangular matrix and copy this partial