You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2021/09/03 10:21:45 UTC
[systemds] branch master updated: [SYSTEMDS-3085] FederatedCTable - Keep Output Federated (Closes #1371)
This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 0fa4463 [SYSTEMDS-3085] FederatedCTable - Keep Output Federated Closes #1371.
0fa4463 is described below
commit 0fa4463b42dce5a65bdeb3f9d7a1c422db174cda
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Wed Aug 11 12:50:36 2021 +0200
[SYSTEMDS-3085] FederatedCTable - Keep Output Federated
Closes #1371.
---
.../instructions/fed/CtableFEDInstruction.java | 306 ++++++++++++++-------
.../federated/primitives/FederatedCtableTest.java | 11 +-
.../federated/FederatedCtableFedOutput.dml | 10 +-
.../FederatedCtableFedOutputReference.dml | 12 +-
4 files changed, 234 insertions(+), 105 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index a7d4cb6..2a1cbb1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -22,7 +22,10 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.Arrays;
import java.util.Collections;
import java.util.concurrent.Future;
+import java.util.Iterator;
+import java.util.SortedMap;
import java.util.stream.IntStream;
+import java.util.TreeMap;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types.DataType;
@@ -38,7 +41,6 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import org.apache.sysds.runtime.functionobjects.And;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -105,8 +107,10 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
}
// get new output dims
- Long[] dims1 = getOutputDimension(mo1, input1, _outDim1, mo1.getFedMapping().getFederatedRanges());
- Long[] dims2 = getOutputDimension(mo2, input2, _outDim2, mo1.getFedMapping().getFederatedRanges());
+ Long[] dims1 = getOutputDimension(mo1, reversed ? input2 : input1, reversed ? _outDim2 : _outDim1,
+ mo1.getFedMapping().getFederatedRanges());
+ Long[] dims2 = getOutputDimension(mo2, reversed ? input1 : input2, reversed ? _outDim1 : _outDim2,
+ mo1.getFedMapping().getFederatedRanges());
MatrixObject mo3 = input3 != null && input3.isMatrix() ? ec.getMatrixObject(input3) : null;
@@ -116,119 +120,157 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
mo1 = ec.getMatrixObject(input3);
}
- long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
- boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+ // static non-partitioned output dimension (same for all federated partitions)
+ long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+ boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);
- processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, dims1, dims2);
+ processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, staticDim, dims2);
}
+ /**
+ * Broadcast mo2 (and mo3 if present and not federated/aligned with mo1), execute the
+ * ctable instruction on the federated sites, and keep the output federated or aggregate it locally.
+ *
+ * @param ec execution context
+ * @param mo1 input matrix object 1
+ * @param mo2 input matrix object 2
+ * @param mo3 input matrix object 3 or null
+ * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+ * @param reversedWeights boolean indicating if inputs mo1 and mo3 are reversed
+ * @param fedOutput boolean indicating if output can be kept federated
+ * @param staticDim static non-partitioned dimension of the output
+ * @param dims2 dimensions of the partial outputs along the federated partitioning
+ */
private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3,
- boolean reversed, boolean reversedWeights, boolean fedOutput, Long[] dims1, Long[] dims2) {
+ boolean reversed, boolean reversedWeights, boolean fedOutput, long staticDim, Long[] dims2) {
+
+ FederationMap fedMap = mo1.getFedMapping();
+
+ FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+ FederatedRequest[] fr2 = null;
+ FederatedRequest fr3, fr4, fr5;
Future<FederatedResponse>[] ffr;
- FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
- FederatedRequest fr2, fr3;
if(mo3 != null && mo1.isFederated() && mo3.isFederated()
- && mo1.getFedMapping().isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
+ && fedMap.isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
if(!reversed)
- fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
- new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
+ fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+ new long[] {fedMap.getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
else
- fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
- new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), mo3.getFedMapping().getID()});
-
- fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
- ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+ fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+ new long[] {fr1[0].getID(), fedMap.getID(), mo3.getFedMapping().getID()});
}
else if(mo3 == null) {
if(!reversed)
- fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
- new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+ fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+ new long[] {fedMap.getID(), fr1[0].getID()});
else
- fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
- new long[] {fr1[0].getID(), mo1.getFedMapping().getID()});
-
- fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
- ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
-
- } else {
- FederatedRequest[] fr4 = mo1.getFedMapping().broadcastSliced(mo3, false);
+ fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+ new long[] {fr1[0].getID(), fedMap.getID()});
+ }
+ else {
+ fr2 = fedMap.broadcastSliced(mo3, false);
if(!reversed && !reversedWeights)
- fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
- new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), fr4[0].getID()});
+ fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+ new long[] {fedMap.getID(), fr1[0].getID(), fr2[0].getID()});
else if(reversed && !reversedWeights)
- fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
- new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), fr4[0].getID()});
+ fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+ new long[] {fr1[0].getID(), fedMap.getID(), fr2[0].getID()});
else
- fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
- new long[] {fr1[0].getID(), fr4[0].getID(), mo1.getFedMapping().getID()});
-
- fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
- ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr4, fr2, fr3);
+ fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+ new long[] {fr1[0].getID(), fr2[0].getID(), fedMap.getID()});
}
- if(fedOutput && isFedOutput(ffr, dims1)) {
+ if(fedOutput) { // keep the partial outputs on the federated sites (no GET_VAR, no cleanup)
+ if(fr2 != null) // mo3 was broadcast via fr2
+ fedMap.execute(getTID(), true, fr1, fr2, fr3); // NOTE(review): returned futures are not checked; confirm execute(..., true, ...) surfaces worker failures
+ else
+ fedMap.execute(getTID(), true, fr1, fr3);
+
MatrixObject out = ec.getMatrixObject(output);
- FederationMap newFedMap = modifyFedRanges(mo1.getFedMapping(), dims1, dims2);
- setFedOutput(mo1, out, newFedMap, dims1, fr2.getID());
+ FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(fr3.getID()),
+ staticDim, dims2, reversed);
+ setFedOutput(mo1, out, newFedMap, staticDim, dims2, reversed);
} else {
+ fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID()); // fetch the partial outputs
+ fr5 = fedMap.cleanup(getTID(), fr3.getID()); // cleanup request for the partial outputs on the sites
+ if(fr2 != null) // mo3 was broadcast via fr2
+ ffr = fedMap.execute(getTID(), true, fr1, fr2, fr3, fr4, fr5);
+ else
+ ffr = fedMap.execute(getTID(), true, fr1, fr3, fr4, fr5);
+
ec.setMatrixOutput(output.getName(), aggResult(ffr));
}
}
- boolean isFedOutput(Future<FederatedResponse>[] ffr, Long[] dims1) {
- boolean fedOutput = true;
-
- long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / ffr.length;
- try {
- MatrixBlock curr;
- MatrixBlock prev =(MatrixBlock) ffr[0].get().getData()[0];
- for(int i = 1; i < ffr.length && fedOutput; i++) {
- curr = (MatrixBlock) ffr[i].get().getData()[0];
- MatrixBlock sliced = curr.slice((int) (curr.getNumRows() - fedSize), curr.getNumRows() - 1);
-
- if(curr.getNumColumns() != prev.getNumColumns())
- return false;
-
- // no intersection
- if(curr.getNumRows() == (i+1) * prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
- && (curr.getNumRows() - sliced.getNumRows()) == i * prev.getNumRows()
- && curr.getNonZeros() - sliced.getNonZeros() == 0)
- continue;
-
- // check intersect with AND and compare number of nnz
- MatrixBlock prevExtend = new MatrixBlock(curr.getNumRows(), curr.getNumColumns(), true, 0);
- prevExtend.copy(0, prev.getNumRows()-1, 0, prev.getNumColumns()-1, prev, true);
-
- MatrixBlock intersect = curr.binaryOperationsInPlace(new BinaryOperator(And.getAndFnObject()), prevExtend);
- if(intersect.getNonZeros() != 0)
- fedOutput = false;
- prev = sliced;
- }
- }
- catch(Exception e) {
- e.printStackTrace();
- }
- return fedOutput;
- }
+ /**
+ * Evaluate if the output can be kept federated on the different federated
+ * sites or if the output needs to be aggregated on the coordinator, based
+ * on the output ranges of mo2.
+ * The output can be kept federated if the slices of mo2, sliced corresponding
+ * to the federated ranges of mo1, have strictly separable and ascending value
+ * ranges. From this property it follows that the partial outputs can also
+ * be separated, and hence the overall output can be created by a simple
+ * binding through a federated mapping.
+ *
+ * @param fedMap the federation map of the federated matrix input mo1
+ * @param mo2 input matrix object mo2
+ * @return boolean indicating if the output can be kept on the federated sites
+ */
+ private boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
+ MatrixBlock mb = mo2.acquireReadAndRelease();
+ FederatedRange[] fedRanges = fedMap.getFederatedRanges(); // federated ranges of mo1
+ SortedMap<Double, Double> fedDims = new TreeMap<Double, Double>(); // <beginDim (slice min), endDim (slice max)>, ordered by min
+
+ // collect min and max of the corresponding slices of mo2
+ IntStream.range(0, fedRanges.length).forEach(i -> {
+ MatrixBlock sliced = mb.slice(
+ fedRanges[i].getBeginDimsInt()[0], fedRanges[i].getEndDimsInt()[0] - 1,
+ fedRanges[i].getBeginDimsInt()[1], fedRanges[i].getEndDimsInt()[1] - 1);
+ fedDims.put(sliced.min(), sliced.max());
+ });
+ boolean retVal = (fedDims.size() == fedRanges.length); // no duplicate begin dimension entries
- private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap, Long[] dims1, long outId) {
- long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / dims1.length;
+ Iterator<SortedMap.Entry<Double, Double>> iter = fedDims.entrySet().iterator();
+ SortedMap.Entry<Double, Double> entry = iter.next(); // first entry does not have to be checked
+ double prevEndDim = entry.getValue();
+ while(iter.hasNext() && retVal) {
+ entry = iter.next();
+ // previous end dimension must be less than current begin dimension (no overlaps of ranges); NOTE(review): entries are ordered by min value, not by partition index, so the ascending partition order assumed by modifyFedRanges and SliceOutput is not verified here — confirm
+ retVal &= (prevEndDim < entry.getKey());
+ prevEndDim = entry.getValue();
+ }
- long d1 = Collections.max(Arrays.asList(dims1), Long::compare);
- long d2 = Collections.max(Arrays.asList(dims1), Long::compare);
+ return retVal;
+ }
+
+ /**
+ * Set the output and its data characteristics on the federated sites.
+ *
+ * @param mo1 input matrix object mo1
+ * @param out input matrix object of the output
+ * @param fedMap the federation map of the federated matrix input mo1
+ * @param staticDim static non-partitioned dimension of the output
+ * @param dims2 dimensions of the partial outputs along the federated partitioning
+ * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+ */
+ private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap,
+ long staticDim, Long[] dims2, boolean reversed) {
+ // get the final output dimensions
+ final long d1 = (reversed ? Collections.max(Arrays.asList(dims2)) : staticDim);
+ final long d2 = (reversed ? staticDim : Collections.max(Arrays.asList(dims2)));
// set output
out.getDataCharacteristics().set(d1, d2, (int) mo1.getBlocksize(), mo1.getNnz());
- out.setFedMapping(fedMap.copyWithNewID(outId));
+ out.setFedMapping(fedMap);
long varID = FederationUtils.getNextFedDataID();
- out.getFedMapping().mapParallel(varID, (range, data) -> {
+ fedMap.mapParallel(varID, (range, data) -> {
try {
FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(
FederatedRequest.RequestType.EXEC_UDF, -1,
- new SliceOutput(data.getVarID(), fedSize))).get();
+ new SliceOutput(data.getVarID(), staticDim, dims2, reversed))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
}
@@ -239,6 +281,9 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
});
}
+ /**
+ * Aggregate the partial outputs locally.
+ */
private static MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
MatrixBlock resultBlock = new MatrixBlock(1, 1, true, 0);
int dim1 = 0, dim2 = 0;
@@ -266,27 +311,44 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
return resultBlock;
}
- private static FederationMap modifyFedRanges(FederationMap fedMap, Long[] dims1, Long[] dims2) {
- IntStream.range(0, fedMap.getFederatedRanges().length).forEach(i -> {
- fedMap.getFederatedRanges()[i]
- .setBeginDim(0, i == 0 ? 0 : fedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
- fedMap.getFederatedRanges()[i].setEndDim(0, dims1[i]);
- fedMap.getFederatedRanges()[i]
- .setBeginDim(1, i == 0 ? 0 : fedMap.getFederatedRanges()[i - 1].getBeginDims()[1]);
- fedMap.getFederatedRanges()[i].setEndDim(1, dims2[i]);
+ /**
+ * Set the ranges of the federation map according to the static dimension and
+ * the individual dimensions of the partial output matrices.
+ *
+ * @param fedMap the federation map of the federated matrix input mo1
+ * @param staticDim static non-partitioned dimension of the output
+ * @param dims2 end dimensions of the partial outputs along the federated partitioning (must ascend in partition order)
+ * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+ * @return FederationMap the modified federation map
+ */
+ private static FederationMap modifyFedRanges(FederationMap fedMap, long staticDim,
+ Long[] dims2, boolean reversed) {
+ // range i spans [0, staticDim) on the static axis and [dims2[i-1] (or 0), dims2[i]) on the partitioned axis
+ IntStream.range(0, fedMap.getFederatedRanges().length).forEach(counter -> {
+ FederatedRange fedRange = fedMap.getFederatedRanges()[counter];
+ fedRange.setBeginDim(reversed ? 1 : 0, 0);
+ fedRange.setEndDim(reversed ? 1 : 0, staticDim);
+ fedRange.setBeginDim(reversed ? 0 : 1, counter == 0 ? 0 : dims2[counter-1]);
+ fedRange.setEndDim(reversed ? 0 : 1, dims2[counter]);
});
return fedMap;
}
- private Long[] getOutputDimension(MatrixObject in, CPOperand inOp, CPOperand outOp, FederatedRange[] federatedRanges) {
+ /**
+ * Compute the output dimensions of the partial outputs according to the
+ * federated ranges.
+ */
+ private Long[] getOutputDimension(MatrixObject in, CPOperand inOp, CPOperand outOp,
+ FederatedRange[] federatedRanges) {
Long[] fedDims = new Long[federatedRanges.length];
if(!in.isFederated()) {
//slice
MatrixBlock mb = in.acquireReadAndRelease();
IntStream.range(0, federatedRanges.length).forEach(i -> {
- MatrixBlock sliced = mb
- .slice(federatedRanges[i].getBeginDimsInt()[0], federatedRanges[i].getEndDimsInt()[0] - 1);
+ MatrixBlock sliced = mb.slice(
+ federatedRanges[i].getBeginDimsInt()[0], federatedRanges[i].getEndDimsInt()[0] - 1,
+ federatedRanges[i].getBeginDimsInt()[1], federatedRanges[i].getEndDimsInt()[1] - 1);
fedDims[i] = (long) sliced.max();
});
return fedDims;
@@ -326,29 +388,79 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
return String.join(Lop.OPERAND_DELIMITOR, maxInstParts);
}
+ /**
+ * Static class which extends FederatedUDF to modify the partial outputs on
+ * the federated sites such that they can be bound without any local
+ * aggregation.
+ */
private static class SliceOutput extends FederatedUDF {
private static final long serialVersionUID = -2808597461054603816L;
- private final long _fedSize;
+ private final int _staticDim;
+ private final Long[] _fedDims;
+ private final boolean _reversed;
- protected SliceOutput(long input, long fedSize) {
+ protected SliceOutput(long input, long staticDim, Long[] fedDims, boolean reversed) {
super(new long[] {input});
- _fedSize = fedSize;
+ _staticDim = (int)staticDim;
+ _fedDims = fedDims;
+ _reversed = reversed;
}
+ /**
+ * Find the dimensions of the partial output matrix and expand it to the
+ * global static dimension along the non-partitioned axis and crop it
+ * along the partitioned axis.
+ *
+ * @param ec the execution context
+ * @param data the federated data objects; data[0] is the partial output matrix object
+ * @return FederatedResponse with status SUCCESS and an empty object
+ */
public FederatedResponse execute(ExecutionContext ec, Data... data) {
MatrixObject mo = (MatrixObject) data[0];
MatrixBlock mb = mo.acquireReadAndRelease();
- MatrixBlock sliced = mb.slice((int) (mb.getNumRows()-_fedSize), mb.getNumRows()-1);
+ int beginDim = 0; // stays 0 if endDim matches no entry of _fedDims
+ int endDim = (_reversed ? mb.getNumRows() : mb.getNumColumns());
+ int localStaticDim = (_reversed ? mb.getNumColumns() : mb.getNumRows());
+ for(int counter = 0; counter < _fedDims.length; counter++) {
+ if(_fedDims[counter] == endDim) {
+ beginDim = (counter == 0 ? 0 : _fedDims[counter - 1].intValue());
+ break;
+ }
+ }
+
+ mb = expandMatrix(mb, localStaticDim);
+
+ // crop to [beginDim, endDim) along the partitioned axis and [0, _staticDim) along the static axis
+ MatrixBlock sliced = _reversed ? mb.slice(beginDim, endDim - 1, 0, _staticDim - 1)
+ : mb.slice(0, _staticDim - 1, beginDim, endDim - 1);
mo.acquireModify(sliced);
mo.release();
return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {});
}
+
+ /**
+ * Expand the matrix with zeros up to the specified static dimension.
+ *
+ * @param mb the matrix block of the partial output
+ * @param localStaticDim the current size of mb along the non-partitioned (static) axis
+ * @return MatrixBlock the output matrix block expanded to the global static dimension
+ */
+ private MatrixBlock expandMatrix(MatrixBlock mb, int localStaticDim) {
+ int diff = _staticDim - localStaticDim;
+ if(diff > 0) {
+ MatrixBlock tmpMb = (_reversed ? new MatrixBlock(mb.getNumRows(), diff, (double) 0)
+ : new MatrixBlock(diff, mb.getNumColumns(), (double) 0));
+ mb = mb.append(tmpMb, null, _reversed);
+ }
+ return mb;
+ }
+
@Override
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
return null;
}
}
-}
\ No newline at end of file
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
index a5793b5..9aeb776 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
@@ -85,6 +85,9 @@ public class FederatedCtableTest extends AutomatedTestBase {
@Test
public void federatedCtableMatrixInputSinglenode() { runCtable(Types.ExecMode.SINGLE_NODE, false, true); }
+ @Test
+ public void federatedCtableMatrixInputFedOutputSingleNode() { runCtable(Types.ExecMode.SINGLE_NODE, true, true); }
+
public void runCtable(Types.ExecMode execMode, boolean fedOutput, boolean matrixInput) {
String TEST_NAME = fedOutput ? TEST_NAME2 : TEST_NAME1;
@@ -108,7 +111,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
loadTestConfiguration(config);
if(fedOutput)
- runFedCtable(HOME, TEST_NAME, port1, port2, port3, port4);
+ runFedCtable(HOME, TEST_NAME, matrixInput, port1, port2, port3, port4);
else
runNonFedCtable(HOME, TEST_NAME, matrixInput, port1, port2, port3, port4);
checkResults();
@@ -155,7 +158,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
runTest(true, false, null, -1);
}
- private void runFedCtable(String HOME, String TEST_NAME, int port1, int port2, int port3, int port4) {
+ private void runFedCtable(String HOME, String TEST_NAME, boolean matrixInput, int port1, int port2, int port3, int port4) {
int r = rows / 4;
int c = cols;
@@ -174,7 +177,8 @@ public class FederatedCtableTest extends AutomatedTestBase {
fullDMLScriptName = HOME + TEST_NAME2 + "Reference.dml";
programArgs = new String[]{"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"), Boolean.toString(reversedInputs).toUpperCase(),
- Boolean.toString(weighted).toUpperCase(), expected("F")};
+ Boolean.toString(weighted).toUpperCase(), Boolean.toString(matrixInput).toUpperCase(),
+ expected("F")};
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
@@ -185,6 +189,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
"rows=" + rows, "cols=" + cols, "revIn=" + Boolean.toString(reversedInputs).toUpperCase(),
+ "matrixInput=" + Boolean.toString(matrixInput).toUpperCase(),
"weighted=" + Boolean.toString(weighted).toUpperCase(), "out=" + output("F")
};
runTest(true, false, null, -1);
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
index 9c21ed5..a2eda9d 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
@@ -28,8 +28,14 @@ n = ncol(X);
# prepare offset vectors and one-hot encoded X
maxs = colMaxs(X);
-rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
-cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+if($matrixInput) {
+ rix = matrix(seq(1,m)%*%matrix(1,1,n), m, n);
+ cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m, n);
+}
+else {
+ rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
+ cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+}
W = rix + cix;
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
index e0721df..4fc6852 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
@@ -27,8 +27,14 @@ n = ncol(X);
# prepare offset vectors and one-hot encoded X
maxs = colMaxs(X);
-rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1)
-cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+if($7) { # matrix input
+ rix = matrix(seq(1,m)%*%matrix(1,1,n), m, n);
+ cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m, n);
+}
+else {
+ rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
+ cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+}
W = rix + cix;
@@ -43,4 +49,4 @@ else
else
X2 = table(rix, cix);
-write(X2, $7);
+write(X2, $8);