You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/18 20:17:33 UTC

[systemds] branch master updated: [SYSTEMDS-3101] Fix federated spoof instruction (federated output)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 82536c1  [SYSTEMDS-3101] Fix federated spoof instruction (federated output)
82536c1 is described below

commit 82536c1841b546db4f519086d2d7a6cba011603c
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sat Sep 18 22:16:43 2021 +0200

    [SYSTEMDS-3101] Fix federated spoof instruction (federated output)
    
    Closes #1380.
    
    Other cleanups:
    Closes #1336.
    Closes #1365.
---
 .../instructions/fed/SpoofFEDInstruction.java      | 493 ++++++++++++---------
 .../codegen/FederatedCodegenMultipleFedMOTest.java |   6 +-
 .../codegen/FederatedOuterProductTmplTest.java     |   8 +-
 .../codegen/FederatedRowwiseTmplTest.java          |   2 +-
 .../pipelines/BuiltinTopkEvaluateTest.java         |   1 -
 5 files changed, 289 insertions(+), 221 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index ecf310c..331ecfc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -50,6 +50,7 @@ import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.concurrent.Future;
+import java.util.stream.IntStream;
 
 public class SpoofFEDInstruction extends FEDInstruction
 {
@@ -82,37 +83,38 @@ public class SpoofFEDInstruction extends FEDInstruction
 
 	@Override
 	public void processInstruction(ExecutionContext ec) {
+		FederationMap fedMap = null;
+		for(CPOperand cpo : _inputs) { // searching for the first federated matrix to obtain the federation map
+			Data tmpData = ec.getVariable(cpo);
+			if(tmpData instanceof MatrixObject && ((MatrixObject)tmpData).isFederatedExcept(FType.BROADCAST)) {
+				fedMap = ((MatrixObject)tmpData).getFedMapping();
+				break;
+			}
+		}
+
 		Class<?> scla = _op.getClass().getSuperclass();
 		SpoofFEDType spoofType = null;
 		if(scla == SpoofCellwise.class)
-			spoofType = new SpoofFEDCellwise(_op, _output);
+			spoofType = new SpoofFEDCellwise(_op, _output, fedMap.getType());
 		else if(scla == SpoofRowwise.class)
-			spoofType = new SpoofFEDRowwise(_op, _output);
+			spoofType = new SpoofFEDRowwise(_op, _output, fedMap.getType());
 		else if(scla == SpoofMultiAggregate.class)
-			spoofType = new SpoofFEDMultiAgg(_op, _output);
+			spoofType = new SpoofFEDMultiAgg(_op, _output, fedMap.getType());
 		else if(scla == SpoofOuterProduct.class)
-			spoofType = new SpoofFEDOuterProduct(_op, _output);
+			spoofType = new SpoofFEDOuterProduct(_op, _output, fedMap.getType(), _inputs);
 		else
 			throw new DMLRuntimeException("Federated code generation only supported" +
 				" for cellwise, rowwise, multiaggregate, and outerproduct templates.");
 
+		processRequest(ec, fedMap, spoofType);
+	}
 
-		FederationMap fedMap = null;
-		long id = 0;
-		for(CPOperand cpo : _inputs) { // searching for the first federated matrix to obtain the federation map
-			Data tmpData = ec.getVariable(cpo);
-			if(tmpData instanceof MatrixObject && ((MatrixObject)tmpData).isFederatedExcept(FType.BROADCAST)) {
-				fedMap = ((MatrixObject)tmpData).getFedMapping();
-				id = ((MatrixObject)tmpData).getUniqueID();
-				break;
-			}
-		}
-
+	private void processRequest(ExecutionContext ec, FederationMap fedMap, SpoofFEDType spoofType) {
 		ArrayList<FederatedRequest> frBroadcast = new ArrayList<>();
 		ArrayList<FederatedRequest[]> frBroadcastSliced = new ArrayList<>();
 		long[] frIds = new long[_inputs.length];
 		int index = 0;
-		
+
 		for(CPOperand cpo : _inputs) {
 			Data tmpData = ec.getVariable(cpo);
 			if(tmpData instanceof MatrixObject) {
@@ -121,7 +123,7 @@ public class SpoofFEDInstruction extends FEDInstruction
 					frIds[index++] = mo.getFedMapping().getID();
 				}
 				else if(spoofType.needsBroadcastSliced(fedMap, mo.getNumRows(), mo.getNumColumns(), index)) {
-					FederatedRequest[] tmpFr = spoofType.broadcastSliced(mo, fedMap, id);
+					FederatedRequest[] tmpFr = spoofType.broadcastSliced(mo, fedMap);
 					frIds[index++] = tmpFr[0].getID();
 					frBroadcastSliced.add(tmpFr);
 				}
@@ -144,48 +146,71 @@ public class SpoofFEDInstruction extends FEDInstruction
 
 		FederatedRequest frCompute = FederationUtils.callInstruction(instString, _output, _inputs, frIds);
 
-		// get partial results from federated workers
-		FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frCompute.getID());
+		FederatedRequest frGet = null;
+		FederatedRequest frCleanup = null;
+		if(!spoofType.isFedOutput()) {
+			// get partial results from federated workers
+			frGet = new FederatedRequest(RequestType.GET_VAR, frCompute.getID());
+			// cleanup the federated request of callInstruction
+			frCleanup = fedMap.cleanup(getTID(), frCompute.getID());
+		}
 
-		ArrayList<FederatedRequest> frCleanup = new ArrayList<>();
-		frCleanup.add(fedMap.cleanup(getTID(), frCompute.getID()));
-		for(FederatedRequest[] fr : frBroadcastSliced)
-			frCleanup.add(fedMap.cleanup(getTID(), fr[0].getID()));
+		FederatedRequest[] frAll;
+		if(frGet == null) // no get request if output is kept federated
+			frAll = ArrayUtils.addAll(
+				frBroadcast.toArray(new FederatedRequest[0]), frCompute);
+		else
+			frAll = ArrayUtils.addAll(
+				frBroadcast.toArray(new FederatedRequest[0]), frCompute, frGet, frCleanup);
 
-		FederatedRequest[] frAll = ArrayUtils.addAll(ArrayUtils.addAll(
-			frBroadcast.toArray(new FederatedRequest[0]), frCompute, frGet),
-			frCleanup.toArray(new FederatedRequest[0]));
 		Future<FederatedResponse>[] response = fedMap.executeMultipleSlices(
 			getTID(), true, frBroadcastSliced.toArray(new FederatedRequest[0][]), frAll);
 
 		// setting the output with respect to the different aggregation types
 		// of the different spoof templates
-		spoofType.setOutput(ec, response, fedMap);
+		spoofType.setOutput(ec, response, fedMap, frCompute.getID());
 	}
 
 
+	// abstract class to differentiate between the different spoof templates
 	private static abstract class SpoofFEDType {
 		CPOperand _output;
+		FType _fedType;
 
-		protected SpoofFEDType(CPOperand out) {
+		protected SpoofFEDType(CPOperand out, FType fedType) {
 			_output = out;
+			_fedType = fedType;
 		}
-		
-		protected FederatedRequest[] broadcastSliced(MatrixObject mo, FederationMap fedMap, long id) {
+
+		/**
+		 * performs the sliced broadcast of the given matrix object
+		 *
+		 * @param mo the matrix object to broadcast sliced
+		 * @param fedMap the federated mapping
+		 * @return FederatedRequest[] the resulting federated request array of the broadcast
+		 */
+		protected FederatedRequest[] broadcastSliced(MatrixObject mo, FederationMap fedMap) {
 			return fedMap.broadcastSliced(mo, false);
 		}
 
+		/**
+		 * determine if a specific matrix object needs to be broadcast sliced
+		 *
+		 * @param fedMap the federated mapping
+		 * @param rowNum the number of rows of the matrix object
+		 * @param colNum the number of columns of the matrix object
+		 * @param inputIndex the index of the matrix inside the instruction inputs
+		 * @return boolean indicates if the matrix needs to be broadcast sliced
+		 */
 		protected boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum, int inputIndex) {
-			FType fedType = fedMap.getType();
-
 			//TODO fix check by num rows/cols
 			boolean retVal = (rowNum == fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1));
-			if(fedType == FType.ROW)
-				retVal |= (rowNum == fedMap.getMaxIndexInRange(0) 
-					&& (colNum == 1 || colNum == fedMap.getSize() || fedMap.getMaxIndexInRange(1) == 1));
-			else if(fedType == FType.COL)
+			if(_fedType == FType.ROW)
+				retVal |= (rowNum == fedMap.getMaxIndexInRange(0)
+					&& (colNum == 1 || fedMap.getMaxIndexInRange(1) == 1));
+			else if(_fedType == FType.COL)
 				retVal |= (colNum == fedMap.getMaxIndexInRange(1)
-					&& (rowNum == 1 || rowNum == fedMap.getSize() || fedMap.getMaxIndexInRange(0) == 1));
+					&& (rowNum == 1 || fedMap.getMaxIndexInRange(0) == 1));
 			else {
 				throw new DMLRuntimeException("Only row partitioned or column" +
 					" partitioned federated input supported yet.");
@@ -193,236 +218,281 @@ public class SpoofFEDInstruction extends FEDInstruction
 			return retVal;
 		}
 
-		protected abstract void setOutput(ExecutionContext ec,
-			Future<FederatedResponse>[] response, FederationMap fedMap);
+		/**
+		 * set the output by either calling setFedOutput to keep the output federated
+		 * or calling aggResult to aggregate the partial results locally
+		 */
+		protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response,
+			FederationMap fedMap, long frComputeID) {
+			if(isFedOutput())
+				setFedOutput(ec, fedMap, frComputeID);
+			else
+				aggResult(ec, response, fedMap);
+		}
+
+		// determine if the output can be kept on the federated sites
+		protected abstract boolean isFedOutput();
+		// set the output by deriving a new federated mapping
+		protected abstract void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID);
+		// aggregate the partial results locally
+		protected abstract void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response,
+			FederationMap fedMap);
 	}
 
