You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2021/09/03 10:21:45 UTC

[systemds] branch master updated: [SYSTEMDS-3085] FederatedCTable - Keep Output Federated Closes #1371.

This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fa4463  [SYSTEMDS-3085] FederatedCTable - Keep Output Federated Closes #1371.
0fa4463 is described below

commit 0fa4463b42dce5a65bdeb3f9d7a1c422db174cda
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Wed Aug 11 12:50:36 2021 +0200

    [SYSTEMDS-3085] FederatedCTable - Keep Output Federated
    Closes #1371.
---
 .../instructions/fed/CtableFEDInstruction.java     | 306 ++++++++++++++-------
 .../federated/primitives/FederatedCtableTest.java  |  11 +-
 .../federated/FederatedCtableFedOutput.dml         |  10 +-
 .../FederatedCtableFedOutputReference.dml          |  12 +-
 4 files changed, 234 insertions(+), 105 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index a7d4cb6..2a1cbb1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -22,7 +22,10 @@ package org.apache.sysds.runtime.instructions.fed;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.concurrent.Future;
+import java.util.Iterator;
+import java.util.SortedMap;
 import java.util.stream.IntStream;
+import java.util.TreeMap;
 
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types.DataType;
@@ -38,7 +41,6 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import org.apache.sysds.runtime.functionobjects.And;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -105,8 +107,10 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
 		}
 
 		// get new output dims
-		Long[] dims1 = getOutputDimension(mo1, input1, _outDim1, mo1.getFedMapping().getFederatedRanges());
-		Long[] dims2 = getOutputDimension(mo2, input2, _outDim2, mo1.getFedMapping().getFederatedRanges());
+		Long[] dims1 = getOutputDimension(mo1, reversed ? input2 : input1, reversed ? _outDim2 : _outDim1,
+			mo1.getFedMapping().getFederatedRanges());
+		Long[] dims2 = getOutputDimension(mo2, reversed ? input1 : input2, reversed ? _outDim1 : _outDim2,
+			mo1.getFedMapping().getFederatedRanges());
 
 		MatrixObject mo3 = input3 != null && input3.isMatrix() ? ec.getMatrixObject(input3) : null;
 
@@ -116,119 +120,157 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);
 
-		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, dims1, dims2);
+		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, staticDim, dims2);
 	}
 
+	/**
+	 * Broadcast, execute, and finalize the federated instruction according to
+	 * the specified inputs.
+	 *
+	 * @param ec execution context
+	 * @param mo1 input matrix object 1
+	 * @param mo2 input matrix object 2
+	 * @param mo3 input matrix object 3 or null
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 * @param reversedWeights boolean indicating if inputs mo1 and mo3 are reversed
+	 * @param fedOutput boolean indicating if output can be kept federated
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 */
 	private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3,
-		boolean reversed, boolean reversedWeights, boolean fedOutput, Long[] dims1, Long[] dims2) {
+		boolean reversed, boolean reversedWeights, boolean fedOutput, long staticDim, Long[] dims2) {
+
+		FederationMap fedMap = mo1.getFedMapping();
+
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+		FederatedRequest[] fr2 = null;
+		FederatedRequest fr3, fr4, fr5;
 		Future<FederatedResponse>[] ffr;
 
-		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
-		FederatedRequest fr2, fr3;
 		if(mo3 != null && mo1.isFederated() && mo3.isFederated()
-		&& mo1.getFedMapping().isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
+			&& fedMap.isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), mo3.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), mo3.getFedMapping().getID()});
 		}
 		else if(mo3 == null) {
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fedMap.getID(), fr1[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
-
-		} else {
-			FederatedRequest[] fr4 = mo1.getFedMapping().broadcastSliced(mo3, false);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fr1[0].getID(), fedMap.getID()});
+		}
+		else {
+			fr2 = fedMap.broadcastSliced(mo3, false);
 			if(!reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), fr2[0].getID()});
 			else if(reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), fr2[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), fr4[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr4, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fr2[0].getID(), fedMap.getID()});
 		}
 
