You are viewing a plain-text version of this content. (The canonical link to the original message was not preserved in this copy.)
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/03/27 23:28:54 UTC

[systemds] branch master updated: [SYSTEMDS-2856] Extended multi-threading element-wise binary operations

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 28a6fa5  [SYSTEMDS-2856] Extended multi-threading element-wise binary operations
28a6fa5 is described below

commit 28a6fa5443f29d878f182817ed1ac5f77a6c502f
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Mar 27 21:16:26 2021 +0100

    [SYSTEMDS-2856] Extended multi-threading element-wise binary operations
    
    This patch makes some minor performance improvements by additionally
    supporting multi-threading for sparse-safe, element-wise binary
    matrix-vector operations (previously only matrix-matrix), provided that
    both inputs and the output are dense. Furthermore, it includes a small
    specialized code path for axpy (+* and -*).
---
 .../runtime/functionobjects/MinusMultiply.java     |   4 +
 .../runtime/functionobjects/PlusMultiply.java      |   4 +
 .../runtime/matrix/data/LibMatrixBincell.java      | 127 ++++++++++++---------
 3 files changed, 81 insertions(+), 54 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/MinusMultiply.java b/src/main/java/org/apache/sysds/runtime/functionobjects/MinusMultiply.java
index 9a34603..3d56d8c 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/MinusMultiply.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/MinusMultiply.java
@@ -39,6 +39,10 @@ public class MinusMultiply extends TernaryValueFunction implements ValueFunction
 	private MinusMultiply(double cnt) {
 		_cnt = cnt;
 	}
+	
+	public double getConstant() {
+		return _cnt;
+	}
 
 	public static MinusMultiply getFnObject() {
 		if ( singleObj == null )
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/PlusMultiply.java b/src/main/java/org/apache/sysds/runtime/functionobjects/PlusMultiply.java
index 2ae8f0b..85033b3 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/PlusMultiply.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/PlusMultiply.java
@@ -45,6 +45,10 @@ public class PlusMultiply extends TernaryValueFunction implements ValueFunctionW
 			singleObj = new PlusMultiply();
 		return singleObj;
 	}
+
+	public double getConstant() {
+		return _cnt;
+	}
 	
 	@Override
 	public double execute(double in1, double in2, double in3) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
index 44a25f4..161328e 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
@@ -75,7 +75,11 @@ public class LibMatrixBincell
 		MATRIX_COL_VECTOR,
 		MATRIX_ROW_VECTOR,
 		OUTER_VECTOR_VECTOR,
-		INVALID,
+		INVALID;
+		public boolean isMatrixVector() {
+			return this == MATRIX_COL_VECTOR
+				|| this == MATRIX_ROW_VECTOR;
+		}
 	}
 	
 	private LibMatrixBincell() {
@@ -194,7 +198,8 @@ public class LibMatrixBincell
 		if( m1.isEmpty() || m2.isEmpty()
  			|| ret.getLength() < PAR_NUMCELL_THRESHOLD2
 			|| ((op.sparseSafe || isSparseSafeDivide(op, m2))
-				&& atype != BinaryAccessType.MATRIX_MATRIX))
+				&& !(atype == BinaryAccessType.MATRIX_MATRIX
+					|| atype.isMatrixVector() && isAllDense(m1, m2, ret))))
 		{
 			bincellOp(m1, m2, ret, op);
 			return;
@@ -296,6 +301,10 @@ public class LibMatrixBincell
 		return (op.fn instanceof Divide && rhs.getNonZeros()==(long)rhs.getNumRows()*rhs.getNumColumns());
 	}
 	
+	public static boolean isAllDense(MatrixBlock... mb) {
+		return Arrays.stream(mb).allMatch(m -> !m.sparse);
+	}
+	
 	//////////////////////////////////////////////////////
 	// private sparse-safe/sparse-unsafe implementations
 	///////////////////////////////////