+	// CELLWISE TEMPLATE
 	private static class SpoofFEDCellwise extends SpoofFEDType {
 		private final SpoofCellwise _op;
+		private final CellType _cellType;
 
-		SpoofFEDCellwise(SpoofOperator op, CPOperand out) {
-			super(out);
+		SpoofFEDCellwise(SpoofOperator op, CPOperand out, FType fedType) {
+			super(out, fedType);
 			_op = (SpoofCellwise)op;
+			_cellType = _op.getCellType();
 		}
 
-		protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
-			FType fedType = fedMap.getType();
-			AggOp aggOp = ((SpoofCellwise)_op).getAggOp();
-			CellType cellType = ((SpoofCellwise)_op).getCellType();
-			if(cellType == CellType.FULL_AGG) { // full aggregation
-				AggregateUnaryOperator aop = null;
-				if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
-					aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-				else if(aggOp == AggOp.MIN)
-					aop = InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
-				else if(aggOp == AggOp.MAX)
-					aop = InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
-				else
-					throw new DMLRuntimeException("Aggregation operation not supported yet.");
-				ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
-			}
-			else if(cellType == CellType.ROW_AGG) { // row aggregation
-				if(fedType == FType.ROW) {
-					// bind partial results from federated responses
-					ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
-				}
-				else if(fedType == FType.COL) {
-					AggregateUnaryOperator aop = null;
-					if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
-						aop = InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
-					else if(aggOp == AggOp.MIN)
-						aop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
-					else if(aggOp == AggOp.MAX)
-						aop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
-					else
-						throw new DMLRuntimeException("Aggregation operation not supported yet.");
-					ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
-				}
-				else {
-					throw new DMLRuntimeException("Aggregation type for federated spoof instructions not supported yet.");
-				}
-			}
-			else if(cellType == CellType.COL_AGG) { // col aggregation
-				if(fedType == FType.ROW) {
-					AggregateUnaryOperator aop = null;
-					if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
-						aop = InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
-					else if(aggOp == AggOp.MIN)
-						aop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
-					else if(aggOp == AggOp.MAX)
-						aop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
-					else
-						throw new DMLRuntimeException("Aggregation operation not supported yet.");
-					ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
-				}
-				else if(fedType == FType.COL) {
-					// cbind partial results from federated responses
-					ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, true));
-				}
-				else {
-					throw new DMLRuntimeException("Aggregation type for federated spoof instructions not supported yet.");
-				}
+		protected boolean isFedOutput() {
+			boolean retVal = false;
+			retVal |= (_cellType == CellType.ROW_AGG && _fedType == FType.ROW);
+			retVal |= (_cellType == CellType.COL_AGG && _fedType == FType.COL);
+			retVal |= (_cellType == CellType.NO_AGG);
+			return retVal;
+		}
+
+		protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
+			// derive output federated mapping
+			MatrixObject out = ec.getMatrixObject(_output);
+			FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(frComputeID));
+			out.setFedMapping(newFedMap);
+		}
+
+		private FederationMap modifyFedRanges(FederationMap fedMap) {
+			if(_cellType == CellType.ROW_AGG || _cellType == CellType.COL_AGG) {
+				int dim = (_cellType == CellType.COL_AGG ? 0 : 1);
+				// crop federation map to a vector
+				IntStream.range(0, fedMap.getFederatedRanges().length).forEach(i -> {
+					fedMap.getFederatedRanges()[i].setBeginDim(dim, 0);
+					fedMap.getFederatedRanges()[i].setEndDim(dim, 1);
+				});
 			}
-			else if(cellType == CellType.NO_AGG) { // no aggregation
-				if(fedType == FType.ROW) //rbind
-					ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
-				else if(fedType == FType.COL) //cbind
-					ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, true));
-				else
-					throw new DMLRuntimeException("Only row partitioned or column" +
-						" partitioned federated matrices supported yet.");
+			return fedMap;
+		}
+
+		protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response,
+			FederationMap fedMap) {
+			AggOp aggOp = _op.getAggOp();
+
+			// build up the instruction for aggregation
+			// (uak+/uamin/uamax/uark+/uarmin/uarmax/uack+/uacmin/uacmax)
+			String aggInst = "ua";
+			switch(_cellType) {
+				case FULL_AGG: break;
+				case ROW_AGG: aggInst += "r"; break;
+				case COL_AGG: aggInst += "c"; break;
+				case NO_AGG:
+				default:
+					throw new DMLRuntimeException("Aggregation type not supported yet.");
 			}
-			else {
-				throw new DMLRuntimeException("Aggregation type not supported yet.");
+
+			switch(aggOp) {
+				case SUM:
+				case SUM_SQ: aggInst += "k+"; break;
+				case MIN:    aggInst += "min"; break;
+				case MAX:    aggInst += "max"; break;
+				default:
+					throw new DMLRuntimeException("Aggregation operation not supported yet.");
 			}
+
+			AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(aggInst);
+			if(_cellType == CellType.FULL_AGG)
+				ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
+			else
+				ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
 		}
 	}
 