-		if(fedOutput && isFedOutput(ffr, dims1)) {
+		if(fedOutput) {
+			if(fr2 != null) // broadcasted mo3
+				fedMap.execute(getTID(), true, fr1, fr2, fr3);
+			else
+				fedMap.execute(getTID(), true, fr1, fr3);
+
 			MatrixObject out = ec.getMatrixObject(output);
-			FederationMap newFedMap = modifyFedRanges(mo1.getFedMapping(), dims1, dims2);
-			setFedOutput(mo1, out, newFedMap, dims1, fr2.getID());
+			FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(fr3.getID()),
+				staticDim, dims2, reversed);
+			setFedOutput(mo1, out, newFedMap, staticDim, dims2, reversed);
 		} else {
+			fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
+			fr5 = fedMap.cleanup(getTID(), fr3.getID());
+			if(fr2 != null) // broadcasted mo3
+				ffr = fedMap.execute(getTID(), true, fr1, fr2, fr3, fr4, fr5);
+			else
+				ffr = fedMap.execute(getTID(), true, fr1, fr3, fr4, fr5);
+
 			ec.setMatrixOutput(output.getName(), aggResult(ffr));
 		}
 	}
 
-	boolean isFedOutput(Future<FederatedResponse>[] ffr,  Long[] dims1) {
-		boolean fedOutput = true;
-
-		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / ffr.length;
-		try {
-			MatrixBlock curr;
-			MatrixBlock prev =(MatrixBlock) ffr[0].get().getData()[0];
-			for(int i = 1; i < ffr.length && fedOutput; i++) {
-				curr = (MatrixBlock) ffr[i].get().getData()[0];
-				MatrixBlock sliced = curr.slice((int) (curr.getNumRows() - fedSize), curr.getNumRows() - 1);
-
-				if(curr.getNumColumns() != prev.getNumColumns())
-					return false;
-
-				// no intersection
-				if(curr.getNumRows() == (i+1) * prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
-					&& (curr.getNumRows() - sliced.getNumRows()) == i * prev.getNumRows()
-					&& curr.getNonZeros() - sliced.getNonZeros() == 0)
-					continue;
-
-				// check intersect with AND and compare number of nnz
-				MatrixBlock prevExtend = new MatrixBlock(curr.getNumRows(), curr.getNumColumns(), true, 0);
-				prevExtend.copy(0, prev.getNumRows()-1, 0, prev.getNumColumns()-1, prev, true);
-
-				MatrixBlock  intersect = curr.binaryOperationsInPlace(new BinaryOperator(And.getAndFnObject()), prevExtend);
-				if(intersect.getNonZeros() != 0)
-					fedOutput = false;
-				prev = sliced;
-			}
-		}
-		catch(Exception e) {
-			e.printStackTrace();
-		}
-		return fedOutput;
-	}
+	/**
+	 * Evaluate if the output can be kept federated on the different federated
+	 * sites or if the output needs to be aggregated on the coordinator, based
+	 * on the output ranges of mo2.
+	 * The output can be kept federated if the slices of mo2, sliced corresponding
+	 * to the federated ranges of mo1, have strict separable and ascending value
+	 * ranges. From this property it follows that the partial outputs can also
+	 * be separated, and hence the overall output can be created by a simple
+	 * binding through a federated mapping.
+	 *
+	 * @param fedMap the federation map of the federated matrix input mo1
+	 * @param mo2 input matrix object mo2
+	 * @return boolean indicating if the output can be kept on the federated sites
+	 */
+	private boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
+		MatrixBlock mb = mo2.acquireReadAndRelease();
+		FederatedRange[] fedRanges = fedMap.getFederatedRanges(); // federated ranges of mo1
+		SortedMap<Double, Double> fedDims = new TreeMap<Double, Double>(); // <beginDim, endDim>
+
+		// collect min and max of the corresponding slices of mo2
+		IntStream.range(0, fedRanges.length).forEach(i -> {
+			MatrixBlock sliced = mb.slice(
+				fedRanges[i].getBeginDimsInt()[0], fedRanges[i].getEndDimsInt()[0] - 1,
+				fedRanges[i].getBeginDimsInt()[1], fedRanges[i].getEndDimsInt()[1] - 1);
+			fedDims.put(sliced.min(), sliced.max());
+		});
 
