You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/05/29 23:44:20 UTC
[systemds] branch master updated: [SYSTEMDS-2982] Federated codegen w/ aligned federated inputs
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new bd6c642 [SYSTEMDS-2982] Federated codegen w/ aligned federated inputs
bd6c642 is described below
commit bd6c64265ff25e923141166a8f9cd4ecea16caf6
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sun May 30 01:29:00 2021 +0200
[SYSTEMDS-2982] Federated codegen w/ aligned federated inputs
Closes #1287.
---
.../controlprogram/caching/CacheableData.java | 2 +-
.../controlprogram/federated/FederatedRange.java | 4 +
.../controlprogram/federated/FederationMap.java | 37 ++-
.../instructions/cp/SpoofCPInstruction.java | 41 ++-
.../instructions/fed/SpoofFEDInstruction.java | 75 +++--
.../instructions/spark/SpoofSPInstruction.java | 43 ++-
.../org/apache/sysds/test/AutomatedTestBase.java | 28 ++
.../codegen/FederatedCodegenMultipleFedMOTest.java | 269 +++++++++++++++++
.../codegen/FederatedCellwiseTmplTest.dml | 4 +-
.../codegen/FederatedCellwiseTmplTestReference.dml | 4 +-
.../codegen/FederatedCodegenMultipleFedMOTest.dml | 333 +++++++++++++++++++++
.../FederatedCodegenMultipleFedMOTestReference.dml | 329 ++++++++++++++++++++
.../FederatedOuterProductTmplTestReference.dml | 2 +-
13 files changed, 1108 insertions(+), 63 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index f7970ff..c4d04d7 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -373,7 +373,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
}
public boolean isFederated(FType type) {
- return isFederated() && _fedMapping.getType().isType(type);
+ return isFederated() && (type == null || _fedMapping.getType().isType(type));
}
/**
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 4948d27..73636d8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -94,6 +94,10 @@ public class FederatedRange implements Comparable<FederatedRange> {
return -1;
if ( _beginDims[i] > o._beginDims[i])
return 1;
+ if ( _endDims[i] < o._endDims[i])
+ return -1;
+ if ( _endDims[i] > o._endDims[i])
+ return 1;
}
return 0;
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 7a52b11..c77ff79 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -247,7 +247,42 @@ public class FederationMap {
}
return ret;
}
-
+
+ /**
+ * Determines whether this and that federated data are partitioned into aligned row/column ranges (as selected by equalRows/equalCols)
+ * that are located at the same federated sites (which often allows for purely federated operations).
+ * @param that FederationMap to check alignment with
+ * @param transposed true if that FederationMap should be transposed before checking alignment
+ * @param equalRows true to indicate that the row dimension should be checked for alignment
+ * @param equalCols true to indicate that the col dimension should be checked for alignment
+ * @return true if this and that FederationMap are aligned
+ */
+ public boolean isAligned(FederationMap that, boolean transposed, boolean equalRows, boolean equalCols) {
+ boolean ret = true;
+ final int ROW_IX = transposed ? 1 : 0; // swapping row and col dimension index of "that" if transposed
+ final int COL_IX = transposed ? 0 : 1;
+
+ for(Pair<FederatedRange, FederatedData> e : _fedMap) {
+ boolean rangeFound = false; // to indicate if at least one matching range has been found
+ for(FederatedRange r : that.getFederatedRanges()) {
+ long[] rbd = r.getBeginDims();
+ long[] red = r.getEndDims();
+ long[] ebd = e.getKey().getBeginDims();
+ long[] eed = e.getKey().getEndDims();
+ // searching for the matching federated range of "that"
+ if((!equalRows || (rbd[ROW_IX] == ebd[0] && red[ROW_IX] == eed[0]))
+ && (!equalCols || (rbd[COL_IX] == ebd[1] && red[COL_IX] == eed[1]))) {
+ rangeFound = true;
+ FederatedData dat2 = that.getFederatedData(r);
+ ret &= e.getValue().equalAddress(dat2); // both partitions must be located on the same fed worker
+ }
+ }
+ if(!(ret &= rangeFound)) // setting ret to false if no matching range has been found
+ break; // directly returning if not ret to skip further checks
+ }
+ return ret;
+ }
+
public Future<FederatedResponse>[] execute(long tid, FederatedRequest... fr) {
return execute(tid, false, fr);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
index e9bacd3..38fd8d7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
@@ -27,8 +27,11 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofOperator;
+import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageCodegenItem;
@@ -131,16 +134,38 @@ public class SpoofCPInstruction extends ComputationCPInstruction {
}
public boolean isFederated(ExecutionContext ec) {
- for(CPOperand input : _in)
- if( ec.isFederated(input) )
- return true;
- return false;
+ return isFederated(ec, null);
}
public boolean isFederated(ExecutionContext ec, FType type) {
- for(CPOperand input : _in)
- if( ec.isFederated(input, type) )
- return true;
- return false;
+ FederationMap fedMap = null;
+ boolean retVal = false;
+
+ // flags for alignment check
+ boolean equalRows = false;
+ boolean equalCols = false;
+ boolean transposed = false; // flag indicates to check for transposed alignment
+
+ for(CPOperand input : _in) {
+ Data data = ec.getVariable(input);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated(type)) {
+ MatrixObject mo = ((MatrixObject) data);
+ if(fedMap == null) { // first federated matrix
+ fedMap = mo.getFedMapping();
+ retVal = true;
+
+ // setting the constraints for alignment check on further federated matrices
+ equalRows = mo.isFederated(FType.ROW);
+ equalCols = mo.isFederated(FType.COL);
+ transposed = (getOperatorClass().getSuperclass() == SpoofOuterProduct.class);
+ }
+ else if(!fedMap.isAligned(mo.getFedMapping(), false, equalRows, equalCols)
+ && (!transposed || !(fedMap.isAligned(mo.getFedMapping(), true, equalRows, equalCols)
+ || mo.getFedMapping().isAligned(fedMap, true, equalRows, equalCols)))) {
+ retVal = false; // multiple federated matrices must be aligned
+ }
+ }
+ }
+ return retVal;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index 13b1785..918ec00 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -94,60 +94,51 @@ public class SpoofFEDInstruction extends FEDInstruction
throw new DMLRuntimeException("Federated code generation only supported" +
" for cellwise, rowwise, multiaggregate, and outerproduct templates.");
- ArrayList<CPOperand> inCpoMat = new ArrayList<>();
- ArrayList<CPOperand> inCpoScal = new ArrayList<>();
- ArrayList<MatrixObject> inMo = new ArrayList<>();
- ArrayList<ScalarObject> inSo = new ArrayList<>();
+
FederationMap fedMap = null;
- for(CPOperand cpo : _inputs) {
+ for(CPOperand cpo : _inputs) { // searching for the first federated matrix to obtain the federation map
Data tmpData = ec.getVariable(cpo);
- if(tmpData instanceof MatrixObject) {
- MatrixObject tmp = (MatrixObject) tmpData;
- if(fedMap == null & tmp.isFederated()) { //take first
- inCpoMat.add(0, cpo); // insert federated CPO at the beginning
- fedMap = tmp.getFedMapping();
- }
- else {
- inCpoMat.add(cpo);
- inMo.add(tmp);
- }
- }
- else if(tmpData instanceof ScalarObject) {
- ScalarObject tmp = (ScalarObject) tmpData;
- inCpoScal.add(cpo);
- inSo.add(tmp);
+ if(tmpData instanceof MatrixObject && ((MatrixObject)tmpData).isFederated()) {
+ fedMap = ((MatrixObject)tmpData).getFedMapping();
+ break;
}
}
ArrayList<FederatedRequest> frBroadcast = new ArrayList<>();
ArrayList<FederatedRequest[]> frBroadcastSliced = new ArrayList<>();
- long[] frIds = new long[1 + inMo.size() + inSo.size()];
+ long[] frIds = new long[_inputs.length];
int index = 0;
- frIds[index++] = fedMap.getID(); // insert federation map id at the beginning
- for(MatrixObject mo : inMo) {
- if(spoofType.needsBroadcastSliced(fedMap, mo.getNumRows(), mo.getNumColumns(), index)) {
- FederatedRequest[] tmpFr = spoofType.broadcastSliced(mo, fedMap);
- frIds[index++] = tmpFr[0].getID();
- frBroadcastSliced.add(tmpFr);
+
+ for(CPOperand cpo : _inputs) {
+ Data tmpData = ec.getVariable(cpo);
+ if(tmpData instanceof MatrixObject) {
+ MatrixObject mo = (MatrixObject) tmpData;
+ if(mo.isFederated()) {
+ frIds[index++] = mo.getFedMapping().getID();
+ }
+ else if(spoofType.needsBroadcastSliced(fedMap, mo.getNumRows(), mo.getNumColumns(), index)) {
+ FederatedRequest[] tmpFr = spoofType.broadcastSliced(mo, fedMap);
+ frIds[index++] = tmpFr[0].getID();
+ frBroadcastSliced.add(tmpFr);
+ }
+ else {
+ FederatedRequest tmpFr = fedMap.broadcast(mo);
+ frIds[index++] = tmpFr.getID();
+ frBroadcast.add(tmpFr);
+ }
}
- else {
- FederatedRequest tmpFr = fedMap.broadcast(mo);
+ else if(tmpData instanceof ScalarObject) {
+ ScalarObject so = (ScalarObject) tmpData;
+ FederatedRequest tmpFr = fedMap.broadcast(so);
frIds[index++] = tmpFr.getID();
frBroadcast.add(tmpFr);
}
}
- for(ScalarObject so : inSo) {
- FederatedRequest tmpFr = fedMap.broadcast(so);
- frIds[index++] = tmpFr.getID();
- frBroadcast.add(tmpFr);
- }
// change the is_literal flag from true to false because when broadcasted it is not a literal anymore
instString = instString.replace("true", "false");
- CPOperand[] inCpo = ArrayUtils.addAll(inCpoMat.toArray(new CPOperand[0]),
- inCpoScal.toArray(new CPOperand[0]));
- FederatedRequest frCompute = FederationUtils.callInstruction(instString, _output, inCpo, frIds);
+ FederatedRequest frCompute = FederationUtils.callInstruction(instString, _output, _inputs, frIds);
// get partial results from federated workers
FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frCompute.getID());
@@ -184,14 +175,18 @@ public class SpoofFEDInstruction extends FEDInstruction
protected boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum, int inputIndex) {
FType fedType = fedMap.getType();
+
boolean retVal = (rowNum == fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1));
if(fedType == FType.ROW)
- retVal |= (rowNum == fedMap.getMaxIndexInRange(0) && (colNum == 1 || colNum == fedMap.getSize()));
+ retVal |= (rowNum == fedMap.getMaxIndexInRange(0)
+ && (colNum == 1 || colNum == fedMap.getSize() || fedMap.getMaxIndexInRange(1) == 1));
else if(fedType == FType.COL)
- retVal |= ((rowNum == 1 || rowNum == fedMap.getSize()) && colNum == fedMap.getMaxIndexInRange(1));
- else
+ retVal |= (colNum == fedMap.getMaxIndexInRange(1)
+ && (rowNum == 1 || rowNum == fedMap.getSize() || fedMap.getMaxIndexInRange(0) == 1));
+ else {
throw new DMLRuntimeException("Only row partitioned or column" +
" partitioned federated input supported yet.");
+ }
return retVal;
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
index f76d74f..cc6d7bc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
@@ -39,14 +39,17 @@ import org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
@@ -678,16 +681,40 @@ public class SpoofSPInstruction extends SPInstruction {
}
public boolean isFederated(ExecutionContext ec) {
- for(CPOperand input : _in)
- if( ec.isFederated(input) )
- return true;
- return false;
+ return isFederated(ec, null);
}
public boolean isFederated(ExecutionContext ec, FType type) {
- for(CPOperand input : _in)
- if( ec.isFederated(input, type) )
- return true;
- return false;
+ //FIXME remove redundancy with SpoofCPInstruction
+
+ FederationMap fedMap = null;
+ boolean retVal = false;
+
+ // flags for alignment check
+ boolean equalRows = false;
+ boolean equalCols = false;
+ boolean transposed = false; // flag indicates to check for transposed alignment
+
+ for(CPOperand input : _in) {
+ Data data = ec.getVariable(input);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated(type)) {
+ MatrixObject mo = ((MatrixObject) data);
+ if(fedMap == null) { // first federated matrix
+ fedMap = mo.getFedMapping();
+ retVal = true;
+
+ // setting the constraints for alignment check on further federated matrices
+ equalRows = mo.isFederated(FType.ROW);
+ equalCols = mo.isFederated(FType.COL);
+ transposed = (getOperatorClass().getSuperclass() == SpoofOuterProduct.class);
+ }
+ else if(!fedMap.isAligned(mo.getFedMapping(), false, equalRows, equalCols)
+ && (!transposed || !(fedMap.isAligned(mo.getFedMapping(), true, equalRows, equalCols)
+ || mo.getFedMapping().isAligned(fedMap, true, equalRows, equalCols)))) {
+ retVal = false; // multiple federated matrices must be aligned
+ }
+ }
+ }
+ return retVal;
}
}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index c90892b..40cf34d 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -2080,6 +2080,20 @@ public abstract class AutomatedTestBase {
return(count >= minCount);
}
+ protected boolean heavyHittersContainsString(String str, int minCount, long minCallCount) {
+ int count = 0;
+ long callCount = Long.MAX_VALUE;
+ for(String opcode : Statistics.getCPHeavyHitterOpCodes()) {
+ if(opcode.equals(str)) {
+ count++;
+ long tmpCallCount = Statistics.getCPHeavyHitterCount(opcode);
+ if(tmpCallCount < callCount)
+ callCount = tmpCallCount;
+ }
+ }
+ return (count >= minCount && callCount >= minCallCount);
+ }
+
protected boolean heavyHittersContainsSubString(String... str) {
for(String opcode : Statistics.getCPHeavyHitterOpCodes())
for(String s : str)
@@ -2095,6 +2109,20 @@ public abstract class AutomatedTestBase {
return(count >= minCount);
}
+ protected boolean heavyHittersContainsSubString(String str, int minCount, long minCallCount) {
+ int count = 0;
+ long callCount = Long.MAX_VALUE;
+ for(String opcode : Statistics.getCPHeavyHitterOpCodes()) {
+ if(opcode.contains(str)) {
+ count++;
+ long tmpCallCount = Statistics.getCPHeavyHitterCount(opcode);
+ if(tmpCallCount < callCount)
+ callCount = tmpCallCount;
+ }
+ }
+ return (count >= minCount && callCount >= minCallCount);
+ }
+
protected boolean checkedPrivacyConstraintsContains(PrivacyLevel... levels) {
for(PrivacyLevel level : levels)
if(!(CheckedConstraintsLog.getCheckedConstraints().containsKey(level)))
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
new file mode 100644
index 0000000..65f1728
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.codegen;
+
+import java.io.File;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedCodegenMultipleFedMOTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME = "FederatedCodegenMultipleFedMOTest";
+
+ private final static String TEST_DIR = "functions/federated/codegen/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedCodegenMultipleFedMOTest.class.getSimpleName() + "/";
+
+ private final static String TEST_CONF = "SystemDS-config-codegen.xml";
+
+ private final static String OUTPUT_NAME = "Z";
+ private final static double TOLERANCE = 1e-7;
+ private final static int BLOCKSIZE = 1024;
+
+ @Parameterized.Parameter()
+ public int test_num;
+ @Parameterized.Parameter(1)
+ public int rows_x;
+ @Parameterized.Parameter(2)
+ public int cols_x;
+ @Parameterized.Parameter(3)
+ public int rows_y;
+ @Parameterized.Parameter(4)
+ public int cols_y;
+ @Parameterized.Parameter(5)
+ public boolean row_partitioned;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{OUTPUT_NAME}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ // rows must be even for row partitioned X and Y
+ // cols must be even for col partitioned X and Y
+ return Arrays.asList(new Object[][] {
+ // {test_num, rows_x, cols_x, rows_y, cols_y, row_partitioned}
+
+ // cellwise
+ // row partitioned
+ {1, 4, 4, 4, 4, true},
+ // {2, 4, 4, 4, 1, true},
+ {3, 4, 1, 4, 1, true},
+ {4, 1000, 1, 1000, 1, true},
+ // {5, 500, 2, 500, 2, true},
+ {6, 2, 500, 2, 500, true},
+ {7, 2, 4, 2, 4, true},
+ // column partitioned
+ // {1, 4, 4, 4, 4, false},
+ {2, 4, 4, 1, 4, false},
+ {5, 500, 2, 500, 2, false},
+ // {6, 2, 500, 2, 500, false},
+ {7, 2, 4, 2, 4, false},
+
+ // rowwise
+ // {101, 6, 2, 6, 2, true},
+ {102, 6, 1, 6, 4, true},
+ // {103, 6, 4, 6, 2, true},
+ {104, 150, 10, 150, 10, true},
+
+ // multi aggregate
+ // row partitioned
+ // {201, 6, 4, 6, 4, true},
+ {202, 6, 4, 6, 4, true},
+ // {203, 20, 1, 20, 1, true},
+ // col partitioned
+ {201, 6, 4, 6, 4, false},
+ {202, 6, 4, 6, 4, false},
+
+ // outer product
+ // row partitioned
+ // {301, 1500, 1500, 1500, 10, true},
+ {303, 4000, 2000, 4000, 10, true},
+ // {305, 4000, 2000, 4000, 10, true},
+ // {307, 1000, 2000, 1000, 10, true},
+ // {309, 1000, 2000, 1000, 10, true},
+ // col partitioned
+ // {302, 2000, 2000, 10, 2000, false},
+ // {304, 4000, 2000, 10, 2000, false},
+ // {306, 4000, 2000, 10, 2000, false},
+ {308, 1000, 2000, 10, 2000, false},
+ // {310, 1000, 2000, 10, 2000, false},
+ // row and col partitioned
+ // {311, 1000, 2000, 1000, 10, true}, // not working yet - ArrayIndexOutOfBoundsException in dotProduct
+ {312, 1000, 2000, 10, 2000, false},
+ // {313, 4000, 2000, 4000, 10, true}, // not working yet - ArrayIndexOutOfBoundsException in dotProduct
+ {314, 4000, 2000, 10, 2000, false},
+
+ // combined tests
+ {401, 20, 10, 20, 6, true}, // cellwise, rowwise, multiaggregate
+ {402, 2000, 2000, 2000, 10, true}, // outerproduct
+
+ });
+ }
+
+ @BeforeClass
+ public static void init() {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+
+ @Test
+ @Ignore
+ public void federatedCodegenMultipleFedMOSingleNode() {
+ testFederatedCodegenMultipleFedMO(ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ @Ignore
+ public void federatedCodegenMultipleFedMOSpark() {
+ testFederatedCodegenMultipleFedMO(ExecMode.SPARK);
+ }
+
+ @Test
+ public void federatedCodegenMultipleFedMOHybrid() {
+ testFederatedCodegenMultipleFedMO(ExecMode.HYBRID);
+ }
+
+ private void testFederatedCodegenMultipleFedMO(ExecMode exec_mode) {
+ // store the previous platform config to restore it after the test
+ ExecMode platform_old = setExecMode(exec_mode);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ int fed_rows_x = rows_x;
+ int fed_cols_x = cols_x;
+ int fed_rows_y = rows_y;
+ int fed_cols_y = cols_y;
+ if(row_partitioned) {
+ fed_rows_x /= 2;
+ fed_rows_y /= 2;
+ }
+ else {
+ fed_cols_x /= 2;
+ fed_cols_y /= 2;
+ }
+
+ // generate dataset
+ // matrix handled by two federated workers
+ double[][] X1 = getRandomMatrix(fed_rows_x, fed_cols_x, 0, 1, 0.1, 3);
+ double[][] X2 = getRandomMatrix(fed_rows_x, fed_cols_x, 0, 1, 0.1, 23);
+ // matrix handled by two federated workers
+ double[][] Y1 = getRandomMatrix(fed_rows_y, fed_cols_y, 0, 1, 0.1, 64);
+ double[][] Y2 = getRandomMatrix(fed_rows_y, fed_cols_y, 0, 1, 0.1, 135);
+
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows_x, fed_cols_x, BLOCKSIZE, fed_rows_x * fed_cols_x));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows_x, fed_cols_x, BLOCKSIZE, fed_rows_x * fed_cols_x));
+ writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(fed_rows_y, fed_cols_y, BLOCKSIZE, fed_rows_y * fed_cols_y));
+ writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(fed_rows_y, fed_cols_y, BLOCKSIZE, fed_rows_y * fed_cols_y));
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread thread2 = startLocalFedWorkerThread(port2);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "-nvargs",
+ "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+ "in_Y1=" + input("Y1"), "in_Y2=" + input("Y2"),
+ "in_rp=" + Boolean.toString(row_partitioned).toUpperCase(),
+ "in_test_num=" + Integer.toString(test_num),
+ "out_Z=" + expected(OUTPUT_NAME)};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+ "in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
+ "in_rp=" + Boolean.toString(row_partitioned).toUpperCase(),
+ "in_test_num=" + Integer.toString(test_num),
+ "rows_x=" + rows_x, "cols_x=" + cols_x,
+ "rows_y=" + rows_y, "cols_y=" + cols_y,
+ "out_Z=" + output(OUTPUT_NAME)};
+ runTest(true, false, null, -1);
+
+ // compare the results via files
+ HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+ HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+ TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+
+ TestUtils.shutdownThreads(thread1, thread2);
+
+ // check for federated operations
+ if(test_num >= 0 && test_num < 100)
+ Assert.assertTrue(heavyHittersContainsSubString("fed_spoofCell"));
+ else if(test_num < 200)
+ Assert.assertTrue(heavyHittersContainsSubString("fed_spoofRA"));
+ else if(test_num < 300)
+ Assert.assertTrue(heavyHittersContainsSubString("fed_spoofMA"));
+ else if(test_num < 400)
+ Assert.assertTrue(heavyHittersContainsSubString("fed_spoofOP"));
+ else if(test_num == 401) {
+ Assert.assertTrue(heavyHittersContainsSubString("fed_spoofRA"));
+ Assert.assertTrue(heavyHittersContainsSubString("fed_spoofCell"));
+ Assert.assertTrue(heavyHittersContainsSubString("fed_spoofMA", exec_mode == ExecMode.SPARK ? 0 : 1));
+ }
+ else if(test_num == 402)
+ Assert.assertTrue(heavyHittersContainsSubString("fed_spoofOP", 3, exec_mode == ExecMode.SPARK? 1 :2));
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y2")));
+
+ resetExecMode(platform_old);
+ }
+
+ /**
+ * Override default configuration with custom test configuration to ensure
+ * scratch space and local temporary directory locations are also updated.
+ */
+ @Override
+ protected File getConfigTemplateFile() {
+ // Instrumentation in this test's output log to show custom configuration file used for template.
+ File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF);
+ return TEST_CONF_FILE;
+ }
+}
diff --git a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
index 3f91385..45f790b 100644
--- a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
+++ b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
@@ -87,7 +87,7 @@ else if(test_num == 9) {
Y = matrix(seq(6, 1005), 500, 2);
U = X + 7 * Y;
- Z = as.matrix(sum(log(U)))
+ Z = as.matrix(sum(log(U)));
}
else if(test_num == 10) {
# X ... 500x2 matrix
@@ -106,7 +106,7 @@ else if(test_num == 12) {
Y = matrix(seq(6, 1005), 2, 500);
U = X + 7 * Y;
- Z = as.matrix(sum(sqrt(U)))
+ Z = as.matrix(sum(sqrt(U)));
}
else if(test_num == 13) {
# X ... 2x4 matrix
diff --git a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
index 2c13e6a..e2e3b4b 100644
--- a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
+++ b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
@@ -85,7 +85,7 @@ else if(test_num == 9) {
Y = matrix(seq(6, 1005), 500, 2);
U = X + 7 * Y;
- Z = as.matrix(sum(log(U)))
+ Z = as.matrix(sum(log(U)));
}
else if(test_num == 10) {
while(FALSE){} #TODO
@@ -104,7 +104,7 @@ else if(test_num == 12) {
Y = matrix(seq(6, 1005), 2, 500);
U = X + 7 * Y;
- Z = as.matrix(sum(sqrt(U)))
+ Z = as.matrix(sum(sqrt(U)));
}
else if(test_num == 13) {
# X ... 2x4 matrix
diff --git a/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.dml b/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.dml
new file mode 100644
index 0000000..5e2b796
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.dml
@@ -0,0 +1,333 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num; # selects which of the test cases below is run
+row_part = $in_rp; # TRUE: inputs are partitioned by rows, otherwise by columns
+
+if(row_part) { # two federated ranges per matrix, split at half the rows
+ X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows_x / 2, $cols_x), list($rows_x / 2, 0), list($rows_x, $cols_x)));
+ Y = federated(addresses=list($in_Y1, $in_Y2),
+ ranges=list(list(0, 0), list($rows_y / 2, $cols_y), list($rows_y / 2, 0), list($rows_y, $cols_y)));
+}
+else { # two federated ranges per matrix, split at half the columns
+ X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows_x, $cols_x / 2), list(0, $cols_x / 2), list($rows_x, $cols_x)));
+ Y = federated(addresses=list($in_Y1, $in_Y2),
+ ranges=list(list(0, 0), list($rows_y, $cols_y / 2), list(0, $cols_y / 2), list($rows_y, $cols_y)));
+}
+
+if(test_num == 1) { # cellwise #4
+ # X ... 4x4 matrix
+ # Y ... 4x4 matrix
+ w = matrix(3, rows=4, cols=4);
+ Z = test1(X, Y, w);
+}
+else if(test_num == 2) { # cellwise #5
+ # X ... 4x4 matrix
+ # Y ... 4x1 / 1x4 vector
+ U = matrix( "1 2 3 4", rows=4, cols=1);
+ Z = test2(X, Y, U);
+}
+else if(test_num == 3) { # cellwise #6
+ # X ... 4x1 vector
+ # Y ... 4x1 vector
+ v = matrix("3 3 3 3", rows=4, cols=1);
+ Z = test3(X, Y, v);
+}
+else if(test_num == 4) { # cellwise #7
+ # X ... 1000x1 vector
+ # Y ... 1000x1 vector
+ Z = test4(X, Y);
+}
+else if(test_num == 5) { # cellwise #9
+ # X ... 500x2 matrix
+ # Y ... 500x2 matrix
+ Z = test5(X, Y);
+}
+else if(test_num == 6) { # cellwise #12
+ # X ... 2x500 matrix
+ # Y ... 2x500 matrix
+ Z = test6(X, Y);
+}
+else if(test_num == 7) { # cellwise #13
+ # X ... 2x4 matrix
+ # Y ... 2x4 matrix
+ w = matrix(seq(1,8), rows=2, cols=4);
+ Z = test1(X, Y, w);
+}
+else if(test_num == 101) { # rowwise #2
+ # X ... 6x2 matrix
+ # Y ... 6x2 matrix
+ U = matrix(1, rows=2, cols=1);
+ Z = test101(X, Y, U);
+}
+else if(test_num == 102) { # rowwise #3
+ # X ... 6x1 vector
+ # Y ... 6x4 matrix
+ U = matrix( "1 2 3 4 5 6", rows=6, cols=1);
+ V = matrix(1,rows=4,cols=1);
+ Z = test102(X, Y, U, V);
+}
+else if(test_num == 103) { # rowwise #4
+ # X ... 6x4 matrix
+ # Y ... 6x2 matrix
+ Z = test103(X, Y);
+}
+else if(test_num == 104) { # rowwise #10
+ # X ... 150x10 matrix
+ # Y ... 150x10 matrix
+ Z = test104(X, Y);
+}
+else if(test_num == 201) { # multiagg #4
+ # X ... 6x4 matrix
+ # Y ... 6x4 matrix
+ Z = test201(X, Y);
+}
+else if(test_num == 202) { # multiagg #5
+ # X ... 6x4 matrix
+ # Y ... 6x4 matrix
+ U = matrix(seq(0,23), rows=6, cols=4);
+ V = matrix(seq(2,25), rows=6, cols=4);
+ Z = test202(X, Y, U, V);
+}
+else if(test_num == 203) { # multiagg #7
+ # X ... 20x1 vector
+ # Y ... 20x1 vector
+ Z = test203(X, Y);
+}
+else if(test_num == 301) { # outerproduct #1
+ # X ... 1500x1500 matrix
+ # Y ... 1500x10 matrix
+ V = matrix(seq(1,15000), rows=1500, cols=10);
+ Z = test301(X, Y, V);
+}
+else if(test_num == 302) { # outerproduct #1
+ # X ... 2000x2000 matrix
+ # Y ... 10x2000 matrix
+ U = matrix(seq(1,20000), rows=2000, cols=10);
+ Z = test301(X, U, t(Y));
+}
+else if(test_num == 303) { # outerproduct #2
+ # X ... 4000x2000 matrix
+ # Y ... 4000x10 matrix
+ V = matrix(seq(51, 20050), rows=2000, cols=10);
+ Z = test303(X, Y, V);
+}
+else if(test_num == 304) { # outerproduct #2
+ # X ... 4000x2000 matrix
+ # Y ... 10x2000 matrix
+ U = matrix(seq(51, 40050), rows=4000, cols=10);
+ Z = test303(X, U, t(Y));
+}
+else if(test_num == 305) { # outerproduct #6
+ # X ... 4000x2000 matrix
+ # Y ... 4000x10 matrix
+ V = matrix(seq(-1, 19998), rows=2000, cols=10);
+ Z = test305(X, Y, V);
+}
+else if(test_num == 306) { # outerproduct #6
+ # X ... 4000x2000 matrix
+ # Y ... 10x2000 matrix
+ U = matrix(seq(1, 40000), rows=4000, cols=10);
+ Z = test305(X, U, t(Y));
+}
+else if(test_num == 307) { # outerproduct #8
+ # X ... 1000x2000 matrix
+ # Y ... 1000x10 matrix
+ V = matrix(seq(1, 20000), rows=2000, cols=10);
+ Z = test307(X, Y, V);
+}
+else if(test_num == 308) { # outerproduct #8
+ # X ... 1000x2000 matrix
+ # Y ... 10x2000 matrix
+ U = matrix(seq(1, 10000), rows=1000, cols=10);
+ Z = test307(X, U, t(Y));
+}
+else if(test_num == 309) { # outerproduct #9
+ # X ... 1000x2000 matrix
+ # Y ... 1000x10 matrix
+ V = matrix(seq(1, 20000), rows=2000, cols=10);
+ Z = test309(X, Y, V);
+}
+else if(test_num == 310) { # outerproduct #9
+ # X ... 1000x2000 matrix
+ # Y ... 10x2000 matrix
+ U = matrix(seq(1, 10000), rows=1000, cols=10);
+ Z = test309(X, U, t(Y));
+}
+else if(test_num == 311) { # outerproduct #8
+ # X ... 1000x2000 matrix
+ # Y ... 1000x10 matrix
+ Y = t(Y); # col partitioned Y
+ while(FALSE) { } # NOTE(review): presumably an empty loop that only cuts the statement block here - confirm
+ # Y ... 10x1000 matrix
+ V = matrix(seq(1, 20000), rows=2000, cols=10);
+ Z = test307(X, t(Y), V);
+}
+else if(test_num == 312) { # outerproduct #8
+ # X ... 1000x2000 matrix
+ Y = t(Y); # row partitioned Y
+ while(FALSE) { }
+ # Y ... 2000x10 matrix
+ U = matrix(seq(1, 10000), rows=1000, cols=10);
+ Z = test307(X, U, Y);
+}
+else if(test_num == 313) { # outerproduct #2
+ # X ... 4000x2000 matrix
+ # Y ... 4000x10 matrix
+ Y = t(Y); # col partitioned Y
+ while(FALSE) { }
+ # Y ... 10x4000 matrix
+ V = matrix(seq(51, 20050), rows=2000, cols=10);
+ Z = test303(X, t(Y), V);
+}
+else if(test_num == 314) { # outerproduct #2
+ # X ... 4000x2000 matrix
+ # Y ... 10x2000 matrix
+ Y = t(Y); # row partitioned Y
+ while(FALSE) { }
+ # Y ... 2000x10 matrix
+ U = matrix(seq(51, 40050), rows=4000, cols=10);
+ Z = test303(X, U, Y);
+}
+else if(test_num == 401) { # combined tests
+ # X ... 20x10 matrix
+ # Y ... 20x6 matrix
+
+ A = test103(X, Y); # not federated output
+ B = test2(X, Y[, 1], t(cbind(A, A)));
+ while(FALSE) { }
+ U = X[6:13, 7:10];
+ V = B[6:13, 3:6];
+ while(FALSE) { }
+ C = test201(U, V);
+ while(FALSE) { }
+ Z = B - C;
+}
+else if(test_num == 402) { # combined outerproduct tests
+ # X ... 2000x2000 matrix
+ # Y ... 2000x10 matrix
+
+ V = matrix(seq(1,20000), rows=2000, cols=10);
+ A = test301(X, Y, V);
+ while(FALSE) { }
+ B = test305(X, Y, V);
+ while(FALSE) { }
+ C = test309(X, Y, V);
+ while(FALSE) { }
+ X = t(X); # col partitioned X and Y
+ Y = t(Y);
+ while(FALSE) { }
+ U = matrix(seq(1, 20000), rows=2000, cols=10);
+ D = test301(X, U, t(Y));
+ while(FALSE) { }
+ E = test305(X, U, t(Y));
+ while(FALSE) { }
+ F = test309(X, U, t(Y));
+ while(FALSE) { }
+ Z = as.scalar(A) - B + C - as.scalar(D) + E - F;
+}
+
+write(Z, $out_Z);
+
+# ************** Tests defined in functions for reusability **************
+test1 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] w) return(Matrix[Double] Z) { # cellwise; cases 1 and 7
+ Z = 10 + floor(round(abs((X + w) * Y)));
+}
+test2 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U) return(Matrix[Double] Z) { # cellwise; cases 2 and 401
+ G = abs(exp(X));
+ V = 10 + floor(round(abs((X / Y) + U)));
+ Z = G + V; # sum of the two element-wise chains over X
+}
+test3 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] v) return(Matrix[Double] Z) { # cellwise; case 3
+ Z = as.matrix(sum(X * Y * v)); # full sum of the element-wise product, returned as a matrix
+}
+test4 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # cellwise; case 4
+ U = X + Y - 7 + abs(X);
+ Z = t(U) %*% U; # product of the transposed cell-wise result with itself
+}
+test5 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # cellwise; case 5
+ U = X + 7 * Y;
+ Z = as.matrix(sum(log(U))); # full sum of the element-wise log, returned as a matrix
+}
+test6 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # cellwise; case 6
+ U = X + 7 * Y;
+ Z = as.matrix(sum(sqrt(U))); # full sum of the element-wise square root, returned as a matrix
+}
+
+test101 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U) return(Matrix[Double] Z) { # rowwise; case 101
+ lambda = sum(Y); # Y contributes only through its full sum
+ Z = t(X) %*% (lambda * (X %*% U));
+}
+test102 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # rowwise; case 102
+ Z = t(Y) %*% (U + (2 - (X * (Y %*% V))));
+}
+test103 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # rowwise; cases 103 and 401
+ Z = colSums(X / rowSums(Y)); # column sums of X after dividing by the row sums of Y
+}
+test104 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # rowwise; case 104
+ Y = Y + (X <= rowMins(X)); # add an indicator of the cells holding each row's minimum of X
+ U = (Y / rowSums(Y)); # divide every row by its row sum
+ Z = colSums(U);
+}
+
+test201 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # multiagg; cases 201 and 401
+ #disjoint partitions with partial shared reads
+ r1 = sum(X * Y);
+ r2 = sum(X ^ 2);
+ r3 = sum(Y ^ 2);
+ Z = as.matrix(r1 + r2 + r3); # the three full aggregates are added and returned as a matrix
+}
+test202 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # multiagg; case 202
+ #disjoint partitions with transitive partial shared reads
+ r1 = sum(X * U);
+ r2 = sum(V * Y);
+ r3 = sum(X * V * Y);
+ Z = as.matrix(r1 + r2 + r3); # the three full aggregates are added and returned as a matrix
+}
+test203 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # multiagg; case 203
+ r1 = t(X) %*% X;
+ r2 = t(X) %*% Y;
+ r3 = t(Y) %*% Y;
+ Z = r1 + r2 + r3; # sum of the three products
+}
+
+test301 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # outerproduct; cases 301, 302 and 402
+ eps = 0.1;
+ Z = as.matrix(sum(X * log(U %*% t(V) + eps))); # full sum, returned as a matrix
+}
+test303 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # outerproduct; cases 303, 304, 313 and 314
+ eps = 0.1;
+ Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+test305 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # outerproduct; cases 305, 306 and 402
+ eps = 0.1;
+ Z = (X / ((U %*% t(V)) + eps)) %*% V;
+}
+test307 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) {
+ # outerproduct; cases 307, 308, 311 and 312: X times the element-wise logistic function of U %*% t(V)
+ Z = X * (1 / (1 + exp(-(U %*% t(V)))));
+}
+test309 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # outerproduct; cases 309, 310 and 402
+ eps = 0.4; # same expression as test303, with a different eps
+ Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
diff --git a/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTestReference.dml b/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTestReference.dml
new file mode 100644
index 0000000..6279e3d
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTestReference.dml
@@ -0,0 +1,329 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num; # selects which of the test cases below is run
+row_part = $in_rp; # TRUE: the two input parts are stacked by rows, otherwise by columns
+
+if(row_part) { # build X and Y by appending the second part below the first
+ X = rbind(read($in_X1), read($in_X2));
+ Y = rbind(read($in_Y1), read($in_Y2));
+}
+else { # build X and Y by appending the second part to the right of the first
+ X = cbind(read($in_X1), read($in_X2));
+ Y = cbind(read($in_Y1), read($in_Y2));
+}
+
+if(test_num == 1) { # cellwise #4
+ # X ... 4x4 matrix
+ # Y ... 4x4 matrix
+ w = matrix(3, rows=4, cols=4);
+ Z = test1(X, Y, w);
+}
+else if(test_num == 2) { # cellwise #5
+ # X ... 4x4 matrix
+ # Y ... 4x1 / 1x4 vector
+ U = matrix( "1 2 3 4", rows=4, cols=1);
+ Z = test2(X, Y, U);
+}
+else if(test_num == 3) { # cellwise #6
+ # X ... 4x1 vector
+ # Y ... 4x1 vector
+ v = matrix("3 3 3 3", rows=4, cols=1);
+ Z = test3(X, Y, v);
+}
+else if(test_num == 4) { # cellwise #7
+ # X ... 1000x1 vector
+ # Y ... 1000x1 vector
+ Z = test4(X, Y);
+}
+else if(test_num == 5) { # cellwise #9
+ # X ... 500x2 matrix
+ # Y ... 500x2 matrix
+ Z = test5(X, Y);
+}
+else if(test_num == 6) { # cellwise #12
+ # X ... 2x500 matrix
+ # Y ... 2x500 matrix
+ Z = test6(X, Y);
+}
+else if(test_num == 7) { # cellwise #13
+ # X ... 2x4 matrix
+ # Y ... 2x4 matrix
+ w = matrix(seq(1,8), rows=2, cols=4);
+ Z = test1(X, Y, w);
+}
+else if(test_num == 101) { # rowwise #2
+ # X ... 6x2 matrix
+ # Y ... 6x2 matrix
+ U = matrix(1, rows=2, cols=1);
+ Z = test101(X, Y, U);
+}
+else if(test_num == 102) { # rowwise #3
+ # X ... 6x1 vector
+ # Y ... 6x4 matrix
+ U = matrix( "1 2 3 4 5 6", rows=6, cols=1);
+ V = matrix(1,rows=4,cols=1);
+ Z = test102(X, Y, U, V);
+}
+else if(test_num == 103) { # rowwise #4
+ # X ... 6x4 matrix
+ # Y ... 6x2 matrix
+ Z = test103(X, Y);
+}
+else if(test_num == 104) { # rowwise #10
+ # X ... 150x10 matrix
+ # Y ... 150x10 matrix
+ Z = test104(X, Y);
+}
+else if(test_num == 201) { # multiagg #4
+ # X ... 6x4 matrix
+ # Y ... 6x4 matrix
+ Z = test201(X, Y);
+}
+else if(test_num == 202) { # multiagg #5
+ # X ... 6x4 matrix
+ # Y ... 6x4 matrix
+ U = matrix(seq(0,23), rows=6, cols=4);
+ V = matrix(seq(2,25), rows=6, cols=4);
+ Z = test202(X, Y, U, V);
+}
+else if(test_num == 203) { # multiagg #7
+ # X ... 20x1 vector
+ # Y ... 20x1 vector
+ Z = test203(X, Y);
+}
+else if(test_num == 301) { # outerproduct #1
+ # X ... 1500x1500 matrix
+ # Y ... 1500x10 matrix
+ V = matrix(seq(1,15000), rows=1500, cols=10);
+ Z = test301(X, Y, V);
+}
+else if(test_num == 302) { # outerproduct #1
+ # X ... 2000x2000 matrix
+ # Y ... 10x2000 matrix
+ U = matrix(seq(1,20000), rows=2000, cols=10);
+ Z = test301(X, U, t(Y));
+}
+else if(test_num == 303) { # outerproduct #2
+ # X ... 4000x2000 matrix
+ # Y ... 4000x10 matrix
+ V = matrix(seq(51, 20050), rows=2000, cols=10);
+ Z = test303(X, Y, V);
+}
+else if(test_num == 304) { # outerproduct #2
+ # X ... 4000x2000 matrix
+ # Y ... 10x2000 matrix
+ U = matrix(seq(51, 40050), rows=4000, cols=10);
+ Z = test303(X, U, t(Y));
+}
+else if(test_num == 305) { # outerproduct #6
+ # X ... 4000x2000 matrix
+ # Y ... 4000x10 matrix
+ V = matrix(seq(-1, 19998), rows=2000, cols=10);
+ Z = test305(X, Y, V);
+}
+else if(test_num == 306) { # outerproduct #6
+ # X ... 4000x2000 matrix
+ # Y ... 10x2000 matrix
+ U = matrix(seq(1, 40000), rows=4000, cols=10);
+ Z = test305(X, U, t(Y));
+}
+else if(test_num == 307) { # outerproduct #8
+ # X ... 1000x2000 matrix
+ # Y ... 1000x10 matrix
+ V = matrix(seq(1, 20000), rows=2000, cols=10);
+ Z = test307(X, Y, V);
+}
+else if(test_num == 308) { # outerproduct #8
+ # X ... 1000x2000 matrix
+ # Y ... 10x2000 matrix
+ U = matrix(seq(1, 10000), rows=1000, cols=10);
+ Z = test307(X, U, t(Y));
+}
+else if(test_num == 309) { # outerproduct #9
+ # X ... 1000x2000 matrix
+ # Y ... 1000x10 matrix
+ V = matrix(seq(1, 20000), rows=2000, cols=10);
+ Z = test309(X, Y, V);
+}
+else if(test_num == 310) { # outerproduct #9
+ # X ... 1000x2000 matrix
+ # Y ... 10x2000 matrix
+ U = matrix(seq(1, 10000), rows=1000, cols=10);
+ Z = test309(X, U, t(Y));
+}
+else if(test_num == 311) { # outerproduct #8
+ # X ... 1000x2000 matrix
+ # Y ... 1000x10 matrix
+ Y = t(Y); # col partitioned Y
+ while(FALSE) { } # NOTE(review): presumably an empty loop that only cuts the statement block here - confirm
+ # Y ... 10x1000 matrix
+ V = matrix(seq(1, 20000), rows=2000, cols=10);
+ Z = test307(X, t(Y), V);
+}
+else if(test_num == 312) { # outerproduct #8
+ # X ... 1000x2000 matrix
+ Y = t(Y); # row partitioned Y
+ while(FALSE) { }
+ # Y ... 2000x10 matrix
+ U = matrix(seq(1, 10000), rows=1000, cols=10);
+ Z = test307(X, U, Y);
+}
+else if(test_num == 313) { # outerproduct #2
+ # X ... 4000x2000 matrix
+ # Y ... 4000x10 matrix
+ Y = t(Y); # col partitioned Y
+ while(FALSE) { }
+ # Y ... 10x4000 matrix
+ V = matrix(seq(51, 20050), rows=2000, cols=10);
+ Z = test303(X, t(Y), V);
+}
+else if(test_num == 314) { # outerproduct #2
+ # X ... 4000x2000 matrix
+ # Y ... 10x2000 matrix
+ Y = t(Y); # row partitioned Y
+ while(FALSE) { }
+ # Y ... 2000x10 matrix
+ U = matrix(seq(51, 40050), rows=4000, cols=10);
+ Z = test303(X, U, Y);
+}
+else if(test_num == 401) { # combined tests
+ # X ... 20x10 matrix
+ # Y ... 20x6 matrix
+
+ A = test103(X, Y); # not federated output
+ B = test2(X, Y[, 1], t(cbind(A, A)));
+ while(FALSE) { }
+ U = X[6:13, 7:10];
+ V = B[6:13, 3:6];
+ while(FALSE) { }
+ C = test201(U, V);
+ while(FALSE) { }
+ Z = B - C;
+}
+else if(test_num == 402) { # combined outerproduct tests
+ # X ... 2000x2000 matrix
+ # Y ... 2000x10 matrix
+
+ V = matrix(seq(1,20000), rows=2000, cols=10);
+ A = test301(X, Y, V);
+ while(FALSE) { }
+ B = test305(X, Y, V);
+ while(FALSE) { }
+ C = test309(X, Y, V);
+ while(FALSE) { }
+ X = t(X); # col partitioned X and Y
+ Y = t(Y);
+ while(FALSE) { }
+ U = matrix(seq(1, 20000), rows=2000, cols=10);
+ D = test301(X, U, t(Y));
+ while(FALSE) { }
+ E = test305(X, U, t(Y));
+ while(FALSE) { }
+ F = test309(X, U, t(Y));
+ while(FALSE) { }
+ Z = as.scalar(A) - B + C - as.scalar(D) + E - F;
+}
+
+write(Z, $out_Z);
+
+# ************** Tests defined in functions for reusability **************
+test1 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] w) return(Matrix[Double] Z) { # cellwise; cases 1 and 7
+ Z = 10 + floor(round(abs((X + w) * Y)));
+}
+test2 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U) return(Matrix[Double] Z) { # cellwise; cases 2 and 401
+ G = abs(exp(X));
+ V = 10 + floor(round(abs((X / Y) + U)));
+ Z = G + V; # sum of the two element-wise chains over X
+}
+test3 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] v) return(Matrix[Double] Z) { # cellwise; case 3
+ Z = as.matrix(sum(X * Y * v)); # full sum of the element-wise product, returned as a matrix
+}
+test4 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # cellwise; case 4
+ U = X + Y - 7 + abs(X);
+ Z = t(U) %*% U; # product of the transposed cell-wise result with itself
+}
+test5 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # cellwise; case 5
+ U = X + 7 * Y;
+ Z = as.matrix(sum(log(U))); # full sum of the element-wise log, returned as a matrix
+}
+test6 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # cellwise; case 6
+ U = X + 7 * Y;
+ Z = as.matrix(sum(sqrt(U))); # full sum of the element-wise square root, returned as a matrix
+}
+
+test101 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U) return(Matrix[Double] Z) { # rowwise; case 101
+ lambda = sum(Y); # Y contributes only through its full sum
+ Z = t(X) %*% (lambda * (X %*% U));
+}
+test102 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # rowwise; case 102
+ Z = t(Y) %*% (U + (2 - (X * (Y %*% V))));
+}
+test103 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # rowwise; cases 103 and 401
+ Z = colSums(X / rowSums(Y)); # column sums of X after dividing by the row sums of Y
+}
+test104 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # rowwise; case 104
+ Y = Y + (X <= rowMins(X)); # add an indicator of the cells holding each row's minimum of X
+ U = (Y / rowSums(Y)); # divide every row by its row sum
+ Z = colSums(U);
+}
+
+test201 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # multiagg; cases 201 and 401
+ #disjoint partitions with partial shared reads
+ r1 = sum(X * Y);
+ r2 = sum(X ^ 2);
+ r3 = sum(Y ^ 2);
+ Z = as.matrix(r1 + r2 + r3); # the three full aggregates are added and returned as a matrix
+}
+test202 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # multiagg; case 202
+ #disjoint partitions with transitive partial shared reads
+ r1 = sum(X * U);
+ r2 = sum(V * Y);
+ r3 = sum(X * V * Y);
+ Z = as.matrix(r1 + r2 + r3); # the three full aggregates are added and returned as a matrix
+}
+test203 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) { # multiagg; case 203
+ r1 = t(X) %*% X;
+ r2 = t(X) %*% Y;
+ r3 = t(Y) %*% Y;
+ Z = r1 + r2 + r3; # sum of the three products
+}
+
+test301 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # outerproduct; cases 301, 302 and 402
+ eps = 0.1;
+ Z = as.matrix(sum(X * log(U %*% t(V) + eps))); # full sum, returned as a matrix
+}
+test303 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # outerproduct; cases 303, 304, 313 and 314
+ eps = 0.1;
+ Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+test305 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # outerproduct; cases 305, 306 and 402
+ eps = 0.1;
+ Z = (X / ((U %*% t(V)) + eps)) %*% V;
+}
+test307 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) {
+ # outerproduct; cases 307, 308, 311 and 312: X times the element-wise logistic function of U %*% t(V)
+ Z = X * (1 / (1 + exp(-(U %*% t(V)))));
+}
+test309 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) return(Matrix[Double] Z) { # outerproduct; cases 309, 310 and 402
+ eps = 0.4; # same expression as test303, with a different eps
+ Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
diff --git a/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
index e592dcb..242305c 100644
--- a/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
+++ b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
@@ -31,7 +31,7 @@ else {
if(test_num == 1) { # wcemm
# X ... 2000x2000 matrix
-
+
U = matrix(seq(1, 20000), rows=2000, cols=10);
V = matrix(seq(20001, 40000), rows=2000, cols=10);
eps = 0.1;