+	// ROWWISE TEMPLATE
 	private static class SpoofFEDRowwise extends SpoofFEDType {
 		private final SpoofRowwise _op;
+		private final RowType _rowType;
 
-		SpoofFEDRowwise(SpoofOperator op, CPOperand out) {
-			super(out);
+		SpoofFEDRowwise(SpoofOperator op, CPOperand out, FType fedType) {
+			super(out, fedType);
 			_op = (SpoofRowwise)op;
+			_rowType = _op.getRowType();
+		}
+
+		protected boolean isFedOutput() {
+			boolean retVal = false;
+			retVal |= (_rowType == RowType.NO_AGG);
+			retVal |= (_rowType == RowType.NO_AGG_B1);
+			retVal |= (_rowType == RowType.NO_AGG_CONST);
+			retVal &= (_fedType == FType.ROW);
+			return retVal;
 		}
 
-		protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
-			RowType rowType = ((SpoofRowwise)_op).getRowType();
-			if(rowType == RowType.FULL_AGG) { // full aggregation
-				// aggregate partial results from federated responses as sum
-				AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+		protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
+			// derive output federated mapping
+			MatrixObject out = ec.getMatrixObject(_output);
+			FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(frComputeID), out.getNumColumns());
+			out.setFedMapping(newFedMap);
+		}
+
+		private static FederationMap modifyFedRanges(FederationMap fedMap, long cols) {
+			IntStream.range(0, fedMap.getFederatedRanges().length).forEach(i -> {
+				fedMap.getFederatedRanges()[i].setBeginDim(1, 0);
+				fedMap.getFederatedRanges()[i].setEndDim(1, cols);
+			});
+			return fedMap;
+		}
+
+		protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response,
+			FederationMap fedMap) {
+			if(_fedType != FType.ROW)
+				throw new DMLRuntimeException("Only row partitioned federated matrices supported yet.");
+
+			// build up the instruction for aggregation (uak+/uark+/uack+)
+			String aggInst = "ua";
+			if(_rowType == RowType.FULL_AGG) // full aggregation
+				aggInst += "k+";
+			else if(_rowType == RowType.ROW_AGG) // row aggregation
+				aggInst += "rk+";
+			else if(_rowType.isColumnAgg()) // col aggregation
+				aggInst += "ck+";
+			else
+				throw new DMLRuntimeException("AggregationType not supported yet.");
+
+			// aggregate partial results from federated responses as sum/rowSum/colSum
+			AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(aggInst);
+			if(_rowType == RowType.FULL_AGG)
 				ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
-			}
-			else if(rowType == RowType.ROW_AGG) { // row aggregation
-				// aggregate partial results from federated responses as rowSum
-				AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
-				ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
-			}
-			else if(rowType == RowType.COL_AGG
-				|| rowType == RowType.COL_AGG_T
-				|| rowType == RowType.COL_AGG_B1
-				|| rowType == RowType.COL_AGG_B1_T
-				|| rowType == RowType.COL_AGG_B1R
-				|| rowType == RowType.COL_AGG_CONST) { // col aggregation
-				// aggregate partial results from federated responses as colSum
-				AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
+			else
 				ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
-			}
-			else if(rowType == RowType.NO_AGG
-				|| rowType == RowType.NO_AGG_B1
-				|| rowType == RowType.NO_AGG_CONST) { // no aggregation
-				if(fedMap.getType() == FType.ROW) {
-					// bind partial results from federated responses
-					ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
-				}
-				else {
-					throw new DMLRuntimeException("Only row partitioned federated matrices supported yet.");
-				}
-			}
-			else {
-				throw new DMLRuntimeException("AggregationType not supported yet.");
-			}
 		}
 	}
 
