Posted to commits@systemds.apache.org by ba...@apache.org on 2022/01/16 13:06:46 UTC

[systemds] 01/03: [SYSTEMDS-3243] Consistent allocation of MatrixBlock for MM

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 80025ddf3e8409d238c06084445dafa55e7a8579
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Sun Jan 16 13:56:01 2022 +0100

    [SYSTEMDS-3243] Consistent allocation of MatrixBlock for MM
    
    This commit changes the matrix multiplication to not allocate or
    analyze the output and inputs before calls to the libraries, removing
    an unnecessary analysis step from MatrixBlock and avoiding a sparse
    allocation that is later turned into a dense allocation in some cases.
    
    The MM is now consolidated into a single code path (covering both
    single- and multi-threaded execution) that checks for output
    allocation, making the API more robust and removing code duplication.
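    
    For context, a minimal usage sketch of the consolidated API after this
    change; the variables a, b, and k are illustrative, not from the patch:
    
        // passing ret == null lets the library allocate the output itself
        MatrixBlock c = LibMatrixMult.matrixMult(a, b, null);
        // multi-threaded variant, using up to k threads
        MatrixBlock d = LibMatrixMult.matrixMult(a, b, null, k);
        // an existing block may be passed as ret; it is reset to the
        // output dimensions and format before being filled
        MatrixBlock e = LibMatrixMult.matrixMult(a, b, c, k);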
---
 .../sysds/runtime/matrix/data/LibMatrixMult.java   | 228 ++++++++++++---------
 .../sysds/runtime/matrix/data/LibMatrixNative.java | 135 ++++++------
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  66 +++---
 3 files changed, 217 insertions(+), 212 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index a503085..8384dd2 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -88,14 +88,14 @@ public class LibMatrixMult
 	 * 
 	 * All variants use a IKJ access pattern, and internally use dense output. After the
 	 * actual computation, we recompute nnz and check for sparse/dense representation.
-	 *  
 	 * 
 	 * @param m1 first matrix
 	 * @param m2 second matrix
 	 * @param ret result matrix
+	 * @return the result MatrixBlock
 	 */
-	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret) {
-		matrixMult(m1, m2, ret, 0, m1.rlen);
+	public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret) {
+		return matrixMult(m1, m2, ret, false, 1);
 	}
 	
 	/**
@@ -109,141 +109,165 @@ public class LibMatrixMult
 	 * @param m2 second matrix
 	 * @param ret result matrix
 	 * @param fixedRet if true, output representation is fixed and nnzs not recomputed
+	 * @return the result MatrixBlock
 	 */
-	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean fixedRet) {
-		matrixMult(m1, m2, ret, 0, m1.rlen, fixedRet);
+	public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean fixedRet) {
+		return matrixMult(m1, m2, ret, fixedRet, 1);
 	}
 	
-	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru) {
-		matrixMult(m1, m2, ret, rl, ru, false);
+	/**
+	 * Performs a multi-threaded matrix multiplication and stores the result in the output matrix.
+	 * The parameter k (k&gt;=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen).
+	 * 
+	 * @param m1 first matrix
+	 * @param m2 second matrix
+	 * @param ret result matrix
+	 * @param k maximum parallelism
+	 * @return the result MatrixBlock
+	 */
+	public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
+		return matrixMult(m1, m2, ret, false, k);
 	}
 	
-	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, boolean fixedRet) {
-		//check inputs / outputs
-		if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
-			ret.examSparsity(); //turn empty dense into sparse
-			return;
-		}
+	/**
+	 * Performs a matrix multiplication and stores the result in the output matrix.
+	 * 
+	 * All variants use an IKJ access pattern, and internally use dense output. After the
+	 * actual computation, we recompute nnz and check for sparse/dense representation.
+	 * 
+	 * This method allows disabling the output sparsity check (examSparsity). This is useful if matrixMult is used
+	 * as an intermediate operation (for example, in LibMatrixDNN), where the output is internally
+	 * consumed by another dense instruction, making a repeated conversion to sparse wasteful.
+	 * This should be used only in rare cases, and if you are unsure,
+	 * use the method 'matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret)' instead.
+	 * 
+	 * The parameter k (k&gt;=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen).
+	 * 
+	 * @param m1 first matrix
+	 * @param m2 second matrix
+	 * @param ret result matrix
+	 * @param fixedRet if true, output representation is fixed and nnzs not recomputed
+	 * @param k maximum parallelism
+	 * @return the result MatrixBlock
+	 */
+	public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean fixedRet, int k) {
+		if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) 
+			return emptyMatrixMult(m1, m2, ret);
 		
-		//Timing time = new Timing(true);
+		// Timing time = new Timing(true);
 		
-		//pre-processing: output allocation
+		// pre analysis
 		boolean m1Perm = m1.isSparsePermutationMatrix();
-		boolean ultraSparse = (fixedRet && ret.sparse)
-			|| (!fixedRet && isUltraSparseMatrixMult(m1, m2, m1Perm));
-		boolean sparse = !m1Perm && !ultraSparse && !fixedRet 
+		boolean ultraSparse = (fixedRet && ret.sparse) ||
+			(!fixedRet && isUltraSparseMatrixMult(m1, m2, m1Perm));
+		boolean sparse = !fixedRet && !ultraSparse && !m1Perm
 			&& isSparseOutputMatrixMult(m1, m2);
-		boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
-		m2 = prepMatrixMultRightInput(m1, m2);
-		ret.sparse = ultraSparse | sparse;
-		ret.allocateBlock();
 		
-		//prepare row-upper for special cases of vector-matrix
-		boolean pm2 = !ultraSparse &&
-			checkParMatrixMultRightInputRows(m1, m2, Integer.MAX_VALUE);
-		int ru2 = (pm2 && ru==m1.rlen) ? m2.rlen : ru; 
-		int cu = m2.clen;
+		// allocate output
+		if(ret == null)
+			ret = new MatrixBlock(m1.rlen, m2.clen, ultraSparse | sparse);
+		else 
+			ret.reset(m1.rlen, m2.clen, ultraSparse | sparse);
+		ret.allocateBlock();
 		
-		//core matrix mult computation
-		if( ultraSparse )
+	// Detect if we should transpose the skinny right-hand side.
+		boolean tm2 = !fixedRet && checkPrepMatrixMultRightInput(m1,m2);
+		m2 = prepMatrixMultRightInput(m1, m2, tm2);
+
+		// check for multi-threading
+		if (!ret.isThreadSafe() 
+				|| !satisfiesMultiThreadingConstraints(m1, m2, m1.rlen==1, true, 2, k)
+				|| fixedRet) // Fixed ret not supported in multithreaded execution yet
+			k = 1;
+
+		if(k <= 1)
+			singleThreadMatrixMult(m1, m2, ret, ultraSparse, sparse, tm2, m1Perm, fixedRet);
+		else
+			parallelMatrixMult(m1, m2, ret, k, ultraSparse, sparse, tm2, m1Perm);
+
+		//System.out.println("MM "+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +
+		//		"("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());
+	
+		return ret;
+	}
+
+	private static void singleThreadMatrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,  
+		boolean ultraSparse, boolean sparse, boolean tm2, boolean m1Perm, boolean fixedRet){
+		// prepare row-upper for special cases of vector-matrix
+		final boolean pm2 = !ultraSparse && checkParMatrixMultRightInputRows(m1, m2, Integer.MAX_VALUE);
+		final int ru2 = (pm2) ? m2.rlen : m1.rlen;
+
+		// core matrix mult computation
+		if(ultraSparse)
 			matrixMultUltraSparse(m1, m2, ret, m1Perm, 0, ru2);
 		else if(!m1.sparse && !m2.sparse)
-			matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, cu);
+			matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, m2.clen);
 		else if(m1.sparse && m2.sparse)
 			matrixMultSparseSparse(m1, m2, ret, pm2, sparse, 0, ru2);
 		else if(m1.sparse)
 			matrixMultSparseDense(m1, m2, ret, pm2, 0, ru2);
 		else
 			matrixMultDenseSparse(m1, m2, ret, pm2, 0, ru2);
-		
-		//post-processing: nnz/representation
-		if( !fixedRet ) {
-			if( !ret.sparse )
+
+		// post-processing: nnz/representation
+		if(!fixedRet) {
+			if(!ret.sparse)
 				ret.recomputeNonZeros();
 			ret.examSparsity();
 		}
-		
-		//System.out.println("MM ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +
-		//		"("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());
 	}
-	
-	/**
-	 * Performs a multi-threaded matrix multiplication and stores the result in the output matrix.
-	 * The parameter k (k&gt;=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen).
-	 * 
-	 * @param m1 first matrix
-	 * @param m2 second matrix
-	 * @param ret result matrix
-	 * @param k maximum parallelism
-	 */
-	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
-		//check inputs / outputs
-		if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
-			ret.examSparsity(); //turn empty dense into sparse
-			return;
-		}
-		
-		//check too small workload and fallback to sequential if needed
-		if( !satisfiesMultiThreadingConstraints(m1, m2, m1.rlen==1, true, 2, k) ) {
-			matrixMult(m1, m2, ret);
-			return;
-		}
-		
-		//Timing time = new Timing(true);
-		
-		//pre-processing: output allocation (in contrast to single-threaded,
-		//we need to allocate sparse as well in order to prevent synchronization)
-		boolean m1Perm = m1.isSparsePermutationMatrix();
-		boolean ultraSparse = isUltraSparseMatrixMult(m1, m2, m1Perm);
-		boolean sparse = !ultraSparse && !m1Perm && isSparseOutputMatrixMult(m1, m2);
-		boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
-		m2 = prepMatrixMultRightInput(m1, m2);
-		ret.sparse = ultraSparse | sparse;
-		ret.allocateBlock();
-		
-		if (!ret.isThreadSafe()) {
-			matrixMult(m1, m2, ret);
-			return;
-		}
-		
-		//prepare row-upper for special cases of vector-matrix / matrix-matrix
+
+	private static void parallelMatrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k, 
+		boolean ultraSparse, boolean sparse, boolean tm2, boolean m1Perm){
+		// prepare row-upper for special cases of vector-matrix / matrix-matrix
 		boolean pm2r = !ultraSparse && !sparse && checkParMatrixMultRightInputRows(m1, m2, k);
 		boolean pm2c = !ultraSparse && checkParMatrixMultRightInputCols(m1, m2, k, pm2r);
-		int num = pm2r ? m2.rlen : pm2c ? m2.clen : m1.rlen; 
-		
-		//core multi-threaded matrix mult computation
-		//(currently: always parallelization over number of rows)
+		int num = pm2r ? m2.rlen : pm2c ? m2.clen : m1.rlen;
+
+		// core multi-threaded matrix mult computation
+		// (currently: always parallelization over number of rows)
 		try {
 			ExecutorService pool = CommonThreadPool.get(k);
 			ArrayList<MatrixMultTask> tasks = new ArrayList<>();
-			ArrayList<Integer> blklens = UtilFunctions.getBalancedBlockSizesDefault(num, k, (pm2r||pm2c));
-			for( int i=0, lb=0; i<blklens.size(); lb+=blklens.get(i), i++ )
-				tasks.add(new MatrixMultTask(m1, m2, ret, tm2, pm2r, pm2c, m1Perm, sparse, lb, lb+blklens.get(i)));
-			//execute tasks
+			ArrayList<Integer> blklens = UtilFunctions.getBalancedBlockSizesDefault(num, k, (pm2r || pm2c));
+			for(int i = 0, lb = 0; i < blklens.size(); lb += blklens.get(i), i++)
+				tasks.add(new MatrixMultTask(m1, m2, ret, tm2, pm2r, pm2c, m1Perm, sparse, lb, lb + blklens.get(i)));
+			// execute tasks
 			List<Future<Object>> taskret = pool.invokeAll(tasks);
 			pool.shutdown();
-			//aggregate partial results (nnz, ret for vector/matrix)
-			ret.nonZeros = 0; //reset after execute
-			for( Future<Object> task : taskret ) {
-				if( pm2r ) //guaranteed single block
-					vectAdd((double[])task.get(), ret.getDenseBlockValues(), 0, 0, ret.rlen*ret.clen);
+			// aggregate partial results (nnz, ret for vector/matrix)
+			ret.nonZeros = 0; // reset after execute
+			for(Future<Object> task : taskret) {
+				if(pm2r) // guaranteed single block
+					vectAdd((double[]) task.get(), ret.getDenseBlockValues(), 0, 0, ret.rlen * ret.clen);
 				else
-					ret.nonZeros += (Long)task.get();
+					ret.nonZeros += (Long) task.get();
 			}
-			if( pm2r )
+			if(pm2r)
 				ret.recomputeNonZeros();
 		}
 		catch(Exception ex) {
 			throw new DMLRuntimeException(ex);
 		}
-		
-		//post-processing (nnz maintained in parallel)
+
+		// post-processing (nnz maintained in parallel)
 		ret.examSparsity();
-		
-		//System.out.println("MM k="+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +
-		//		"("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());
 	}
-	
+
+	public static MatrixBlock emptyMatrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret){
+		final int rl = m1.rlen;
+		final int cl = m2.clen;
+
+		if(ret == null)
+			return new MatrixBlock(rl, cl, true);
+		else {
+			ret.reset(rl, cl, true);
+			ret.setNonZeros(0);
+			ret.cleanupBlock(true, true);
+			return ret;
+		}
+	}
+
 	/**
 	 * Performs a matrix multiplication chain operation of type t(X)%*%(X%*%v) or t(X)%*%(w*(X%*%v)).
 	 * 
@@ -3959,16 +3983,16 @@ public class LibMatrixMult
 		boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(m1.rlen, m2.clen, estNnz);
 		return m2.clen < 4*1024 && sparseOut;
 	}
-	
+
 	public static boolean isOuterProductTSMM(int rlen, int clen, boolean left) {
 		return left ? rlen == 1 & clen > 1 : rlen > 1 & clen == 1;
 	}
 
-	private static MatrixBlock prepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2 ) {
+	private static MatrixBlock prepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2, boolean tm2 ) {
 		MatrixBlock ret = m2;
 		
 		//transpose if dense-dense, skinny rhs matrix (not vector), and memory guarded by output 
-		if( checkPrepMatrixMultRightInput(m1, m2)  ) {
+		if( tm2 ) {
 			MatrixBlock tmpBlock = new MatrixBlock(m2.clen, m2.rlen, m2.sparse);
 			LibMatrixReorg.reorg(m2, tmpBlock, new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
 			ret = tmpBlock;
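
As a side note on the new emptyMatrixMult contract above: if either input is an
empty block, the result is an empty sparse block with the output dimensions and
zero non-zeros. A minimal sketch, with hypothetical inputs:

    MatrixBlock a = new MatrixBlock(100, 10, true); // empty, sparse
    MatrixBlock b = new MatrixBlock(10, 5, false);  // dense, all zero
    MatrixBlock c = LibMatrixMult.matrixMult(a, b, null);
    // c is a 100 x 5 sparse block with c.getNonZeros() == 0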
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
index 5c92253..ac0b069 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
@@ -64,86 +64,83 @@ public class LibMatrixNative
 	 * @param m2 rhs matrix block
 	 * @param ret output matrix block
 	 * @param k number of threads
+	 * @return the given ret MatrixBlock if allocated, otherwise a new MatrixBlock
 	 */
-	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
-		matrixMult(m1, m2, ret, k, true);
-	}
-	
-	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k, boolean examSparsity) {
-		// Sanity check:
-		k = k <= 0 ? NativeHelper.getMaxNumThreads() : k;
-		
-		// check inputs / outputs
-		if (m1.isEmptyBlock(false) || m2.isEmptyBlock(false)){
-			ret.setNonZeros(0);
-			if(examSparsity)
-				ret.examSparsity(); // turn empty dense into sparse
-			return;
-		}
+	public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
 		
-		boolean isValidForNative = !isMatMultMemoryBound(m1.rlen, m1.clen, m2.clen) 
-			&& !m1.isInSparseFormat() && !m2.isInSparseFormat()
-			&& (m1.getDenseBlock().isContiguous() || !isSinglePrecision())
-			&& m2.getDenseBlock().isContiguous() //contiguous but not allocated
-			&& 8L * ret.getLength() < Integer.MAX_VALUE;
-
-		if( NativeHelper.isNativeLibraryLoaded() && isValidForNative ) 
-		{
-			ret.sparse = false;
-			ret.allocateDenseBlock();
-			long start = DMLScript.STATISTICS ? System.nanoTime() : 0;
-			long nnz = 0;
-			if( isSinglePrecision() ) {
-				FloatBuffer fin1 = toFloatBuffer(m1.getDenseBlockValues(), inBuff, true);
-				FloatBuffer fin2 = toFloatBuffer(m2.getDenseBlockValues(), filterBuff, true);
-				FloatBuffer fout = toFloatBuffer(ret.getDenseBlockValues(), outBuff, false);
-				nnz = NativeHelper.smmdd(fin1, fin2, fout, 
-					m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), k);
-				fromFloatBuffer(outBuff.get(), ret.getDenseBlockValues());
-			}
-			else {
-				DenseBlock a = m1.getDenseBlock();
-				if( a.isContiguous() ) {
-					nnz = NativeHelper.dmmdd(m1.getDenseBlockValues(), m2.getDenseBlockValues(),
-						ret.getDenseBlockValues(), m1.rlen, m1.clen, m2.clen, k);
+		if(NativeHelper.isNativeLibraryLoaded()){
+			// Sanity check:
+			k = k <= 0 ? NativeHelper.getMaxNumThreads() : k;
+			
+			// check inputs / outputs
+			if (m1.isEmptyBlock(false) || m2.isEmptyBlock(false))
+				return LibMatrixMult.emptyMatrixMult(m1,m2, ret);
+			
+			boolean isValidForNative = !isMatMultMemoryBound(m1.rlen, m1.clen, m2.clen) 
+				&& !m1.isInSparseFormat() && !m2.isInSparseFormat()
+				&& (m1.getDenseBlock().isContiguous() || !isSinglePrecision())
+				&& m2.getDenseBlock().isContiguous() //contiguous but not allocated
+				&& 8L * m1.rlen * m2.clen < Integer.MAX_VALUE; // ret may still be null here, so derive the length from the inputs
+	
+			if( isValidForNative ) 
+			{
+				// allocate output
+				if(ret == null)
+					ret = new MatrixBlock(m1.rlen, m2.clen, false);
+				else 
+					ret.reset(m1.rlen, m2.clen, false);
+				ret.allocateBlock();
+				
+				long start = DMLScript.STATISTICS ? System.nanoTime() : 0;
+				long nnz = 0;
+				if( isSinglePrecision() ) {
+					FloatBuffer fin1 = toFloatBuffer(m1.getDenseBlockValues(), inBuff, true);
+					FloatBuffer fin2 = toFloatBuffer(m2.getDenseBlockValues(), filterBuff, true);
+					FloatBuffer fout = toFloatBuffer(ret.getDenseBlockValues(), outBuff, false);
+					nnz = NativeHelper.smmdd(fin1, fin2, fout, 
+						m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), k);
+					fromFloatBuffer(outBuff.get(), ret.getDenseBlockValues());
 				}
 				else {
-					//sequential processing of individual blocks to 
-					//avoid segementation faults with concurrent multi-threaded BLAS calls
-					for(int bix = 0; bix < a.numBlocks(); bix++) {
-						double[] tmp = new double[a.blockSize(bix)*m2.clen];
-						nnz += NativeHelper.dmmdd(a.valuesAt(bix), m2.getDenseBlockValues(),
-							tmp, a.blockSize(bix), m1.clen, m2.clen, k);
-						int rl = bix * a.blockSize();
-						ret.getDenseBlock().set(rl, rl+a.blockSize(bix), 0, m2.clen,
-							DenseBlockFactory.createDenseBlock(tmp, new int[]{a.blockSize(bix),m2.clen}));
+					DenseBlock a = m1.getDenseBlock();
+					if( a.isContiguous() ) {
+						nnz = NativeHelper.dmmdd(m1.getDenseBlockValues(), m2.getDenseBlockValues(),
+							ret.getDenseBlockValues(), m1.rlen, m1.clen, m2.clen, k);
+					}
+					else {
+						//sequential processing of individual blocks to 
+						//avoid segmentation faults with concurrent multi-threaded BLAS calls
+						for(int bix = 0; bix < a.numBlocks(); bix++) {
+							double[] tmp = new double[a.blockSize(bix)*m2.clen];
+							nnz += NativeHelper.dmmdd(a.valuesAt(bix), m2.getDenseBlockValues(),
+								tmp, a.blockSize(bix), m1.clen, m2.clen, k);
+							int rl = bix * a.blockSize();
+							ret.getDenseBlock().set(rl, rl+a.blockSize(bix), 0, m2.clen,
+								DenseBlockFactory.createDenseBlock(tmp, new int[]{a.blockSize(bix),m2.clen}));
+						}
 					}
 				}
-			}
-			
-			if(nnz > -1) {
-				if(DMLScript.STATISTICS) {
-					Statistics.nativeLibMatrixMultTime += System.nanoTime() - start;
-					Statistics.numNativeLibMatrixMultCalls.increment();
-				}
-				ret.setNonZeros(nnz);
-				if(examSparsity)
+				
+				if(nnz > -1) {
+					if(DMLScript.STATISTICS) {
+						Statistics.nativeLibMatrixMultTime += System.nanoTime() - start;
+						Statistics.numNativeLibMatrixMultCalls.increment();
+					}
+					ret.setNonZeros(nnz);
 					ret.examSparsity();
-				return;
+					return ret;
+				}
+				//else record failure and fallback to java
+				Statistics.incrementNativeFailuresCounter();
+				LOG.warn("matrixMult: Native mat mult failed. Falling back to java version ("
+					+ "loaded=" + NativeHelper.isNativeLibraryLoaded()
+					+ ", sparse=" + (m1.isInSparseFormat() | m2.isInSparseFormat()) + ")");
 			}
-			//else record failure and fallback to java
-			Statistics.incrementNativeFailuresCounter();
-			LOG.warn("matrixMult: Native mat mult failed. Falling back to java version ("
-				+ "loaded=" + NativeHelper.isNativeLibraryLoaded()
-				+ ", sparse=" + (m1.isInSparseFormat() | m2.isInSparseFormat()) + ")");
 		}
-		else if(isValidForNative)
+		else
 			LOG.warn("Was valid for native MM but native lib was not loaded");
 		
-		if (k == 1)
-			LibMatrixMult.matrixMult(m1, m2, ret, !examSparsity);
-		else
-			LibMatrixMult.matrixMult(m1, m2, ret, k);
+		return LibMatrixMult.matrixMult(m1, m2, ret, k);
 	}
 	
 	public static void tsmm(MatrixBlock m1, MatrixBlock ret, boolean leftTrans, int k) {
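
For reference, with both entry points now returning the result block, the
dispatch in MatrixBlock.aggregateBinaryOperations (third file, below) reduces
to the following pattern:

    MatrixBlock out = NativeHelper.isNativeLibraryLoaded()
        ? LibMatrixNative.matrixMult(m1, m2, ret, k)
        : LibMatrixMult.matrixMult(m1, m2, ret, k);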
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index f208174..a0fcef6 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2616,22 +2616,6 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		return size;
 	}
 
-	public static SparsityEstimate estimateSparsityOnAggBinary(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op)
-	{
-		//Since MatrixMultLib always uses a dense output (except for ultra-sparse mm)
-		//with subsequent check for sparsity, we should always return a dense estimate.
-		//Once, we support more aggregate binary operations, we need to change this.
-		
-		//WARNING: KEEP CONSISTENT WITH LIBMATRIXMULT
-		//Note that it is crucial to report the right output representation because
-		//in case of block reuse (e.g., mmcj) the output 'reset' refers to either
-		//dense or sparse representation and hence would produce incorrect results
-		//if we report the wrong representation (i.e., missing reset on ultrasparse mm). 
-		
-		boolean ultrasparse = (m1.isUltraSparse() || m2.isUltraSparse());
-		return new SparsityEstimate(ultrasparse, m1.getNumRows()*m2.getNumRows());
-	}
-
 	private static SparsityEstimate estimateSparsityOnBinary(MatrixBlock m1, MatrixBlock m2, BinaryOperator op)
 	{
 		SparsityEstimate est = new SparsityEstimate();
@@ -4988,34 +4972,34 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 	}
 
 	public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+		checkAggregateBinaryOperations(m1, m2, op);
+		final int k = op.getNumThreads();
+		if(NativeHelper.isNativeLibraryLoaded())
+			return LibMatrixNative.matrixMult(m1, m2, ret, k);
+		else 
+			return LibMatrixMult.matrixMult(m1, m2, ret, k);
+	}
+
+	protected void checkAggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op) {
 		//check input types, dimensions, configuration
-		if( m1.clen != m2.rlen ) {
+		if( m1.clen != m2.rlen )
 			throw new RuntimeException("Dimensions do not match for matrix multiplication ("+m1.clen+"!="+m2.rlen+").");
-		}
-		if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus) ) {
-			throw new DMLRuntimeException("Unsupported binary aggregate operation: ("+op.binaryFn+", "+op.aggOp+").");
-		}
-		
-		//setup meta data (dimensions, sparsity)
-		int rl = m1.rlen;
-		int cl = m2.clen;
-		SparsityEstimate sp = estimateSparsityOnAggBinary(m1, m2, op);
-		
-		//create output matrix block
-		if( ret==null )
-			ret = new MatrixBlock(rl, cl, sp.sparse, sp.estimatedNonZeros);
-		else
-			ret.reset(rl, cl, sp.sparse, sp.estimatedNonZeros);
-		
-		//compute matrix multiplication (only supported binary aggregate operation)
-		if( NativeHelper.isNativeLibraryLoaded() )
-			LibMatrixNative.matrixMult(m1, m2, ret, op.getNumThreads());
-		else if( op.getNumThreads() > 1 )
-			LibMatrixMult.matrixMult(m1, m2, ret, op.getNumThreads());
-		else
-			LibMatrixMult.matrixMult(m1, m2, ret);
+		checkAggregateBinaryOperationsCommon(m1, m2, op);
+	}
+
+	protected void checkAggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op, boolean transposeLeft,
+			boolean transposeRight) {
+		//check input types, dimensions, configuration
+		if((transposeLeft ? m1.rlen : m1.clen) != (transposeRight ? m2.clen : m2.rlen))
+			throw new RuntimeException("Dimensions do not match for matrix multiplication ("+(transposeLeft?m1.rlen:m1.clen)+"!="+(transposeRight?m2.clen:m2.rlen)+").");
+		checkAggregateBinaryOperationsCommon(m1, m2, op);
+	}
 		
-		return ret;
+	private void checkAggregateBinaryOperationsCommon(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op){
+		if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus) )
+			throw new DMLRuntimeException("Unsupported binary aggregate operation: ("+op.binaryFn+", "+op.aggOp+").");
+		if(!(m1 == this || m2 == this))
+			throw new DMLRuntimeException("Invalid aggregateBinaryOperatio: one of either input should be this");
 	}
 
 	public MatrixBlock aggregateTernaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret,