You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@systemds.apache.org by GitBox <gi...@apache.org> on 2021/12/18 18:11:33 UTC

[GitHub] [systemds] mboehm7 commented on a change in pull request #1480: [SYSTEMDS-3243] Consistent allocation of MatrixBlock for MM

mboehm7 commented on a change in pull request #1480:
URL: https://github.com/apache/systemds/pull/1480#discussion_r771847935



##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
##########
@@ -119,33 +119,55 @@ public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, i
 	}
 	
 	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, boolean fixedRet) {
-		//check inputs / outputs
-		if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
-			ret.examSparsity(); //turn empty dense into sparse

Review comment:
       Please avoid formatting-only changes when reworking a central piece of code, so that we reviewers can focus on the actual changes.

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
##########
@@ -177,71 +196,80 @@ else if(m1.sparse)
 	 * @param k maximum parallelism
 	 */
 	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
-		//check inputs / outputs
-		if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
-			ret.examSparsity(); //turn empty dense into sparse
-			return;
-		}
-		
-		//check too small workload and fallback to sequential if needed
-		if( !satisfiesMultiThreadingConstraints(m1, m2, m1.rlen==1, true, 2, k) ) {
-			matrixMult(m1, m2, ret);
+		// check inputs / outputs
+		if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
+			ret.examSparsity(); // turn empty dense into sparse
 			return;
 		}
-		
-		//Timing time = new Timing(true);
-		
-		//pre-processing: output allocation (in contrast to single-threaded,
-		//we need to allocate sparse as well in order to prevent synchronization)
-		boolean m1Perm = m1.isSparsePermutationMatrix();
-		boolean ultraSparse = isUltraSparseMatrixMult(m1, m2, m1Perm);
-		boolean sparse = !ultraSparse && !m1Perm && isSparseOutputMatrixMult(m1, m2);
-		boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
-		m2 = prepMatrixMultRightInput(m1, m2);
+
+		// pre-processing: output allocation (in contrast to single-threaded,
+		// we need to allocate sparse as well in order to prevent synchronization)
+		final boolean m1Perm = m1.isSparsePermutationMatrix();
+		final boolean ultraSparse = isUltraSparseMatrixMult(m1, m2, m1Perm);
+		final boolean sparse = !ultraSparse && !m1Perm && isSparseOutputMatrixMult(m1, m2);
 		ret.sparse = ultraSparse | sparse;
 		ret.allocateBlock();
-		
-		if (!ret.isThreadSafe()) {
-			matrixMult(m1, m2, ret);
+
+		matrixMult(m1, m2, ret, k, m1Perm, ultraSparse, sparse);
+
+	}

Review comment:
       Avoid empty lines before the end of a method.

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
##########
@@ -90,7 +84,8 @@ public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, i
 		if( NativeHelper.isNativeLibraryLoaded() && isValidForNative ) 
 		{
 			ret.sparse = false;
-			ret.allocateDenseBlock();
+			if(!ret.isAllocated())
+				ret.allocateDenseBlock();

Review comment:
       These changes are NOT equivalent: allocateDenseBlock also resets the values to zero, whereas merely probing for an already-allocated representation might yield incorrect results if the native operation adds to the existing values or does not fully overwrite them.

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
##########
@@ -5096,34 +5081,52 @@ public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m
 		return aggregateBinaryOperations(m1, m2, null, op);
 	}
 
-	public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+	public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
 		//check input types, dimensions, configuration
-		if( m1.clen != m2.rlen ) {
+		if( m1.clen != m2.rlen )
 			throw new RuntimeException("Dimensions do not match for matrix multiplication ("+m1.clen+"!="+m2.rlen+").");
-		}
-		if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus) ) {
+		if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus) )
 			throw new DMLRuntimeException("Unsupported binary aggregate operation: ("+op.binaryFn+", "+op.aggOp+").");