+	// MULTIAGGREGATE TEMPLATE
 	private static class SpoofFEDMultiAgg extends SpoofFEDType {
 		private final SpoofMultiAggregate _op;
 
-		SpoofFEDMultiAgg(SpoofOperator op, CPOperand out) {
-			super(out);
+		SpoofFEDMultiAgg(SpoofOperator op, CPOperand out, FType fedType) {
+			super(out, fedType);
 			_op = (SpoofMultiAggregate)op;
 		}
 
-		protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
-			MatrixBlock[] partRes = FederationUtils.getResults(response);
-			SpoofCellwise.AggOp[] aggOps = ((SpoofMultiAggregate)_op).getAggOps();
-			for(int counter = 1; counter < partRes.length; counter++) {
-				SpoofMultiAggregate.aggregatePartialResults(aggOps, partRes[0], partRes[counter]);
-			}
-			ec.setMatrixOutput(_output.getName(), partRes[0]);
+		protected boolean isFedOutput() {
+			return false;
 		}
-	}
 
+		protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
+			throw new DMLRuntimeException("SpoofFEDMultiAgg cannot create a federated output.");
+		}
 
+		protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response,
+			FederationMap fedMap) {
+				MatrixBlock[] partRes = FederationUtils.getResults(response);
+				SpoofCellwise.AggOp[] aggOps = _op.getAggOps();
+				for(int counter = 1; counter < partRes.length; counter++) {
+					SpoofMultiAggregate.aggregatePartialResults(aggOps, partRes[0], partRes[counter]);
+				}
+				ec.setMatrixOutput(_output.getName(), partRes[0]);
+			}
+	}
+
+	// OUTER PRODUCT TEMPLATE
 	private static class SpoofFEDOuterProduct extends SpoofFEDType {
 		private final SpoofOuterProduct _op;
+		private final OutProdType _outProdType;
+		private CPOperand[] _inputs;
 
-		SpoofFEDOuterProduct(SpoofOperator op, CPOperand out) {
-			super(out);
+		SpoofFEDOuterProduct(SpoofOperator op, CPOperand out, FType fedType, CPOperand[] inputs) {
+			super(out, fedType);
 			_op = (SpoofOuterProduct)op;
+			_outProdType = _op.getOuterProdType();
+			_inputs = inputs;
+		}
+
+		protected FederatedRequest[] broadcastSliced(MatrixObject mo, FederationMap fedMap) {
+			return fedMap.broadcastSliced(mo, (_fedType == FType.COL));
 		}
 
 		protected boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum, int inputIndex) {
 			boolean retVal = false;
-			FType fedType = fedMap.getType();
 			
 			retVal |= (rowNum == fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1));
 			
-			if(fedType == FType.ROW)
+			if(_fedType == FType.ROW)
 				retVal |= (rowNum == fedMap.getMaxIndexInRange(0)) && (inputIndex != 2); // input at index 2 is V
-			else if(fedType == FType.COL)
+			else if(_fedType == FType.COL)
 				retVal |= (rowNum == fedMap.getMaxIndexInRange(1)) && (inputIndex != 1); // input at index 1 is U
 			else
 				throw new DMLRuntimeException("Only row partitioned or column" +
 					" partitioned federated input supported yet.");
-			
+
 			return retVal;
 		}
 
-		protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
-			FType fedType = fedMap.getType();
-			OutProdType outProdType = ((SpoofOuterProduct)_op).getOuterProdType();
-			if(outProdType == OutProdType.LEFT_OUTER_PRODUCT) {
-				if(fedType == FType.ROW) {
-					// aggregate partial results from federated responses as elementwise sum
-					AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-					ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
-				}
-				else if(fedType == FType.COL) {
-					// bind partial results from federated responses
-					ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
-				}
-				else {
-					throw new DMLRuntimeException("Only row partitioned or column" +
-						" partitioned federated matrices supported yet.");
-				}
+		protected boolean isFedOutput() {
+			boolean retVal = false;
+			retVal |= (_outProdType == OutProdType.LEFT_OUTER_PRODUCT && _fedType == FType.COL);
+			retVal |= (_outProdType == OutProdType.RIGHT_OUTER_PRODUCT && _fedType == FType.ROW);
+			retVal |= (_outProdType == OutProdType.CELLWISE_OUTER_PRODUCT);
+			return retVal;
+		}
+
+		protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
+			FederationMap newFedMap = fedMap.copyWithNewID(frComputeID);
+			long[] outDims = new long[2];
+
+			// find the resulting output dimensions
+			MatrixObject X = ec.getMatrixObject(_inputs[0]);
+			switch(_outProdType) {
+				case LEFT_OUTER_PRODUCT: // LEFT: nrows of transposed X, ncols of U
+					newFedMap = newFedMap.transpose();
+					outDims[0] = X.getNumColumns();
+					outDims[1] = ec.getMatrixObject(_inputs[1]).getNumColumns();
+					break;
+				case RIGHT_OUTER_PRODUCT: // RIGHT: nrows of X, ncols of V
+					outDims[0] = X.getNumRows();
+					outDims[1] = ec.getMatrixObject(_inputs[2]).getNumColumns();
+					break;
+				case CELLWISE_OUTER_PRODUCT: // BASIC: preserve dimensions of X
+					outDims[0] = X.getNumRows();
+					outDims[1] = X.getNumColumns();
+					break;
+				default:
+					throw new DMLRuntimeException("Outer Product Type " + _outProdType + " not supported yet.");
 			}
-			else if(outProdType == OutProdType.RIGHT_OUTER_PRODUCT) {
-				if(fedType == FType.ROW) {
-					// bind partial results from federated responses
-					ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
-				}
-				else if(fedType == FType.COL) {
+
+			// derive output federated mapping
+			MatrixObject out = ec.getMatrixObject(_output);
+			int dim = (newFedMap.getType() == FType.ROW ? 1 : 0);
+			newFedMap = modifyFedRanges(newFedMap, dim, outDims[dim]);
+			out.setFedMapping(newFedMap);
+		}
+
+		private static FederationMap modifyFedRanges(FederationMap fedMap, int dim, long value) {
+			IntStream.range(0, fedMap.getFederatedRanges().length).forEach(i -> {
+				fedMap.getFederatedRanges()[i].setBeginDim(dim, 0);
+				fedMap.getFederatedRanges()[i].setEndDim(dim, value);
+			});
+			return fedMap;
+		}
+
+		protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response,
+			FederationMap fedMap) {
+			AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+			switch(_outProdType) {
+				case LEFT_OUTER_PRODUCT:
+				case RIGHT_OUTER_PRODUCT:
 					// aggregate partial results from federated responses as elementwise sum
-					AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
 					ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
-				}
-				else {
-					throw new DMLRuntimeException("Only row partitioned or column" +
-						" partitioned federated matrices supported yet.");
-				}
-			}
-			else if(outProdType == OutProdType.CELLWISE_OUTER_PRODUCT) {
-				if(fedType == FType.ROW) {
-					// rbind partial results from federated responses
-					ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
-				}
-				else if(fedType == FType.COL) {
-					// cbind partial results from federated responses
-					ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, true));
-				}
-				else {
-					throw new DMLRuntimeException("Only row partitioned or column" +
-						" partitioned federated matrices supported yet.");
-				}
-			}
-			else if(outProdType == OutProdType.AGG_OUTER_PRODUCT) {
-				// aggregate partial results from federated responses as sum
-				AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-				ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
-			}
-			else {
-				throw new DMLRuntimeException("Outer Product Type " + outProdType + " not supported yet.");
+					break;
+				case AGG_OUTER_PRODUCT:
+					// aggregate partial results from federated responses as sum
+					ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
+					break;
+				default:
+					throw new DMLRuntimeException("Outer Product Type " + _outProdType + " not supported yet.");
 			}
 		}
 	}
@@ -458,5 +528,4 @@ public class SpoofFEDInstruction extends FEDInstruction
 		}
 		return retVal;
 	}
-
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
index 65f1728..61722db 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
@@ -104,7 +104,7 @@ public class FederatedCodegenMultipleFedMOTest extends AutomatedTestBase
 			// row partitioned
 			// {201, 6, 4, 6, 4, true},
 			{202, 6, 4, 6, 4, true},
-			// {203, 20, 1, 20, 1, true},
+			// FIXME: [SYSTEMDS-3110] {203, 20, 1, 20, 1, true},
 			// col partitioned
 			{201, 6, 4, 6, 4, false},
 			{202, 6, 4, 6, 4, false},
@@ -123,9 +123,9 @@ public class FederatedCodegenMultipleFedMOTest extends AutomatedTestBase
 			{308, 1000, 2000, 10, 2000, false},
 			// {310, 1000, 2000, 10, 2000, false},
 			// row and col partitioned
-			// {311, 1000, 2000, 1000, 10, true}, // not working yet - ArrayIndexOutOfBoundsException in dotProduct
+			// {311, 1000, 2000, 1000, 10, true}, // FIXME: ArrayIndexOutOfBoundsException in dotProduct
 			{312, 1000, 2000, 10, 2000, false},