@@ -323,7 +332,7 @@ public class LibMatrixBincell
 		{
 			//note: m2 vector and hence always dense
 			if( !m1.sparse && !m2.sparse && !ret.sparse ) //DENSE all
-				safeBinaryMVDense(m1, m2, ret, op);
+				return safeBinaryMVDense(m1, m2, ret, op, rl, ru);
 			else if( m1.sparse && !m2.sparse && !ret.sparse
 				&& atype == BinaryAccessType.MATRIX_ROW_VECTOR)
 				safeBinaryMVSparseDenseRow(m1, m2, ret, op);
@@ -374,18 +383,20 @@ public class LibMatrixBincell
 		return ret.getNonZeros();
 	}
 
-	private static void safeBinaryMVDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
+	private static long safeBinaryMVDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, int rl, int ru) {
 		boolean isMultiply = (op.fn instanceof Multiply);
 		boolean skipEmpty = (isMultiply);
 		BinaryAccessType atype = getBinaryAccessType(m1, m2);
-		int rlen = m1.rlen;
 		int clen = m1.clen;
 		
 		//early abort on skip and empy
 		if( skipEmpty && (m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) )
-			return; // skip entire empty block
+			return 0; // skip entire empty block
+		
+		//guard for postponed allocation in single-threaded exec
+		if( !ret.isAllocated() )
+			ret.allocateDenseBlock();
 		
-		ret.allocateDenseBlock();
 		DenseBlock da = m1.getDenseBlock();
 		DenseBlock dc = ret.getDenseBlock();
 		long nnz = 0;
@@ -394,34 +405,31 @@ public class LibMatrixBincell
 		{
 			double[] b = m2.getDenseBlockValues(); // always single block
 			
-			for( int bi=0; bi<dc.numBlocks(); bi++ ) {
-				double[] a = da.valuesAt(bi);
-				double[] c = dc.valuesAt(bi);
-				int len = dc.blockSize(bi);
-				int off = bi * dc.blockSize();
-				for( int i=0, ix=0; i<len; i++, ix+=clen )
-				{
-					//replicate vector value
-					double v2 = (b==null) ? 0 : b[off+i];
-					if( skipEmpty && v2 == 0 ) //skip empty rows
-						continue;
-						
-					if( isMultiply && v2 == 1 ) { //ROW COPY
-						//a guaranteed to be non-null (see early abort)
-						System.arraycopy(a, ix, c, ix, clen);
-						nnz += m1.recomputeNonZeros(i, i, 0, clen-1);
-					}
-					else { //GENERAL CASE
-						if( a != null )
-							for( int j=0; j<clen; j++ ) {
-								c[ix+j] = op.fn.execute( a[ix+j], v2 );	
-								nnz += (c[ix+j] != 0) ? 1 : 0;
-							}
-						else {
-							double val = op.fn.execute( 0, v2 );
-							Arrays.fill(c, ix, ix+clen, val);
-							nnz += (val != 0) ? clen : 0;
+			for( int i=rl; i<ru; i++ ) {
+				double[] a = da.values(i);
+				double[] c = dc.values(i);
+				int ix = da.pos(i);
+				
+				//replicate vector value
+				double v2 = (b==null) ? 0 : b[i];
+				if( skipEmpty && v2 == 0 ) //skip empty rows
+					continue;
+					
+				if( isMultiply && v2 == 1 ) { //ROW COPY
+					//a guaranteed to be non-null (see early abort)
+					System.arraycopy(a, ix, c, ix, clen);
+					nnz += m1.recomputeNonZeros(i, i, 0, clen-1);
+				}
+				else { //GENERAL CASE
+					if( a != null )
+						for( int j=0; j<clen; j++ ) {
+							double val = op.fn.execute( a[ix+j], v2 );
+							nnz += ((c[ix+j] = val) != 0) ? 1 : 0;
 						}
+					else {
+						double val = op.fn.execute( 0, v2 );
+						Arrays.fill(c, ix, ix+clen, val);
+						nnz += (val != 0) ? clen : 0;
 					}
 				}
 			}
@@ -431,38 +439,37 @@ public class LibMatrixBincell
 			double[] b = m2.getDenseBlockValues(); // always single block
 			
 			if( da==null && b==null ) { //both empty
-				double v = op.fn.execute( 0, 0 );
-				dc.set(v);
-				nnz += (v != 0) ? (long)rlen*clen : 0;
+				double val = op.fn.execute( 0, 0 );
+				dc.set(rl, ru, 0, clen, val);
+				nnz += (val != 0) ? (long)(ru-rl)*clen : 0;
 			}
 			else if( da==null ) //left empty
 			{
 				//compute first row
-				double[] c = dc.valuesAt(0);
+				double[] c = dc.values(rl);
 				for( int j=0; j<clen; j++ ) {
-					c[j] = op.fn.execute( 0, b[j] );
-					nnz += (c[j] != 0) ? rlen : 0;
+					double val = op.fn.execute( 0, b[j] );
+					nnz += ((c[j]=val) != 0) ? (ru-rl) : 0;
 				}
 				//copy first to all other rows
-				for( int i=1; i<rlen; i++ )
+				for( int i=rl+1; i<ru; i++ )
 					dc.set(i, c);
 			}
 			else //default case (incl right empty) 
 			{
-				for( int bi=0; bi<dc.numBlocks(); bi++ ) {
-					double[] a = da.valuesAt(bi);
-					double[] c = dc.valuesAt(bi);
-					int len = dc.blockSize(bi);
-					for( int i=0, ix=0; i<len; i++, ix+=clen )
-						for( int j=0; j<clen; j++ ) {
-							c[ix+j] = op.fn.execute( a[ix+j], ((b!=null) ? b[j] : 0) );
-							nnz += (c[ix+j] != 0) ? 1 : 0;
-						}
+				for( int i=rl; i<ru; i++ ) {
+					double[] a = da.values(i);
+					double[] c = dc.values(i);
+					int ix = da.pos(i);
+					for( int j=0; j<clen; j++ ) {
+						double val = op.fn.execute( a[ix+j], ((b!=null) ? b[j] : 0) );
+						nnz += ((c[ix+j]=val) != 0) ? 1 : 0;
+					}
 				}
 			}
 		}
 		
-		ret.nonZeros = nnz;
+		return nnz;
 	}
 
 	private static void safeBinaryMVSparseDenseRow(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) {
@@ -924,6 +931,10 @@ public class LibMatrixBincell
 	private static long safeBinaryMMDenseDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
 		BinaryOperator op, int rl, int ru)
 	{
+		boolean isPM = m1.clen >= 512 & (op.fn instanceof PlusMultiply | op.fn instanceof MinusMultiply);
+		double cntPM = !isPM ? Double.NaN : (op.fn instanceof PlusMultiply ?
+			((PlusMultiply)op.fn).getConstant() : -1d * ((MinusMultiply)op.fn).getConstant());
+		
 		//guard for postponed allocation in single-threaded exec
 		if( !ret.isAllocated() )
 			ret.allocateDenseBlock();
@@ -941,9 +952,17 @@ public class LibMatrixBincell
 			double[] b = db.values(i);
 			double[] c = dc.values(i);
 			int pos = da.pos(i);
-			for(int j=pos; j<pos+clen; j++) {
-				c[j] = fn.execute(a[j], b[j]);
-				lnnz += (c[j]!=0)? 1 : 0;
+			
+			if( isPM ) {
+				System.arraycopy(a, pos, c, pos, clen);
+				LibMatrixMult.vectMultiplyAdd(cntPM, b, c, pos, pos, clen);
+				lnnz += UtilFunctions.computeNnz(c, pos, clen);
+			}
+			else {
+				for(int j=pos; j<pos+clen; j++) {
+					c[j] = fn.execute(a[j], b[j]);
+					lnnz += (c[j]!=0)? 1 : 0;
+				}
 			}
 		}
 		return lnnz;