+		boolean retVal = (fedDims.size() == fedRanges.length); // no duplicate begin dimension entries
 
-	private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap, Long[] dims1, long outId) {
-		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / dims1.length;
+		Iterator<SortedMap.Entry<Double, Double>> iter = fedDims.entrySet().iterator();
+		SortedMap.Entry<Double, Double> entry = iter.next(); // first entry does not have to be checked
+		double prevEndDim = entry.getValue();
+		while(iter.hasNext() && retVal) {
+			entry = iter.next();
+			// previous end dimension must be less than current begin dimension (no overlaps of ranges)
+			retVal &= (prevEndDim < entry.getKey());
+			prevEndDim = entry.getValue();
+		}
 
-		long d1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		long d2 = Collections.max(Arrays.asList(dims1), Long::compare);
+		return retVal;
+	}
+
+	/**
+	 * Set the output and its data characteristics on the federated sites.
+	 *
+	 * @param mo1 input matrix object mo1
+	 * @param out input matrix object of the output
+	 * @param fedMap the federation map of the federated matrix input mo1
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 */
+	private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap,
+		long staticDim, Long[] dims2, boolean reversed) {
+		// get the final output dimensions
+		final long d1 = (reversed ? Collections.max(Arrays.asList(dims2)) : staticDim);
+		final long d2 = (reversed ? staticDim : Collections.max(Arrays.asList(dims2)));
 
 		// set output
 		out.getDataCharacteristics().set(d1, d2, (int) mo1.getBlocksize(), mo1.getNnz());
-		out.setFedMapping(fedMap.copyWithNewID(outId));
+		out.setFedMapping(fedMap);
 
 		long varID = FederationUtils.getNextFedDataID();
-		out.getFedMapping().mapParallel(varID, (range, data) -> {
+		fedMap.mapParallel(varID, (range, data) -> {
 			try {
 				FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(
 					FederatedRequest.RequestType.EXEC_UDF, -1,
-					new SliceOutput(data.getVarID(), fedSize))).get();
+					new SliceOutput(data.getVarID(), staticDim, dims2, reversed))).get();
 				if(!response.isSuccessful())
 					response.throwExceptionFromResponse();
 			}
@@ -239,6 +281,9 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
 		});
 	}
 
+	/**
+	 * Aggregate the partial outputs locally.
+	 */
 	private static MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
 		MatrixBlock resultBlock = new MatrixBlock(1, 1, true, 0);
 		int dim1 = 0, dim2 = 0;
@@ -266,27 +311,44 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
 		return resultBlock;
 	}
 
