You are viewing a plain-text version of this content. (The hyperlink to the canonical version was lost in the plain-text conversion.)
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/05/18 13:05:59 UTC
[systemds] branch master updated: [SYSTEMDS-2604] Extended federation map/instructions (federated output)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 77d8202 [SYSTEMDS-2604] Extended federation map/instructions (federated output)
77d8202 is described below
commit 77d82022e65eff0caf18efc73941e4374ac0c214
Author: sebwrede <sw...@know-center.at>
AuthorDate: Tue May 18 15:03:46 2021 +0200
[SYSTEMDS-2604] Extended federation map/instructions (federated output)
Closes #1237.
Co-authored-by: Matthias Boehm <mb...@gmail.com>
---
.../java/org/apache/sysds/hops/AggUnaryOp.java | 7 +-
src/main/java/org/apache/sysds/hops/BinaryOp.java | 5 +-
src/main/java/org/apache/sysds/hops/Hop.java | 42 +++--
src/main/java/org/apache/sysds/hops/ReorgOp.java | 1 -
src/main/java/org/apache/sysds/hops/TernaryOp.java | 2 -
src/main/java/org/apache/sysds/lops/Binary.java | 16 +-
src/main/java/org/apache/sysds/lops/Lop.java | 7 +-
src/main/java/org/apache/sysds/lops/MapMult.java | 41 ++---
src/main/java/org/apache/sysds/lops/MatMultCP.java | 14 +-
.../org/apache/sysds/lops/PartialAggregate.java | 18 +-
src/main/java/org/apache/sysds/lops/Ternary.java | 21 +--
src/main/java/org/apache/sysds/lops/Transform.java | 8 +-
src/main/java/org/apache/sysds/lops/Unary.java | 10 +-
.../controlprogram/caching/MatrixObject.java | 23 +--
.../controlprogram/federated/FederatedRange.java | 14 +-
.../controlprogram/federated/FederationMap.java | 194 +++++++++++++++------
.../controlprogram/federated/FederationUtils.java | 54 +++++-
.../paramserv/FederatedPSControlThread.java | 4 +-
.../paramserv/dp/BalanceToAvgFederatedScheme.java | 4 +-
.../paramserv/dp/DataPartitionFederatedScheme.java | 6 +-
.../dp/ReplicateToMaxFederatedScheme.java | 4 +-
.../paramserv/dp/ShuffleFederatedScheme.java | 4 +-
.../dp/SubsampleToMinFederatedScheme.java | 4 +-
.../runtime/instructions/InstructionUtils.java | 60 +------
.../cp/AggregateBinaryCPInstruction.java | 28 +--
.../fed/AggregateBinaryFEDInstruction.java | 167 ++++++++++++++----
.../fed/AggregateUnaryFEDInstruction.java | 32 ++--
.../instructions/fed/BinaryFEDInstruction.java | 15 +-
.../fed/BinaryMatrixMatrixFEDInstruction.java | 48 +++--
.../fed/BinaryMatrixScalarFEDInstruction.java | 6 +-
.../fed/ComputationFEDInstruction.java | 12 +-
.../runtime/instructions/fed/FEDInstruction.java | 20 ++-
.../instructions/fed/FEDInstructionUtils.java | 20 ++-
.../instructions/fed/IndexingFEDInstruction.java | 8 +-
.../instructions/fed/InitFEDInstruction.java | 20 +--
.../fed/QuantilePickFEDInstruction.java | 6 +-
.../fed/QuantileSortFEDInstruction.java | 28 ---
.../instructions/fed/ReorgFEDInstruction.java | 18 +-
.../instructions/fed/TernaryFEDInstruction.java | 14 +-
.../instructions/fed/TsmmFEDInstruction.java | 6 +-
.../instructions/fed/UnaryFEDInstruction.java | 30 +++-
.../instructions/fed/VariableFEDInstruction.java | 23 +--
.../sysds/runtime/io/ReaderWriterFederated.java | 13 +-
.../org/apache/sysds/test/AutomatedTestBase.java | 5 +-
.../primitives/FederatedNegativeTest.java | 9 +-
.../privacy/algorithms/FederatedL2SVMTest.java | 6 +-
.../fedplanning/FederatedMultiplyPlanningTest.java | 160 ++++++++++-------
.../FederatedMultiplyPlanningTest.dml | 0
.../FederatedMultiplyPlanningTest2.dml | 0
.../FederatedMultiplyPlanningTest2Reference.dml | 0
.../FederatedMultiplyPlanningTest3.dml | 4 +-
.../FederatedMultiplyPlanningTest3Reference.dml | 0
.../FederatedMultiplyPlanningTest4.dml} | 7 +-
.../FederatedMultiplyPlanningTest4Reference.dml} | 4 +-
.../FederatedMultiplyPlanningTest5.dml} | 9 +-
.../FederatedMultiplyPlanningTest5Reference.dml} | 8 +-
.../FederatedMultiplyPlanningTest6.dml} | 12 +-
.../FederatedMultiplyPlanningTest6Reference.dml} | 7 +-
.../FederatedMultiplyPlanningTestReference.dml | 0
59 files changed, 755 insertions(+), 553 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 5d54535..118299c 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -157,11 +157,8 @@ public class AggUnaryOp extends MultiThreadedHop
setLineNumbers(agg1);
setLops(agg1);
- if (getDataType() == DataType.SCALAR) {
+ if (getDataType() == DataType.SCALAR)
agg1.getOutputParameters().setDimensions(1, 1, getBlocksize(), getNnz());
- } else {
- setFederatedOutput(agg1);
- }
}
else if( et == ExecType.SPARK )
{
@@ -380,7 +377,7 @@ public class AggUnaryOp extends MultiThreadedHop
&& !(getInput().get(0) instanceof DataOp) //input is not checkpoint
&& (getInput().get(0).getParent().size()==1 //uagg is only parent, or
|| !requiresAggregation(getInput().get(0), _direction)) //w/o agg
- && getInput().get(0).optFindExecType() == ExecType.SPARK )
+ && getInput().get(0).optFindExecType() == ExecType.SPARK )
{
//pull unary aggregate into spark
_etype = ExecType.SPARK;
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 897f707..d94f607 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -224,8 +224,6 @@ public class BinaryOp extends MultiThreadedHop
constructLopsBinaryDefault();
}
- setFederatedOutput(getLops());
-
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
@@ -418,7 +416,8 @@ public class BinaryOp extends MultiThreadedHop
Lop tmp = null;
if( ot != null ) {
tmp = new Unary(getInput(0).constructLops(), getInput(1).constructLops(),
- ot, getDataType(), getValueType(), et);
+ ot, getDataType(), getValueType(), et,
+ OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
}
else { //general case
tmp = new Binary(getInput(0).constructLops(), getInput(1).constructLops(),
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index c4c2b5e..1c86cea 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -47,6 +47,7 @@ import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
@@ -87,7 +88,7 @@ public abstract class Hop implements ParseInfo {
* If it is true, the output should be kept at federated sites.
* If it is false, the output should be retrieved by the coordinator.
*/
- protected boolean _federatedOutput = false;
+ protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
// Estimated size for the output produced from this Hop
protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
@@ -288,6 +289,10 @@ public abstract class Hop implements ParseInfo {
}
public void constructAndSetLopsDataFlowProperties() {
+ //propagate federated output configuration to lops
+ if( isFederated() )
+ getLops().setFederatedOutput(_federatedOutput);
+
//Step 1: construct reblock lop if required (output of hop)
constructAndSetReblockLopIfRequired();
@@ -751,28 +756,25 @@ public abstract class Hop implements ParseInfo {
}
/**
- * Returns true if any input has federated ExecType and configures such input to keep the output federated.
+ * Returns true if any input has federated ExecType.
* This method can only return true if FedDecision is activated.
* @return true if any input has federated ExecType
*/
protected boolean inputIsFED(){
- if ( !OptimizerUtils.FEDERATED_COMPILATION ) return false;
- boolean fedFound = false;
- for ( Hop input : _input ){
- if ( input.isFederated() ){
- input._federatedOutput = true;
- fedFound = true;
- }
- }
- return fedFound;
+ if ( !OptimizerUtils.FEDERATED_COMPILATION )
+ return false;
+ for ( Hop input : _input )
+ if ( input.isFederated() || input.isFederatedOutput() )
+ return true;
+ return false;
}
-
- /**
- * Returns true if the execution is federated and/or if the output is federated.
- * @return true if federated
- */
+
public boolean isFederated(){
- return getExecType() == ExecType.FED || hasFederatedOutput();
+ return getExecType() == ExecType.FED;
+ }
+
+ public boolean isFederatedOutput(){
+ return _federatedOutput == FederatedOutput.FOUT;
}
public ArrayList<Hop> getParent() {
@@ -822,7 +824,7 @@ public abstract class Hop implements ParseInfo {
}
public boolean hasFederatedOutput(){
- return _federatedOutput;
+ return _federatedOutput == FederatedOutput.FOUT;
}
public void setUpdateType(UpdateType update){
@@ -1458,10 +1460,6 @@ public abstract class Hop implements ParseInfo {
lop.setPrivacyConstraint(getPrivacy());
}
- protected void setFederatedOutput(Lop lop){
- lop.setFederatedOutput(_federatedOutput);
- }
-
/**
* Set parse information.
*
diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java b/src/main/java/org/apache/sysds/hops/ReorgOp.java
index badb057..9326d9d 100644
--- a/src/main/java/org/apache/sysds/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java
@@ -162,7 +162,6 @@ public class ReorgOp extends MultiThreadedHop
else { //general case
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
Transform transform1 = new Transform(lin, _op, getDataType(), getValueType(), et, k);
- setFederatedOutput(transform1);
setOutputDimensions(transform1);
setLineNumbers(transform1);
setLops(transform1);
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index 3a8d02b..ccbf746 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -195,8 +195,6 @@ public class TernaryOp extends MultiThreadedHop
catch(LopsException e) {
throw new HopsException(this.printErrorLocation() + "error constructing Lops for TernaryOp Hop " , e);
}
-
- setFederatedOutput(getLops());
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java b/src/main/java/org/apache/sysds/lops/Binary.java
index 5ba77bb..f8d4eb6 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -81,19 +81,17 @@ public class Binary extends Lop
@Override
public String getInstructions(String input1, String input2, String output) {
- InstructionUtils.concatBaseOperands(
+ String ret = InstructionUtils.concatOperands(
getExecType().name(), getOpcode(),
getInputs().get(0).prepInputOperand(input1),
getInputs().get(1).prepInputOperand(input2),
- prepOutputOperand(output)
- );
+ prepOutputOperand(output));
- if ( getExecType() == ExecType.CP || getExecType() == ExecType.FED){
- InstructionUtils.concatAdditionalOperand(String.valueOf(_numThreads));
- if ( federatedOutput )
- InstructionUtils.concatAdditionalOperand(String.valueOf(federatedOutput));
- }
+ if ( getExecType() == ExecType.CP )
+ ret = InstructionUtils.concatOperands(ret, String.valueOf(_numThreads));
+ else if( getExecType() == ExecType.FED )
+ ret = InstructionUtils.concatOperands(ret, String.valueOf(_numThreads), _fedOutput.name());
- return InstructionUtils.getInstructionString();
+ return ret;
}
}
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java
index a92609c..0433296 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
@@ -119,7 +120,7 @@ public abstract class Lop
* If it is true, the output should be kept at federated sites.
* If it is false, the output should be retrieved by the coordinator.
*/
- protected boolean federatedOutput = false;
+ protected FederatedOutput _fedOutput = null;
/**
* refers to #lops whose input is equal to the output produced by this lop.
@@ -294,8 +295,8 @@ public abstract class Lop
return privacyConstraint;
}
- public void setFederatedOutput(boolean federatedOutput){
- this.federatedOutput = federatedOutput;
+ public void setFederatedOutput(FederatedOutput fedOutput){
+ _fedOutput = fedOutput;
}
public void setConsumerCount(int cc) {
diff --git a/src/main/java/org/apache/sysds/lops/MapMult.java b/src/main/java/org/apache/sysds/lops/MapMult.java
index 3b30129..9429c10 100644
--- a/src/main/java/org/apache/sysds/lops/MapMult.java
+++ b/src/main/java/org/apache/sysds/lops/MapMult.java
@@ -22,7 +22,7 @@ package org.apache.sysds.lops;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.lops.LopProperties.ExecType;
-
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
@@ -94,35 +94,18 @@ public class MapMult extends Lop
}
@Override
- public String getInstructions(String input1, String input2, String output)
- {
- StringBuilder sb = new StringBuilder();
-
- sb.append(getExecType());
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(OPCODE);
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(0).prepInputOperand(input1));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(1).prepInputOperand(input2));
+ public String getInstructions(String input1, String input2, String output) {
+ String ret = InstructionUtils.concatOperands(
+ getExecType().name(), OPCODE,
+ getInputs().get(0).prepInputOperand(input1),
+ getInputs().get(1).prepInputOperand(input2),
+ prepOutputOperand(output),
+ _cacheType.name(),
+ String.valueOf(_outputEmptyBlocks));
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(prepOutputOperand(output));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_cacheType);
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_outputEmptyBlocks);
-
- if( getExecType() == ExecType.SPARK ) {
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_aggtype.toString());
- }
+ if( getExecType() == ExecType.SPARK )
+ ret = InstructionUtils.concatOperands(ret, _aggtype.name());
- return sb.toString();
+ return ret;
}
}
diff --git a/src/main/java/org/apache/sysds/lops/MatMultCP.java b/src/main/java/org/apache/sysds/lops/MatMultCP.java
index 056a5c7..ac01c98 100644
--- a/src/main/java/org/apache/sysds/lops/MatMultCP.java
+++ b/src/main/java/org/apache/sysds/lops/MatMultCP.java
@@ -72,17 +72,18 @@ public class MatMultCP extends Lop {
@Override
public String getInstructions(String input1, String input2, String output) {
+ String ret = null;
if(!useTranspose) {
- return InstructionUtils.concatOperands(getExecType().name(),
- "ba+*",
+ ret = InstructionUtils.concatOperands(
+ getExecType().name(), "ba+*",
getInputs().get(0).prepInputOperand(input1),
getInputs().get(1).prepInputOperand(input2),
prepOutputOperand(output),
String.valueOf(numThreads));
}
else { // GPU or compressed
- return InstructionUtils.concatOperands(getExecType().name(),
- "ba+*",
+ ret = InstructionUtils.concatOperands(
+ getExecType().name(), "ba+*",
getInputs().get(0).prepInputOperand(input1),
getInputs().get(1).prepInputOperand(input2),
prepOutputOperand(output),
@@ -90,5 +91,10 @@ public class MatMultCP extends Lop {
String.valueOf(isLeftTransposed),
String.valueOf(isRightTransposed));
}
+
+ if ( getExecType() == ExecType.FED )
+ ret = InstructionUtils.concatOperands(ret, _fedOutput.name());
+
+ return ret;
}
}
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 118a804..7106114 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -217,22 +217,22 @@ public class PartialAggregate extends Lop
@Override
public String getInstructions(String input1, String output)
{
- InstructionUtils.concatBaseOperands(
- getExecType().name(),
- getOpcode(),
+ String ret = InstructionUtils.concatOperands(
+ getExecType().name(), getOpcode(),
getInputs().get(0).prepInputOperand(input1),
prepOutputOperand(output));
if ( getExecType() == ExecType.SPARK )
- InstructionUtils.concatAdditionalOperand(_aggtype.toString());
+ ret = InstructionUtils.concatOperands(ret, _aggtype.name());
else if ( getExecType() == ExecType.CP || getExecType() == ExecType.FED ){
- InstructionUtils.concatAdditionalOperand(Integer.toString(_numThreads));
+ ret = InstructionUtils.concatOperands(ret, Integer.toString(_numThreads));
if ( getOpcode().equalsIgnoreCase("uarimin") || getOpcode().equalsIgnoreCase("uarimax") )
- InstructionUtils.concatAdditionalOperand("1");
- if ( getExecType() == ExecType.FED && operation != AggOp.VAR )
- InstructionUtils.concatAdditionalOperand(String.valueOf(federatedOutput));
+ ret = InstructionUtils.concatOperands(ret, "1");
+ if ( getExecType() == ExecType.FED )
+ ret = InstructionUtils.concatOperands(ret, _fedOutput.name());
}
- return InstructionUtils.getInstructionString();
+
+ return ret;
}
public static String getOpcode(AggOp op, Direction dir)
diff --git a/src/main/java/org/apache/sysds/lops/Ternary.java b/src/main/java/org/apache/sysds/lops/Ternary.java
index a6ad9d2..b7cad7e 100644
--- a/src/main/java/org/apache/sysds/lops/Ternary.java
+++ b/src/main/java/org/apache/sysds/lops/Ternary.java
@@ -59,19 +59,20 @@ public class Ternary extends Lop
@Override
public String getInstructions(String input1, String input2, String input3, String output) {
- InstructionUtils.concatOperands(
- getExecType().name(),
- _op.toString(),
+ String ret = InstructionUtils.concatOperands(
+ getExecType().name(), _op.toString(),
getInputs().get(0).prepInputOperand(input1),
getInputs().get(1).prepInputOperand(input2),
getInputs().get(2).prepInputOperand(input3),
- prepOutputOperand(output)
- );
- if( (getExecType() == ExecType.CP || getExecType() == ExecType.FED ) && getDataType().isMatrix() ){
- InstructionUtils.concatAdditionalOperand(String.valueOf(_numThreads));
- if ( federatedOutput )
- InstructionUtils.concatAdditionalOperand(String.valueOf(federatedOutput));
+ prepOutputOperand(output));
+
+ if( getDataType().isMatrix() ) {
+ if( getExecType() == ExecType.CP )
+ ret = InstructionUtils.concatOperands(ret, String.valueOf(_numThreads));
+ else if( getExecType() == ExecType.FED )
+ ret = InstructionUtils.concatOperands(ret, String.valueOf(_numThreads), _fedOutput.name());
}
- return InstructionUtils.getInstructionString();
+
+ return ret;
}
}
diff --git a/src/main/java/org/apache/sysds/lops/Transform.java b/src/main/java/org/apache/sysds/lops/Transform.java
index 2c1df26..b0e5e3e 100644
--- a/src/main/java/org/apache/sysds/lops/Transform.java
+++ b/src/main/java/org/apache/sysds/lops/Transform.java
@@ -172,16 +172,16 @@ public class Transform extends Lop
&& (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.SORT) ) {
sb.append( OPERAND_DELIMITOR );
sb.append( _numThreads );
- if ( federatedOutput ){
+ if ( getExecType()==ExecType.FED ) {
sb.append( OPERAND_DELIMITOR );
- sb.append( federatedOutput );
+ sb.append( _fedOutput.name() );
}
}
- if( getExecType()==ExecType.SPARK && _operation == ReOrgOp.RESHAPE ) {
+ else if( getExecType()==ExecType.SPARK && _operation == ReOrgOp.RESHAPE ) {
sb.append( OPERAND_DELIMITOR );
sb.append( _outputEmptyBlock );
}
- if( getExecType()==ExecType.SPARK && _operation == ReOrgOp.SORT ){
+ else if( getExecType()==ExecType.SPARK && _operation == ReOrgOp.SORT ){
sb.append( OPERAND_DELIMITOR );
sb.append( _bSortIndInMem );
}
diff --git a/src/main/java/org/apache/sysds/lops/Unary.java b/src/main/java/org/apache/sysds/lops/Unary.java
index 0e34ba2..bec7034 100644
--- a/src/main/java/org/apache/sysds/lops/Unary.java
+++ b/src/main/java/org/apache/sysds/lops/Unary.java
@@ -54,9 +54,10 @@ public class Unary extends Lop
* @param vt value type
* @param et execution type
*/
- public Unary(Lop input1, Lop input2, OpOp1 op, DataType dt, ValueType vt, ExecType et) {
+ public Unary(Lop input1, Lop input2, OpOp1 op, DataType dt, ValueType vt, ExecType et, int numThreads) {
super(Lop.Type.UNARY, dt, vt);
init(input1, input2, op, dt, vt, et);
+ _numThreads = numThreads;
}
private void init(Lop input1, Lop input2, OpOp1 op, DataType dt, ValueType vt, ExecType et) {
@@ -182,7 +183,12 @@ public class Unary extends Lop
sb.append( getInputs().get(1).prepInputOperand(input2));
sb.append( OPERAND_DELIMITOR );
- sb.append( this.prepOutputOperand(output));
+ sb.append( prepOutputOperand(output));
+
+ if( getExecType() == ExecType.CP ) {
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( String.valueOf(_numThreads) );
+ }
return sb.toString();
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index e55509b..3001e44 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -41,6 +41,7 @@ import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
@@ -548,28 +549,18 @@ public class MatrixObject extends CacheableData<MatrixBlock>
throws IOException
{
// TODO sparse optimization
- MatrixBlock ret = new MatrixBlock((int) dims[0], (int) dims[1], false);
List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = fedMap.requestFederatedData();
try {
- for (Pair<FederatedRange, Future<FederatedResponse>> readResponse : readResponses) {
- FederatedRange range = readResponse.getLeft();
- FederatedResponse response = readResponse.getRight().get();
- // add result
- int[] beginDimsInt = range.getBeginDimsInt();
- int[] endDimsInt = range.getEndDimsInt();
- MatrixBlock multRes = (MatrixBlock) response.getData()[0];
- ret.copy(beginDimsInt[0], endDimsInt[0] - 1,
- beginDimsInt[1], endDimsInt[1] - 1, multRes, false);
- ret.setNonZeros(ret.getNonZeros() + multRes.getNonZeros());
- }
+ if ( fedMap.getType() == FederationMap.FType.PART )
+ return FederationUtils.aggregateResponses(readResponses);
+ else
+ return FederationUtils.bindResponses(readResponses, dims);
}
- catch (Exception e) {
+ catch(Exception e) {
throw new DMLRuntimeException("Federated matrix read failed.", e);
}
-
- return ret;
}
-
+
/**
* Writes in-memory matrix to HDFS in a specified format.
*/
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 3bd5734..4948d27 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -24,8 +24,8 @@ import java.util.Arrays;
import org.apache.sysds.runtime.util.IndexRange;
public class FederatedRange implements Comparable<FederatedRange> {
- private long[] _beginDims;
- private long[] _endDims;
+ private final long[] _beginDims;
+ private final long[] _endDims;
/**
* Create a range with the indexes of each dimension between their respective <code>beginDims</code> and
@@ -81,6 +81,7 @@ public class FederatedRange implements Comparable<FederatedRange> {
size *= getSize(i);
return size;
}
+
public long getSize(int dim) {
return _endDims[dim] - _beginDims[dim];
@@ -102,7 +103,8 @@ public class FederatedRange implements Comparable<FederatedRange> {
return Arrays.toString(_beginDims) + " - " + Arrays.toString(_endDims);
}
- @Override public boolean equals(Object o) {
+ @Override
+ public boolean equals(Object o) {
if(this == o)
return true;
if(o == null || getClass() != o.getClass())
@@ -111,10 +113,10 @@ public class FederatedRange implements Comparable<FederatedRange> {
return Arrays.equals(_beginDims, range._beginDims) && Arrays.equals(_endDims, range._endDims);
}
- @Override public int hashCode() {
+ @Override
+ public int hashCode() {
int result = Arrays.hashCode(_beginDims);
- result = 31 * result + Arrays.hashCode(_endDims);
- return result;
+ return 31 * result + Arrays.hashCode(_endDims);
}
public FederatedRange shift(long rshift, long cshift) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 9afa295..7a52b11 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -23,9 +23,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
-import java.util.Map;
import java.util.Map.Entry;
-import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
@@ -46,18 +44,44 @@ import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.IndexRange;
public class FederationMap {
+ public enum FPartitioning{
+ ROW, //row partitioned, groups of entire rows
+ COL, //column partitioned, groups of entire columns
+ MIXED, //arbitrary rectangles
+ NONE, //entire data in a location
+ }
+
+ public enum FReplication {
+ NONE, //every data item in a separate location
+ FULL, //every data item at every location
+ OVERLAP, //every data item partially at every location, w/ addition as aggregation method
+ }
+
public enum FType {
- ROW, // row partitioned, groups of rows
- COL, // column partitioned, groups of columns
- FULL, // Meaning both Row and Column indicating a single federated location and a full matrix
- OTHER;
-
+ ROW(FPartitioning.ROW, FReplication.NONE),
+ COL(FPartitioning.COL, FReplication.NONE),
+ FULL(FPartitioning.NONE, FReplication.NONE),
+ BROADCAST(FPartitioning.NONE, FReplication.FULL),
+ PART(FPartitioning.NONE, FReplication.OVERLAP),
+ OTHER(FPartitioning.MIXED, FReplication.NONE);
+
+ private final FPartitioning _partType;
+ @SuppressWarnings("unused") //not yet
+ private final FReplication _repType;
+
+ private FType(FPartitioning ptype, FReplication rtype) {
+ _partType = ptype;
+ _repType = rtype;
+ }
+
public boolean isRowPartitioned() {
- return this == ROW || this == FULL;
+ return _partType == FPartitioning.ROW
+ || _partType == FPartitioning.NONE;
}
public boolean isColPartitioned() {
- return this == COL || this == FULL;
+ return _partType == FPartitioning.COL
+ || _partType == FPartitioning.NONE;
}
public boolean isType(FType t) {
@@ -75,18 +99,18 @@ public class FederationMap {
}
private long _ID = -1;
- private final Map<FederatedRange, FederatedData> _fedMap;
+ private final List<Pair<FederatedRange, FederatedData>> _fedMap;
private FType _type;
- public FederationMap(Map<FederatedRange, FederatedData> fedMap) {
+ public FederationMap(List<Pair<FederatedRange, FederatedData>> fedMap) {
this(-1, fedMap);
}
- public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap) {
+ public FederationMap(long ID, List<Pair<FederatedRange, FederatedData>> fedMap) {
this(ID, fedMap, FType.OTHER);
}
- public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap, FType type) {
+ public FederationMap(long ID, List<Pair<FederatedRange, FederatedData>> fedMap, FType type) {
_ID = ID;
_fedMap = fedMap;
_type = type;
@@ -113,13 +137,31 @@ public class FederationMap {
}
public FederatedRange[] getFederatedRanges() {
- return _fedMap.keySet().toArray(new FederatedRange[0]);
+ return _fedMap.stream().map(e -> e.getKey()).toArray(FederatedRange[]::new);
+ }
+
+ public FederatedData[] getFederatedData() {
+ return _fedMap.stream().map(e -> e.getValue()).toArray(FederatedData[]::new);
+ }
+
+ private FederatedData getFederatedData(FederatedRange range) {
+ for( Pair<FederatedRange, FederatedData> e : _fedMap )
+ if( e.getKey().equals(range) )
+ return e.getValue();
+ return null;
+ }
+
+ private void removeFederatedData(FederatedRange range) {
+ Iterator<Pair<FederatedRange, FederatedData>> iter = _fedMap.iterator();
+ while( iter.hasNext() )
+ if( iter.next().getKey().equals(range) )
+ iter.remove();
}
- public Map<FederatedRange, FederatedData> getMap() {
+ public List<Pair<FederatedRange, FederatedData>> getMap() {
return _fedMap;
}
-
+
public FederatedRequest broadcast(CacheableData<?> data) {
// prepare single request for all federated data
long id = FederationUtils.getNextFedDataID();
@@ -152,7 +194,7 @@ public class FederationMap {
// prepare indexing ranges
int[][] ix = new int[_fedMap.size()][];
int pos = 0;
- for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
+ for(Pair<FederatedRange, FederatedData> e : _fedMap) {
int beg = e.getKey().getBeginDimsInt()[(_type == FType.ROW ? 0 : 1)];
int end = e.getKey().getEndDimsInt()[(_type == FType.ROW ? 0 : 1)];
int nr = _type == FType.ROW ? cb.getNumRows() : cb.getNumColumns();
@@ -189,18 +231,23 @@ public class FederationMap {
return ret;
}
+ /**
+ * Determines if the two federation maps are aligned row/column partitions
+ * at the same federated sites (which allows for purely federated operation)
+ * @param that FederationMap to check alignment with
+ * @param transposed true if that FederationMap should be transposed before checking alignment
+ * @return true if this and that FederationMap are aligned
+ */
public boolean isAligned(FederationMap that, boolean transposed) {
- // determines if the two federated data are aligned row/column partitions
- // at the same federated site (which allows for purely federated operation)
boolean ret = true;
- for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
+ for(Pair<FederatedRange, FederatedData> e : _fedMap) {
FederatedRange range = !transposed ? e.getKey() : new FederatedRange(e.getKey()).transpose();
- FederatedData dat2 = that._fedMap.get(range);
+ FederatedData dat2 = that.getFederatedData(range);
ret &= e.getValue().equalAddress(dat2);
}
return ret;
}
-
+
public Future<FederatedResponse>[] execute(long tid, FederatedRequest... fr) {
return execute(tid, false, fr);
}
@@ -220,7 +267,7 @@ public class FederationMap {
setThreadID(tid, frSlices, fr);
List<Future<FederatedResponse>> ret = new ArrayList<>();
int pos = 0;
- for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
+ for(Pair<FederatedRange, FederatedData> e : _fedMap)
ret.add(e.getValue().executeFederatedOperation((frSlices != null) ? addAll(frSlices[pos++], fr) : fr));
// prepare results (future federated responses), with optional wait to ensure the
@@ -237,7 +284,7 @@ public class FederationMap {
setThreadID(tid, frSlices2, fr);
List<Future<FederatedResponse>> ret = new ArrayList<>();
int pos = 0;
- for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
+ for(Pair<FederatedRange, FederatedData> e : _fedMap) {
if(Arrays.asList(fedRange1).contains(e.getKey())) {
FederatedRequest[] newFr = (frSlices1 != null) ? ((frSlices2 != null) ? (addAll(frSlices2[pos],
addAll(frSlices1[pos++], fr))) : addAll(frSlices1[pos++], fr)) : fr;
@@ -254,7 +301,9 @@ public class FederationMap {
}
public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest[] frSlices1, FederatedRequest[] frSlices2, FederatedRequest... fr) {
- return execute(tid, wait, Arrays.stream(_fedMap.keySet().toArray()).toArray(FederatedRange[]::new), null, frSlices1, frSlices2, fr);
+ return execute(tid, wait,
+ _fedMap.stream().map(e->e.getKey()).toArray(FederatedRange[]::new),
+ null, frSlices1, frSlices2, fr);
}
@SuppressWarnings("unchecked")
@@ -265,7 +314,7 @@ public class FederationMap {
setThreadID(tid, allSlices, fr);
List<Future<FederatedResponse>> ret = new ArrayList<>();
int pos = 0;
- for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
+ for(Pair<FederatedRange, FederatedData> e : _fedMap) {
FederatedRequest[] fedReq = fr;
for(FederatedRequest[] slice : frSlices)
fedReq = addAll(slice[pos], fedReq);
@@ -286,7 +335,7 @@ public class FederationMap {
List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>();
FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID);
- for(Map.Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
+ for(Pair<FederatedRange, FederatedData> e : _fedMap)
readResponses.add(new ImmutablePair<>(e.getKey(), e.getValue().executeFederatedOperation(request)));
return readResponses;
}
@@ -303,8 +352,8 @@ public class FederationMap {
VariableCPInstruction.prepareRemoveInstruction(id).toString());
request.setTID(tid);
List<Future<FederatedResponse>> tmp = new ArrayList<>();
- for(FederatedData fd : _fedMap.values())
- tmp.add(fd.executeFederatedOperation(request));
+ for(Pair<FederatedRange, FederatedData> fd : _fedMap)
+ tmp.add(fd.getValue().executeFederatedOperation(request));
// This cleaning is allowed to go in a separate thread, and finish on its own.
// The benefit is that the program is able to continue working on other things.
// The downside is that at the end of execution these threads can have executed
@@ -348,40 +397,80 @@ public class FederationMap {
return copyFederationMap;
}
+ /**
+ * Copy the federation map with the next available federated ID as reference to the federated data.
+ * This means that the federated map refers to the next federated data object on the workers.
+ * @return copied federation map with next federated ID
+ */
public FederationMap copyWithNewID() {
return copyWithNewID(FederationUtils.getNextFedDataID());
}
+ /**
+ * Copy the federation map with the given ID as reference to the federated data.
+ * This means that the federated map refers to the federated data object on the workers with the given ID.
+ * @param id federated data object ID
+ * @return copied federation map with given federated ID
+ */
public FederationMap copyWithNewID(long id) {
- Map<FederatedRange, FederatedData> map = new TreeMap<>();
+ List<Pair<FederatedRange, FederatedData>> map = new ArrayList<>();
// TODO handling of file path, but no danger as never written
- for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
+ for(Entry<FederatedRange, FederatedData> e : _fedMap) {
if(e.getKey().getSize() != 0)
- map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
+ map.add(Pair.of(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id)));
}
return new FederationMap(id, map, _type);
}
+ /**
+ * Copy the federation map with the given ID as reference to the federated data
+ * and with given clen as end dimension for the columns in the range.
+ * This means that the federated map refers to the federated data object on the workers with the given ID.
+ * @param id federated data object ID
+ * @param clen column length of data objects on federated workers
+ * @return copied federation map with given federated ID and ranges adapted according to clen
+ */
public FederationMap copyWithNewID(long id, long clen) {
- Map<FederatedRange, FederatedData> map = new TreeMap<>();
+ List<Pair<FederatedRange, FederatedData>> map = new ArrayList<>();
// TODO handling of file path, but no danger as never written
- for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
- map.put(new FederatedRange(e.getKey(), clen), e.getValue().copyWithNewID(id));
+ for(Pair<FederatedRange, FederatedData> e : _fedMap)
+ map.add(Pair.of(new FederatedRange(e.getKey(), clen), e.getValue().copyWithNewID(id)));
return new FederationMap(id, map, _type);
}
+ /**
+ * Copy federated mapping while giving the federated data new IDs
+ * and setting the ranges from zero to row and column ends specified.
+ * The overlapping ranges are given an overlap number to separate the ranges when putting to the federated map.
+ * The federation map returned is of type FType.PART.
+ * @param rowRangeEnd end of range for the rows
+ * @param colRangeEnd end of range for the columns
+ * @param outputID ID given to the output
+ * @return new federation map with overlapping ranges with partially aggregated values
+ */
+ public FederationMap copyWithNewIDAndRange(long rowRangeEnd, long colRangeEnd, long outputID){
+ List<Pair<FederatedRange, FederatedData>> outputMap = new ArrayList<>();
+ for(Pair<FederatedRange, FederatedData> e : _fedMap) {
+ if(e.getKey().getSize() != 0)
+ outputMap.add(Pair.of(
+ new FederatedRange(new long[]{0,0}, new long[]{rowRangeEnd, colRangeEnd}),
+ e.getValue().copyWithNewID(outputID)));
+ }
+ return new FederationMap(outputID, outputMap, FType.PART);
+ }
+
public FederationMap bind(long rOffset, long cOffset, FederationMap that) {
- for(Entry<FederatedRange, FederatedData> e : that._fedMap.entrySet()) {
- _fedMap.put(new FederatedRange(e.getKey()).shift(rOffset, cOffset), e.getValue().copyWithNewID(_ID));
+ for(Entry<FederatedRange, FederatedData> e : that._fedMap) {
+ _fedMap.add(Pair.of(new FederatedRange(e.getKey()).shift(rOffset, cOffset), e.getValue().copyWithNewID(_ID)));
}
return this;
}
public FederationMap transpose() {
- Map<FederatedRange, FederatedData> tmp = new TreeMap<>(_fedMap);
+ List<Pair<FederatedRange, FederatedData>> tmp = new ArrayList<>(_fedMap);
_fedMap.clear();
- for(Entry<FederatedRange, FederatedData> e : tmp.entrySet()) {
- _fedMap.put(new FederatedRange(e.getKey()).transpose(), e.getValue().copyWithNewID(_ID));
+ for(Pair<FederatedRange, FederatedData> e : tmp) {
+ _fedMap.add(Pair.of(new FederatedRange(e.getKey()).transpose(), e.getValue().copyWithNewID(_ID)));
}
// derive output type
switch(_type) {
@@ -394,6 +483,8 @@ public class FederationMap {
case COL:
_type = FType.ROW;
break;
+ case PART:
+ _type = FType.PART;
default:
_type = FType.OTHER;
}
@@ -401,7 +492,7 @@ public class FederationMap {
}
public long getMaxIndexInRange(int dim) {
- return _fedMap.keySet().stream().mapToLong(range -> range.getEndDims()[dim]).max().orElse(-1L);
+ return _fedMap.stream().mapToLong(range -> range.getKey().getEndDims()[dim]).max().orElse(-1L);
}
/**
@@ -415,7 +506,7 @@ public class FederationMap {
ExecutorService pool = CommonThreadPool.get(_fedMap.size());
ArrayList<MappingTask> mappingTasks = new ArrayList<>();
- for(Map.Entry<FederatedRange, FederatedData> fedMap : _fedMap.entrySet())
+ for(Pair<FederatedRange, FederatedData> fedMap : _fedMap)
mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), forEachFunction, _ID));
CommonThreadPool.invokeAndShutdown(pool, mappingTasks);
}
@@ -435,7 +526,7 @@ public class FederationMap {
FederationMap fedMapCopy = copyWithNewID(_ID);
ArrayList<MappingTask> mappingTasks = new ArrayList<>();
- for(Map.Entry<FederatedRange, FederatedData> fedMap : fedMapCopy._fedMap.entrySet())
+ for(Pair<FederatedRange, FederatedData> fedMap : fedMapCopy._fedMap)
mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), mappingFunction, newVarID));
CommonThreadPool.invokeAndShutdown(pool, mappingTasks);
fedMapCopy._ID = newVarID;
@@ -445,7 +536,7 @@ public class FederationMap {
public FederationMap filter(IndexRange ixrange) {
FederationMap ret = this.clone(); // same ID
- Iterator<Entry<FederatedRange, FederatedData>> iter = ret._fedMap.entrySet().iterator();
+ Iterator<Pair<FederatedRange, FederatedData>> iter = ret._fedMap.iterator();
while(iter.hasNext()) {
Entry<FederatedRange, FederatedData> e = iter.next();
FederatedRange range = e.getKey();
@@ -466,19 +557,20 @@ public class FederationMap {
}
public void reverseFedMap() {
+ // TODO perf
// TODO: add a check if the map is sorted based on indexes before reversing.
// TODO: add a setup such that on construction the federated map is already sorted.
- FederatedRange[] fedRanges = this.getFederatedRanges();
+ FederatedRange[] fedRanges = getFederatedRanges();
for(int i = 0; i < Math.floor(fedRanges.length / 2.0); i++) {
- FederatedData data1 = _fedMap.get(fedRanges[i]);
- FederatedData data2 = _fedMap.get(fedRanges[fedRanges.length-1-i]);
+ FederatedData data1 = getFederatedData(fedRanges[i]);
+ FederatedData data2 = getFederatedData(fedRanges[fedRanges.length-1-i]);
- _fedMap.remove(fedRanges[i]);
- _fedMap.remove(fedRanges[fedRanges.length-1-i]);
+ removeFederatedData(fedRanges[i]);
+ removeFederatedData(fedRanges[fedRanges.length-1-i]);
- _fedMap.put(fedRanges[i], data2);
- _fedMap.put(fedRanges[fedRanges.length-1-i], data1);
+ _fedMap.add(Pair.of(fedRanges[i], data2));
+ _fedMap.add(Pair.of(fedRanges[fedRanges.length-1-i], data1));
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 94fe0bd..ff91c35 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -19,13 +19,13 @@
package org.apache.sysds.runtime.controlprogram.federated;
+import java.util.ArrayList;
import java.util.Arrays;
-import java.util.HashMap;
import java.util.List;
-import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Future;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.log4j.Logger;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.Lop;
@@ -65,9 +65,10 @@ public class FederationUtils {
return _idSeq.getNextID();
}
- public static FederatedRequest callInstruction(String inst, CPOperand varOldOut, CPOperand[] varOldIn, long[] varNewIn, boolean federatedOutput){
+ //TODO remove rmFedOutFlag, once all federated instructions have this flag, then unconditionally remove
+ public static FederatedRequest callInstruction(String inst, CPOperand varOldOut, CPOperand[] varOldIn, long[] varNewIn, boolean rmFedOutFlag){
long id = getNextFedDataID();
- String linst = InstructionUtils.instructionStringFEDPrepare(inst, varOldOut, id, varOldIn, varNewIn, federatedOutput);
+ String linst = InstructionUtils.instructionStringFEDPrepare(inst, varOldOut, id, varOldIn, varNewIn, rmFedOutFlag);
return new FederatedRequest(RequestType.EXEC_INST, id, linst);
}
@@ -460,9 +461,48 @@ public class FederationUtils {
public static FederationMap federateLocalData(CacheableData<?> data) {
long id = FederationUtils.getNextFedDataID();
FederatedLocalData federatedLocalData = new FederatedLocalData(id, data);
- Map<FederatedRange, FederatedData> fedMap = new HashMap<>();
- fedMap.put(new FederatedRange(new long[2], new long[] {data.getNumRows(), data.getNumColumns()}),
- federatedLocalData);
+ List<Pair<FederatedRange, FederatedData>> fedMap = new ArrayList<>();
+ fedMap.add(Pair.of(
+ new FederatedRange(new long[2], new long[] {data.getNumRows(), data.getNumColumns()}),
+ federatedLocalData));
return new FederationMap(id, fedMap);
}
+
+ /**
+ * Bind data from federated workers based on non-overlapping federated ranges.
+ * @param readResponses responses from federated workers containing the federated ranges and data
+ * @param dims dimensions of output MatrixBlock
+ * @return MatrixBlock of consolidated data
+ * @throws Exception in case of problems with getting data from responses
+ */
+ public static MatrixBlock bindResponses(List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses, long[] dims)
+ throws Exception
+ {
+ MatrixBlock ret = new MatrixBlock((int) dims[0], (int) dims[1], false);
+ for(Pair<FederatedRange, Future<FederatedResponse>> readResponse : readResponses) {
+ FederatedRange range = readResponse.getLeft();
+ FederatedResponse response = readResponse.getRight().get();
+ // add result
+ int[] beginDimsInt = range.getBeginDimsInt();
+ int[] endDimsInt = range.getEndDimsInt();
+ MatrixBlock multRes = (MatrixBlock) response.getData()[0];
+ ret.copy(beginDimsInt[0], endDimsInt[0] - 1, beginDimsInt[1], endDimsInt[1] - 1, multRes, false);
+ ret.setNonZeros(ret.getNonZeros() + multRes.getNonZeros());
+ }
+ return ret;
+ }
+
+ /**
+ * Aggregate partially aggregated data from federated workers
+ * by adding values with the same index in different federated locations.
+ * @param readResponses responses from federated workers containing the federated data
+ * @return MatrixBlock of consolidated, aggregated data
+ */
+ @SuppressWarnings("unchecked")
+ public static MatrixBlock aggregateResponses(List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses) {
+ List<Future<FederatedResponse>> dataParts = new ArrayList<>();
+ for ( Pair<FederatedRange, Future<FederatedResponse>> readResponse : readResponses )
+ dataParts.add(readResponse.getValue());
+ return FederationUtils.aggAdd(dataParts.toArray(new Future[0]));
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 79ce52c..d286c12 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -104,8 +104,8 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
incWorkerNumber();
// prepare features and labels
- _featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0];
- _labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0];
+ _featuresData = _features.getFedMapping().getFederatedData()[0];
+ _labelsData = _labels.getFedMapping().getFederatedData()[0];
// weighting factor is always set, but only used when weighting is specified
_weightingFactor = weightingFactor;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
index 9c90767..02b81d0 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -58,8 +58,8 @@ public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
for(int i = 0; i < pFeatures.size(); i++) {
// Works, because the map contains a single entry
- FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
- FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+ FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getFederatedData()[0];
+ FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getFederatedData()[0];
Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, average_num_rows)));
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index c6429b4..e9bec6c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -35,7 +36,6 @@ import org.apache.sysds.runtime.meta.MetaDataFormat;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.HashMap;
import java.util.List;
public abstract class DataPartitionFederatedScheme {
@@ -88,8 +88,8 @@ public abstract class DataPartitionFederatedScheme {
);
// Create new federation map
- HashMap<FederatedRange, FederatedData> newFedHashMap = new HashMap<>();
- newFedHashMap.put(range, data);
+ List<Pair<FederatedRange, FederatedData>> newFedHashMap = new ArrayList<>();
+ newFedHashMap.add(Pair.of(range, data));
slice.setFedMapping(new FederationMap(fedMatrix.getFedMapping().getID(), newFedHashMap));
slice.getFedMapping().setType(FederationMap.FType.ROW);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
index 77b2287..7348afd 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -61,8 +61,8 @@ public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme
for(int i = 0; i < pFeatures.size(); i++) {
// Works, because the map contains a single entry
- FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
- FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+ FederatedData featuresData = pFeatures.get(i).getFedMapping().getFederatedData()[0];
+ FederatedData labelsData = pLabels.get(i).getFedMapping().getFederatedData()[0];
Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, max_rows)));
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index af95270..8037611 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -55,8 +55,8 @@ public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
for(int i = 0; i < pFeatures.size(); i++) {
// Works, because the map contains a single entry
- FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
- FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+ FederatedData featuresData = pFeatures.get(i).getFedMapping().getFederatedData()[0];
+ FederatedData labelsData = pLabels.get(i).getFedMapping().getFederatedData()[0];
Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
featuresData.getVarID(), new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed)));
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
index 369b3dd..4e92eec 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -61,8 +61,8 @@ public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme
for(int i = 0; i < pFeatures.size(); i++) {
// Works, because the map contains a single entry
- FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getMap().values().toArray()[0];
- FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
+ FederatedData featuresData = pFeatures.get(i).getFedMapping().getFederatedData()[0];
+ FederatedData labelsData = pLabels.get(i).getFedMapping().getFederatedData()[0];
Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, min_rows)));
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index eb4b4a6..51c7851 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -1061,54 +1061,6 @@ public class InstructionUtils
return _strBuilders.get().toString();
}
- /**
- * Concat the inputs as operands to generate the base instruction string.
- * The base instruction string can subsequently be extended with the
- * concatAdditional methods. The concatenation will be done using a
- * ThreadLocal StringBuilder, so the concatenation is local to the thread.
- * When all additional operands have been appended, the complete instruction
- * string can be retrieved by calling the getInstructionString method.
- * @param inputs operand inputs given as strings
- */
- public static void concatBaseOperands(String... inputs){
- concatBaseOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
- }
-
- /**
- * Concat input as an additional operand to the current thread-local base instruction string.
- * @param input operand input given as string
- */
- public static void concatAdditionalOperand(String input){
- StringBuilder sb = _strBuilders.get();
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(input);
- }
-
- /**
- * Concat inputs as additional operands to the current thread-local base instruction string.
- * @param inputs operand inputs given as strings
- */
- public static void concatAdditionalOperands(String... inputs){
- concatOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
- }
-
- /**
- * Returns the current thread-local instruction string.
- * This instruction string is built using the concat methods.
- * @return instruction string
- */
- public static String getInstructionString(){
- return _strBuilders.get().toString();
- }
-
- private static void concatOperandsWithDelim(String delim, String... inputs){
- StringBuilder sb = _strBuilders.get();
- for( int i=0; i<inputs.length; i++ ) {
- sb.append(delim);
- sb.append(inputs[i]);
- }
- }
-
private static void concatBaseOperandsWithDelim(String delim, String... inputs){
StringBuilder sb = _strBuilders.get();
sb.setLength(0); //reuse allocated space
@@ -1149,11 +1101,12 @@ public class InstructionUtils
* @param federatedOutput federated output flag
* @return instruction string prepared for federated request
*/
- public static String instructionStringFEDPrepare(String inst, CPOperand varOldOut, long id, CPOperand[] varOldIn, long[] varNewIn, boolean federatedOutput){
+ public static String instructionStringFEDPrepare(String inst, CPOperand varOldOut, long id, CPOperand[] varOldIn, long[] varNewIn, boolean rmFederatedOutput){
String linst = replaceExecTypeWithCP(inst);
linst = replaceOutputOperand(linst, varOldOut, id);
linst = replaceInputOperand(linst, varOldIn, varNewIn);
- linst = removeFEDOutputFlag(linst, federatedOutput);
+ if(rmFederatedOutput)
+ linst = removeFEDOutputFlag(linst);
return linst;
}
@@ -1175,11 +1128,8 @@ public class InstructionUtils
return linst;
}
- private static String removeFEDOutputFlag(String linst, boolean federatedOutput){
- if ( federatedOutput ){
- linst = linst.substring(0, linst.lastIndexOf(Lop.OPERAND_DELIMITOR));
- }
- return linst;
+ private static String removeFEDOutputFlag(String linst){
+ return linst.substring(0, linst.lastIndexOf(Lop.OPERAND_DELIMITOR));
}
private static String replaceOperand(String linst, CPOperand oldOperand, String newOperandName){
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
index 7a1c42a..6981877 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
@@ -54,30 +54,20 @@ public class AggregateBinaryCPInstruction extends BinaryCPInstruction {
if(!opcode.equalsIgnoreCase("ba+*")) {
throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
}
- int numFields = parts.length - 1;
- if(numFields == 4) {
- CPOperand in1 = new CPOperand(parts[1]);
- CPOperand in2 = new CPOperand(parts[2]);
- CPOperand out = new CPOperand(parts[3]);
- int k = Integer.parseInt(parts[4]);
- AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(k);
- return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str);
- }
- else if(numFields == 6) {
- CPOperand in1 = new CPOperand(parts[1]);
- CPOperand in2 = new CPOperand(parts[2]);
- CPOperand out = new CPOperand(parts[3]);
- int k = Integer.parseInt(parts[4]);
+
+ int numFields = InstructionUtils.checkNumFields(parts, 4, 6);
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand out = new CPOperand(parts[3]);
+ int k = Integer.parseInt(parts[4]);
+ AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(k);
+ if ( numFields == 6 ){
boolean isLeftTransposed = Boolean.parseBoolean(parts[5]);
boolean isRightTransposed = Boolean.parseBoolean(parts[6]);
- AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(k);
return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str, isLeftTransposed,
isRightTransposed);
}
- else {
- throw new DMLRuntimeException("NumFields expected number (" + 4 + " or " + 6
- + ") != is not equal to actual number (" + numFields + ").");
- }
+ else return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 9822bef..c731ce0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -27,6 +27,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -41,6 +42,11 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
CPOperand in2, CPOperand out, String opcode, String istr) {
super(FEDType.AggregateBinary, op, in1, in2, out, opcode, istr);
}
+
+ public AggregateBinaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+ String opcode, String istr, FederatedOutput fedOut) {
+ super(FEDType.AggregateBinary, op, in1, in2, out, opcode, istr, fedOut);
+ }
public static AggregateBinaryFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
@@ -48,13 +54,14 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
if(!opcode.equalsIgnoreCase("ba+*"))
throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
- InstructionUtils.checkNumFields(parts, 4);
+ InstructionUtils.checkNumFields(parts, 5);
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
int k = Integer.parseInt(parts[4]);
+ FederatedOutput fedOut = FederatedOutput.valueOf(parts[5]);
return new AggregateBinaryFEDInstruction(
- InstructionUtils.getMatMultOperator(k), in1, in2, out, opcode, str);
+ InstructionUtils.getMatMultOperator(k), in1, in2, out, opcode, str, fedOut);
}
@Override
@@ -62,39 +69,77 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
MatrixObject mo1 = ec.getMatrixObject(input1);
MatrixObject mo2 = ec.getMatrixObject(input2);
+ //TODO cleanup unnecessary redundancy
//#1 federated matrix-vector multiplication
if(mo1.isFederated(FType.COL) && mo2.isFederated(FType.ROW)
&& mo1.getFedMapping().isAligned(mo2.getFedMapping(), true) ) {
FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2},
- new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
- FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
- FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
- //execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
- MatrixBlock ret = FederationUtils.aggAdd(tmp);
- ec.setMatrixOutput(output.getName(), ret);
+ new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, true);
+
+ if ( _fedOut.isForcedFederated() ){
+ mo1.getFedMapping().execute(getTID(), fr1);
+ setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr1.getID(), ec);
+ }
+ else {
+ FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
}
- else if(mo1.isFederated(FType.ROW)) { // MV + MM
+ else if(mo1.isFederated(FType.ROW) || mo1.isFederated(FType.PART)) { // MV + MM
//construct commands: broadcast rhs, fed mv, retrieve results
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
+ new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), fr1.getID()}, true);
if( mo2.getNumColumns() == 1 && mo2.getNumRows() != mo1.getNumColumns()) { //MV
- FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
- FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
- //execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
- MatrixBlock ret = FederationUtils.bind(tmp, false);
- ec.setMatrixOutput(output.getName(), ret);
+ if ( _fedOut.isForcedFederated() ){
+ FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+ if ( mo1.isFederated(FType.PART) )
+ setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ else
+ setOutputFedMapping(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+ MatrixBlock ret;
+ if ( mo1.isFederated(FType.PART) )
+ ret = FederationUtils.aggAdd(tmp);
+ else
+ ret = FederationUtils.bind(tmp, false);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
}
else { //MM
//execute federated operations and aggregate
- FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
- mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
- MatrixObject out = ec.getMatrixObject(output);
- out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
- out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), mo2.getNumColumns()));
+ if ( !_fedOut.isForcedLocal() ){
+ FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+ if ( mo1.isFederated(FType.PART) || mo2.isFederated(FType.PART) )
+ setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ else
+ setOutputFedMapping(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+ MatrixBlock ret;
+ if ( mo1.isFederated(FType.PART) )
+ ret = FederationUtils.aggAdd(tmp);
+ else
+ ret = FederationUtils.bind(tmp, false);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
}
}
//#2 vector - federated matrix multiplication
@@ -102,26 +147,44 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
//construct commands: broadcast rhs, fed mv, retrieve results
FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, true);
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2}, new long[]{fr1[0].getID(), mo2.getFedMapping().getID()});
- FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
- FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
- //execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
- MatrixBlock ret = FederationUtils.aggAdd(tmp);
- ec.setMatrixOutput(output.getName(), ret);
+ new CPOperand[]{input1, input2},
+ new long[]{fr1[0].getID(), mo2.getFedMapping().getID()}, true);
+ if ( _fedOut.isForcedFederated() ){
+ // Partial aggregates (set fedmapping to the partial aggs)
+ FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1[0].getID());
+ mo2.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+ setPartialOutput(mo2.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
}
//#3 col-federated matrix vector multiplication
else if (mo1.isFederated(FType.COL)) {// VM + MM
//construct commands: broadcast rhs, fed mv, retrieve results
FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, true);
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
- FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
- FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
- //execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
- MatrixBlock ret = FederationUtils.aggAdd(tmp);
- ec.setMatrixOutput(output.getName(), ret);
+ new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), fr1[0].getID()}, true);
+ if ( _fedOut.isForcedFederated() ){
+ // Partial aggregates (set fedmapping to the partial aggs)
+ FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+ mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+ setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
}
else { //other combinations
throw new DMLRuntimeException("Federated AggregateBinary not supported with the "
@@ -129,4 +192,36 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
+" "+mo2.isFederated()+":"+mo2.getFedMapping());
}
}
+
+ /**
+ * Sets the output with a federated mapping of overlapping partial aggregates.
+ * The output dimensions are nrow(mo1) x ncol(mo2) and the federation map is
+ * derived via copyWithNewIDAndRange over the same dimensions.
+ * @param federationMap federated map from which the federated metadata is retrieved
+ * @param mo1 matrix object with number of rows used to set the number of rows of the output
+ * @param mo2 matrix object with number of columns used to set the number of columns of the output
+ * @param outputID ID of the output
+ * @param ec execution context
+ */
+ private void setPartialOutput(FederationMap federationMap, MatrixObject mo1, MatrixObject mo2,
+ long outputID, ExecutionContext ec){
+ MatrixObject out = prepareOutput(mo1, mo2, ec);
+ out.setFedMapping(federationMap
+ .copyWithNewIDAndRange(mo1.getNumRows(), mo2.getNumColumns(), outputID));
+ }
+
+ /**
+ * Sets the output with a federated map copied from the given federation map,
+ * re-keyed to the output ID and using the number of columns of mo2.
+ * @param federationMap federation map to be set in output
+ * @param mo1 matrix object with number of rows used to set the number of rows of the output
+ * @param mo2 matrix object with number of columns used to set the number of columns of the output
+ * @param outputID ID of the output
+ * @param ec execution context
+ */
+ private void setOutputFedMapping(FederationMap federationMap, MatrixObject mo1, MatrixObject mo2,
+ long outputID, ExecutionContext ec){
+ MatrixObject out = prepareOutput(mo1, mo2, ec);
+ out.setFedMapping(federationMap.copyWithNewID(outputID, mo2.getNumColumns()));
+ }
+
+ /**
+ * Retrieves the output matrix object from the execution context and sets its
+ * data characteristics to nrow(mo1) x ncol(mo2) with the block size of mo1.
+ * Shared by both output setters so the dimension logic lives in one place.
+ * @param mo1 matrix object providing the number of rows and the block size
+ * @param mo2 matrix object providing the number of columns
+ * @param ec execution context
+ * @return the output matrix object with updated data characteristics
+ */
+ private MatrixObject prepareOutput(MatrixObject mo1, MatrixObject mo2, ExecutionContext ec){
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
+ return out;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index da68d07..1a9f9f9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -38,15 +38,15 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
+ /**
+ * Creates a federated aggregate-unary instruction over a single input.
+ * @param auop aggregate unary operator
+ * @param in input operand
+ * @param out output operand
+ * @param opcode instruction opcode
+ * @param istr instruction string
+ * @param fedOut requested output placement (FOUT forced federated, LOUT forced local, NONE runtime heuristics)
+ */
private AggregateUnaryFEDInstruction(AggregateUnaryOperator auop,
- CPOperand in, CPOperand out, String opcode, String istr, boolean federatedOutput)
+ CPOperand in, CPOperand out, String opcode, String istr, FederatedOutput fedOut)
{
- super(FEDType.AggregateUnary, auop, in, out, opcode, istr, federatedOutput);
+ super(FEDType.AggregateUnary, auop, in, out, opcode, istr, fedOut);
}
+ /**
+ * Creates a federated aggregate-unary instruction with two input operands.
+ * @param op operator
+ * @param in1 first input operand
+ * @param in2 second input operand
+ * @param out output operand
+ * @param opcode instruction opcode
+ * @param istr instruction string
+ * @param fedOut requested output placement (FOUT forced federated, LOUT forced local, NONE runtime heuristics)
+ */
protected AggregateUnaryFEDInstruction(Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, boolean federatedOutput)
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FederatedOutput fedOut)
{
- super(FEDType.AggregateUnary, op, in1, in2, out, opcode, istr, federatedOutput);
+ super(FEDType.AggregateUnary, op, in1, in2, out, opcode, istr, fedOut);
}
protected AggregateUnaryFEDInstruction(Operator op,
@@ -76,12 +76,12 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
if(InstructionUtils.getExecType(str) == ExecType.SPARK)
str = InstructionUtils.replaceOperand(str, 4, "-1");
- boolean federatedOutput = false;
- if ( parts.length > 6 )
- federatedOutput = Boolean.parseBoolean(parts[5]);
- else if ( parts.length == 5 && !parts[4].equals("uarimin") && !parts[4].equals("uarimax") )
- federatedOutput = Boolean.parseBoolean(parts[4]);
- return new AggregateUnaryFEDInstruction(aggun, in1, out, opcode, str, federatedOutput);
+ FederatedOutput fedOut = null;
+ if ( parts.length == 5 && !parts[4].equals("uarimin") && !parts[4].equals("uarimax") )
+ fedOut = FederatedOutput.valueOf(parts[4]);
+ else
+ fedOut = FederatedOutput.valueOf(parts[5]);
+ return new AggregateUnaryFEDInstruction(aggun, in1, out, opcode, str, fedOut);
}
@Override
@@ -102,9 +102,9 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
if((instOpcode.equalsIgnoreCase("uarimax") || instOpcode.equalsIgnoreCase("uarimin")) && in.isFederated(FederationMap.FType.COL))
instString = InstructionUtils.replaceOperand(instString, 5, "2");
- //create federated commands for aggregation
-
- if ( _federatedOutput )
+ // create federated commands for aggregation
+ // (by default obtain output, even though unnecessary row aggregates)
+ if ( _fedOut.isForcedFederated() )
processFederatedOutput(map, in, ec);
else
processGetOutput(map, aop, ec, in);
@@ -121,7 +121,7 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
throw new DMLRuntimeException("Output of FED instruction, " + output.toString()
+ ", is a scalar and the output is set to be federated. Scalars cannot be federated. ");
FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, _federatedOutput);
+ new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, true);
map.execute(getTID(), fr1);
// derive new fed mapping for output
@@ -138,7 +138,7 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
*/
private void processGetOutput(FederationMap map, AggregateUnaryOperator aggUOptr, ExecutionContext ec, MatrixObject in){
FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()});
+ new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, true);
FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
@@ -151,7 +151,7 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
}
private void processVar(ExecutionContext ec){
- if ( _federatedOutput ){
+ if ( _fedOut.isForcedFederated() ){
throw new DMLRuntimeException("Output of " + toString() + " should not be federated "
+ "since the instruction requires consolidation of partial results to be computed.");
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 659281a..9c74ea3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -31,13 +31,13 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
+ /**
+ * Creates a federated binary instruction with an explicit output placement.
+ * @param fedOut requested output placement (FOUT forced federated, LOUT forced local, NONE runtime heuristics)
+ */
protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, boolean federatedOutput) {
- super(type, op, in1, in2, out, opcode, istr, federatedOutput);
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
+ super(type, op, in1, in2, out, opcode, istr, fedOut);
}
+ /**
+ * Creates a federated binary instruction whose output placement is left to
+ * runtime heuristics (delegates with FederatedOutput.NONE).
+ */
protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
- this(type, op, in1, in2, out, opcode, istr, false);
+ this(type, op, in1, in2, out, opcode, istr, FederatedOutput.NONE);
}
public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
@@ -57,7 +57,7 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
- boolean federatedOutput = parts.length > 5 && Boolean.parseBoolean(parts[5]);
+ FederatedOutput fedOut = FederatedOutput.valueOf(parts[parts.length-1]);
checkOutputDataType(in1, in2, out);
Operator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
@@ -67,11 +67,11 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
if( in1.getDataType() == DataType.SCALAR && in2.getDataType() == DataType.SCALAR )
throw new DMLRuntimeException("Federated binary scalar scalar operations not yet supported");
else if( in1.getDataType() == DataType.MATRIX && in2.getDataType() == DataType.MATRIX )
- return new BinaryMatrixMatrixFEDInstruction(operator, in1, in2, out, opcode, str, federatedOutput);
+ return new BinaryMatrixMatrixFEDInstruction(operator, in1, in2, out, opcode, str, fedOut);
else if( in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.TENSOR )
throw new DMLRuntimeException("Federated binary tensor tensor operations not yet supported");
else if( in1.isMatrix() && in2.isScalar() || in2.isMatrix() && in1.isScalar() )
- return new BinaryMatrixScalarFEDInstruction(operator, in1, in2, out, opcode, str, federatedOutput);
+ return new BinaryMatrixScalarFEDInstruction(operator, in1, in2, out, opcode, str, fedOut);
else
throw new DMLRuntimeException("Federated binary operations not yet supported:" + opcode);
}
@@ -89,13 +89,11 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
+ /**
+ * Parses an instruction with three inputs and one output. It requires exactly
+ * four operand fields and splits them, in order, into in1, in2, in3 and out,
+ * mutating the given operands in place.
+ * @param instr instruction string
+ * @return the opcode (first field of the instruction)
+ */
protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
InstructionUtils.checkNumFields ( parts, 4 );
-
String opcode = parts[0];
in1.split(parts[1]);
in2.split(parts[2]);
in3.split(parts[3]);
out.split(parts[4]);
-
return opcode;
}
@@ -113,7 +111,6 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR + "RIGHT", "");
inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR + VectorType.ROW_VECTOR.name(), "");
inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR + VectorType.COL_VECTOR.name(), "");
-
return inst_str;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index e67a353..db82bba 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -23,6 +23,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -32,8 +33,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
{
+ /**
+ * Creates a federated matrix-matrix binary instruction (FEDType.Binary).
+ * @param fedOut requested output placement (FOUT forced federated, LOUT forced local, NONE runtime heuristics)
+ */
protected BinaryMatrixMatrixFEDInstruction(Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, boolean federatedOutput) {
- super(FEDType.Binary, op, in1, in2, out, opcode, istr, federatedOutput);
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
+ super(FEDType.Binary, op, in1, in2, out, opcode, istr, fedOut);
}
@Override
@@ -53,14 +54,16 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
FederatedRequest fr2 = null;
if( mo2.isFederated() ) {
if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
- fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
- new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, _federatedOutput);
+ fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, true);
mo1.getFedMapping().execute(getTID(), true, fr2);
}
else if ( !mo1.isFederated() ){
FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, false);
- fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
- new long[]{fr1[0].getID(), mo2.getFedMapping().getID()});
+ fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2},
+ new long[]{fr1[0].getID(), mo2.getFedMapping().getID()}, true);
mo2.getFedMapping().execute(getTID(), true, fr1, fr2);
}
else {
@@ -75,7 +78,7 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
// only one partition (MM on a single fed worker)
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
- new long[]{mo1.getFedMapping().getID(), fr1.getID()}, _federatedOutput);
+ new long[]{mo1.getFedMapping().getID(), fr1.getID()}, true);
FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
//execute federated instruction and cleanup intermediates
mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
@@ -89,7 +92,7 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
// MV row partitioned row vector, MV col partitioned col vector
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
- new long[]{mo1.getFedMapping().getID(), fr1.getID()}, _federatedOutput);
+ new long[]{mo1.getFedMapping().getID(), fr1.getID()}, true);
FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
//execute federated instruction and cleanup intermediates
mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
@@ -99,17 +102,27 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
// row partitioned MM or col partitioned MM
FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
- new long[]{mo1.getFedMapping().getID(), fr1[0].getID()}, _federatedOutput);
+ new long[]{mo1.getFedMapping().getID(), fr1[0].getID()}, true);
FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
//execute federated instruction and cleanup intermediates
mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
}
+ else if ( mo1.isFederated(FType.PART) && !mo2.isFederated() ){
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+ fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), fr1.getID()}, true);
+ FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ //execute federated instruction and cleanup intermediates
+ mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+ }
else {
throw new DMLRuntimeException("Matrix-matrix binary operations are only supported with a row partitioned or column partitioned federated input yet.");
}
}
- if ( mo1.isFederated() )
+ if ( mo1.isFederated(FType.PART) && !mo2.isFederated() )
+ setOutputFedMappingPart(mo1, mo2, fr2.getID(), ec);
+ else if ( mo1.isFederated() )
setOutputFedMapping(mo1, fr2.getID(), ec);
else if ( mo2.isFederated() )
setOutputFedMapping(mo2, fr2.getID(), ec);
@@ -117,6 +130,21 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
}
/**
+ * Sets the output to a federation map of overlapping partial aggregates. The
+ * federated metadata is taken from mo1's map, re-keyed to the output ID, and the
+ * output range is nrow(mo1) x ncol(mo2).
+ * @param mo1 source of the output row count, block size and federation map
+ * @param mo2 source of the output column count
+ * @param outputID ID of output
+ * @param ec execution context
+ */
+ private void setOutputFedMappingPart(MatrixObject mo1, MatrixObject mo2, long outputID, ExecutionContext ec){
+ long rows = mo1.getNumRows();
+ long cols = mo2.getNumColumns();
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(rows, cols, (int)mo1.getBlocksize());
+ out.setFedMapping(mo1.getFedMapping().copyWithNewIDAndRange(rows, cols, outputID));
+ }
+
+ /**
* Set data characteristics and fed mapping for output.
* @param moFederated federated matrix object from which data characteristics and fed mapping are derived
* @param outputFedmappingID ID for the fed mapping of output
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index 441a00b..5edcbd3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -29,8 +29,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
public class BinaryMatrixScalarFEDInstruction extends BinaryFEDInstruction
{
+ /**
+ * Creates a federated matrix-scalar binary instruction (FEDType.Binary).
+ * @param fedOut requested output placement (FOUT forced federated, LOUT forced local, NONE runtime heuristics)
+ */
protected BinaryMatrixScalarFEDInstruction(Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, boolean federatedOutput) {
- super(FEDType.Binary, op, in1, in2, out, opcode, istr, federatedOutput);
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
+ super(FEDType.Binary, op, in1, in2, out, opcode, istr, fedOut);
}
@Override
@@ -44,7 +44,7 @@ public class BinaryMatrixScalarFEDInstruction extends BinaryFEDInstruction
mo.getFedMapping().broadcast(ec.getScalarInput(scalar)) : null;
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{matrix, (fr1 != null)?scalar:null},
- new long[]{mo.getFedMapping().getID(), (fr1 != null)?fr1.getID():-1}, _federatedOutput);
+ new long[]{mo.getFedMapping().getID(), (fr1 != null)?fr1.getID():-1}, true);
//execute federated matrix-scalar operation and cleanups
if( fr1 != null ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
index 692455c..2ce7480 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
@@ -39,17 +39,17 @@ public abstract class ComputationFEDInstruction extends FEDInstruction implement
+ /**
+ * Creates a two-input computation instruction (third input null) whose output
+ * placement is left to runtime heuristics (FederatedOutput.NONE).
+ */
protected ComputationFEDInstruction(FEDType type, Operator op,
CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
- this(type, op, in1, in2, null, out, opcode, istr, false);
+ this(type, op, in1, in2, null, out, opcode, istr, FederatedOutput.NONE);
}
+ /**
+ * Creates a two-input computation instruction (third input null) with an
+ * explicit output placement.
+ * @param fedOut requested output placement (FOUT forced federated, LOUT forced local, NONE runtime heuristics)
+ */
protected ComputationFEDInstruction(FEDType type, Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, boolean federatedOutput) {
- this(type, op, in1, in2, null, out, opcode, istr, federatedOutput);
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
+ this(type, op, in1, in2, null, out, opcode, istr, fedOut);
}
protected ComputationFEDInstruction(FEDType type, Operator op,
- CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, boolean federatedOutput){
- super(type, op, opcode, istr, federatedOutput);
+ CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, FederatedOutput fedOut){
+ super(type, op, opcode, istr, fedOut);
input1 = in1;
input2 = in2;
input3 = in3;
@@ -58,7 +58,7 @@ public abstract class ComputationFEDInstruction extends FEDInstruction implement
+ /**
+ * Creates a three-input computation instruction whose output placement is left
+ * to runtime heuristics (delegates with FederatedOutput.NONE).
+ */
protected ComputationFEDInstruction(FEDType type, Operator op,
CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
- this(type, op, in1, in2, in3, out, opcode, istr, false);
+ this(type, op, in1, in2, in3, out, opcode, istr, FederatedOutput.NONE);
}
public String getOutputVariableName() {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 871bdb1..f35030f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -49,25 +49,37 @@ public abstract class FEDInstruction extends Instruction {
SpoofFused,
Unary
}
+
+ /**
+ * Requested placement of an instruction's output.
+ */
+ public enum FederatedOutput {
+ /** Forced federated output. */
+ FOUT,
+ /** Forced local output (consolidated in CP). */
+ LOUT,
+ /** Placement decided by runtime heuristics. */
+ NONE;
+
+ /** @return true only for {@link #FOUT} */
+ public boolean isForcedFederated() {
+ return FOUT == this;
+ }
+
+ /** @return true only for {@link #LOUT} */
+ public boolean isForcedLocal() {
+ return LOUT == this;
+ }
+ }
protected final FEDType _fedType;
protected long _tid = -1; //main
- protected boolean _federatedOutput = false;
+ protected FederatedOutput _fedOut = FederatedOutput.NONE;
protected FEDInstruction(FEDType type, String opcode, String istr) {
this(type, null, opcode, istr);
}
protected FEDInstruction(FEDType type, Operator op, String opcode, String istr) {
- this(type, op, opcode, istr, false);
+ this(type, op, opcode, istr, FederatedOutput.NONE);
}
- protected FEDInstruction(FEDType type, Operator op, String opcode, String istr, boolean federatedOutput) {
+ /**
+ * Creates a federated instruction with an explicit output placement.
+ * @param type federated instruction type
+ * @param op operator (may be null)
+ * @param opcode instruction opcode
+ * @param istr instruction string
+ * @param fedOut requested output placement; null is treated as NONE
+ */
+ protected FEDInstruction(FEDType type, Operator op, String opcode, String istr, FederatedOutput fedOut) {
super(op);
_fedType = type;
instString = istr;
instOpcode = opcode;
- _federatedOutput = federatedOutput;
+ // subclasses call _fedOut.isForcedFederated()/isForcedLocal() unconditionally,
+ // so never store null; fall back to the same default as the field initializer
+ _fedOut = (fedOut != null) ? fedOut : FederatedOutput.NONE;
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 8f22539..4b38250 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -28,6 +28,7 @@ import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
@@ -47,6 +48,7 @@ import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
@@ -87,7 +89,8 @@ public class FEDInstructionUtils {
MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
MatrixObject mo2 = ec.getMatrixObject(instruction.input2);
if (mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW) || mo1.isFederated(FType.COL)) {
- fedinst = AggregateBinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst = AggregateBinaryFEDInstruction.parseInstruction(
+ InstructionUtils.concatOperands(inst.getInstructionString(), FederatedOutput.NONE.name()));
}
}
}
@@ -111,7 +114,8 @@ public class FEDInstructionUtils {
CacheableData<?> mo = ec.getCacheableData(rinst.input1);
if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() )
- fedinst = ReorgFEDInstruction.parseInstruction(rinst.getInstructionString());
+ fedinst = ReorgFEDInstruction.parseInstruction(
+ InstructionUtils.concatOperands(rinst.getInstructionString(),FederatedOutput.NONE.name()));
}
else if(instruction.input1 != null && instruction.input1.isMatrix()
&& ec.containsVariable(instruction.input1)) {
@@ -127,7 +131,8 @@ public class FEDInstructionUtils {
fedinst = ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
else if(inst instanceof AggregateUnaryCPInstruction && mo1.isFederated() &&
((AggregateUnaryCPInstruction) instruction).getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT)
- fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst = AggregateUnaryFEDInstruction.parseInstruction(
+ InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
else if(inst instanceof UnaryMatrixCPInstruction && mo1.isFederated()) {
if(UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode()) &&
!(inst.getOpcode().equalsIgnoreCase("ucumk+*") && mo1.isFederated(FType.COL)))
@@ -147,7 +152,8 @@ public class FEDInstructionUtils {
ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
fedinst = CovarianceFEDInstruction.parseInstruction(inst.getInstructionString());
else
- fedinst = BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst = BinaryFEDInstruction.parseInstruction(
+ InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
}
}
else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
@@ -273,7 +279,8 @@ public class FEDInstructionUtils {
AggregateUnarySPInstruction instruction = (AggregateUnarySPInstruction) inst;
Data data = ec.getVariable(instruction.input1);
if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
- fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst = AggregateUnaryFEDInstruction.parseInstruction(
+ InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
}
}
else if(inst instanceof BinarySPInstruction) {
@@ -307,7 +314,8 @@ public class FEDInstructionUtils {
Data data = ec.getVariable(instruction.input1);
if((data instanceof MatrixObject && ((MatrixObject)data).isFederated())
|| (data instanceof TensorObject && ((TensorObject)data).isFederated())) {
- fedinst = BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ fedinst = BinaryFEDInstruction.parseInstruction(
+ InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
}
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index 90f93de..ef0223d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -35,6 +35,7 @@ import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
@@ -144,10 +145,10 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
//create new frame schema
List<Types.ValueType> schema = new ArrayList<>();
-
// replace old reshape values for each worker
int i = 0;
- for(FederatedRange range : fedMap.getMap().keySet()) {
+ for(org.apache.commons.lang3.tuple.Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
+ FederatedRange range = e.getKey();
long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
long rsn = (ixrange.rowStart >= rs) ? (ixrange.rowStart - rs) : 0;
@@ -219,7 +220,8 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
// replace old reshape values for each worker
int i = 0, prev = 0, from = fedMap.getSize();
- for(FederatedRange range : fedMap.getMap().keySet()) {
+ for(org.apache.commons.lang3.tuple.Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
+ FederatedRange range = e.getKey();
long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
long rsn = (ixrange.rowStart >= rs) ? (ixrange.rowStart - rs) : 0;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index bc16149..9b6d3f0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -27,8 +27,6 @@ import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import java.util.Map;
-import java.util.TreeMap;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
@@ -214,15 +212,14 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
public static void federateMatrix(CacheableData<?> output, List<Pair<FederatedRange, FederatedData>> workers) {
- Map<FederatedRange, FederatedData> fedMapping = new TreeMap<>();
- for(Pair<FederatedRange, FederatedData> t : workers) {
- fedMapping.put(t.getLeft(), t.getRight());
- }
+ List<Pair<FederatedRange, FederatedData>> fedMapping = new ArrayList<>();
+ for(Pair<FederatedRange, FederatedData> e : workers)
+ fedMapping.add(e);
List<Pair<FederatedData, Future<FederatedResponse>>> idResponses = new ArrayList<>();
long id = FederationUtils.getNextFedDataID();
boolean rowPartitioned = true;
boolean colPartitioned = true;
- for(Map.Entry<FederatedRange, FederatedData> entry : fedMapping.entrySet()) {
+ for(Pair<FederatedRange, FederatedData> entry : fedMapping) {
FederatedRange range = entry.getKey();
FederatedData value = entry.getValue();
if(!value.isInitialized()) {
@@ -268,10 +265,9 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
}
public static void federateFrame(FrameObject output, List<Pair<FederatedRange, FederatedData>> workers) {
- Map<FederatedRange, FederatedData> fedMapping = new TreeMap<>();
- for(Pair<FederatedRange, FederatedData> t : workers) {
- fedMapping.put(t.getLeft(), t.getRight());
- }
+ List<Pair<FederatedRange, FederatedData>> fedMapping = new ArrayList<>();
+ for(Pair<FederatedRange, FederatedData> e : workers)
+ fedMapping.add(e);
// we want to wait for the futures with the response containing varIDs and the schemas of the frames
// on the distributed workers. We need the FederatedData, the starting column of the sub frame (for the schema)
// and the future for the response
@@ -279,7 +275,7 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
long id = FederationUtils.getNextFedDataID();
boolean rowPartitioned = true;
boolean colPartitioned = true;
- for(Map.Entry<FederatedRange, FederatedData> entry : fedMapping.entrySet()) {
+ for(Pair<FederatedRange, FederatedData> entry : fedMapping) {
FederatedRange range = entry.getKey();
FederatedData value = entry.getValue();
if(!value.isInitialized()) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index 04b50ac..f984967 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -51,14 +51,14 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
}
private QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand in2, CPOperand out, OperationTypes type,
- boolean inmem, String opcode, String istr, boolean federatedOutput) {
- super(FEDType.QPick, op, in, in2, out, opcode, istr, federatedOutput);
+ boolean inmem, String opcode, String istr, FederatedOutput fedOut) {
+ super(FEDType.QPick, op, in, in2, out, opcode, istr, fedOut);
_type = type;
}
private QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand in2, CPOperand out, OperationTypes type,
boolean inmem, String opcode, String istr) {
- this(op, in, in2, out, type, inmem, opcode, istr, false);
+ this(op, in, in2, out, type, inmem, opcode, istr, FederatedOutput.NONE);
}
public static QuantilePickFEDInstruction parseInstruction ( String str ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index c1aa37d..0a545bc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -76,34 +76,6 @@ public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
}
}
-
-// @Override
-// public void processInstruction(ExecutionContext ec) {
-// MatrixObject in = ec.getMatrixObject(input1.getName());
-// FederationMap map = in.getFedMapping();
-//
-// //create federated commands for aggregation
-// FederatedRequest fr1 = FederationUtils
-// .callInstruction(instString, output, new CPOperand[] {input1}, new long[] {in.getFedMapping().getID()});
-// FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
-// FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
-//
-// Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
-//
-// try {
-// Object d = tmp[0].get().getData()[0];
-// System.out.println(1);
-// }
-// catch(Exception e) {
-// e.printStackTrace();
-// }
-//
-// MatrixObject out = ec.getMatrixObject(output);
-// out.getDataCharacteristics().set(in.getDataCharacteristics());
-// out.setFedMapping(in.getFedMapping().copyWithNewID(fr2.getID()));
-// }
-
-
@Override
public void processInstruction(ExecutionContext ec) {
MatrixObject in = ec.getMatrixObject(input1);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index f4999f8..19847a2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -53,8 +53,8 @@ import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
public class ReorgFEDInstruction extends UnaryFEDInstruction {
- public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, boolean federatedOutput) {
- super(FEDType.Reorg, op, in1, out, opcode, istr, federatedOutput);
+ public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
+ super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
}
public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr) {
@@ -72,8 +72,8 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
in.split(parts[1]);
out.split(parts[2]);
int k = Integer.parseInt(parts[3]);
- boolean federatedOutput = parts.length > 4 && Boolean.parseBoolean(parts[4]);
- return new ReorgFEDInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str, federatedOutput);
+ FederatedOutput fedOut = parts.length > 4 ? FederatedOutput.valueOf(parts[4]) : FederatedOutput.NONE;
+ return new ReorgFEDInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str, fedOut);
}
else if ( opcode.equalsIgnoreCase("rdiag") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
@@ -100,9 +100,8 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
if(instOpcode.equals("r'")) {
//execute transpose at federated site
FederatedRequest fr1 = FederationUtils.callInstruction(instString,
- output,
- new CPOperand[] {input1},
- new long[] {mo1.getFedMapping().getID()}, _federatedOutput);
+ output, new CPOperand[] {input1},
+ new long[] {mo1.getFedMapping().getID()}, true);
mo1.getFedMapping().execute(getTID(), true, fr1);
//drive output federated mapping
@@ -113,9 +112,8 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
else if(instOpcode.equalsIgnoreCase("rev")) {
//execute transpose at federated site
FederatedRequest fr1 = FederationUtils.callInstruction(instString,
- output,
- new CPOperand[] {input1},
- new long[] {mo1.getFedMapping().getID()});
+ output, new CPOperand[] {input1},
+ new long[] {mo1.getFedMapping().getID()}, true);
mo1.getFedMapping().execute(getTID(), true, fr1);
if(mo1.isFederated(FederationMap.FType.ROW))
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 2805107..c7dd8b6 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -37,8 +37,8 @@ import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
public class TernaryFEDInstruction extends ComputationFEDInstruction {
private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
- String opcode, String str, boolean federatedOutput) {
- super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str, federatedOutput);
+ String opcode, String str, FederatedOutput fedOut) {
+ super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str, fedOut);
}
public static TernaryFEDInstruction parseInstruction(String str) {
@@ -49,9 +49,9 @@ public class TernaryFEDInstruction extends ComputationFEDInstruction {
CPOperand operand3 = new CPOperand(parts[3]);
CPOperand outOperand = new CPOperand(parts[4]);
int numThreads = parts.length>5 ? Integer.parseInt(parts[5]) : 1;
- boolean federatedOutput = parts.length > 6 && parts[6].equals("true");
+ FederatedOutput fedOut = parts.length>6 ? FederatedOutput.valueOf(parts[6]) : FederatedOutput.NONE;
TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode, numThreads);
- return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str, federatedOutput);
+ return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str, fedOut);
}
@Override
@@ -166,12 +166,12 @@ public class TernaryFEDInstruction extends ComputationFEDInstruction {
*/
private void sendFederatedRequests(ExecutionContext ec, MatrixObject fedMapObj, long fedOutputID,
FederatedRequest[] federatedSlices1, FederatedRequest[] federatedSlices2, FederatedRequest... federatedRequests){
- if ( _federatedOutput ){
+ if ( !_fedOut.isForcedLocal() ){
fedMapObj.getFedMapping().execute(getTID(), true, federatedSlices1, federatedSlices2, federatedRequests);
setOutputFedMapping(ec, fedMapObj, fedOutputID);
- } else {
- processAndRetrieve(ec, fedMapObj, fedOutputID, federatedSlices1, federatedSlices2, federatedRequests);
}
+ else
+ processAndRetrieve(ec, fedMapObj, fedOutputID, federatedSlices1, federatedSlices2, federatedRequests);
}
/**
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index bb14774..c0fc942 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -39,14 +39,14 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
@SuppressWarnings("unused")
private final int _numThreads;
- public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type, int k, String opcode, String istr, boolean federatedOutput) {
- super(FEDType.Tsmm, null, in, null, out, opcode, istr, federatedOutput);
+ public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type, int k, String opcode, String istr, FederatedOutput fedOut) {
+ super(FEDType.Tsmm, null, in, null, out, opcode, istr, fedOut);
_type = type;
_numThreads = k;
}
public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type, int k, String opcode, String istr) {
- this(in, out, type, k, opcode, istr, false);
+ this(in, out, type, k, opcode, istr, FederatedOutput.NONE);
}
public static TsmmFEDInstruction parseInstruction(String str) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
index fa0754f..0ae3178 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
@@ -31,8 +31,8 @@ public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
}
protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in, CPOperand out, String opcode, String instr,
- boolean federatedOutput) {
- this(type, op, in, null, null, out, opcode, instr, federatedOutput);
+ FederatedOutput fedOut) {
+ this(type, op, in, null, null, out, opcode, instr, fedOut);
}
protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode,
@@ -41,23 +41,27 @@ public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
}
protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode,
- String instr, boolean federatedOutput) {
- this(type, op, in1, in2, null, out, opcode, instr, federatedOutput);
+ String instr, FederatedOutput fedOut) {
+ this(type, op, in1, in2, null, out, opcode, instr, fedOut);
}
protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
String opcode, String instr) {
- this(type, op, in1, in2, in3, out, opcode, instr, false);
+ this(type, op, in1, in2, in3, out, opcode, instr, FederatedOutput.NONE);
}
protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
- String opcode, String instr, boolean federatedOutput) {
- super(type, op, in1, in2, in3, out, opcode, instr, federatedOutput);
+ String opcode, String instr, FederatedOutput fedOut) {
+ super(type, op, in1, in2, in3, out, opcode, instr, fedOut);
}
static String parseUnaryInstruction(String instr, CPOperand in, CPOperand out) {
- InstructionUtils.checkNumFields(instr, 2);
- return parse(instr, in, null, null, out);
+ //TODO: simplify once all fed instructions have consistent flags
+ int num = InstructionUtils.checkNumFields(instr, 2, 3);
+ if(num == 2)
+ return parse(instr, in, null, null, out);
+ else
+ return parseWithFedOutFlag(instr, in, out);
}
static String parseUnaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) {
@@ -98,4 +102,12 @@ public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
}
return opcode;
}
+
+ private static String parseWithFedOutFlag(String instr, CPOperand in1, CPOperand out) {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
+ String opcode = parts[0];
+ in1.split(parts[1]);
+ out.split(parts[parts.length - 2]);
+ return opcode;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
index a479be7..197ba43 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
@@ -19,8 +19,9 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.util.ArrayList;
import java.util.Arrays;
-import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
@@ -103,12 +104,12 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra
MatrixObject out = ec.getMatrixObject(_in.getOutput());
FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID());
- Map<FederatedRange, FederatedData> newMap = new HashMap<>();
- for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getMap().entrySet()) {
+ List<Pair<FederatedRange, FederatedData>> newMap = new ArrayList<>();
+ for(Pair<FederatedRange, FederatedData> pair : outMap.getMap()) {
FederatedData om = pair.getValue();
- FederatedData nf = new FederatedData(Types.DataType.MATRIX, om.getAddress(), om.getFilepath(),
- om.getVarID());
- newMap.put(pair.getKey(), nf);
+ FederatedData nf = new FederatedData(Types.DataType.MATRIX,
+ om.getAddress(), om.getFilepath(), om.getVarID());
+ newMap.add(Pair.of(pair.getKey(), nf));
}
out.setFedMapping(outMap);
}
@@ -130,12 +131,12 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra
FrameObject out = ec.getFrameObject(_in.getOutput());
out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID());
- Map<FederatedRange, FederatedData> newMap = new HashMap<>();
- for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getMap().entrySet()) {
+ List<Pair<FederatedRange, FederatedData>> newMap = new ArrayList<>();
+ for(Pair<FederatedRange, FederatedData> pair : outMap.getMap()) {
FederatedData om = pair.getValue();
- FederatedData nf = new FederatedData(Types.DataType.FRAME, om.getAddress(), om.getFilepath(),
- om.getVarID());
- newMap.put(pair.getKey(), nf);
+ FederatedData nf = new FederatedData(Types.DataType.FRAME,
+ om.getAddress(), om.getFilepath(), om.getVarID());
+ newMap.add(Pair.of(pair.getKey(), nf));
}
ValueType[] schema = new ValueType[(int) mo1.getDataCharacteristics().getCols()];
Arrays.fill(schema, ValueType.FP64);
diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
index 5252c5e..070c534 100644
--- a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
+++ b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
@@ -25,8 +25,6 @@ import java.io.OutputStreamWriter;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
import java.util.stream.Collectors;
import com.fasterxml.jackson.core.type.TypeReference;
@@ -119,13 +117,10 @@ public class ReaderWriterFederated {
}
}
- private static FederatedDataAddress[] parseMap(Map<FederatedRange, FederatedData> map) {
- FederatedDataAddress[] res = new FederatedDataAddress[map.size()];
- int i = 0;
- for(Entry<FederatedRange, FederatedData> ent : map.entrySet()) {
- res[i++] = new FederatedDataAddress(ent.getKey(), ent.getValue());
- }
- return res;
+ private static FederatedDataAddress[] parseMap(List<Pair<FederatedRange, FederatedData>> map) {
+ return map.stream()
+ .map(e -> new FederatedDataAddress(e.getKey(), e.getValue()))
+ .toArray(FederatedDataAddress[]::new);
}
/**
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index c7821af..e22d0e2 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -40,6 +40,7 @@ import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
@@ -643,7 +644,7 @@ public abstract class AutomatedTestBase {
new MatrixCharacteristics(nrows, ncol), Types.FileFormat.BINARY));
// write parts and generate FederationMap
- HashMap<FederatedRange, FederatedData> fedHashMap = new HashMap<>();
+ List<Pair<FederatedRange, FederatedData>> fedHashMap = new ArrayList<>();
for(int i = 0; i < numFederatedWorkers; i++) {
double lowerBound = ranges[i][0];
double upperBound = ranges[i][1];
@@ -658,7 +659,7 @@ public abstract class AutomatedTestBase {
// generate fedmap entry
FederatedRange range = new FederatedRange(new long[]{(long) lowerBound, 0}, new long[]{(long) upperBound, ncol});
FederatedData data = new FederatedData(DataType.MATRIX, new InetSocketAddress(ports.get(i)), input(path));
- fedHashMap.put(range, data);
+ fedHashMap.add(Pair.of(range, data));
}
federatedMatrixObject.setFedMapping(new FederationMap(FederationUtils.getNextFedDataID(), fedHashMap));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedNegativeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedNegativeTest.java
index f5fba16..86e75eb 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedNegativeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedNegativeTest.java
@@ -23,10 +23,11 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.net.InetSocketAddress;
-import java.util.HashMap;
-import java.util.Map;
+import java.util.ArrayList;
+import java.util.List;
import java.util.concurrent.Future;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
@@ -52,11 +53,11 @@ public class FederatedNegativeTest {
NegativeTest1();
}
FederationUtils.resetFedDataID(); //ensure expected ID when tests run in single JVM
- Map<FederatedRange, FederatedData> fedMap = new HashMap<>();
+ List<Pair<FederatedRange, FederatedData>> fedMap = new ArrayList<>();
FederatedRange r = new FederatedRange(new long[]{0,0}, new long[]{1,1});
FederatedData d = new FederatedData(Types.DataType.SCALAR,
new InetSocketAddress("localhost", port), "Nowhere");
- fedMap.put(r,d);
+ fedMap.add(Pair.of(r,d));
FederationMap fedM = new FederationMap(fedMap);
FederatedRequest fr = new FederatedRequest(FederatedRequest.RequestType.GET_VAR);
Future<FederatedResponse>[] res = fedM.execute(0, fr);
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
index bc4dec4..3994baa 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
@@ -20,6 +20,8 @@
package org.apache.sysds.test.functions.privacy.algorithms;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
@@ -54,6 +56,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
// PrivateAggregation Single Input
@Test
+ @Ignore
public void federatedL2SVMCPPrivateAggregationX1() throws JSONException {
Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
privacyConstraints.put("X1", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
@@ -61,6 +64,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
}
@Test
+ @Ignore
public void federatedL2SVMCPPrivateAggregationX2() throws JSONException {
Map<String, PrivacyConstraint> privacyConstraints = new HashMap<>();
privacyConstraints.put("X2", new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
@@ -399,7 +403,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
}
if ( expectedPrivacyLevel != null)
- assert(checkedPrivacyConstraintsContains(expectedPrivacyLevel));
+ Assert.assertTrue(checkedPrivacyConstraintsContains(expectedPrivacyLevel));
}
finally {
TestUtils.shutdownThreads(t1, t2);
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 7134347..79fe54f 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -21,6 +21,8 @@ package org.apache.sysds.test.functions.privacy.fedplanning;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -34,13 +36,18 @@ import org.apache.sysds.test.TestUtils;
import java.util.Arrays;
import java.util.Collection;
+import static org.junit.Assert.assertTrue;
+
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
- private final static String TEST_DIR = "functions/privacy/";
+ private final static String TEST_DIR = "functions/privacy/fedplanning/";
private final static String TEST_NAME = "FederatedMultiplyPlanningTest";
private final static String TEST_NAME_2 = "FederatedMultiplyPlanningTest2";
private final static String TEST_NAME_3 = "FederatedMultiplyPlanningTest3";
+ private final static String TEST_NAME_4 = "FederatedMultiplyPlanningTest4";
+ private final static String TEST_NAME_5 = "FederatedMultiplyPlanningTest5";
+ private final static String TEST_NAME_6 = "FederatedMultiplyPlanningTest6";
private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@@ -55,6 +62,9 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
addTestConfiguration(TEST_NAME_2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
addTestConfiguration(TEST_NAME_3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_3, new String[] {"Z.scalar"}));
+ addTestConfiguration(TEST_NAME_4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_4, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME_5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_5, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME_6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_6, new String[] {"Z"}));
}
@Parameterized.Parameters
@@ -67,86 +77,105 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
@Test
public void federatedMultiplyCP() {
- OptimizerUtils.FEDERATED_COMPILATION = true;
federatedTwoMatricesSingleNodeTest(TEST_NAME);
}
@Test
+ @Ignore
public void federatedRowSum(){
- OptimizerUtils.FEDERATED_COMPILATION = true;
federatedTwoMatricesSingleNodeTest(TEST_NAME_2);
}
@Test
public void federatedTernarySequence(){
- OptimizerUtils.FEDERATED_COMPILATION = true;
federatedTwoMatricesSingleNodeTest(TEST_NAME_3);
}
- private void writeStandardMatrix(String matrixName, long seed){
- int halfRows = rows/2;
- double[][] matrix = getRandomMatrix(halfRows, cols, 0, 1, 1, seed);
- writeInputMatrixWithMTD(matrixName, matrix, false,
- new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols),
- new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
+ @Test
+ public void federatedAggregateBinarySequence(){
+ cols = rows; // Test4 computes X %*% Y and then Z0 * Y, which only conforms for square r x r inputs
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_4);
}
- public void federatedTwoMatricesSingleNodeTest(String testName){
- federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName);
+ @Test
+ @Ignore // TODO(review): no reason recorded for disabling; consider @Ignore("reason")
+ public void federatedAggregateBinaryColFedSequence(){
+ cols = rows; // Test5 multiplies the r x r result of X %*% t(Y) elementwise with the r x c matrix Y
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_5);
}
- public void federatedTwoMatricesTest(Types.ExecMode execMode, String testName) {
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- Types.ExecMode platformOld = rtplatform;
- rtplatform = execMode;
- if(rtplatform == Types.ExecMode.SPARK) {
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
- }
-
- getAndLoadTestConfiguration(testName);
- String HOME = SCRIPT_DIR + TEST_DIR;
-
- // Write input matrices
- writeStandardMatrix("X1", 42);
- writeStandardMatrix("X2", 1340);
- writeStandardMatrix("Y1", 44);
- writeStandardMatrix("Y2", 21);
+ @Test
+ @Ignore // TODO(review): no reason recorded for disabling; consider @Ignore("reason")
+ public void federatedAggregateBinarySequence2(){ // column-federated X times row-federated vector Y (Test6)
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_6);
+ }
- int port1 = getRandomAvailablePort();
- int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ private void writeStandardMatrix(String matrixName, long seed){ // row partition with a PrivateAggregation privacy constraint
+ writeStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
+ }
- TestConfiguration config = availableTestConfigurations.get(testName);
- loadTestConfiguration(config);
+ private void writeStandardMatrix(String matrixName, long seed, PrivacyConstraint privacyConstraint){ // writes one (rows/2) x cols row partition plus metadata
+ int halfRows = rows/2;
+ double[][] matrix = getRandomMatrix(halfRows, cols, 0, 1, 1, seed);
+ MatrixCharacteristics mc = new MatrixCharacteristics(halfRows, cols, blocksize, (long) halfRows * cols); // nnz = every cell of the partition; long cast avoids int overflow
+ writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint);
+ }
- // Run actual dml script with federated matrix
- fullDMLScriptName = HOME + testName + ".dml";
- programArgs = new String[] {"-explain", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
- "X2=" + TestUtils.federatedAddress(port2, input("X2")),
- "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
- "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
- runTest(true, false, null, -1);
+ private void writeColStandardMatrix(String matrixName, long seed){ // column partition with a PrivateAggregation privacy constraint
+ writeColStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ }
- OptimizerUtils.FEDERATED_COMPILATION = false;
+ private void writeColStandardMatrix(String matrixName, long seed, PrivacyConstraint privacyConstraint){ // writes one rows x (cols/2) column partition plus metadata
+ int halfCols = cols/2;
+ double[][] matrix = getRandomMatrix(rows, halfCols, 0, 1, 1, seed);
+ MatrixCharacteristics mc = new MatrixCharacteristics(rows, halfCols, blocksize, (long) halfCols *rows); // nnz = every cell of the partition
+ writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint);
+ }
- // Run reference dml script with normal matrix
- fullDMLScriptName = HOME + testName + "Reference.dml";
- programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
- "Y2=" + input("Y2"), "Z=" + expected("Z")};
- runTest(true, false, null, -1);
+ private void writeRowFederatedVector(String matrixName, long seed){ // vector partition with a PrivateAggregation privacy constraint
+ writeRowFederatedVector(matrixName, seed, new PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+ }
- // compare via files
- compareResults(1e-9);
- heavyHittersContainsString("fed_*", "fed_ba+*");
+ private void writeRowFederatedVector(String matrixName, long seed, PrivacyConstraint privacyConstraint){ // writes one (cols/2) x 1 row partition of vector Y plus metadata
+ int halfCols = cols / 2;
+ double[][] matrix = getRandomMatrix(halfCols, 1, 0, 1, 1, seed);
+ MatrixCharacteristics mc = new MatrixCharacteristics(halfCols, 1, blocksize, (long) halfCols); // nnz = halfCols * 1 cells; was halfCols*rows, which exceeds the vector's size
+ writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint);
+ }
- TestUtils.shutdownThreads(t1, t2);
+ private void writeInputMatrices(String testName){
+ if ( testName.equals(TEST_NAME_5) ){ // X and Y both column-partitioned
+ writeColStandardMatrix("X1", 42);
+ writeColStandardMatrix("X2", 1340);
+ writeColStandardMatrix("Y1", 44);
+ writeColStandardMatrix("Y2", 21);
+ return;
+ }
+ if ( testName.equals(TEST_NAME_6) ){ // column-partitioned X, row-partitioned vector Y
+ writeColStandardMatrix("X1", 42);
+ writeColStandardMatrix("X2", 1340);
+ writeRowFederatedVector("Y1", 44);
+ writeRowFederatedVector("Y2", 21);
+ return;
+ }
+ // all remaining tests: row-partitioned X
+ writeStandardMatrix("X1", 42);
+ writeStandardMatrix("X2", 1340);
+ if ( testName.equals(TEST_NAME_4) ){ // Y is read locally (not federated) here; written with a null privacy constraint
+ writeStandardMatrix("Y1", 44, null);
+ writeStandardMatrix("Y2", 21, null);
+ return;
+ }
+ writeStandardMatrix("Y1", 44);
+ writeStandardMatrix("Y2", 21);
+ }
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ private void federatedTwoMatricesSingleNodeTest(String testName){ // runs the given test script in SINGLE_NODE execution mode
federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName);
}
- public void federatedThreeMatricesTest(Types.ExecMode execMode, String testName) {
+ private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName) {
+ OptimizerUtils.FEDERATED_COMPILATION = true;
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
rtplatform = execMode;
@@ -157,27 +186,25 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
getAndLoadTestConfiguration(testName);
String HOME = SCRIPT_DIR + TEST_DIR;
- // Write input matrices
- writeStandardMatrix("X1", 42);
- writeStandardMatrix("X2", 1340);
- writeStandardMatrix("Y1", 44);
- writeStandardMatrix("Y2", 21);
- writeStandardMatrix("W1", 55);
+ writeInputMatrices(testName);
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
Thread t2 = startLocalFedWorkerThread(port2);
- TestConfiguration config = availableTestConfigurations.get(testName);
- loadTestConfiguration(config);
-
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + testName + ".dml";
- programArgs = new String[] {"-explain", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ programArgs = new String[] {"-stats", "-explain", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
"Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
"Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+ if ( testName.equals(TEST_NAME_4) || testName.equals(TEST_NAME_5) ){
+ programArgs = new String[] {"-stats","-explain", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+ }
runTest(true, false, null, -1);
OptimizerUtils.FEDERATED_COMPILATION = false;
@@ -190,7 +217,10 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
// compare via files
compareResults(1e-9);
- heavyHittersContainsString("fed_*", "fed_ba+*");
+ if ( testName.equals(TEST_NAME_3) )
+ assertTrue(heavyHittersContainsString("fed_+*", "fed_1-*"));
+ else
+ assertTrue(heavyHittersContainsString("fed_*", "fed_ba+*"));
TestUtils.shutdownThreads(t1, t2);
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.dml
similarity index 100%
copy from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest.dml
copy to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.dml
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest2.dml
similarity index 100%
rename from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2.dml
rename to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest2.dml
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest2Reference.dml
similarity index 100%
rename from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest2Reference.dml
rename to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest2Reference.dml
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest3.dml
similarity index 86%
copy from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml
copy to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest3.dml
index 8e39a89..148d8ef 100644
--- a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest3.dml
@@ -20,9 +20,9 @@
#-------------------------------------------------------------
X = federated(addresses=list($X1, $X2),
-ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
Y = federated(addresses=list($Y1, $Y2),
-ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), list($r, $c)))
+ ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), list($r, $c)))
W = rand(rows=$r, cols=$c, min=0, max=1, pdf='uniform', seed=5)
s = 3.5
Z0 = W + s * X
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest3Reference.dml
similarity index 100%
rename from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3Reference.dml
rename to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest3Reference.dml
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest4.dml
similarity index 87%
rename from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest.dml
rename to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest4.dml
index 04b3804..a798ed2 100644
--- a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest.dml
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest4.dml
@@ -21,8 +21,7 @@
X = federated(addresses=list($X1, $X2),
ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
-Y = federated(addresses=list($Y1, $Y2),
- ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), list($r, $c)))
-Z0 = X * Y
-Z = t(Z0) %*% X
+Y = rbind(read($Y1), read($Y2)) # Y is local: row-bind of two plain input files
+Z0 = X %*% Y # row-federated X times local Y; requires ncol(X) == nrow(Y)
+Z = Z0 * Y # elementwise; requires Z0 and Y to have equal dimensions
write(Z, $Z)
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest4Reference.dml
similarity index 97%
copy from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
copy to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest4Reference.dml
index ee595d7..21c5990 100644
--- a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest4Reference.dml
@@ -21,6 +21,6 @@
X = rbind(read($X1), read($X2))
Y = rbind(read($Y1), read($Y2))
-Z0 = X * Y
-Z = t(Z0) %*% X
+Z0 = X %*% Y # local reference for the federated X %*% Y
+Z = Z0 * Y
write(Z, $Z)
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest5.dml
similarity index 83%
copy from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
copy to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest5.dml
index ee595d7..76da220 100644
--- a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest5.dml
@@ -19,8 +19,9 @@
#
#-------------------------------------------------------------
-X = rbind(read($X1), read($X2))
-Y = rbind(read($Y1), read($Y2))
-Z0 = X * Y
-Z = t(Z0) %*% X
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r, $c / 2), list(0, $c / 2), list($r, $c))) # column-partitioned: each worker holds all rows and half the columns
+Y = cbind(read($Y1), read($Y2)) # Y is local: column-bind of two plain input files
+Z0 = X %*% t(Y) # yields an r x r result
+Z = Z0 * Y # elementwise; only conforms when r == c
write(Z, $Z)
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest5Reference.dml
similarity index 91%
copy from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
copy to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest5Reference.dml
index ee595d7..6d3131d 100644
--- a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest5Reference.dml
@@ -19,8 +19,8 @@
#
#-------------------------------------------------------------
-X = rbind(read($X1), read($X2))
-Y = rbind(read($Y1), read($Y2))
-Z0 = X * Y
-Z = t(Z0) %*% X
+X = cbind(read($X1), read($X2)) # local equivalent of the column-partitioned federated X
+Y = cbind(read($Y1), read($Y2))
+Z0 = X %*% t(Y)
+Z = Z0 * Y
write(Z, $Z)
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest6.dml
similarity index 78%
rename from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml
rename to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest6.dml
index 8e39a89..08b68bc 100644
--- a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest3.dml
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest6.dml
@@ -20,12 +20,10 @@
#-------------------------------------------------------------
X = federated(addresses=list($X1, $X2),
-ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+ ranges=list(list(0, 0), list($r, $c / 2), list(0, $c / 2), list($r, $c))) # column-partitioned X
Y = federated(addresses=list($Y1, $Y2),
-ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), list($r, $c)))
-W = rand(rows=$r, cols=$c, min=0, max=1, pdf='uniform', seed=5)
-s = 3.5
-Z0 = W + s * X
-Z1 = 1 - Y * Z0
-Z = sum(Z1)
+ ranges=list(list(0, 0), list($c / 2, 1), list($c / 2, 0), list($c, 1))) # Y: c x 1 vector, row-partitioned to match X's column split
+P = rand(rows = 1, cols = ncol(X), min=0, max=1, pdf='uniform', seed=6) # fixed seed so the reference script generates the same P
+Z0 = X %*% Y
+Z = Z0 %*% P
write(Z, $Z)
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest6Reference.dml
similarity index 88%
copy from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
copy to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest6Reference.dml
index ee595d7..77e0825 100644
--- a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest6Reference.dml
@@ -19,8 +19,9 @@
#
#-------------------------------------------------------------
-X = rbind(read($X1), read($X2))
+X = cbind(read($X1), read($X2)) # local equivalent of the column-partitioned federated X
Y = rbind(read($Y1), read($Y2))
-Z0 = X * Y
-Z = t(Z0) %*% X
+P = rand(rows = 1, cols = ncol(X), min=0, max=1, pdf='uniform', seed=6) # same seed as the federated script
+Z0 = X %*% Y
+Z = Z0 %*% P
write(Z, $Z)
diff --git a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTestReference.dml
similarity index 100%
rename from src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
rename to src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTestReference.dml