+		if(!(m1 == this || m2 == this))
+			throw new DMLRuntimeException("Invalid aggregateBinaryOperatio: one of either input should be this");
+		return matrixMult(m1, m2, ret, op.getNumThreads());
+	}
+
+	public MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k){
+
+		final int rl = m1.rlen;
+		final int cl = m2.clen;
+
+		if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
+			if(ret == null)
+				return new MatrixBlock(rl, cl, true);
+			else {
+				ret.reset(rl, cl, true);
+				ret.setNonZeros(0);
+				ret.cleanupBlock(true, true);
+				return ret;
+			}
 		}
-		
-		//setup meta data (dimensions, sparsity)
-		int rl = m1.rlen;
-		int cl = m2.clen;
-		SparsityEstimate sp = estimateSparsityOnAggBinary(m1, m2, op);
-		
-		//create output matrix block
-		if( ret==null )
-			ret = new MatrixBlock(rl, cl, sp.sparse, sp.estimatedNonZeros);
+
+		final boolean m1Perm = m1.isSparsePermutationMatrix();
+		final boolean ultraSparse = LibMatrixMult.isUltraSparseMatrixMult(m1, m2, m1Perm);
+		final boolean sparse = !m1Perm && !ultraSparse && LibMatrixMult.isSparseOutputMatrixMult(m1, m2);
+		final boolean sparseRet = ultraSparse | sparse;
+
+		// create output matrix block
+		if(ret == null)
+			ret = new MatrixBlock(rl, cl, sparseRet);
 		else
-			ret.reset(rl, cl, sp.sparse, sp.estimatedNonZeros);
-		
-		//compute matrix multiplication (only supported binary aggregate operation)
-		if( NativeHelper.isNativeLibraryLoaded() )
-			LibMatrixNative.matrixMult(m1, m2, ret, op.getNumThreads());
-		else if( op.getNumThreads() > 1 )
-			LibMatrixMult.matrixMult(m1, m2, ret, op.getNumThreads());
+			ret.reset(rl, cl, sparseRet);
+		ret.allocateBlock();

Review comment:
       Do not allocate the block here; there are still cases in ultra-sparse matrix multiplications where the output representation is decided in a data-driven manner (e.g., for permutation matrices).

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
##########
@@ -5096,34 +5081,52 @@ public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m
 		return aggregateBinaryOperations(m1, m2, null, op);
 	}
 
-	public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+	public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {

Review comment:
       remove the final qualifier

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
##########
@@ -119,33 +119,55 @@ public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, i
 	}
 	
 	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, boolean fixedRet) {
-		//check inputs / outputs
-		if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
-			ret.examSparsity(); //turn empty dense into sparse
+		// check inputs / outputs
+		if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
+			ret.examSparsity(); // turn empty dense into sparse
 			return;
 		}
-		
-		//Timing time = new Timing(true);
-		
-		//pre-processing: output allocation
+
+		// pre-processing: output allocation
 		boolean m1Perm = m1.isSparsePermutationMatrix();
-		boolean ultraSparse = (fixedRet && ret.sparse)
-			|| (!fixedRet && isUltraSparseMatrixMult(m1, m2, m1Perm));
-		boolean sparse = !m1Perm && !ultraSparse && !fixedRet 
-			&& isSparseOutputMatrixMult(m1, m2);
-		boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
-		m2 = prepMatrixMultRightInput(m1, m2);
+		boolean ultraSparse = (fixedRet && ret.sparse) || (!fixedRet && isUltraSparseMatrixMult(m1, m2, m1Perm));
+		boolean sparse = !m1Perm && !ultraSparse && !fixedRet && isSparseOutputMatrixMult(m1, m2);
 		ret.sparse = ultraSparse | sparse;
 		ret.allocateBlock();
-		
-		//prepare row-upper for special cases of vector-matrix
-		boolean pm2 = !ultraSparse &&
-			checkParMatrixMultRightInputRows(m1, m2, Integer.MAX_VALUE);
-		int ru2 = (pm2 && ru==m1.rlen) ? m2.rlen : ru; 
+
+		matrixMult(m1, m2, ret, rl, ru, fixedRet, m1Perm, ultraSparse, sparse);
+	}
+
+	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean m1Perm, boolean ultraSparse,
+		boolean sparse) {
+		LibMatrixMult.matrixMult(m1, m2, ret, 0, m1.rlen, false, m1Perm, ultraSparse, sparse);
+	}
+
+
+	/**
+	 * Matrix multiplication of m1 on the left and m2 on the right, ret is assumed to already be allocated in correct
+	 * format and size.
+	 * 
+	 * @param m1          left matrix
+	 * @param m2          right matrix
+	 * @param ret         result matrix
+	 * @param rl          row start to multiply from
+	 * @param ru          row end to multiply to (not included)
+	 * @param fixedRet    if the output is in fixed format.
+	 * @param m1Perm      if m1 is a sparse permutation matrix. (one value per row.)
+	 * @param ultraSparse if the multiplication is an ultra sparse multiplication
+	 * @param sparse      if the output is sparse
+	 */
+	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, boolean fixedRet,
+		boolean m1Perm, boolean ultraSparse, boolean sparse) {
+
+		boolean tm2 = checkPrepMatrixMultRightInput(m1, m2);
+		m2 = prepMatrixMultRightInput(m1, m2);
+

Review comment:
       These checks and modifications of the input are in the wrong place. They must not be in a method that takes index ranges (rl, ru), because that would mean the input is prepared (modified) again for every task.

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
##########
@@ -177,71 +196,80 @@ else if(m1.sparse)
 	 * @param k maximum parallelism
 	 */
 	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
-		//check inputs / outputs
-		if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
-			ret.examSparsity(); //turn empty dense into sparse
-			return;
-		}
-		
-		//check too small workload and fallback to sequential if needed
-		if( !satisfiesMultiThreadingConstraints(m1, m2, m1.rlen==1, true, 2, k) ) {
-			matrixMult(m1, m2, ret);
+		// check inputs / outputs
+		if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
+			ret.examSparsity(); // turn empty dense into sparse
 			return;
 		}
-		
-		//Timing time = new Timing(true);
-		
-		//pre-processing: output allocation (in contrast to single-threaded,
-		//we need to allocate sparse as well in order to prevent synchronization)
-		boolean m1Perm = m1.isSparsePermutationMatrix();
-		boolean ultraSparse = isUltraSparseMatrixMult(m1, m2, m1Perm);
-		boolean sparse = !ultraSparse && !m1Perm && isSparseOutputMatrixMult(m1, m2);
-		boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
-		m2 = prepMatrixMultRightInput(m1, m2);
+
+		// pre-processing: output allocation (in contrast to single-threaded,
+		// we need to allocate sparse as well in order to prevent synchronization)
+		final boolean m1Perm = m1.isSparsePermutationMatrix();
+		final boolean ultraSparse = isUltraSparseMatrixMult(m1, m2, m1Perm);
+		final boolean sparse = !ultraSparse && !m1Perm && isSparseOutputMatrixMult(m1, m2);
 		ret.sparse = ultraSparse | sparse;
 		ret.allocateBlock();
-		
-		if (!ret.isThreadSafe()) {
-			matrixMult(m1, m2, ret);
+
+		matrixMult(m1, m2, ret, k, m1Perm, ultraSparse, sparse);
+
+	}
+
+	/**
+	 * Parallel matrix multiplication of m1 on the left and m2 on the right, ret is assumed to already be allocated in
+	 * correct format and size.
+	 * 
+	 * @param m1          left matrix
+	 * @param m2          right matrix
+	 * @param ret         result matrix
+	 * @param k           the parallelization degree
+	 * @param m1Perm      if m1 is a sparse permutation matrix. (one value per row.)
+	 * @param ultraSparse if the multiplication is an ultra sparse multiplication
+	 * @param sparse      if the output is sparse
+	 */
+	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k, boolean m1Perm,
+		boolean ultraSparse, boolean sparse) {
+		// check too small workload or thread safety and fallback to sequential if needed
+		if(!satisfiesMultiThreadingConstraints(m1, m2, m1.rlen == 1, true, 2, k) || !ret.isThreadSafe()) {
+			matrixMult(m1, m2, ret, m1Perm, ultraSparse, sparse);
 			return;
 		}
-		
-		//prepare row-upper for special cases of vector-matrix / matrix-matrix
+
+		boolean tm2 = checkPrepMatrixMultRightInput(m1, m2);
+		m2 = prepMatrixMultRightInput(m1, m2);
+
+		// prepare row-upper for special cases of vector-matrix / matrix-matrix
 		boolean pm2r = !ultraSparse && !sparse && checkParMatrixMultRightInputRows(m1, m2, k);
 		boolean pm2c = !ultraSparse && checkParMatrixMultRightInputCols(m1, m2, k, pm2r);
-		int num = pm2r ? m2.rlen : pm2c ? m2.clen : m1.rlen; 
-		
-		//core multi-threaded matrix mult computation
-		//(currently: always parallelization over number of rows)
+		int num = pm2r ? m2.rlen : pm2c ? m2.clen : m1.rlen;
+
+		// core multi-threaded matrix mult computation
+		// (currently: always parallelization over number of rows)
 		try {
 			ExecutorService pool = CommonThreadPool.get(k);
 			ArrayList<MatrixMultTask> tasks = new ArrayList<>();
-			ArrayList<Integer> blklens = UtilFunctions.getBalancedBlockSizesDefault(num, k, (pm2r||pm2c));
-			for( int i=0, lb=0; i<blklens.size(); lb+=blklens.get(i), i++ )
-				tasks.add(new MatrixMultTask(m1, m2, ret, tm2, pm2r, pm2c, m1Perm, sparse, lb, lb+blklens.get(i)));
-			//execute tasks
+			ArrayList<Integer> blklens = UtilFunctions.getBalancedBlockSizesDefault(num, k, (pm2r || pm2c));
+			for(int i = 0, lb = 0; i < blklens.size(); lb += blklens.get(i), i++)
+				tasks.add(new MatrixMultTask(m1, m2, ret, tm2, pm2r, pm2c, m1Perm, sparse, lb, lb + blklens.get(i)));
+			// execute tasks
 			List<Future<Object>> taskret = pool.invokeAll(tasks);
 			pool.shutdown();
-			//aggregate partial results (nnz, ret for vector/matrix)
-			ret.nonZeros = 0; //reset after execute
-			for( Future<Object> task : taskret ) {
-				if( pm2r ) //guaranteed single block
-					vectAdd((double[])task.get(), ret.getDenseBlockValues(), 0, 0, ret.rlen*ret.clen);
+			// aggregate partial results (nnz, ret for vector/matrix)
+			ret.nonZeros = 0; // reset after execute
+			for(Future<Object> task : taskret) {
+				if(pm2r) // guaranteed single block
+					vectAdd((double[]) task.get(), ret.getDenseBlockValues(), 0, 0, ret.rlen * ret.clen);
 				else
-					ret.nonZeros += (Long)task.get();
+					ret.nonZeros += (Long) task.get();
 			}
-			if( pm2r )
+			if(pm2r)
 				ret.recomputeNonZeros();
 		}
 		catch(Exception ex) {
 			throw new DMLRuntimeException(ex);
 		}
-		
-		//post-processing (nnz maintained in parallel)
+
+		// post-processing (nnz maintained in parallel)
 		ret.examSparsity();
-		
-		//System.out.println("MM k="+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +

Review comment:
       keep this comment

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
##########
@@ -5096,34 +5081,52 @@ public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m
 		return aggregateBinaryOperations(m1, m2, null, op);
 	}
 
-	public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+	public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
 		//check input types, dimensions, configuration
-		if( m1.clen != m2.rlen ) {
+		if( m1.clen != m2.rlen )
 			throw new RuntimeException("Dimensions do not match for matrix multiplication ("+m1.clen+"!="+m2.rlen+").");
-		}
-		if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus) ) {
+		if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus) )
 			throw new DMLRuntimeException("Unsupported binary aggregate operation: ("+op.binaryFn+", "+op.aggOp+").");
+		if(!(m1 == this || m2 == this))
+			throw new DMLRuntimeException("Invalid aggregateBinaryOperatio: one of either input should be this");
+		return matrixMult(m1, m2, ret, op.getNumThreads());
+	}
+
+	public MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k){
+
+		final int rl = m1.rlen;
+		final int cl = m2.clen;
+
+		if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
+			if(ret == null)
+				return new MatrixBlock(rl, cl, true);
+			else {
+				ret.reset(rl, cl, true);
+				ret.setNonZeros(0);
+				ret.cleanupBlock(true, true);
+				return ret;
+			}
 		}
-		
-		//setup meta data (dimensions, sparsity)
-		int rl = m1.rlen;
-		int cl = m2.clen;
-		SparsityEstimate sp = estimateSparsityOnAggBinary(m1, m2, op);
-		
-		//create output matrix block
-		if( ret==null )
-			ret = new MatrixBlock(rl, cl, sp.sparse, sp.estimatedNonZeros);
+
+		final boolean m1Perm = m1.isSparsePermutationMatrix();
+		final boolean ultraSparse = LibMatrixMult.isUltraSparseMatrixMult(m1, m2, m1Perm);
+		final boolean sparse = !m1Perm && !ultraSparse && LibMatrixMult.isSparseOutputMatrixMult(m1, m2);
+		final boolean sparseRet = ultraSparse | sparse;
+
+		// create output matrix block
+		if(ret == null)
+			ret = new MatrixBlock(rl, cl, sparseRet);
 		else
-			ret.reset(rl, cl, sp.sparse, sp.estimatedNonZeros);
-		
-		//compute matrix multiplication (only supported binary aggregate operation)
-		if( NativeHelper.isNativeLibraryLoaded() )
-			LibMatrixNative.matrixMult(m1, m2, ret, op.getNumThreads());
-		else if( op.getNumThreads() > 1 )
-			LibMatrixMult.matrixMult(m1, m2, ret, op.getNumThreads());
+			ret.reset(rl, cl, sparseRet);
+		ret.allocateBlock();
+
+		if(!sparseRet && NativeHelper.isNativeLibraryLoaded())

Review comment:
       Why all this splitting of core functionality just for this check? For sparse inputs it would still enter the native library and then fall back. Couldn't you instead simply check that neither the lhs nor the rhs is sparse, or even factor the checks in LibMatrixNative into a method and call it here (which would cover the common case)?

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
##########
@@ -141,7 +135,7 @@ else if(isValidForNative)
 			LOG.warn("Was valid for native MM but native lib was not loaded");
 		
 		if (k == 1)
-			LibMatrixMult.matrixMult(m1, m2, ret, !examSparsity);

Review comment:
       Here is the relationship to fixedRet; why did you remove this?

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
##########
@@ -66,18 +66,12 @@ public static boolean isMatMultMemoryBound(int m1Rlen, int m1Clen, int m2Clen) {
 	 * @param k number of threads
 	 */
 	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
-		matrixMult(m1, m2, ret, k, true);
-	}
-	
-	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k, boolean examSparsity) {

Review comment:
       Shouldn't this be linked to fixedRet, for callers who want to keep the allocated output? Otherwise, examSparsity might convert the representation.

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
##########
@@ -119,33 +119,55 @@ public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, i
 	}
 	
 	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, boolean fixedRet) {
-		//check inputs / outputs
-		if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
-			ret.examSparsity(); //turn empty dense into sparse
+		// check inputs / outputs
+		if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
+			ret.examSparsity(); // turn empty dense into sparse
 			return;
 		}
-		
-		//Timing time = new Timing(true);
-		
-		//pre-processing: output allocation
+
+		// pre-processing: output allocation
 		boolean m1Perm = m1.isSparsePermutationMatrix();
-		boolean ultraSparse = (fixedRet && ret.sparse)
-			|| (!fixedRet && isUltraSparseMatrixMult(m1, m2, m1Perm));
-		boolean sparse = !m1Perm && !ultraSparse && !fixedRet 
-			&& isSparseOutputMatrixMult(m1, m2);
-		boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
-		m2 = prepMatrixMultRightInput(m1, m2);
+		boolean ultraSparse = (fixedRet && ret.sparse) || (!fixedRet && isUltraSparseMatrixMult(m1, m2, m1Perm));
+		boolean sparse = !m1Perm && !ultraSparse && !fixedRet && isSparseOutputMatrixMult(m1, m2);
 		ret.sparse = ultraSparse | sparse;
 		ret.allocateBlock();
-		
-		//prepare row-upper for special cases of vector-matrix
-		boolean pm2 = !ultraSparse &&
-			checkParMatrixMultRightInputRows(m1, m2, Integer.MAX_VALUE);
-		int ru2 = (pm2 && ru==m1.rlen) ? m2.rlen : ru; 
+
+		matrixMult(m1, m2, ret, rl, ru, fixedRet, m1Perm, ultraSparse, sparse);
+	}
+
+	public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean m1Perm, boolean ultraSparse,
+		boolean sparse) {
+		LibMatrixMult.matrixMult(m1, m2, ret, 0, m1.rlen, false, m1Perm, ultraSparse, sparse);
+	}

Review comment:
       Please try to keep the attack surface as small as possible: this method is called from multiple places, and people might mistakenly use it to bypass the sparsity checks above.

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
##########
@@ -155,18 +177,15 @@ else if(m1.sparse)
 			matrixMultSparseDense(m1, m2, ret, pm2, 0, ru2);
 		else
 			matrixMultDenseSparse(m1, m2, ret, pm2, 0, ru2);
-		
-		//post-processing: nnz/representation
-		if( !fixedRet ) {
-			if( !ret.sparse )
+
+		// post-processing: nnz/representation
+		if(!fixedRet) {
+			if(!ret.sparse)
 				ret.recomputeNonZeros();
 			ret.examSparsity();
 		}
-		
-		//System.out.println("MM ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +
-		//		"("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());

Review comment:
       Leave the commented-out core timing code in place, as it is frequently re-enabled during fine-grained debugging.

##########
File path: src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
##########
@@ -5096,34 +5081,52 @@ public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m
 		return aggregateBinaryOperations(m1, m2, null, op);
 	}
 
-	public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+	public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
 		//check input types, dimensions, configuration
-		if( m1.clen != m2.rlen ) {
+		if( m1.clen != m2.rlen )
 			throw new RuntimeException("Dimensions do not match for matrix multiplication ("+m1.clen+"!="+m2.rlen+").");
-		}
-		if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus) ) {
+		if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus) )
 			throw new DMLRuntimeException("Unsupported binary aggregate operation: ("+op.binaryFn+", "+op.aggOp+").");
+		if(!(m1 == this || m2 == this))
+			throw new DMLRuntimeException("Invalid aggregateBinaryOperatio: one of either input should be this");
+		return matrixMult(m1, m2, ret, op.getNumThreads());
+	}
+
+	public MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k){
+
+		final int rl = m1.rlen;
+		final int cl = m2.clen;
+
+		if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {

Review comment:
       I generally dislike that parts of the main kernel (empty-input handling, sparsity/special-case handling) are now split up and duplicated across two places.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscribe@systemds.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org