-			// {313, 4000, 2000, 4000, 10, true}, // not working yet - ArrayIndexOutOfBoundsException in dotProduct
+			// {313, 4000, 2000, 4000, 10, true}, // FIXME: ArrayIndexOutOfBoundsException in dotProduct
 			{314, 4000, 2000, 10, 2000, false},
 
 			// combined tests
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
index edc9ab7..cef5fd5 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
@@ -86,14 +86,14 @@ public class FederatedOuterProductTmplTest extends AutomatedTestBase
 			{9, 1000, 2000, true},
 
 			// column partitioned
-			//FIXME {1, 2000, 2000, false},
+			{1, 2000, 2000, false},
 			// {2, 4000, 2000, false},
 			// {3, 1000, 1000, false},
-			//FIXME {4, 4000, 2000, false},
-			//FIXME {5, 4000, 2000, false},
+			{4, 4000, 2000, false},
+			{5, 4000, 2000, false},
 			// {6, 4000, 2000, false},
 			//FIXME {7, 2000, 2000, false},
-			//FIXME {8, 1000, 2000, false},
+			{8, 1000, 2000, false},
 			// {9, 1000, 2000, false},
 		});
 	}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
index b4bff76..89475d8 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
@@ -117,7 +117,7 @@ public class FederatedRowwiseTmplTest extends AutomatedTestBase
 	}
 
 	@Test
-	public void federatedCodegenCellwiseHybrid() {
+	public void federatedCodegenRowwiseHybrid() {
 		testFederatedCodegenRowwise(ExecMode.HYBRID);
 	}
 
diff --git a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
index 71160b7..f2e873c 100644
--- a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
@@ -25,7 +25,6 @@ import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Assert;
 import org.junit.Ignore;
-import org.junit.Test;
 
 public class BuiltinTopkEvaluateTest extends AutomatedTestBase {
 	//	private final static String TEST_NAME1 = "prioritized";