-	private static FederationMap modifyFedRanges(FederationMap fedMap, Long[] dims1, Long[] dims2) {
-		IntStream.range(0, fedMap.getFederatedRanges().length).forEach(i -> {
-			fedMap.getFederatedRanges()[i]
-				.setBeginDim(0, i == 0 ? 0 : fedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
-			fedMap.getFederatedRanges()[i].setEndDim(0, dims1[i]);
-			fedMap.getFederatedRanges()[i]
-				.setBeginDim(1, i == 0 ? 0 : fedMap.getFederatedRanges()[i - 1].getBeginDims()[1]);
-			fedMap.getFederatedRanges()[i].setEndDim(1, dims2[i]);
+	/**
+	 * Set the ranges of the federation map according to the static dimension and
+	 * the individual dimensions of the partial output matrices.
+	 *
+	 * @param fedMap the federation map of the federated matrix input mo1
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 * @return FederationMap the modified federation map
+	 */
+	private static FederationMap modifyFedRanges(FederationMap fedMap, long staticDim,
+		Long[] dims2, boolean reversed) {
+		// set the federated ranges to the individual partition sizes
+		IntStream.range(0, fedMap.getFederatedRanges().length).forEach(counter -> {
+			FederatedRange fedRange = fedMap.getFederatedRanges()[counter];
+			fedRange.setBeginDim(reversed ? 1 : 0, 0);
+			fedRange.setEndDim(reversed ? 1 : 0, staticDim);
+			fedRange.setBeginDim(reversed ? 0 : 1, counter == 0 ? 0 : dims2[counter-1]);
+			fedRange.setEndDim(reversed ? 0 : 1, dims2[counter]);
 		});
 		return fedMap;
 	}
 
-	private Long[] getOutputDimension(MatrixObject in, CPOperand inOp, CPOperand outOp, FederatedRange[] federatedRanges) {
+	/**
+	 * Compute the output dimensions of the partial outputs according to the
+	 * federated ranges.
+	 */
+	private Long[] getOutputDimension(MatrixObject in, CPOperand inOp, CPOperand outOp,
+		FederatedRange[] federatedRanges) {
 		Long[] fedDims = new Long[federatedRanges.length];
 
 		if(!in.isFederated()) {
 			//slice
 			MatrixBlock mb = in.acquireReadAndRelease();
 			IntStream.range(0, federatedRanges.length).forEach(i -> {
-				MatrixBlock sliced = mb
-					.slice(federatedRanges[i].getBeginDimsInt()[0], federatedRanges[i].getEndDimsInt()[0] - 1);
+				MatrixBlock sliced = mb.slice(
+					federatedRanges[i].getBeginDimsInt()[0], federatedRanges[i].getEndDimsInt()[0] - 1,
+					federatedRanges[i].getBeginDimsInt()[1], federatedRanges[i].getEndDimsInt()[1] - 1);
 				fedDims[i] = (long) sliced.max();
 			});
 			return fedDims;
@@ -326,29 +388,79 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
 		return String.join(Lop.OPERAND_DELIMITOR, maxInstParts);
 	}
 
+	/**
+	 * Static class which extends FederatedUDF to modify the partial outputs on
+	 * the federated sites such that they can be bound without any local
+	 * aggregation.
+	 */
 	private static class SliceOutput extends FederatedUDF {
 
 		private static final long serialVersionUID = -2808597461054603816L;
-		private final long _fedSize;
+		private final int _staticDim;
+		private final Long[] _fedDims;
+		private final boolean _reversed;
 
-		protected SliceOutput(long input, long fedSize) {
+		protected SliceOutput(long input, long staticDim, Long[] fedDims, boolean reversed) {
 			super(new long[] {input});
-			_fedSize = fedSize;
+			_staticDim = (int)staticDim;
+			_fedDims = fedDims;
+			_reversed = reversed;
 		}
 
+		/**
+		 * Find the dimensions of the partial output matrix and expand it to the
+		 * global static dimension along the non-partitioned axis and crop it
+		 * along the paritioned axis.
+		 *
+		 * @param ec the execution context
+		 * @param data
+		 * @return FederatedResponse with status SUCCESS and an empty object
+		 */
 		public FederatedResponse execute(ExecutionContext ec, Data... data) {
 			MatrixObject mo = (MatrixObject) data[0];
 			MatrixBlock mb = mo.acquireReadAndRelease();
 
-			MatrixBlock sliced = mb.slice((int) (mb.getNumRows()-_fedSize), mb.getNumRows()-1);
+			int beginDim = 0;
+			int endDim = (_reversed ? mb.getNumRows() : mb.getNumColumns());
+			int localStaticDim = (_reversed ? mb.getNumColumns() : mb.getNumRows());
+			for(int counter = 0; counter < _fedDims.length; counter++) {
+				if(_fedDims[counter] == endDim) {
+					beginDim = (counter == 0 ? 0 : _fedDims[counter - 1].intValue());
+					break;
+				}
+			}
+
+			mb = expandMatrix(mb, localStaticDim);
+
+			// crop the output
+			MatrixBlock sliced = _reversed ? mb.slice(beginDim, endDim - 1, 0, _staticDim - 1)
+				: mb.slice(0, _staticDim - 1, beginDim, endDim - 1);
 			mo.acquireModify(sliced);
 			mo.release();
 
 			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {});
 		}
+
+		/**
+		 * Expand the matrix with zeros up to the specified static dimension.
+		 *
+		 * @param mb the matrix block of the partial output
+		 * @param localStaticDim the static dimension of the output matrix block
+		 * @return MatrixBlock the output matrix block expanded to the global static dimension
+		 */
+		private MatrixBlock expandMatrix(MatrixBlock mb, int localStaticDim) {
+			int diff = _staticDim - localStaticDim;
+			if(diff > 0) {
+				MatrixBlock tmpMb = (_reversed ? new MatrixBlock(mb.getNumRows(), diff, (double) 0)
+					: new MatrixBlock(diff, mb.getNumColumns(), (double) 0));
+				mb = mb.append(tmpMb, null, _reversed);
+			}
+			return mb;
+		}
+
 		@Override
 		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
 			return null;
 		}
 	}
-}
\ No newline at end of file
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
index a5793b5..9aeb776 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
@@ -85,6 +85,9 @@ public class FederatedCtableTest extends AutomatedTestBase {
 	@Test
 	public void federatedCtableMatrixInputSinglenode() { runCtable(Types.ExecMode.SINGLE_NODE, false, true); }
 
+	@Test
+	public void federatedCtableMatrixInputFedOutputSingleNode() { runCtable(Types.ExecMode.SINGLE_NODE, true, true); }
+
 
 	public void runCtable(Types.ExecMode execMode, boolean fedOutput, boolean matrixInput) {
 		String TEST_NAME = fedOutput ? TEST_NAME2 : TEST_NAME1;
@@ -108,7 +111,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
 		loadTestConfiguration(config);
 
 		if(fedOutput)
-			runFedCtable(HOME, TEST_NAME, port1, port2, port3, port4);
+			runFedCtable(HOME, TEST_NAME, matrixInput, port1, port2, port3, port4);
 		else
 			runNonFedCtable(HOME, TEST_NAME, matrixInput, port1, port2, port3, port4);
 		checkResults();
@@ -155,7 +158,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
 		runTest(true, false, null, -1);
 	}
 
-	private void runFedCtable(String HOME, String TEST_NAME, int port1, int port2, int port3, int port4) {
+	private void runFedCtable(String HOME, String TEST_NAME, boolean matrixInput, int port1, int port2, int port3, int port4) {
 		int r = rows / 4;
 		int c = cols;
 
@@ -174,7 +177,8 @@ public class FederatedCtableTest extends AutomatedTestBase {
 		fullDMLScriptName = HOME + TEST_NAME2 + "Reference.dml";
 		programArgs = new String[]{"-stats", "100", "-args",
 			input("X1"), input("X2"), input("X3"), input("X4"), Boolean.toString(reversedInputs).toUpperCase(),
-			Boolean.toString(weighted).toUpperCase(), expected("F")};
+			Boolean.toString(weighted).toUpperCase(), Boolean.toString(matrixInput).toUpperCase(),
+			expected("F")};
 		runTest(true, false, null, -1);
 
 		// Run actual dml script with federated matrix
@@ -185,6 +189,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
 			"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
 			"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
 			"rows=" + rows, "cols=" + cols, "revIn=" + Boolean.toString(reversedInputs).toUpperCase(),
+			"matrixInput=" + Boolean.toString(matrixInput).toUpperCase(),
 			"weighted=" + Boolean.toString(weighted).toUpperCase(), "out=" + output("F")
 		};
 		runTest(true, false, null, -1);
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
index 9c21ed5..a2eda9d 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
@@ -28,8 +28,14 @@ n = ncol(X);
 
 # prepare offset vectors and one-hot encoded X
 maxs = colMaxs(X);
-rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
-cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+if($matrixInput) {
+  rix = matrix(seq(1,m)%*%matrix(1,1,n), m, n);
+  cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m, n);
+}
+else {
+  rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
+  cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+}
 
 W = rix + cix;
 
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
index e0721df..4fc6852 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
@@ -27,8 +27,14 @@ n = ncol(X);
 # prepare offset vectors and one-hot encoded X
 maxs = colMaxs(X);
 
-rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1)
-cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+if($7) { # matrix input
+  rix = matrix(seq(1,m)%*%matrix(1,1,n), m, n);
+  cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m, n);
+}
+else {
+  rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
+  cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+}
 
 W = rix + cix;
 
@@ -43,4 +49,4 @@ else
   else
     X2 = table(rix, cix);
 
-write(X2, $7);
+write(X2, $8);