You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/11/15 20:04:36 UTC
[systemds] branch master updated: [SYSTEMDS-2733] Fix federated
append, right-indexing, as.frame, tests
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 4dedddc [SYSTEMDS-2733] Fix federated append, right-indexing, as.frame, tests
4dedddc is described below
commit 4dedddcbf626799d91507a84987ee7f9c9e6a2bb
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sun Nov 15 19:12:40 2020 +0100
[SYSTEMDS-2733] Fix federated append, right-indexing, as.frame, tests
This patch fixes various issues in federated instructions and metadata
handling in order to enable fully federated ML pipelines.
* Append: fix for missing federated append (retain row-partitioning),
correct federated ranges, and only partial broadcasting
* Right Indexing: fix for correct federated ranges, and only indexing
for relevant ranges
* Transformencode: fix for consistent recode maps, and ranges
* Matrix-Frame casts: fix for consistent dimension meta data
* Tests: more reasonable fed worker startup wait time (now a configurable
constant), and fixed output buffering for tests that process the buffer
---
.../sysds/runtime/controlprogram/ProgramBlock.java | 30 ++-
.../controlprogram/federated/FederationMap.java | 30 ++-
.../controlprogram/federated/FederationUtils.java | 2 +-
.../fed/AggregateBinaryFEDInstruction.java | 1 -
.../instructions/fed/AppendFEDInstruction.java | 72 ++++---
.../instructions/fed/FEDInstructionUtils.java | 3 +-
.../fed/MatrixIndexingFEDInstruction.java | 99 ++++------
...tiReturnParameterizedBuiltinFEDInstruction.java | 10 +-
.../instructions/fed/VariableFEDInstruction.java | 7 +-
.../instructions/spark/BinarySPInstruction.java | 14 +-
.../apache/sysds/runtime/meta/MetaDataUtils.java | 39 ++++
.../runtime/transform/encode/EncoderRecode.java | 18 ++
.../org/apache/sysds/test/AutomatedTestBase.java | 2 +
.../federated/algorithms/FederatedBivarTest.java | 6 +-
.../federated/algorithms/FederatedCorTest.java | 6 +-
.../federated/algorithms/FederatedGLMTest.java | 2 +-
.../federated/algorithms/FederatedKmeansTest.java | 2 +-
.../federated/algorithms/FederatedL2SVMTest.java | 2 +-
.../federated/algorithms/FederatedLmPipeline.java | 116 ++++++++++++
.../federated/algorithms/FederatedLogRegTest.java | 2 +-
.../federated/algorithms/FederatedPCATest.java | 6 +-
.../federated/algorithms/FederatedUnivarTest.java | 6 +-
.../federated/algorithms/FederatedVarTest.java | 6 +-
.../federated/algorithms/FederatedYL2SVMTest.java | 2 +-
.../federated/io/FederatedReaderTest.java | 2 +-
.../functions/federated/io/FederatedSSLTest.java | 2 +-
.../federated/io/FederatedWriterTest.java | 2 +-
.../paramserv/FederatedParamservTest.java | 2 +-
.../primitives/FederatedBinaryMatrixTest.java | 2 +-
.../primitives/FederatedBinaryVectorTest.java | 2 +-
.../primitives/FederatedCastToFrameTest.java | 5 +-
.../primitives/FederatedCastToMatrixTest.java | 5 +-
.../primitives/FederatedCentralMomentTest.java | 206 ++++++++++-----------
.../primitives/FederatedColAggregateTest.java | 6 +-
.../primitives/FederatedFullAggregateTest.java | 6 +-
.../primitives/FederatedMultiplyTest.java | 2 +-
.../federated/primitives/FederatedRCBindTest.java | 4 +-
.../primitives/FederatedRemoveEmptyTest.java | 6 +-
.../primitives/FederatedRightIndexTest.java | 6 +-
.../primitives/FederatedRowAggregateTest.java | 6 +-
.../federated/primitives/FederatedSplitTest.java | 187 ++++++++++---------
.../primitives/FederatedStatisticsTest.java | 2 +-
.../TransformFederatedEncodeApplyTest.java | 6 +-
.../TransformFederatedEncodeDecodeTest.java | 6 +-
.../functions/federated/FederatedLmPipeline.dml | 65 +++++++
.../federated/FederatedLmPipelineReference.dml | 64 +++++++
46 files changed, 712 insertions(+), 365 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index bb34bf6..8bfdcc7 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -32,6 +32,7 @@ import org.apache.sysds.parser.ParseInfo;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -56,7 +57,7 @@ public abstract class ProgramBlock implements ParseInfo
public static final String PRED_VAR = "__pred";
protected static final Log LOG = LogFactory.getLog(ProgramBlock.class.getName());
- private static final boolean CHECK_MATRIX_SPARSITY = false;
+ private static final boolean CHECK_MATRIX_PROPERTIES = false;
protected Program _prog; // pointer to Program this ProgramBlock is part of
@@ -273,8 +274,9 @@ public abstract class ProgramBlock implements ParseInfo
// optional check for correct nnz and sparse/dense representation of all
// variables in symbol table (for tracking source of wrong representation)
- if( CHECK_MATRIX_SPARSITY ) {
+ if( CHECK_MATRIX_PROPERTIES ) {
checkSparsity( tmp, ec.getVariables() );
+ checkFederated( tmp, ec.getVariables() );
}
}
catch (DMLScriptException e){
@@ -332,14 +334,11 @@ public abstract class ProgramBlock implements ParseInfo
private static void checkSparsity( Instruction lastInst, LocalVariableMap vars )
{
- for( String varname : vars.keySet() )
- {
+ for( String varname : vars.keySet() ) {
Data dat = vars.get(varname);
- if( dat instanceof MatrixObject )
- {
+ if( dat instanceof MatrixObject ) {
MatrixObject mo = (MatrixObject)dat;
- if( mo.isDirty() && !mo.isPartitioned() )
- {
+ if( mo.isDirty() && !mo.isPartitioned() ) {
MatrixBlock mb = mo.acquireRead();
boolean sparse1 = mb.isInSparseFormat();
long nnz1 = mb.getNonZeros();
@@ -368,6 +367,21 @@ public abstract class ProgramBlock implements ParseInfo
}
}
+ private static void checkFederated( Instruction lastInst, LocalVariableMap vars )
+ {
+ for( String varname : vars.keySet() ) {
+ Data dat = vars.get(varname);
+ if( !(dat instanceof CacheableData) )
+ continue;
+
+ CacheableData<?> mo = (CacheableData<?>)dat;
+ if( mo.isFederated() ) {
+ if( mo.getFedMapping().getFedMapping().isEmpty() )
+ throw new DMLRuntimeException("Invalid empty FederationMap for: "+mo);
+ }
+ }
+ }
+
///////////////////////////////////////////////////////////////////////////
// store position information for program blocks
///////////////////////////////////////////////////////////////////////////
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index f670c17..36e82a3 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.controlprogram.federated;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
@@ -40,6 +41,7 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.apache.sysds.runtime.util.IndexRange;
public class FederationMap
{
@@ -83,6 +85,10 @@ public class FederationMap
_type = type;
}
+ public int getSize() {
+ return _fedMap.size();
+ }
+
public FederatedRange[] getFederatedRanges() {
return _fedMap.keySet().toArray(new FederatedRange[0]);
}
@@ -254,7 +260,7 @@ public class FederationMap
//TODO handling of file path, but no danger as never written
for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
map.put(new FederatedRange(e.getKey(), clen), e.getValue().copyWithNewID(id));
- return new FederationMap(id, map);
+ return new FederationMap(id, map, _type);
}
public FederationMap bind(long rOffset, long cOffset, FederationMap that) {
@@ -324,6 +330,23 @@ public class FederationMap
return fedMapCopy;
}
+ public FederationMap filter(IndexRange ixrange) {
+ FederationMap ret = this.clone(); //same ID
+
+ Iterator<Entry<FederatedRange, FederatedData>> iter = ret._fedMap.entrySet().iterator();
+ while( iter.hasNext() ) {
+ Entry<FederatedRange, FederatedData> e = iter.next();
+ FederatedRange range = e.getKey();
+ long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
+ cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
+ boolean overlap = ((ixrange.colStart <= ce) && (ixrange.colEnd >= cs)
+ && (ixrange.rowStart <= re) && (ixrange.rowEnd >= rs));
+ if( !overlap )
+ iter.remove();
+ }
+ return ret;
+ }
+
private static void setThreadID(long tid, FederatedRequest[]... frsets) {
for( FederatedRequest[] frset : frsets )
if( frset != null )
@@ -360,4 +383,9 @@ public class FederationMap
sb.append("\n"+ _fedMap);
return sb.toString();
}
+
+ @Override
+ public FederationMap clone() {
+ return copyWithNewID(getID());
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 22f4e69..093ff30 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -347,7 +347,7 @@ public class FederationUtils {
throw new DMLRuntimeException("Unsupported aggregation operator: "
+ aop.aggOp.increOp.fn.getClass().getSimpleName());
}
-
+
public static FederationMap federateLocalData(CacheableData<?> data) {
long id = FederationUtils.getNextFedDataID();
FederatedLocalData federatedLocalData = new FederatedLocalData(id, data);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index bdb9784..9107a86 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -94,7 +94,6 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), mo2.getNumColumns()));
- out.getFedMapping().setType(FType.ROW);
}
}
//#2 vector - federated matrix multiplication
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index ee0d8aa..67425f1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -22,7 +22,9 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -30,6 +32,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataUtils;
public class AppendFEDInstruction extends BinaryFEDInstruction {
protected boolean _cbind; // otherwise rbind
@@ -60,7 +63,7 @@ public class AppendFEDInstruction extends BinaryFEDInstruction {
MatrixObject mo1 = ec.getMatrixObject(input1.getName());
MatrixObject mo2 = ec.getMatrixObject(input2.getName());
DataCharacteristics dc1 = mo1.getDataCharacteristics();
- DataCharacteristics dc2 = mo1.getDataCharacteristics();
+ DataCharacteristics dc2 = mo2.getDataCharacteristics();
// check input dimensions
if(_cbind && mo1.getNumRows() != mo2.getNumRows()) {
@@ -80,32 +83,53 @@ public class AppendFEDInstruction extends BinaryFEDInstruction {
throw new DMLRuntimeException(sb.toString());
}
- FederationMap fm1;
- if(mo1.isFederated())
- fm1 = mo1.getFedMapping();
- else
- fm1 = FederationUtils.federateLocalData(mo1);
- FederationMap fm2;
- if(mo2.isFederated())
- fm2 = mo2.getFedMapping();
- else
- fm2 = FederationUtils.federateLocalData(mo2);
-
+ //prepare output
MatrixObject out = ec.getMatrixObject(output);
- long id = FederationUtils.getNextFedDataID();
- if(_cbind) {
- out.getDataCharacteristics().set(dc1.getRows(),
- dc1.getCols() + dc2.getCols(),
- dc1.getBlocksize(),
- dc1.getNonZeros() + dc2.getNonZeros());
- out.setFedMapping(fm1.identCopy(getTID(), id).bind(0, dc1.getCols(), fm2.identCopy(getTID(), id)));
+ MetaDataUtils.updateAppendDataCharacteristics(dc1, dc2, out.getDataCharacteristics(), _cbind);
+
+ // federated/federated
+ if( mo1.isFederated() && mo2.isFederated()
+ && mo1.getFedMapping().getType()==mo2.getFedMapping().getType() )
+ {
+ long id = FederationUtils.getNextFedDataID();
+ long roff = _cbind ? 0 : dc1.getRows();
+ long coff = _cbind ? dc1.getCols() : 0;
+ out.setFedMapping(mo1.getFedMapping().identCopy(getTID(), id)
+ .bind(roff, coff, mo2.getFedMapping().identCopy(getTID(), id)));
+ }
+ // federated/local, local/federated cbind
+ else if( (mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW)) && _cbind ) {
+ MatrixObject moFed = mo1.isFederated(FType.ROW) ? mo1 : mo2;
+ MatrixObject moLoc = mo1.isFederated(FType.ROW) ? mo2 : mo1;
+
+ //construct commands: broadcast lhs, fed append, clean broadcast
+ FederatedRequest[] fr1 = moFed.getFedMapping().broadcastSliced(moLoc, false);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2}, mo1.isFederated(FType.ROW) ?
+ new long[]{ moFed.getFedMapping().getID(), fr1[0].getID()} :
+ new long[]{ fr1[0].getID(), moFed.getFedMapping().getID()});
+ FederatedRequest fr3 = moFed.getFedMapping().cleanup(getTID(), fr1[0].getID());
+
+ //execute federated operations and set output
+ moFed.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+ out.setFedMapping(moFed.getFedMapping().copyWithNewID(fr2.getID(), out.getNumColumns()));
+ }
+ // federated/local, local/federated rbind
+ else if( (mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW)) && !_cbind) {
+ long id = FederationUtils.getNextFedDataID();
+ long roff = _cbind ? 0 : dc1.getRows();
+ long coff = _cbind ? dc1.getCols() : 0;
+ FederationMap fed1 = mo1.isFederated(FType.ROW) ?
+ mo1.getFedMapping() : FederationUtils.federateLocalData(mo1);
+ FederationMap fed2 = mo2.isFederated(FType.ROW) ?
+ mo2.getFedMapping() : FederationUtils.federateLocalData(mo2);
+ out.setFedMapping(fed1.identCopy(getTID(), id)
+ .bind(roff, coff, fed2.identCopy(getTID(), id)));
}
else {
- out.getDataCharacteristics().set(dc1.getRows() + dc2.getRows(),
- dc1.getCols(),
- dc1.getBlocksize(),
- dc1.getNonZeros() + dc2.getNonZeros());
- out.setFedMapping(fm1.identCopy(getTID(), id).bind(dc1.getRows(), 0, fm2.identCopy(getTID(), id)));
+ throw new DMLRuntimeException("Unsupported federated append: "
+ + (mo1.isFederated() ? mo1.getFedMapping().getType().name():"LOCAL") + " "
+ + (mo2.isFederated() ? mo2.getFedMapping().getType().name():"LOCAL") + " " + _cbind);
}
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index d8af245..c0481ab 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -119,7 +119,7 @@ public class FEDInstructionUtils {
BinaryCPInstruction instruction = (BinaryCPInstruction) inst;
if( (instruction.input1.isMatrix() && ec.getMatrixObject(instruction.input1).isFederated())
|| (instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederated()) ) {
- if(instruction.getOpcode().equals("append"))
+ if(instruction.getOpcode().equals("append") )
fedinst = AppendFEDInstruction.parseInstruction(inst.getInstructionString());
else if(instruction.getOpcode().equals("qpick"))
fedinst = QuantilePickFEDInstruction.parseInstruction(inst.getInstructionString());
@@ -149,7 +149,6 @@ public class FEDInstructionUtils {
}
else if(inst instanceof MatrixIndexingCPInstruction) {
// matrix indexing
- LOG.info("Federated Indexing");
MatrixIndexingCPInstruction minst = (MatrixIndexingCPInstruction) inst;
if(inst.getOpcode().equalsIgnoreCase("rightIndex")
&& minst.input1.isMatrix() && ec.getCacheableData(minst.input1).isFederated()) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
index 5c0a821..477379c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
@@ -50,80 +50,65 @@ public final class MatrixIndexingFEDInstruction extends IndexingFEDInstruction {
rightIndexing(ec);
}
- private void rightIndexing(ExecutionContext ec) {
+ private void rightIndexing(ExecutionContext ec)
+ {
+ //get input and requested index range
MatrixObject in = ec.getMatrixObject(input1);
- FederationMap fedMapping = in.getFedMapping();
IndexRange ixrange = getIndexRange(ec);
- // FederationMap.FType fedType;
+
+ //prepare output federation map (copy-on-write)
+ FederationMap fedMap = in.getFedMapping().filter(ixrange);
+
+ //modify federated ranges in place
Map<FederatedRange, IndexRange> ixs = new HashMap<>();
-
- for(int i = 0; i < fedMapping.getFederatedRanges().length; i++) {
- FederatedRange curFedRange = fedMapping.getFederatedRanges()[i];
- long rs = curFedRange.getBeginDims()[0], re = curFedRange.getEndDims()[0],
- cs = curFedRange.getBeginDims()[1], ce = curFedRange.getEndDims()[1];
-
- if((ixrange.colStart <= ce) && (ixrange.colEnd >= cs) && (ixrange.rowStart <= re) && (ixrange.rowEnd >= rs)) {
- // If the indexing range contains values that are within the specific federated range.
- // change the range.
- long rsn = (ixrange.rowStart >= rs) ? (ixrange.rowStart - rs) : 0;
- long ren = (ixrange.rowEnd >= rs && ixrange.rowEnd < re) ? (ixrange.rowEnd - rs) : (re - rs - 1);
- long csn = (ixrange.colStart >= cs) ? (ixrange.colStart - cs) : 0;
- long cen = (ixrange.colEnd >= cs && ixrange.colEnd < ce) ? (ixrange.colEnd - cs) : (ce - cs - 1);
- if(LOG.isDebugEnabled()) {
- LOG.debug("Ranges for fed location: " + rsn + " " + ren + " " + csn + " " + cen);
- LOG.debug("ixRange : " + ixrange);
- LOG.debug("Fed Mapping : " + curFedRange);
- }
- curFedRange.setBeginDim(0, Math.max(rs - ixrange.rowStart, 0));
- curFedRange.setBeginDim(1, Math.max(cs - ixrange.colStart, 0));
- curFedRange.setEndDim(0,
- (ixrange.rowEnd >= re ? re - ixrange.rowStart : ixrange.rowEnd - ixrange.rowStart + 1));
- curFedRange.setEndDim(1,
- (ixrange.colEnd >= ce ? ce - ixrange.colStart : ixrange.colEnd - ixrange.colStart + 1));
- if(LOG.isDebugEnabled()) {
- LOG.debug("Fed Mapping After : " + curFedRange);
- }
- ixs.put(curFedRange, new IndexRange(rsn, ren, csn, cen));
- }
- else {
- // If not within the range, change the range to become an 0 times 0 big range.
- // by setting the end dimensions to the same as the beginning dimensions.
- curFedRange.setBeginDim(0, 0);
- curFedRange.setBeginDim(1, 0);
- curFedRange.setEndDim(0, 0);
- curFedRange.setEndDim(1, 0);
+ for(FederatedRange range : fedMap.getFedMapping().keySet()) {
+ long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
+ cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
+ long rsn = (ixrange.rowStart >= rs) ? (ixrange.rowStart - rs) : 0;
+ long ren = (ixrange.rowEnd >= rs && ixrange.rowEnd < re) ? (ixrange.rowEnd - rs) : (re - rs - 1);
+ long csn = (ixrange.colStart >= cs) ? (ixrange.colStart - cs) : 0;
+ long cen = (ixrange.colEnd >= cs && ixrange.colEnd < ce) ? (ixrange.colEnd - cs) : (ce - cs - 1);
+ if(LOG.isDebugEnabled()) {
+ LOG.debug("Ranges for fed location: " + rsn + " " + ren + " " + csn + " " + cen);
+ LOG.debug("ixRange : " + ixrange);
+ LOG.debug("Fed Mapping : " + range);
}
-
+ range.setBeginDim(0, Math.max(rs - ixrange.rowStart, 0));
+ range.setBeginDim(1, Math.max(cs - ixrange.colStart, 0));
+ range.setEndDim(0, (ixrange.rowEnd >= re ? re-ixrange.rowStart : ixrange.rowEnd-ixrange.rowStart + 1));
+ range.setEndDim(1, (ixrange.colEnd >= ce ? ce-ixrange.colStart : ixrange.colEnd-ixrange.colStart + 1));
+ if(LOG.isDebugEnabled())
+ LOG.debug("Fed Mapping After : " + range);
+ ixs.put(range, new IndexRange(rsn, ren, csn, cen));
}
+ // execute slicing of valid range
long varID = FederationUtils.getNextFedDataID();
- FederationMap slicedMapping = fedMapping.mapParallel(varID, (range, data) -> {
+ FederationMap slicedFedMap = fedMap.mapParallel(varID, (range, data) -> {
try {
FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(
FederatedRequest.RequestType.EXEC_UDF, -1,
- new SliceMatrix(data.getVarID(), varID, ixs.getOrDefault(range, new IndexRange(-1, -1, -1, -1)))))
- .get();
+ new SliceMatrix(data.getVarID(), varID, ixs.get(range)))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
+ return null;
}
catch(Exception e) {
throw new DMLRuntimeException(e);
}
- return null;
});
+ //update output mapping and data characteristics
MatrixObject sliced = ec.getMatrixObject(output);
sliced.getDataCharacteristics()
- .set(fedMapping.getMaxIndexInRange(0), fedMapping.getMaxIndexInRange(1), (int) in.getBlocksize());
- if(ixrange.rowEnd - ixrange.rowStart == 0) {
- slicedMapping.setType(FederationMap.FType.COL);
- }
- else if(ixrange.colEnd - ixrange.colStart == 0) {
- slicedMapping.setType(FederationMap.FType.ROW);
- }
- sliced.setFedMapping(slicedMapping);
- LOG.debug(slicedMapping);
- LOG.debug(sliced);
+ .set(slicedFedMap.getMaxIndexInRange(0), slicedFedMap.getMaxIndexInRange(1), (int) in.getBlocksize());
+ sliced.setFedMapping(slicedFedMap);
+
+ //TODO is this really necessary
+ if(ixrange.rowEnd - ixrange.rowStart == 0)
+ slicedFedMap.setType(FederationMap.FType.COL);
+ else if(ixrange.colEnd - ixrange.colStart == 0)
+ slicedFedMap.setType(FederationMap.FType.ROW);
}
private static class SliceMatrix extends FederatedUDF {
@@ -141,11 +126,7 @@ public final class MatrixIndexingFEDInstruction extends IndexingFEDInstruction {
@Override
public FederatedResponse execute(ExecutionContext ec, Data... data) {
MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
- MatrixBlock res;
- if(_ixrange.rowStart != -1)
- res = mb.slice(_ixrange, new MatrixBlock());
- else
- res = new MatrixBlock();
+ MatrixBlock res = mb.slice(_ixrange, new MatrixBlock());
MatrixObject mout = ExecutionContext.createMatrixObject(res);
ec.setVariable(String.valueOf(_outputID), mout);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 05c1901..d02d0f5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -128,10 +128,18 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
System.arraycopy(subRangeColNames, 0, colNames, (int) range.getBeginDims()[1], subRangeColNames.length);
}
catch(Exception e) {
- throw new DMLRuntimeException("Federated encoder creation failed: " + e.getMessage());
+ throw new DMLRuntimeException("Federated encoder creation failed: ", e);
}
return null;
});
+
+ //sort for consistent encoding in local and federated
+ if( EncoderRecode.SORT_RECODE_MAP ) {
+ for(Encoder encoder : globalEncoder.getEncoders())
+ if( encoder instanceof EncoderRecode )
+ ((EncoderRecode)encoder).sortCPRecodeMaps();
+ }
+
FrameBlock meta = new FrameBlock((int) fin.getNumColumns(), Types.ValueType.STRING);
meta.setColumnNames(colNames);
globalEncoder.getMetaData(meta);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
index 134a2e3..b45654d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
@@ -126,14 +126,12 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra
// execute function at federated site.
FederatedRequest fr1 = FederationUtils.callInstruction(_in.getInstructionString(),
- _in.getOutput(),
- new CPOperand[] {_in.getInput1()},
- new long[] {mo1.getFedMapping().getID()});
+ _in.getOutput(), new CPOperand[] {_in.getInput1()}, new long[] {mo1.getFedMapping().getID()});
mo1.getFedMapping().execute(getTID(), true, fr1);
// Construct output local.
FrameObject out = ec.getFrameObject(_in.getOutput());
- out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
+ out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID());
Map<FederatedRange, FederatedData> newMap = new HashMap<>();
for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getFedMapping().entrySet()) {
@@ -152,5 +150,4 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
return _in.getLineageItem(ec);
}
-
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
index f4f98dc..82a95a8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
@@ -49,6 +49,7 @@ import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataUtils;
public abstract class BinarySPInstruction extends ComputationSPInstruction {
@@ -344,22 +345,13 @@ public abstract class BinarySPInstruction extends ComputationSPInstruction {
return mcOut;
}
- protected void updateBinaryAppendOutputDataCharacteristics(SparkExecutionContext sec, boolean cbind)
- {
+ protected void updateBinaryAppendOutputDataCharacteristics(SparkExecutionContext sec, boolean cbind) {
DataCharacteristics mc1 = sec.getDataCharacteristics(input1.getName());
DataCharacteristics mc2 = sec.getDataCharacteristics(input2.getName());
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
//infer initially unknown dimensions from inputs
- if(!mcOut.dimsKnown()) {
- if( !mc1.dimsKnown() || !mc2.dimsKnown() )
- throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
-
- if( cbind )
- mcOut.set(mc1.getRows(), mc1.getCols()+mc2.getCols(), mc1.getBlocksize(), mc1.getBlocksize());
- else //rbind
- mcOut.set(mc1.getRows()+mc2.getRows(), mc1.getCols(), mc1.getBlocksize(), mc1.getBlocksize());
- }
+ MetaDataUtils.updateAppendDataCharacteristics(mc1, mc2, mcOut, cbind);
//infer initially unknown nnz from inputs
if( !mcOut.nnzKnown() && mc1.nnzKnown() && mc2.nnzKnown() ) {
diff --git a/src/main/java/org/apache/sysds/runtime/meta/MetaDataUtils.java b/src/main/java/org/apache/sysds/runtime/meta/MetaDataUtils.java
new file mode 100644
index 0000000..f372130
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/meta/MetaDataUtils.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.meta;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+
+public class MetaDataUtils {
+
+ public static void updateAppendDataCharacteristics(DataCharacteristics mc1,
+ DataCharacteristics mc2, DataCharacteristics mcOut, boolean cbind)
+ {
+ if(!mcOut.dimsKnown()) {
+ if( !mc1.dimsKnown() || !mc2.dimsKnown() )
+ throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
+
+ if( cbind )
+ mcOut.set(mc1.getRows(), mc1.getCols()+mc2.getCols(), mc1.getBlocksize());
+ else //rbind
+ mcOut.set(mc1.getRows()+mc2.getRows(), mc1.getCols(), mc1.getBlocksize());
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index 6a1ea0b..1dc7bf2 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.transform.encode;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -40,6 +41,9 @@ public class EncoderRecode extends Encoder
{
private static final long serialVersionUID = 8213163881283341874L;
+ //test property to ensure consistent encoding for local and federated
+ public static boolean SORT_RECODE_MAP = false;
+
//recode maps and custom map for partial recode maps
private HashMap<Integer, HashMap<String, Long>> _rcdMaps = new HashMap<>();
private HashMap<Integer, HashSet<Object>> _rcdMapsPart = null;
@@ -72,6 +76,16 @@ public class EncoderRecode extends Encoder
return _rcdMapsPart;
}
+ public void sortCPRecodeMaps() {
+ for( HashMap<String,Long> map : _rcdMaps.values() ) {
+ String[] keys= map.keySet().toArray(new String[0]);
+ Arrays.sort(keys);
+ map.clear();
+ for(String key : keys)
+ putCode(map, key);
+ }
+ }
+
private long lookupRCDMap(int colID, String key) {
if( !_rcdMaps.containsKey(colID) )
return -1; //empty recode map
@@ -111,6 +125,10 @@ public class EncoderRecode extends Encoder
putCode(map, key);
}
}
+
+ if( SORT_RECODE_MAP ) {
+ sortCPRecodeMaps();
+ }
}
/**
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 3c3471e..1e62975 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -106,6 +106,8 @@ public abstract class AutomatedTestBase {
public static final double GPU_TOLERANCE = 1e-9;
public static final int FED_WORKER_WAIT = 1000; // in ms
+ public static final int FED_WORKER_WAIT_S = 30; // in ms
+
// With OpenJDK 8u242 on Windows, the new changes in JDK are not allowing
// to set the native library paths internally thus breaking the code.
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
index ff811e0..e15c7d1 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
@@ -114,9 +114,9 @@ public class FederatedBivarTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
index 82437b1..20b0147 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
@@ -102,9 +102,9 @@ public class FederatedCorTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
rtplatform = execMode;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index eb8aee8..37a7787 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -95,7 +95,7 @@ public class FederatedGLMTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index f296b3a..e352b5a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -104,7 +104,7 @@ public class FederatedKmeansTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
index f17754e..95e5ba4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
@@ -99,7 +99,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
new file mode 100644
index 0000000..a5a3d02
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.algorithms;
+
+import org.junit.Test;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.EncoderRecode;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+
+@net.jcip.annotations.NotThreadSafe
+public class FederatedLmPipeline extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME = "FederatedLmPipeline";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedLmPipeline.class.getSimpleName() + "/";
+
+ public int rows = 10000;
+ public int cols = 1000;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Test
+ public void federatedLmPipelineContinguous() {
+ federatedLmPipeline(Types.ExecMode.SINGLE_NODE, true);
+ }
+
+ @Test
+ public void federatedLmPipelineSampled() {
+ federatedLmPipeline(Types.ExecMode.SINGLE_NODE, false);
+ }
+
+ public void federatedLmPipeline(ExecMode execMode, boolean contSplits) {
+ ExecMode oldExec = setExecMode(execMode);
+ boolean oldSort = EncoderRecode.SORT_RECODE_MAP;
+ EncoderRecode.SORT_RECODE_MAP = true;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ try {
+ // generate lm data
+ MatrixBlock X = MatrixBlock.randOperations(rows, cols, 1.0, 0, 1, "uniform", 7);
+ MatrixBlock w = MatrixBlock.randOperations(cols, 1, 1.0, 0, 1, "uniform", 3);
+ MatrixBlock y = new MatrixBlock(rows, 1, false).allocateBlock();
+ LibMatrixMult.matrixMult(X, w, y);
+ MatrixBlock c = MatrixBlock.randOperations(rows, 1, 1.0, 1, 50, "uniform", 23);
+ MatrixBlock rc = c.unaryOperations(InstructionUtils.parseUnaryOperator("round"), new MatrixBlock());
+ X = rc.append(X, new MatrixBlock(), true);
+
+ // We have two matrices, each handled by its own federated worker
+ int halfRows = rows / 2;
+ writeInputMatrixWithMTD("X1", X.slice(0, halfRows-1), false);
+ writeInputMatrixWithMTD("X2", X.slice(halfRows, rows-1), false);
+ writeInputMatrixWithMTD("Y", y, false);
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2);
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y"),
+ String.valueOf(contSplits).toUpperCase(), expected("Z")};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + (cols+1),
+ "in_Y=" + input("Y"), "cont=" + String.valueOf(contSplits).toUpperCase(), "out=" + output("Z")};
+ runTest(true, false, null, -1);
+
+ // compare via files
+ compareResults(1e-2);
+ TestUtils.shutdownThreads(t1, t2);
+ }
+ finally {
+ resetExecMode(oldExec);
+ EncoderRecode.SORT_RECODE_MAP = oldSort;
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
index e7f1f80..fe67bc2 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
@@ -95,7 +95,7 @@ public class FederatedLogRegTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
index 8438bb6..ae2f2fa 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
@@ -102,9 +102,9 @@ public class FederatedPCATest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
index 7333533..a4a8236 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
@@ -100,9 +100,9 @@ public class FederatedUnivarTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
index 46af1c9..348f157 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
@@ -109,9 +109,9 @@ public class FederatedVarTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
rtplatform = execMode;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
index d0eaf87..1d58574 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
@@ -104,7 +104,7 @@ public class FederatedYL2SVMTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
index a5630e0..810b882 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
@@ -87,7 +87,7 @@ public class FederatedReaderTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
String host = "localhost";
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
index 261daf6..fc2c1dd 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
@@ -93,7 +93,7 @@ public class FederatedSSLTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
String host = "localhost";
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
index a83fad3..d8bb743 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
@@ -83,7 +83,7 @@ public class FederatedWriterTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
try {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 3015aaa..cc0af07 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -155,7 +155,7 @@ public class FederatedParamservTest extends AutomatedTestBase {
// start worker
ports.add(getRandomAvailablePort());
- threads.add(startLocalFedWorkerThread(ports.get(i), 10));
+ threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S));
// add worker to program args
programArgsList.add("X" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("X" + i)));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
index 958c09b..279f524 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
@@ -95,7 +95,7 @@ public class FederatedBinaryMatrixTest extends AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
index e8dd6f7..d3cea77 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
@@ -96,7 +96,7 @@ public class FederatedBinaryVectorTest extends AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java
index fe03906..3b05391 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java
@@ -97,12 +97,13 @@ public class FederatedCastToFrameTest extends AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
+ setOutputBuffering(true); //otherwise NPE
+
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2")};
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java
index fa51d89..4fb95c6 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java
@@ -126,12 +126,13 @@ public class FederatedCastToMatrixTest extends AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
+ setOutputBuffering(true); //otherwise NPE
+
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2")};
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java
index 828718e..98b72a9 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java
@@ -39,109 +39,109 @@ import org.junit.runners.Parameterized;
@net.jcip.annotations.NotThreadSafe
public class FederatedCentralMomentTest extends AutomatedTestBase {
- private final static String TEST_DIR = "functions/federated/";
- private final static String TEST_NAME = "FederatedCentralMomentTest";
- private final static String TEST_CLASS_DIR = TEST_DIR + FederatedCentralMomentTest.class.getSimpleName() + "/";
-
- private final static int blocksize = 1024;
- @Parameterized.Parameter()
- public int rows;
-
- @Parameterized.Parameter(1)
- public int k;
-
- @Parameterized.Parameters
- public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {
- {1000, 2},
- {1000, 3},
- {1000, 4}
- });
- }
-
- @Override
- public void setUp() {
- TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S.scalar"}));
- }
-
- @Test
- public void federatedCentralMomentCP() { federatedCentralMoment(Types.ExecMode.SINGLE_NODE); }
-
- @Test
- @Ignore
- public void federatedCentralMomentSP() { federatedCentralMoment(Types.ExecMode.SPARK); }
-
- public void federatedCentralMoment(Types.ExecMode execMode) {
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- Types.ExecMode platformOld = rtplatform;
-
- getAndLoadTestConfiguration(TEST_NAME);
- String HOME = SCRIPT_DIR + TEST_DIR;
-
- int r = rows / 4;
-
- double[][] X1 = getRandomMatrix(r, 1, 1, 5, 1, 3);
- double[][] X2 = getRandomMatrix(r, 1, 1, 5, 1, 7);
- double[][] X3 = getRandomMatrix(r, 1, 1, 5, 1, 8);
- double[][] X4 = getRandomMatrix(r, 1, 1, 5, 1, 9);
-
- MatrixCharacteristics mc = new MatrixCharacteristics(r, 1, blocksize, r);
- writeInputMatrixWithMTD("X1", X1, false, mc);
- writeInputMatrixWithMTD("X2", X2, false, mc);
- writeInputMatrixWithMTD("X3", X3, false, mc);
- writeInputMatrixWithMTD("X4", X4, false, mc);
-
- // empty script name because we don't execute any script, just start the worker
- fullDMLScriptName = "";
- int port1 = getRandomAvailablePort();
- int port2 = getRandomAvailablePort();
- int port3 = getRandomAvailablePort();
- int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME = "FederatedCentralMomentTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedCentralMomentTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+
+ @Parameterized.Parameter(1)
+ public int k;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {1000, 2},
+ {1000, 3},
+ {1000, 4}
+ });
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S.scalar"}));
+ }
+
+ @Test
+ public void federatedCentralMomentCP() { federatedCentralMoment(Types.ExecMode.SINGLE_NODE); }
+
+ @Test
+ @Ignore
+ public void federatedCentralMomentSP() { federatedCentralMoment(Types.ExecMode.SPARK); }
+
+ public void federatedCentralMoment(Types.ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ int r = rows / 4;
+
+ double[][] X1 = getRandomMatrix(r, 1, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r, 1, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r, 1, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r, 1, 1, 5, 1, 9);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r, 1, blocksize, r);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+ writeInputMatrixWithMTD("X2", X2, false, mc);
+ writeInputMatrixWithMTD("X3", X3, false, mc);
+ writeInputMatrixWithMTD("X4", X4, false, mc);
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ int port3 = getRandomAvailablePort();
+ int port4 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
- // reference file should not be written to hdfs, so we set platform here
- rtplatform = execMode;
- if(rtplatform == Types.ExecMode.SPARK) {
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
- }
- // Run reference dml script with normal matrix for Row/Col
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-stats", "100", "-args",
- input("X1"), input("X2"), input("X3"), input("X4"), expected("S"), String.valueOf(k)};
- runTest(null);
-
- TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
- loadTestConfiguration(config);
-
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs",
- "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
- "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
- "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
- "rows=" + rows,
- "cols=" + 1,
- "out_S=" + output("S"),
- "k=" + k};
- runTest(null);
-
- // compare all sums via files
- compareResults(0.01);
-
- Assert.assertTrue(heavyHittersContainsString("fed_cm"));
-
- // check that federated input files are still existing
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
-
- TestUtils.shutdownThreads(t1, t2, t3, t4);
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
- }
+ // reference file should not be written to hdfs, so we set platform here
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ // Run reference dml script with normal matrix for Row/Col
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "100", "-args",
+ input("X1"), input("X2"), input("X3"), input("X4"), expected("S"), String.valueOf(k)};
+ runTest(null);
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+ "rows=" + rows,
+ "cols=" + 1,
+ "out_S=" + output("S"),
+ "k=" + k};
+ runTest(null);
+
+ // compare all sums via files
+ compareResults(0.01);
+
+ Assert.assertTrue(heavyHittersContainsString("fed_cm"));
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
index a8480e9..1bdcb5b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
@@ -157,9 +157,9 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
rtplatform = execMode;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
index d388913..9213620 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
@@ -198,9 +198,9 @@ public class FederatedFullAggregateTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
index 3bc2649..8836203 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
@@ -103,7 +103,7 @@ public class FederatedMultiplyTest extends AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
index 540b188..efde5b7 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
@@ -93,8 +93,8 @@ public class FederatedRCBindTest extends AutomatedTestBase {
writeInputMatrixWithMTD("B", B, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols));
int port1 = getRandomAvailablePort();
- int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ int port2 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
// we need the reference file to not be written to hdfs, so we get the correct format
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
index de1e6d5..a629270 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
@@ -113,9 +113,9 @@ public class FederatedRemoveEmptyTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
rtplatform = execMode;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
index b9e7f62..1401792 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
@@ -149,9 +149,9 @@ public class FederatedRightIndexTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
rtplatform = execMode;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
index 49e692e..e0a3632 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
@@ -156,9 +156,9 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
- Thread t2 = startLocalFedWorkerThread(port2, 10);
- Thread t3 = startLocalFedWorkerThread(port3, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
rtplatform = execMode;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
index 9d37aff..3e640c0 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
@@ -39,100 +39,99 @@ import org.junit.runners.Parameterized;
@net.jcip.annotations.NotThreadSafe
public class FederatedSplitTest extends AutomatedTestBase {
- private static final Log LOG = LogFactory.getLog(FederatedSplitTest.class.getName());
- private final static String TEST_DIR = "functions/federated/";
- private final static String TEST_NAME = "FederatedSplitTest";
- private final static String TEST_CLASS_DIR = TEST_DIR + FederatedSplitTest.class.getSimpleName() + "/";
-
- private final static int blocksize = 1024;
- @Parameterized.Parameter()
- public int rows;
- @Parameterized.Parameter(1)
- public int cols;
- @Parameterized.Parameter(2)
- public String cont;
-
- @Parameterized.Parameters
- public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {{152, 12, "TRUE"}, {132, 11, "FALSE"}});
- }
-
- @Override
- public void setUp() {
- TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
- }
-
- @Test
- public void federatedSplitCP() {
- federatedSplit(Types.ExecMode.SINGLE_NODE);
- }
-
- public void federatedSplit(Types.ExecMode execMode) {
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- Types.ExecMode platformOld = rtplatform;
- rtplatform = execMode;
- if(rtplatform == Types.ExecMode.SPARK) {
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
- }
-
- getAndLoadTestConfiguration(TEST_NAME);
- String HOME = SCRIPT_DIR + TEST_DIR;
-
- // write input matrices
- int halfRows = rows / 2;
- // We have two matrices handled by a single federated worker
- double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
- double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
- // And another two matrices handled by a single federated worker
- double[][] Y1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 44);
- double[][] Y2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 21);
-
- writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
- writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
- writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
- writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
-
- TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
- loadTestConfiguration(config);
-
- int port1 = getRandomAvailablePort();
- int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ private static final Log LOG = LogFactory.getLog(FederatedSplitTest.class.getName());
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME = "FederatedSplitTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedSplitTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public String cont;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {{152, 12, "TRUE"}, {132, 11, "FALSE"}});
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Test
+ public void federatedSplitCP() {
+ federatedSplit(Types.ExecMode.SINGLE_NODE);
+ }
+
+ public void federatedSplit(Types.ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int halfRows = rows / 2;
+ // X is split into two halves, X1 and X2, each served by a different federated worker
+ double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+ double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+ // Y is split the same way: Y1 and Y2 are each served by a different federated worker
+ double[][] Y1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 44);
+ double[][] Y2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 21);
+
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+ setOutputBuffering(true); // presumably needed by the runTest(null).toString() calls below; otherwise NPE
+
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
- // Run reference dml script with normal matrix
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
- "Y2=" + input("Y2"), "Z=" + expected("Z"), "Cont=" + cont};
- String out = runTest(null).toString();
-
- // Run actual dml script with federated matrix
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
- "X2=" + TestUtils.federatedAddress(port2, input("X2")),
- "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
- "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z"),
- "Cont=" + cont};
- String fedOut = runTest(null).toString();
-
- LOG.debug(out);
- LOG.debug(fedOut);
- // compare via files
- compareResults(1e-9);
-
- if(cont.equals("TRUE"))
- Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
- else{
-
- Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
- // TODO add federated diag operator.
- // Assert.assertTrue(heavyHittersContainsString("fed_rdiag"));
-
- }
-
- TestUtils.shutdownThreads(t1, t2);
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
- }
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "Z=" + expected("Z"), "Cont=" + cont};
+ String out = runTest(null).toString();
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+ "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z"),
+ "Cont=" + cont};
+ String fedOut = runTest(null).toString();
+
+ LOG.debug(out);
+ LOG.debug(fedOut);
+ // compare via files
+ compareResults(1e-9);
+
+ if(cont.equals("TRUE"))
+ Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
+ else {
+ Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+ // TODO add federated diag operator.
+ // Assert.assertTrue(heavyHittersContainsString("fed_rdiag"));
+ }
+
+ TestUtils.shutdownThreads(t1, t2);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
index 865582d..09ca19e 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
@@ -99,7 +99,7 @@ public class FederatedStatisticsTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, 10);
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
index b7036d0..6c9b034 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -199,9 +199,9 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1, 10);
- t2 = startLocalFedWorkerThread(port2, 10);
- t3 = startLocalFedWorkerThread(port3, 10);
+ t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
t4 = startLocalFedWorkerThread(port4);
FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER,
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
index 458dbc1..71be21b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
@@ -134,9 +134,9 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1, 10);
- t2 = startLocalFedWorkerThread(port2, 10);
- t3 = startLocalFedWorkerThread(port3, 10);
+ t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
t4 = startLocalFedWorkerThread(port4);
// schema
diff --git a/src/test/scripts/functions/federated/FederatedLmPipeline.dml b/src/test/scripts/functions/federated/FederatedLmPipeline.dml
new file mode 100644
index 0000000..323333d
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedLmPipeline.dml
@@ -0,0 +1,65 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+Fin = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)))
+y = read($in_Y)
+
+# one hot encoding categorical, other passthrough
+Fall = as.frame(Fin)
+jspec = "{ ids:true, dummycode:[1] }"
+[X,M] = transformencode(target=Fall, spec=jspec)
+print("ncol(X) = "+ncol(X))
+
+# clip outliers: values outside colMean +/- 1.5 * colSD are replaced by the column mean
+colSD = colSds(X)
+colMean = (colMeans(X))
+upperBound = colMean + 1.5 * colSD
+lowerBound = colMean - 1.5 * colSD
+outFilter = (X < lowerBound) | (X > upperBound)
+X = X - outFilter*X + outFilter*colMeans(X);
+
+# normalization
+X = scale(X=X, center=TRUE, scale=TRUE);
+
+# split training and testing
+[Xtrain , Xtest, ytrain, ytest] = split(X=X, Y=y, cont=$cont, seed=7)
+
+# train regression model
+B = lm(X=Xtrain, y=ytrain, icpt=1, reg=1e-3, tol=1e-9, verbose=TRUE)
+
+# model evaluation on test split
+yhat = lmpredict(X=Xtest, w=B, icpt=1);
+y_residual = ytest - yhat;
+
+avg_res = sum(y_residual) / nrow(ytest);
+ss_res = sum(y_residual^2);
+ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
+R2 = 1 - ss_res / (sum(y^2) - nrow(ytest) * (sum(y)/nrow(ytest))^2);
+print("\nAccuracy:" +
+ "\n--sum(ytest) = " + sum(ytest) +
+ "\n--sum(yhat) = " + sum(yhat) +
+ "\n--AVG_RES_Y: " + avg_res +
+ "\n--SS_AVG_RES_Y: " + ss_avg_res +
+ "\n--R2: " + R2 );
+
+# write trained model and meta data
+write(B, $out)
diff --git a/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml b/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
new file mode 100644
index 0000000..72ca292
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
@@ -0,0 +1,64 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+Fin = rbind(read($1), read($2))
+y = read($3)
+
+# one hot encoding categorical, other passthrough
+Fall = as.frame(Fin)
+jspec = "{ ids:true, dummycode:[1] }"
+[X,M] = transformencode(target=Fall, spec=jspec)
+print("ncol(X) = "+ncol(X))
+
+# clip outliers: values outside colMean +/- 1.5 * colSD are replaced by the column mean
+colSD = colSds(X)
+colMean = (colMeans(X))
+upperBound = colMean + 1.5 * colSD
+lowerBound = colMean - 1.5 * colSD
+outFilter = (X < lowerBound) | (X > upperBound)
+X = X - outFilter*X + outFilter*colMeans(X);
+
+# normalization
+X = scale(X=X, center=TRUE, scale=TRUE);
+
+# split training and testing
+[Xtrain , Xtest, ytrain, ytest] = split(X=X, Y=y, cont=$4, seed=7)
+
+# train regression model
+B = lm(X=Xtrain, y=ytrain, icpt=1, reg=1e-3, tol=1e-9, verbose=TRUE)
+
+# model evaluation on test split
+yhat = lmpredict(X=Xtest, w=B, icpt=1);
+y_residual = ytest - yhat;
+
+avg_res = sum(y_residual) / nrow(ytest);
+ss_res = sum(y_residual^2);
+ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
+R2 = 1 - ss_res / (sum(y^2) - nrow(ytest) * (sum(y)/nrow(ytest))^2);
+print("\nAccuracy:" +
+ "\n--sum(ytest) = " + sum(ytest) +
+ "\n--sum(yhat) = " + sum(yhat) +
+ "\n--AVG_RES_Y: " + avg_res +
+ "\n--SS_AVG_RES_Y: " + ss_avg_res +
+ "\n--R2: " + R2 );
+
+# write trained model and meta data
+write(B, $5)