You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/18 21:40:06 UTC
[systemds] branch master updated: [SYSTEMDS-3093] Support for spark
instructions on federated inputs
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 0d3bb35 [SYSTEMDS-3093] Support for spark instructions on federated inputs
0d3bb35 is described below
commit 0d3bb359ae894252e2509bde1cbe56008f609f5a
Author: OlgaOvcharenko <ov...@gmail.com>
AuthorDate: Sat Sep 18 23:39:18 2021 +0200
[SYSTEMDS-3093] Support for spark instructions on federated inputs
Execute Spark instructions at federated workers (cross-datacenter
federation). Spark instructions are no longer converted into local (CP)
federated instructions; instead, they become Spark federated instructions.
Closes #1363.
---
.../hops/rewrite/RewriteBlockSizeAndReblock.java | 1 +
.../controlprogram/caching/CacheableData.java | 2 +-
.../controlprogram/context/ExecutionContext.java | 19 ++
.../federated/FederatedWorkerHandler.java | 6 +-
.../controlprogram/federated/FederationUtils.java | 40 ++-
.../runtime/instructions/InstructionUtils.java | 14 +-
.../fed/AggregateTernaryFEDInstruction.java | 80 +++---
.../fed/AggregateUnaryFEDInstruction.java | 111 +++++++--
.../instructions/fed/AppendFEDInstruction.java | 11 +-
.../instructions/fed/BinaryFEDInstruction.java | 2 +-
.../instructions/fed/CastFEDInstruction.java | 133 ++++++++++
...on.java => CumulativeOffsetFEDInstruction.java} | 183 +++++++-------
.../runtime/instructions/fed/FEDInstruction.java | 6 +
.../instructions/fed/FEDInstructionUtils.java | 256 +++++++++++++++-----
.../instructions/fed/IndexingFEDInstruction.java | 24 +-
.../instructions/fed/MapmmFEDInstruction.java | 269 +++++++++++++++++++++
.../instructions/fed/ReblockFEDInstruction.java | 93 +++++++
.../instructions/fed/ReorgFEDInstruction.java | 29 ++-
.../instructions/fed/ReshapeFEDInstruction.java | 10 +-
.../instructions/fed/TernaryFEDInstruction.java | 36 ++-
.../fed/UnaryMatrixFEDInstruction.java | 11 +-
.../spark/MatrixAppendMSPInstruction.java | 2 +-
.../spark/ParameterizedBuiltinSPInstruction.java | 5 +
.../matrix/operators/AggregateUnaryOperator.java | 5 +
.../org/apache/sysds/test/AutomatedTestBase.java | 21 +-
.../primitives/FederatedFullAggregateTest.java | 13 +-
.../primitives/FederatedFullCumulativeTest.java | 33 +--
.../federated/primitives/FederatedIfelseTest.java | 10 +-
.../primitives/FederatedLeftIndexTest.java | 12 +-
.../federated/primitives/FederatedMMChainTest.java | 8 +-
.../primitives/FederatedMultiplyTest.java | 12 +-
.../federated/primitives/FederatedProdTest.java | 3 +
.../primitives/FederatedQuantileTest.java | 13 +-
.../federated/primitives/FederatedRdiagTest.java | 4 +-
.../primitives/FederatedRemoveEmptyTest.java | 7 +-
.../federated/primitives/FederatedReshapeTest.java | 5 +-
.../federated/primitives/FederatedRevTest.java | 7 +-
.../primitives/FederatedRightIndexTest.java | 12 +-
.../primitives/FederatedRowAggregateTest.java | 39 ++-
.../primitives/FederatedRowIndexTest.java | 9 +-
.../federated/primitives/FederatedSplitTest.java | 9 +-
.../federated/primitives/FederatedSumTest.java | 7 +-
.../primitives/FederatedTokenizeTest.java | 2 +-
.../federated/primitives/FederatedTriTest.java | 5 +-
.../aggregate/FederatedMeanTestReference.dml | 6 +-
45 files changed, 1276 insertions(+), 309 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java
index ba40730..3a20283 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java
@@ -112,6 +112,7 @@ public class RewriteBlockSizeAndReblock extends HopRewriteRule
dop.setBlocksize(blocksize);
}
else if (dop.getOp() == OpOpData.FEDERATED) {
+ dop.setRequiresReblock(true);
dop.setBlocksize(blocksize);
}
else {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index d2b38a9..e44d061 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -547,7 +547,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
if( DMLScript.STATISTICS )
CacheStatistics.incrementLinHits();
}
- else if( isFederated() ) {
+ else if( isFederatedExcept(FType.BROADCAST) ) {
_data = readBlobFromFederated(_fedMapping);
//mark for initial local write despite read operation
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 75591b6..8735d05 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -606,6 +606,25 @@ public class ExecutionContext {
return ret;
}
+ public static MatrixObject createMatrixObject(DataCharacteristics dc) {
+ MatrixObject ret = new MatrixObject(Types.ValueType.FP64,
+ OptimizerUtils.getUniqueTempFileName());
+ ret.setMetaData(new MetaDataFormat(new MatrixCharacteristics(
+ dc.getRows(), dc.getCols()), FileFormat.BINARY));
+ ret.getMetaData().getDataCharacteristics()
+ .setBlocksize(ConfigurationManager.getBlocksize());
+ return ret;
+ }
+
+ public static FrameObject createFrameObject(DataCharacteristics dc) {
+ FrameObject ret = new FrameObject(OptimizerUtils.getUniqueTempFileName());
+ ret.setMetaData(new MetaDataFormat(new MatrixCharacteristics(
+ dc.getRows(), dc.getCols()), FileFormat.BINARY));
+ ret.getMetaData().getDataCharacteristics()
+ .setBlocksize(ConfigurationManager.getBlocksize());
+ return ret;
+ }
+
public static FrameObject createFrameObject(FrameBlock fb) {
FrameObject ret = new FrameObject(OptimizerUtils.getUniqueTempFileName());
ret.acquireModify(fb);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 062cfa0..a35b736 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -249,7 +249,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
}
private FederatedResponse putVariable(FederatedRequest request) {
- checkNumParams(request.getNumParams(), 1);
+ checkNumParams(request.getNumParams(), 1, 2);
String varname = String.valueOf(request.getID());
ExecutionContext ec = _ecm.get(request.getTID());
@@ -265,6 +265,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
data = (ScalarObject) request.getParam(0);
else if(request.getParam(0) instanceof ListObject)
data = (ListObject) request.getParam(0);
+ else if(request.getNumParams() == 2)
+ data = request.getParam(1) == DataType.MATRIX ?
+ ExecutionContext.createMatrixObject((MatrixCharacteristics) request.getParam(0)) :
+ ExecutionContext.createFrameObject((MatrixCharacteristics) request.getParam(0));
else
throw new DMLRuntimeException(
"FederatedWorkerHandler: Unsupported object type, has to be of type CacheBlock or ScalarObject");
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index f9fb881..e2430bb 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -101,7 +101,6 @@ public class FederationUtils {
FederatedRequest[] fr = new FederatedRequest[inst.length];
for(int j=0; j<inst.length; j++) {
for(int i = 0; i < varOldIn.length; i++) {
- linst[j] = linst[j].replace(ExecType.SPARK.name(), ExecType.CP.name());
linst[j] = linst[j].replace(
Lop.OPERAND_DELIMITOR + varOldOut.getName() + Lop.DATATYPE_PREFIX,
Lop.OPERAND_DELIMITOR + String.valueOf(id) + Lop.DATATYPE_PREFIX);
@@ -118,8 +117,30 @@ public class FederationUtils {
return fr;
}
- public static FederatedRequest callInstruction(String inst, CPOperand varOldOut, long outputId, CPOperand[] varOldIn, long[] varNewIn) {
- String linst = InstructionUtils.replaceOperand(inst, 0, ExecType.CP.name());
+ public static FederatedRequest[] callInstruction(String[] inst, CPOperand varOldOut, long outputId, CPOperand[] varOldIn, long[] varNewIn, ExecType type) {
+ String[] linst = inst;
+ FederatedRequest[] fr = new FederatedRequest[inst.length];
+ for(int j=0; j<inst.length; j++) {
+ for(int i = 0; i < varOldIn.length; i++) {
+ linst[j] = InstructionUtils.replaceOperand(linst[j], 0, type == null ? InstructionUtils.getExecType(linst[j]).name() : type.name());
+ linst[j] = linst[j].replace(
+ Lop.OPERAND_DELIMITOR + varOldOut.getName() + Lop.DATATYPE_PREFIX,
+ Lop.OPERAND_DELIMITOR + String.valueOf(outputId) + Lop.DATATYPE_PREFIX);
+
+ if(varOldIn[i] != null) {
+ linst[j] = linst[j].replace(
+ Lop.OPERAND_DELIMITOR + varOldIn[i].getName() + Lop.DATATYPE_PREFIX,
+ Lop.OPERAND_DELIMITOR + String.valueOf(varNewIn[i]) + Lop.DATATYPE_PREFIX);
+ linst[j] = linst[j].replace("=" + varOldIn[i].getName(), "=" + String.valueOf(varNewIn[i])); //parameterized
+ }
+ }
+ fr[j] = new FederatedRequest(RequestType.EXEC_INST, outputId, (Object) linst[j]);
+ }
+ return fr;
+ }
+
+ public static FederatedRequest callInstruction(String inst, CPOperand varOldOut, long outputId, CPOperand[] varOldIn, long[] varNewIn, ExecType type, boolean rmFedOutputFlag) {
+ String linst = InstructionUtils.replaceOperand(inst, 0, type.name());
linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldOut.getName()+Lop.DATATYPE_PREFIX, Lop.OPERAND_DELIMITOR+outputId+Lop.DATATYPE_PREFIX);
for(int i=0; i<varOldIn.length; i++)
if( varOldIn[i] != null ) {
@@ -128,6 +149,8 @@ public class FederationUtils {
Lop.OPERAND_DELIMITOR+(varNewIn[i])+Lop.DATATYPE_PREFIX);
linst = linst.replace("="+varOldIn[i].getName(), "="+(varNewIn[i])); //parameterized
}
+ if(rmFedOutputFlag)
+ linst = InstructionUtils.removeFEDOutputFlag(linst);
return new FederatedRequest(RequestType.EXEC_INST, outputId, linst);
}
@@ -226,12 +249,15 @@ public class FederationUtils {
public static MatrixBlock aggProd(Future<FederatedResponse>[] ffr, FederationMap fedMap, AggregateUnaryOperator aop) {
try {
boolean rowFed = fedMap.getType() == FederationMap.FType.ROW;
- MatrixBlock ret = rowFed ?
+ MatrixBlock ret = aop.isFullAggregate() ? (rowFed ?
+ new MatrixBlock(ffr.length, 1, 1.0) : new MatrixBlock(1, ffr.length, 1.0)) :
+ (rowFed ?
new MatrixBlock(ffr.length, (int) fedMap.getFederatedRanges()[0].getEndDims()[1], 1.0) :
- new MatrixBlock((int) fedMap.getFederatedRanges()[0].getEndDims()[0], ffr.length, 1.0);
- MatrixBlock res = rowFed ?
+ new MatrixBlock((int) fedMap.getFederatedRanges()[0].getEndDims()[0], ffr.length, 1.0));
+ MatrixBlock res = aop.isFullAggregate() ? new MatrixBlock(1, 1, 1.0) :
+ (rowFed ?
new MatrixBlock(1, (int) fedMap.getFederatedRanges()[0].getEndDims()[1], 1.0) :
- new MatrixBlock((int) fedMap.getFederatedRanges()[0].getEndDims()[0], 1, 1.0);
+ new MatrixBlock((int) fedMap.getFederatedRanges()[0].getEndDims()[0], 1, 1.0));
for(int i = 0; i < ffr.length; i++) {
MatrixBlock tmp = (MatrixBlock) ffr[i].get().getData()[0];
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 828eb0a..246ed87 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions;
import java.util.Arrays;
import java.util.StringTokenizer;
+import org.apache.commons.lang.ArrayUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.CorrectionLocationType;
@@ -1029,6 +1030,16 @@ public class InstructionUtils
return concatOperands(parts);
}
+ public static String removeOperand(String instStr, int operand) {
+ //split instruction and check for correctness
+ String[] parts = instStr.split(Lop.OPERAND_DELIMITOR);
+ if( operand >= parts.length )
+ throw new DMLRuntimeException("Operand position "
+ + operand + " exceeds the length of the instruction.");
+ //remove and reconstruct string
+ return concatOperands((String[]) ArrayUtils.remove(parts, operand));
+ }
+
public static String replaceOperandName(String instStr) {
String[] parts = instStr.split(Lop.OPERAND_DELIMITOR);
String oldName = parts[parts.length-1];
@@ -1115,7 +1126,8 @@ public class InstructionUtils
}
private static String replaceExecTypeWithCP(String inst){
- return inst.replace(Types.ExecType.SPARK.name(), Types.ExecType.CP.name())
+ return inst
+// .replace(Types.ExecType.SPARK.name(), Types.ExecType.CP.name())
.replace(Types.ExecType.FED.name(), Types.ExecType.CP.name());
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index 0fead6b..1b0c0bf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -24,69 +24,84 @@ import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
-public class AggregateTernaryFEDInstruction extends FEDInstruction {
+public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
// private static final Log LOG = LogFactory.getLog(AggregateTernaryFEDInstruction.class.getName());
- public final AggregateTernaryCPInstruction _ins;
-
- protected AggregateTernaryFEDInstruction(AggregateTernaryCPInstruction ins) {
- super(FEDType.AggregateTernary, ins.getOperator(), ins.getOpcode(), ins.getInstructionString());
- _ins = ins;
+ private AggregateTernaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
+ String opcode, String istr) {
+ super(FEDType.AggregateTernary, op, in1, in2, in3, out, opcode, istr);
}
- public static AggregateTernaryFEDInstruction parseInstruction(AggregateTernaryCPInstruction ins) {
- return new AggregateTernaryFEDInstruction(ins);
+ public static AggregateTernaryFEDInstruction parseInstruction(String str) {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+
+ if(opcode.equalsIgnoreCase("tak+*") || opcode.equalsIgnoreCase("tack+*")) {
+ InstructionUtils.checkNumFields(parts, 5);
+
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand in3 = new CPOperand(parts[3]);
+ CPOperand out = new CPOperand(parts[4]);
+ int numThreads = Integer.parseInt(parts[5]);
+
+ AggregateTernaryOperator op = InstructionUtils.parseAggregateTernaryOperator(opcode, numThreads);
+ return new AggregateTernaryFEDInstruction(op, in1, in2, in3, out, opcode, str);
+ }
+ else {
+ throw new DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown opcode " + opcode);
+ }
}
@Override
public void processInstruction(ExecutionContext ec) {
- MatrixObject mo1 = ec.getMatrixObject(_ins.input1);
- MatrixObject mo2 = ec.getMatrixObject(_ins.input2);
- MatrixObject mo3 = _ins.input3.isLiteral() ? null : ec.getMatrixObject(_ins.input3);
+ MatrixObject mo1 = ec.getMatrixObject(input1);
+ MatrixObject mo2 = ec.getMatrixObject(input2);
+ MatrixObject mo3 = input3.isLiteral() ? null : ec.getMatrixObject(input3);
if(mo3 != null && mo1.isFederated() && mo2.isFederated() && mo3.isFederated()
&& mo1.getFedMapping().isAligned(mo2.getFedMapping(), mo1.isFederated(FType.ROW) ? AlignType.ROW : AlignType.COL)
&& mo2.getFedMapping().isAligned(mo3.getFedMapping(), mo1.isFederated(FType.ROW) ? AlignType.ROW : AlignType.COL)) {
- FederatedRequest fr1 = FederationUtils.callInstruction(_ins.getInstructionString(), _ins.getOutput(),
- new CPOperand[] {_ins.input1, _ins.input2, _ins.input3},
+ FederatedRequest fr1 = FederationUtils.callInstruction(getInstructionString(), output,
+ new CPOperand[] {input1, input2, input3},
new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
Future<FederatedResponse>[] response = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
- if(_ins.output.getDataType().isScalar()) {
+ if(output.getDataType().isScalar()) {
AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
- ec.setScalarOutput(_ins.output.getName(), FederationUtils.aggScalar(aop, response, mo1.getFedMapping()));
+ ec.setScalarOutput(output.getName(), FederationUtils.aggScalar(aop, response, mo1.getFedMapping()));
}
else {
- AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(_ins.getOpcode().equals("fed_tak+*") ? "uak+" : "uack+");
- ec.setMatrixOutput(_ins.output.getName(), FederationUtils.aggMatrix(aop, response, mo1.getFedMapping()));
+ AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(getOpcode().equals("fed_tak+*") ? "uak+" : "uack+");
+ ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, response, mo1.getFedMapping()));
}
}
else if(mo1.isFederated() && mo2.isFederated()
&& mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo3 == null) {
- FederatedRequest fr1 = mo1.getFedMapping().broadcast(ec.getScalarInput(_ins.input3));
- FederatedRequest fr2 = FederationUtils.callInstruction(_ins.getInstructionString(),
- _ins.getOutput(),
- new CPOperand[] {_ins.input1, _ins.input2, _ins.input3},
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(ec.getScalarInput(input3));
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[] {input1, input2, input3},
new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), fr1.getID()});
FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
- if(_ins.output.getDataType().isScalar()) {
+ if(output.getDataType().isScalar()) {
double sum = 0;
for(Future<FederatedResponse> fr : tmp)
try {
@@ -96,22 +111,21 @@ public class AggregateTernaryFEDInstruction extends FEDInstruction {
throw new DMLRuntimeException("Federated Get data failed with exception on TernaryFedInstruction", e);
}
- ec.setScalarOutput(_ins.output.getName(), new DoubleObject(sum));
+ ec.setScalarOutput(output.getName(), new DoubleObject(sum));
}
else {
throw new DMLRuntimeException("Not Implemented Federated Ternary Variation");
}
- } else if(mo1.isFederatedExcept(FType.BROADCAST) && _ins.input3.isMatrix() && mo3 != null) {
+ } else if(mo1.isFederatedExcept(FType.BROADCAST) && input3.isMatrix() && mo3 != null) {
FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo3, false);
FederatedRequest[] fr2 = mo1.getFedMapping().broadcastSliced(mo2, false);
- FederatedRequest fr3 = FederationUtils.callInstruction(_ins.getInstructionString(),
- _ins.getOutput(),
- new CPOperand[] {_ins.input1, _ins.input2, _ins.input3},
+ FederatedRequest fr3 = FederationUtils.callInstruction(getInstructionString(), output,
+ new CPOperand[] {input1, input2, input3},
new long[] {mo1.getFedMapping().getID(), fr2[0].getID(), fr1[0].getID()});
FederatedRequest fr4 = new FederatedRequest(RequestType.GET_VAR, fr3.getID());
Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2[0], fr3, fr4);
- if(_ins.output.getDataType().isScalar()) {
+ if(output.getDataType().isScalar()) {
double sum = 0;
for(Future<FederatedResponse> fr : tmp)
try {
@@ -121,7 +135,7 @@ public class AggregateTernaryFEDInstruction extends FEDInstruction {
throw new DMLRuntimeException("Federated Get data failed with exception on TernaryFedInstruction", e);
}
- ec.setScalarOutput(_ins.output.getName(), new DoubleObject(sum));
+ ec.setScalarOutput(output.getName(), new DoubleObject(sum));
}
else {
throw new DMLRuntimeException("Not Implemented Federated Ternary Variation");
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 2e5366e..3183194 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -33,8 +33,11 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
@@ -70,13 +73,13 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
AggregateUnaryOperator aggun = null;
if(opcode.equalsIgnoreCase("uarimax") || opcode.equalsIgnoreCase("uarimin"))
- aggun = InstructionUtils.parseAggregateUnaryRowIndexOperator(opcode, Integer.parseInt(parts[4]), 1);
+ if(InstructionUtils.getExecType(str) == ExecType.SPARK)
+ aggun = InstructionUtils.parseAggregateUnaryRowIndexOperator(opcode, 1, 1);
+ else
+ aggun = InstructionUtils.parseAggregateUnaryRowIndexOperator(opcode, Integer.parseInt(parts[4]), 1);
else
aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
- if(InstructionUtils.getExecType(str) == ExecType.SPARK)
- str = InstructionUtils.replaceOperand(str, 4, "-1");
-
FederatedOutput fedOut = null;
if ( parts.length == 5 && !parts[4].equals("uarimin") && !parts[4].equals("uarimax") )
fedOut = FederatedOutput.valueOf(parts[4]);
@@ -89,10 +92,9 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
public void processInstruction(ExecutionContext ec) {
if (getOpcode().contains("var")) {
processVar(ec);
- }else{
+ } else {
processDefault(ec);
}
-
}
private void processDefault(ExecutionContext ec){
@@ -106,9 +108,16 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
// create federated commands for aggregation
// (by default obtain output, even though unnecessary row aggregates)
if ( _fedOut.isForcedFederated() )
- processFederatedOutput(map, in, ec);
- else
- processGetOutput(map, aop, ec, in);
+ if(instString.startsWith("SPARK"))
+ processFederatedSPOutput(map, in, ec, aop);
+ else
+ processFederatedOutput(map, in, ec);
+ else {
+ if(instString.startsWith("SPARK"))
+ processGetSPOutput(map, in, ec, aop);
+ else
+ processGetOutput(map, aop, ec, in);
+ }
}
/**
@@ -207,33 +216,105 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
throw new DMLRuntimeException("Output of " + toString() + " should not be federated "
+ "since the instruction requires consolidation of partial results to be computed.");
}
+
+ boolean isSpark = instString.startsWith("SPARK");
+
AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
MatrixObject in = ec.getMatrixObject(input1);
FederationMap map = in.getFedMapping();
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest tmpRequest = null;
+ if(isSpark) {
+ if ( output.isScalar() ) {
+ ScalarObject scalarOut = ec.getScalarInput(output);
+ tmpRequest = map.broadcast(scalarOut);
+ id = tmpRequest.getID();
+ }
+ else {
+ if((map.getType() == FederationMap.FType.COL && aop.isColAggregate()) || (map.getType() == FederationMap.FType.ROW && aop.isRowAggregate()))
+ tmpRequest = new FederatedRequest(RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), in.getDataType());
+ else {
+ DataCharacteristics dc = ec.getDataCharacteristics(output.getName());
+ tmpRequest = new FederatedRequest(RequestType.PUT_VAR, id, dc, in.getDataType());
+ }
+ }
+ }
+
// federated ranges mean for variance
Future<FederatedResponse>[] meanTmp = null;
if (getOpcode().contains("var")) {
String meanInstr = instString.replace(getOpcode(), getOpcode().replace("var", "mean"));
+
//create federated commands for aggregation
- FederatedRequest meanFr1 = FederationUtils.callInstruction(meanInstr, output,
- new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()});
+ FederatedRequest meanFr1 = FederationUtils.callInstruction(meanInstr, output, id,
+ new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, isSpark ? ExecType.SPARK : ExecType.CP, isSpark);
FederatedRequest meanFr2 = new FederatedRequest(RequestType.GET_VAR, meanFr1.getID());
FederatedRequest meanFr3 = map.cleanup(getTID(), meanFr1.getID());
- meanTmp = map.execute(getTID(), meanFr1, meanFr2, meanFr3);
+ meanTmp = map.execute(getTID(), isSpark ? new FederatedRequest[] {tmpRequest, meanFr1, meanFr2, meanFr3} : new FederatedRequest[] {meanFr1, meanFr2, meanFr3});
}
//create federated commands for aggregation
- FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()});
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id,
+ new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, isSpark ? ExecType.SPARK : ExecType.CP, isSpark);
FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
//execute federated commands and cleanups
- Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
+ Future<FederatedResponse>[] tmp = map.execute(getTID(), isSpark ? new FederatedRequest[] {tmpRequest, fr1, fr2, fr3} : new FederatedRequest[] { fr1, fr2, fr3});
if( output.isScalar() )
ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp, meanTmp, map));
else
ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, meanTmp, map));
}
+
+ private void processFederatedSPOutput(FederationMap map, MatrixObject in, ExecutionContext ec, AggregateUnaryOperator aop) {
+ DataCharacteristics dc = ec.getDataCharacteristics(output.getName());
+ FederatedRequest fr1;
+ long id = FederationUtils.getNextFedDataID();
+
+ if((map.getType() == FederationMap.FType.COL && aop.isColAggregate()) ||
+ (map.getType() == FederationMap.FType.ROW && aop.isRowAggregate()))
+ fr1 = new FederatedRequest(RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), in.getDataType());
+ else
+ fr1 = new FederatedRequest(RequestType.PUT_VAR, id, dc, in.getDataType());
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id,
+ new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, ExecType.SPARK, true);
+
+ map.execute(getTID(), fr1, fr2);
+ // derive new fed mapping for output
+ MatrixObject out = ec.getMatrixObject(output);
+ out.setFedMapping(in.getFedMapping().copyWithNewID(fr2.getID()));
+ }
+
+ private void processGetSPOutput(FederationMap map, MatrixObject in, ExecutionContext ec, AggregateUnaryOperator aop) {
+ DataCharacteristics dc = ec.getDataCharacteristics(output.getName());
+ FederatedRequest fr1;
+ long id = FederationUtils.getNextFedDataID();
+
+ if ( output.isScalar() ) {
+ ScalarObject scalarOut = ec.getScalarInput(output);
+ fr1 = map.broadcast(scalarOut);
+ id = fr1.getID();
+ }
+ else {
+
+ if((map.getType() == FederationMap.FType.COL && aop.isColAggregate()) || (map.getType() == FederationMap.FType.ROW && aop.isRowAggregate()))
+ fr1 = new FederatedRequest(RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), in.getDataType());
+ else
+ fr1 = new FederatedRequest(RequestType.PUT_VAR, id, dc, in.getDataType());
+ }
+
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id,
+ new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, ExecType.SPARK, true);
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ FederatedRequest fr4 = map.cleanup(getTID(), fr2.getID());
+
+ //execute federated commands and cleanups
+ Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3, fr4);
+ if( output.isScalar() )
+ ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp, map));
+ else
+ ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, map));
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index c2e7ab1..15ba1f1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -32,6 +32,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataUtils;
public class AppendFEDInstruction extends BinaryFEDInstruction {
@@ -101,7 +102,8 @@ public class AppendFEDInstruction extends BinaryFEDInstruction {
}
// federated/local, local/federated cbind
else if( (mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW)) && _cbind ) {
- boolean isFed = mo1.isFederated(FType.ROW);
+ boolean isFed = mo1.isFederated(FType.ROW) && mo1.isFederatedExcept(FType.BROADCAST);
+ boolean isSpark = instString.contains("SPARK");
MatrixObject moFed = isFed ? mo1 : mo2;
MatrixObject moLoc = isFed ? mo2 : mo1;
@@ -113,7 +115,12 @@ public class AppendFEDInstruction extends BinaryFEDInstruction {
new long[]{ fr1[0].getID(), moFed.getFedMapping().getID()});
//execute federated operations and set output
- moFed.getFedMapping().execute(getTID(), true, fr1, fr2);
+ if(isSpark) {
+ FederatedRequest tmp = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, fr2.getID(), new MatrixCharacteristics(-1, -1), mo1.getDataType());
+ moFed.getFedMapping().execute(getTID(), true, fr1, tmp, fr2);
+ } else {
+ moFed.getFedMapping().execute(getTID(), true, fr1, fr2);
+ }
out.setFedMapping(moFed.getFedMapping().copyWithNewID(fr2.getID(), out.getNumColumns()));
}
// federated/local, local/federated rbind
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 0c37a89..c2f07f7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -52,7 +52,7 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
}
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
- InstructionUtils.checkNumFields(parts, 3, 4, 5);
+ InstructionUtils.checkNumFields(parts, 3, 4, 5, 6);
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java
new file mode 100644
index 0000000..5fd53e0
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Arrays;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.UnaryCP;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+
+/**
+ * Federated cast between frames and matrices, executed at every federated site as a SPARK
+ * instruction; the output reuses a copy of the input's federation map (no data fetch).
+ */
+public class CastFEDInstruction extends UnaryFEDInstruction {
+
+	private CastFEDInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
+		super(FEDInstruction.FEDType.Cast, op, in, out, opcode, istr);
+	}
+
+	/**
+	 * Parses a cast instruction string with exactly the fields opcode, input and output.
+	 *
+	 * @param str instruction string (operands carry value types)
+	 * @return the parsed federated cast instruction
+	 */
+	public static CastFEDInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		InstructionUtils.checkNumFields(parts, 2);
+		String opcode = parts[0];
+		CPOperand in = new CPOperand(parts[1]);
+		CPOperand out = new CPOperand(parts[2]);
+		return new CastFEDInstruction(null, in, out, opcode, str);
+	}
+
+	/**
+	 * Dispatches on the opcode to the frame-to-matrix or matrix-to-frame cast.
+	 *
+	 * @throws DMLRuntimeException if the opcode is neither of the two cast opcodes
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		if(getOpcode().equals(UnaryCP.CAST_AS_MATRIX_OPCODE))
+			processCastAsMatrixVariableInstruction(ec);
+		else if(getOpcode().equals(UnaryCP.CAST_AS_FRAME_OPCODE))
+			processCastAsFrameVariableInstruction(ec);
+		else
+			throw new DMLRuntimeException("Unsupported Opcode for federated Variable Instruction : " + getOpcode());
+	}
+
+	/**
+	 * Casts a federated frame to a federated matrix. The output reuses a copy of the
+	 * input's federation map, re-pointed to the variable created at each site.
+	 */
+	private void processCastAsMatrixVariableInstruction(ExecutionContext ec) {
+		FrameObject mo1 = ec.getFrameObject(input1);
+
+		if(!mo1.isFederated())
+			throw new DMLRuntimeException(
+				"Federated Cast: " + "Federated input expected, but invoked w/ " + mo1.isFederated());
+
+		// register the output variable at the federated sites, then run the cast there
+		long id = FederationUtils.getNextFedDataID();
+		FederatedRequest fr1 = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id,
+			new MatrixCharacteristics(-1, -1), Types.DataType.MATRIX);
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id,
+			new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()}, Types.ExecType.SPARK, false);
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+
+		// construct the local output: copy of the input federation map under the new ID
+		MatrixObject out = ec.getMatrixObject(output);
+		FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID());
+		out.setFedMapping(outMap);
+	}
+
+	/**
+	 * Casts a federated matrix to a federated frame with the input's dimensions and an
+	 * all-FP64 schema (one entry per column), federated like the input.
+	 */
+	private void processCastAsFrameVariableInstruction(ExecutionContext ec) {
+		MatrixObject mo1 = ec.getMatrixObject(input1);
+
+		if(!mo1.isFederated())
+			throw new DMLRuntimeException(
+				"Federated Cast: " + "Federated input expected, but invoked w/ " + mo1.isFederated());
+
+		// register the output variable at the federated sites, then run the cast there
+		long id = FederationUtils.getNextFedDataID();
+		FederatedRequest fr1 = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id,
+			new MatrixCharacteristics(-1, -1), Types.DataType.FRAME);
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id,
+			new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()}, Types.ExecType.SPARK, false);
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+
+		// construct the local output: input dimensions, FP64 schema, copied federation map
+		FrameObject out = ec.getFrameObject(output);
+		out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
+		ValueType[] schema = new ValueType[(int) mo1.getDataCharacteristics().getCols()];
+		Arrays.fill(schema, ValueType.FP64);
+		out.setSchema(schema);
+		FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr2.getID());
+		out.setFedMapping(outMap);
+	}
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java
similarity index 62%
copy from src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
copy to src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java
index 24a850f..67288c3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java
@@ -19,11 +19,10 @@
package org.apache.sysds.runtime.instructions.fed;
-import java.util.Arrays;
import java.util.concurrent.Future;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -32,112 +31,129 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.Builtin;
-import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
-
-public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
-
- protected UnaryMatrixFEDInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String instr) {
- super(FEDType.Unary, op, in, out, opcode, instr);
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+
+public class CumulativeOffsetFEDInstruction extends BinaryFEDInstruction
+{
+ private UnaryOperator _uop = null;
+
+ private CumulativeOffsetFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, double init, boolean broadcast, String opcode, String istr) {
+ super(FEDType.CumsumOffset, op, in1, in2, out, opcode, istr);
+
+ if ("bcumoffk+".equals(opcode))
+ _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+"));
+ else if ("bcumoff*".equals(opcode))
+ _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucum*"));
+ else if ("bcumoff+*".equals(opcode))
+ _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*"));
+ else if ("bcumoffmin".equals(opcode))
+ _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucummin"));
+ else if ("bcumoffmax".equals(opcode))
+ _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucummax"));
}
-
- public static boolean isValidOpcode(String opcode) {
- return !LibCommonsMath.isSupportedUnaryOperation(opcode);
- }
-
- public static UnaryMatrixFEDInstruction parseInstruction(String str) {
- CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
- CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
- String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ public static CumulativeOffsetFEDInstruction parseInstruction ( String str ) {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType( str );
+ InstructionUtils.checkNumFields(parts, 5);
String opcode = parts[0];
-
- if(parts.length == 5 && (opcode.equalsIgnoreCase("exp") || opcode.equalsIgnoreCase("log") || opcode.startsWith("ucum"))) {
- in.split(parts[1]);
- out.split(parts[2]);
- ValueFunction func = Builtin.getBuiltinFnObject(opcode);
- if( Arrays.asList(new String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode) ){
- UnaryOperator op = new UnaryOperator(func,Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4]));
- return new UnaryMatrixFEDInstruction(op, in, out, opcode, str);
- }
- else
- return new UnaryMatrixFEDInstruction(null, in, out, opcode, str);
- }
- opcode = parseUnaryInstruction(str, in, out);
- return new UnaryMatrixFEDInstruction(InstructionUtils.parseUnaryOperator(opcode), in, out, opcode, str);
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand out = new CPOperand(parts[3]);
+ double init = Double.parseDouble(parts[4]);
+ boolean broadcast = Boolean.parseBoolean(parts[5]);
+ return new CumulativeOffsetFEDInstruction(null, in1, in2, out, init, broadcast, opcode, str);
}
-
- @Override
+
+ @Override
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
- if(getOpcode().startsWith("ucum") && mo1.isFederated(FederationMap.FType.ROW))
- processCumulativeInstruction(ec, mo1);
+ MatrixObject mo2 = ec.getMatrixObject(input2);
+ if(getOpcode().startsWith("bcumoff") && mo1.isFederated(FederationMap.FType.ROW))
+ processCumulativeInstruction(ec);
else {
//federated execution on arbitrary row/column partitions
//(only assumption for sparse-unsafe: fed mapping covers entire matrix)
- FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
- new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()});
- mo1.getFedMapping().execute(getTID(), true, fr1);
+ FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[] {input1, input2}, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+ FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, fr2.getID(), mo1.getDataCharacteristics(), mo1.getDataType());
+ mo1.getFedMapping().execute(getTID(), true, fr1, fr3, fr2);
- setOutputFedMapping(ec, mo1, fr1.getID());
+ setOutputFedMapping(ec, mo1, fr2.getID());
}
}
- public void processCumulativeInstruction(ExecutionContext ec, MatrixObject mo1) {
+ public void processCumulativeInstruction(ExecutionContext ec) {
+ MatrixObject mo1 = ec.getMatrixObject(input1.getName());
+ MatrixObject mo2 = ec.getMatrixObject(input2.getName());
+ DataCharacteristics mcOut = ec.getDataCharacteristics(output.getName());
+
+ long id = FederationUtils.getNextFedDataID();
+
String opcode = getOpcode();
MatrixObject out;
- if(opcode.equalsIgnoreCase("ucumk+*")) {
- FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
- new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()});
+
+ if(opcode.equalsIgnoreCase("bcumoff+*")) {
+ FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, mcOut, mo1.getDataType());
+ FederatedRequest fr4 = mo1.getFedMapping().broadcast(mo2);
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id,
+ new CPOperand[] {input1, input2}, new long[] {mo1.getFedMapping().getID(), fr4.getID()}, Types.ExecType.SPARK, false);
FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
- Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), true, fr3, fr4, fr1, fr2);
out = setOutputFedMapping(ec, mo1, fr1.getID());
- MatrixBlock scalingValues = getScalars(mo1, tmp);
+ MatrixBlock scalingValues = getScalars(mo1, mo2, tmp);
setScalingValues(ec, mo1, out, scalingValues);
}
else {
- String colAgg = opcode.replace("ucum", "uac");
- String agg2 = opcode.replace(opcode.contains("ucumk")? "ucumk" :"ucum", "");
+ String colAgg = opcode.replace("bcumoff", "uac");
+ String agg2 = opcode.replace(opcode.contains("bcumoffk")? "bcumoffk" :"bcumoff", "");
- double init = opcode.equalsIgnoreCase("ucumk+") ? 0.0:
- opcode.equalsIgnoreCase("ucum*") ? 1.0 :
- opcode.equalsIgnoreCase("ucummin") ? Double.MAX_VALUE : -Double.MAX_VALUE;
+ double init = opcode.equalsIgnoreCase("bcumoffk+") ? 0.0:
+ opcode.equalsIgnoreCase("bcumoff*") ? 1.0 :
+ opcode.equalsIgnoreCase("bcumoffmin") ? Double.MAX_VALUE : -Double.MAX_VALUE;
- Future<FederatedResponse>[] tmp = modifyAndGetInstruction(colAgg, mo1);
- MatrixBlock scalingValues = getResultBlock(tmp, (int)mo1.getNumColumns(), opcode, init);
+ Future<FederatedResponse>[] tmp = modifyAndGetInstruction(colAgg, mo1, mo2);
+ MatrixBlock scalingValues = getResultBlock(tmp, (int)mo1.getNumColumns(), opcode, init, _uop);
out = ec.getMatrixObject(output);
setScalingValues(agg2, ec, mo1, out, scalingValues, init);
}
- processCumulative(out);
+ processCumulative(out, mo2);
}
- private Future<FederatedResponse>[] modifyAndGetInstruction(String newInst, MatrixObject mo1) {
+ private Future<FederatedResponse>[] modifyAndGetInstruction(String newInst, MatrixObject mo1, MatrixObject mo2) {
String modifiedInstString = InstructionUtils.replaceOperand(instString, 1, newInst);
-
- FederatedRequest fr1 = FederationUtils.callInstruction(modifiedInstString, output,
- new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()});
+ modifiedInstString = InstructionUtils.removeOperand(modifiedInstString, 3);
+ modifiedInstString = InstructionUtils.removeOperand(modifiedInstString, 4);
+ modifiedInstString = InstructionUtils.removeOperand(modifiedInstString, 4);
+ modifiedInstString = InstructionUtils.concatOperands(modifiedInstString, AggBinaryOp.SparkAggType.SINGLE_BLOCK.name());
+
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), mo1.getDataType());
+ FederatedRequest fr1 = FederationUtils.callInstruction(modifiedInstString, output, id,
+ new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()}, Types.ExecType.SPARK, false);
FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
- return mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+ return mo1.getFedMapping().execute(getTID(), true, fr3, fr1, fr2);
}
- private void processCumulative(MatrixObject out) {
+ private void processCumulative(MatrixObject out, MatrixObject mo2) {
String modifiedInstString = InstructionUtils.replaceOperand(instString, 2, InstructionUtils.createOperand(output));
+ FederatedRequest fr3 = out.getFedMapping().broadcast(mo2);
FederatedRequest fr4 = FederationUtils.callInstruction(modifiedInstString, output, out.getFedMapping().getID(),
- new CPOperand[] {output}, new long[] {out.getFedMapping().getID()});
- out.getFedMapping().execute(getTID(), true, fr4);
-
+ new CPOperand[] {output, input2}, new long[] {out.getFedMapping().getID(), fr3.getID()}, Types.ExecType.SPARK, false);
+ out.getFedMapping().execute(getTID(), true, fr3, fr4);
out.setFedMapping(out.getFedMapping().copyWithNewID(fr4.getID()));
// modify fed ranges since ucumk+* output is always nx1
- if(getOpcode().equalsIgnoreCase("ucumk+*")) {
+ if(getOpcode().equalsIgnoreCase("bcumoff+*")) {
out.getDataCharacteristics().set(out.getNumRows(), 1L, (int) out.getBlocksize());
for(int i = 0; i < out.getFedMapping().getFederatedRanges().length; i++)
out.getFedMapping().getFederatedRanges()[i].setEndDim(1, 1);
@@ -146,9 +162,9 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
}
}
- private static MatrixBlock getResultBlock(Future<FederatedResponse>[] tmp, int cols, String opcode, double init) {
+ private static MatrixBlock getResultBlock(Future<FederatedResponse>[] tmp, int cols, String opcode, double init, UnaryOperator uop) {
//TODO perf simple rbind, as the first row (init) is anyway not transferred
-
+
//collect row vectors into local matrix
MatrixBlock res = new MatrixBlock(tmp.length, cols, init);
for(int i = 0; i < tmp.length-1; i++)
@@ -156,17 +172,17 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
res.copy(i+1, i+1, 0, cols-1, ((MatrixBlock) tmp[i].get().getData()[0]), true);
}
catch(Exception e) {
- throw new DMLRuntimeException("Federated Get data failed with exception on UnaryMatrixFEDInstruction", e);
+ throw new DMLRuntimeException("Federated Get data failed with exception on CumulativeOffsetFEDInstruction", e);
}
//local cumulative aggregate
return res.unaryOperations(
- new UnaryOperator(Builtin.getBuiltinFnObject(opcode)),
+ uop,
new MatrixBlock());
}
- private MatrixBlock getScalars(MatrixObject mo1, Future<FederatedResponse>[] tmp) {
- MatrixBlock[] aggRes = getAggMatrices(mo1);
+ private MatrixBlock getScalars(MatrixObject mo1, MatrixObject mo2, Future<FederatedResponse>[] tmp) {
+ MatrixBlock[] aggRes = getAggMatrices(mo1, mo2);
MatrixBlock prod = aggRes[0];
MatrixBlock firstValues = aggRes[1];
for(int i = 0; i < tmp.length; i++)
@@ -175,7 +191,7 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
prod.setValue(i, 0, curr.getValue(curr.getNumRows()-1, 0));
}
catch(Exception e) {
- throw new DMLRuntimeException("Federated Get data failed with exception on UnaryMatrixFEDInstruction", e);
+ throw new DMLRuntimeException("Federated Get data failed with exception on CumulativeOffsetFEDInstruction", e);
}
// aggregate sumprod to get scalars
@@ -190,8 +206,8 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
return B.binaryOperationsInPlace(InstructionUtils.parseBinaryOperator("+"), firstValues.slice(0,firstValues.getNumRows()-1,0,0));
}
- private MatrixBlock[] getAggMatrices(MatrixObject mo1) {
- Future<FederatedResponse>[] tmp = modifyAndGetInstruction("ucum*", mo1);
+ private MatrixBlock[] getAggMatrices(MatrixObject mo1, MatrixObject mo2) {
+ Future<FederatedResponse>[] tmp = modifyAndGetInstruction("ucum*", mo1, mo2);
// slice and return prod and first value
MatrixBlock prod = new MatrixBlock(tmp.length, 2, 0.0);
@@ -203,7 +219,7 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
firstValues.copy(i, i, 0,1, curr.slice(0, 0), true);
}
catch(Exception e) {
- throw new DMLRuntimeException("Federated Get data failed with exception on UnaryMatrixFEDInstruction", e);
+ throw new DMLRuntimeException("Federated Get data failed with exception on CumulativeOffsetFEDInstruction", e);
}
return new MatrixBlock[] {prod, firstValues};
}
@@ -226,8 +242,8 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
long varID2 = FederationUtils.getNextFedDataID();
ec.setVariable(String.valueOf(varID2), mo2);
- CPOperand opCond = new CPOperand(String.valueOf(condID), ValueType.FP64, DataType.MATRIX);
- CPOperand op2 = new CPOperand(String.valueOf(varID2), ValueType.FP64, DataType.MATRIX);
+ CPOperand opCond = new CPOperand(String.valueOf(condID), Types.ValueType.FP64, Types.DataType.MATRIX);
+ CPOperand op2 = new CPOperand(String.valueOf(varID2), Types.ValueType.FP64, Types.DataType.MATRIX);
String ternaryInstString = InstructionUtils.constructTernaryString(instString, opCond, input1, op2, output);
@@ -247,7 +263,7 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
private void setScalingValues(String opcode, ExecutionContext ec, MatrixObject mo1, MatrixObject out, MatrixBlock scalingValues, double init) {
//TODO perf improvement (currently this creates a sliced broadcast in the size of the original matrix
//but sparse w/ strategically placed offsets, but would need to be dense for dense prod/cumsum)
-
+
//allocated large matrix of init value and placed offset rows in first row of every partition
MatrixBlock mb2 = new MatrixBlock((int) mo1.getNumRows(), (int) mo1.getNumColumns(), init);
for(int i = 1; i < scalingValues.getNumRows(); i++) {
@@ -258,17 +274,18 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
MatrixObject mo2 = ExecutionContext.createMatrixObject(mb2);
long varID2 = FederationUtils.getNextFedDataID();
ec.setVariable(String.valueOf(varID2), mo2);
- CPOperand op2 = new CPOperand(String.valueOf(varID2), ValueType.FP64, DataType.MATRIX);
+ CPOperand op2 = new CPOperand(String.valueOf(varID2), Types.ValueType.FP64, Types.DataType.MATRIX);
String modifiedInstString = InstructionUtils.constructBinaryInstString(instString, opcode, input1, op2, output);
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), Types.DataType.MATRIX);
FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
- FederatedRequest fr2 = FederationUtils.callInstruction(modifiedInstString, output,
- new CPOperand[] {input1, op2}, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
- mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+ FederatedRequest fr2 = FederationUtils.callInstruction(modifiedInstString, output, id,
+ new CPOperand[] {input1, op2}, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()}, Types.ExecType.SPARK, false);
+ mo1.getFedMapping().execute(getTID(), true, fr1, fr3, fr2);
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
-
ec.removeVariable(op2.getName());
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 0e00faa..010b8b2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -32,11 +32,16 @@ public abstract class FEDInstruction extends Instruction {
AggregateTernary,
Append,
Binary,
+ Cast,
+ Checkpoint,
+ CSVReblock,
Ctable,
CumulativeAggregate,
+ CumsumOffset,
Init,
MultiReturnParameterizedBuiltin,
MMChain,
+ MAPMM,
MatrixIndexing,
Ternary,
Tsmm,
@@ -44,6 +49,7 @@ public abstract class FEDInstruction extends Instruction {
Quaternary,
QSort,
QPick,
+ Reblock,
Reorg,
Reshape,
SpoofFused,
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 38c3b8b..107edac 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -20,7 +20,7 @@
package org.apache.sysds.runtime.instructions.fed;
import org.apache.commons.lang3.ArrayUtils;
-
+import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
@@ -47,28 +47,42 @@ import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinC
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+import org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendRSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryFrameScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixBVectorSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorBroadcastSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CastSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CtableSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.MultiReturnParameterizedBuiltinSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ParameterizedBuiltinSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SpoofSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
@@ -125,7 +139,7 @@ public class FEDInstructionUtils {
ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
CacheableData<?> mo = ec.getCacheableData(rinst.input1);
- if((mo instanceof MatrixObject || mo instanceof FrameObject)
+ if((mo instanceof MatrixObject || mo instanceof FrameObject)
&& mo.isFederatedExcept(FType.BROADCAST) )
fedinst = ReorgFEDInstruction.parseInstruction(
InstructionUtils.concatOperands(rinst.getInstructionString(),FederatedOutput.NONE.name()));
@@ -225,9 +239,9 @@ public class FEDInstructionUtils {
}
else if(inst instanceof AggregateTernaryCPInstruction){
AggregateTernaryCPInstruction ins = (AggregateTernaryCPInstruction) inst;
- if(ins.input1.isMatrix() && ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST)
+ if(ins.input1.isMatrix() && ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST)
&& ins.input2.isMatrix() && ec.getCacheableData(ins.input2).isFederatedExcept(FType.BROADCAST)) {
- fedinst = AggregateTernaryFEDInstruction.parseInstruction(ins);
+ fedinst = AggregateTernaryFEDInstruction.parseInstruction(ins.getInstructionString());
}
}
else if(inst instanceof QuaternaryCPInstruction) {
@@ -266,59 +280,144 @@ public class FEDInstructionUtils {
public static Instruction checkAndReplaceSP(Instruction inst, ExecutionContext ec) {
FEDInstruction fedinst = null;
if (inst instanceof MapmmSPInstruction) {
- // FIXME does not yet work for MV multiplication. SPARK execution mode not supported for federated l2svm
MapmmSPInstruction instruction = (MapmmSPInstruction) inst;
Data data = ec.getVariable(instruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
- String[] instParts = inst.getInstructionString().split(Instruction.OPERAND_DELIM);
- instParts[1] = "ba+*";
- instParts[5] = "16";
- instParts[6] = instParts[7];
- String instString = InstructionUtils.concatOperands(instParts[0], instParts[1], instParts[2],
- instParts[3], instParts[4], instParts[5], instParts[6]);
- fedinst = AggregateBinaryFEDInstruction.parseInstruction(instString);
+ if (data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST)) {
+ fedinst = MapmmFEDInstruction.parseInstruction(instruction.getInstructionString());
+ }
+ }
+ else if(inst instanceof CastSPInstruction){
+ CastSPInstruction ins = (CastSPInstruction) inst;
+ if((ins.getOpcode().equalsIgnoreCase(UnaryCP.CAST_AS_FRAME_OPCODE) || ins.getOpcode().equalsIgnoreCase(UnaryCP.CAST_AS_MATRIX_OPCODE))
+ && ins.input1.isMatrix() && ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST)){
+ fedinst = CastFEDInstruction.parseInstruction(ins.getInstructionString());
+ }
+ }
+ else if (inst instanceof WriteSPInstruction) {
+ WriteSPInstruction instruction = (WriteSPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if (data instanceof CacheableData && ((CacheableData<?>) data).isFederated()) {
+ // Write spark instruction can not be executed for federated matrix objects (tries to get rdds which do
+ // not exist), therefore we replace the instruction with the VariableCPInstruction.
+ return VariableCPInstruction.parseInstruction(instruction.getInstructionString());
+ }
+ }
+ else if(inst instanceof QuaternarySPInstruction) {
+ QuaternarySPInstruction instruction = (QuaternarySPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ fedinst = QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
+ }
+ else if(inst instanceof SpoofSPInstruction) {
+ SpoofSPInstruction ins = (SpoofSPInstruction) inst;
+ Class<?> scla = ins.getOperatorClass().getSuperclass();
+ if(((scla == SpoofCellwise.class || scla == SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
+ && SpoofFEDInstruction.isFederated(ec, ins.getInputs(), scla))
+ || (scla == SpoofRowwise.class && SpoofFEDInstruction.isFederated(ec, FType.ROW, ins.getInputs(), scla))) {
+ fedinst = SpoofFEDInstruction.parseInstruction(inst.getInstructionString());
}
}
- else if (inst instanceof UnarySPInstruction) {
+ else if (inst instanceof UnarySPInstruction && ! (inst instanceof IndexingSPInstruction)) {
+ UnarySPInstruction instruction = (UnarySPInstruction) inst;
if (inst instanceof CentralMomentSPInstruction) {
- CentralMomentSPInstruction instruction = (CentralMomentSPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ CentralMomentSPInstruction cinstruction = (CentralMomentSPInstruction) inst;
+ Data data = ec.getVariable(cinstruction.input1);
+ if (data instanceof MatrixObject && ((MatrixObject) data).isFederated() && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
fedinst = CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
} else if (inst instanceof QuantileSortSPInstruction) {
- QuantileSortSPInstruction instruction = (QuantileSortSPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ QuantileSortSPInstruction qinstruction = (QuantileSortSPInstruction) inst;
+ Data data = ec.getVariable(qinstruction.input1);
+ if (data instanceof MatrixObject && ((MatrixObject) data).isFederated() && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
fedinst = QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
}
else if (inst instanceof AggregateUnarySPInstruction) {
- AggregateUnarySPInstruction instruction = (AggregateUnarySPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
- fedinst = AggregateUnaryFEDInstruction.parseInstruction(
- InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
+ AggregateUnarySPInstruction auinstruction = (AggregateUnarySPInstruction) inst;
+ Data data = ec.getVariable(auinstruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated() && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
+ if(ArrayUtils.contains(new String[]{"uarimin", "uarimax"}, auinstruction.getOpcode())) {
+ if(((MatrixObject) data).getFedMapping().getType() == FType.ROW)
+ fedinst = AggregateUnaryFEDInstruction.parseInstruction(
+ InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
+ }
+ else
+ fedinst = AggregateUnaryFEDInstruction.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
+ }
+ else if(inst instanceof ReorgSPInstruction && (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
+ || inst.getOpcode().equals("rev"))) {
+ ReorgSPInstruction rinst = (ReorgSPInstruction) inst;
+ CacheableData<?> mo = ec.getCacheableData(rinst.input1);
+ if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() && ((MatrixObject) mo).isFederatedExcept(FType.BROADCAST))
+ fedinst = ReorgFEDInstruction.parseInstruction(
+ InstructionUtils.concatOperands(rinst.getInstructionString(), FederatedOutput.NONE.name()));
+ }
+ else if(inst instanceof ReblockSPInstruction && instruction.input1 != null && (instruction.input1.isFrame() || instruction.input1.isMatrix())) {
+ ReblockSPInstruction rinst = (ReblockSPInstruction) instruction;
+ CacheableData<?> data = ec.getCacheableData(rinst.input1);
+ if(data.isFederatedExcept(FType.BROADCAST))
+ fedinst = ReblockFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ else if(instruction.input1 != null && instruction.input1.isMatrix() && ec.containsVariable(instruction.input1)) {
+ MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
+ if(mo1.isFederatedExcept(FType.BROADCAST)) {
+ if(instruction.getOpcode().equalsIgnoreCase("cm"))
+ fedinst = CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
+ else if(inst.getOpcode().equalsIgnoreCase("qsort")) {
+ if(mo1.getFedMapping().getFederatedRanges().length == 1)
+ fedinst = QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ else if(inst.getOpcode().equalsIgnoreCase("rshape")) {
+ fedinst = ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ else if(inst instanceof UnaryMatrixSPInstruction) {
+ if(UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode()))
+ fedinst = UnaryMatrixFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ }
}
}
- else if(inst instanceof BinarySPInstruction) {
+ else if (inst instanceof BinarySPInstruction) {
+ BinarySPInstruction instruction = (BinarySPInstruction) inst;
+
if(inst instanceof QuantilePickSPInstruction) {
- QuantilePickSPInstruction instruction = (QuantilePickSPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ QuantilePickSPInstruction qinstruction = (QuantilePickSPInstruction) inst;
+ Data data = ec.getVariable(qinstruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
fedinst = QuantilePickFEDInstruction.parseInstruction(inst.getInstructionString());
}
else if (inst instanceof AppendGAlignedSPInstruction) {
- // TODO other Append Spark instructions
- AppendGAlignedSPInstruction instruction = (AppendGAlignedSPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
+ AppendGAlignedSPInstruction ainstruction = (AppendGAlignedSPInstruction) inst;
+ Data data1 = ec.getVariable(ainstruction.input1);
+ Data data2 = ec.getVariable(ainstruction.input2);
+ if (data1 instanceof MatrixObject && ((MatrixObject) data1).isFederatedExcept(FType.BROADCAST)
+ && (! ((CacheableData<?>)data2).isFederated() || ((CacheableData<?>)data2).isFederatedExcept(FType.BROADCAST))) {
fedinst = AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
}
}
else if (inst instanceof AppendGSPInstruction) {
- AppendGSPInstruction instruction = (AppendGSPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
- fedinst = AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
+ AppendGSPInstruction ainstruction = (AppendGSPInstruction) inst;
+ Data data1 = ec.getVariable(ainstruction.input1);
+ Data data2 = ec.getVariable(ainstruction.input2);
+ if(data1 instanceof MatrixObject && ((MatrixObject) data1).isFederatedExcept(FType.BROADCAST)
+ && (! ((CacheableData<?>)data2).isFederated() || ((CacheableData<?>)data2).isFederatedExcept(FType.BROADCAST))) {
+ fedinst = AppendFEDInstruction.parseInstruction(ainstruction.getInstructionString());
+ }
+ }
+ else if (inst instanceof AppendMSPInstruction) {
+ AppendMSPInstruction ainstruction = (AppendMSPInstruction) inst;
+ Data data1 = ec.getVariable(ainstruction.input1);
+ Data data2 = ec.getVariable(ainstruction.input2);
+ if(((CacheableData<?>) data1).isFederatedExcept(FType.BROADCAST) && (! ((CacheableData<?>)data2).isFederated()
+ || ((CacheableData<?>)data2).isFederatedExcept(FType.BROADCAST))) {
+ fedinst = AppendFEDInstruction.parseInstruction(ainstruction.getInstructionString());
+ }
+ }
+ else if (inst instanceof AppendRSPInstruction) {
+ AppendRSPInstruction ainstruction = (AppendRSPInstruction) inst;
+ Data data1 = ec.getVariable(ainstruction.input1);
+ Data data2 = ec.getVariable(ainstruction.input2);
+ if(((CacheableData<?>) data1).isFederatedExcept(FType.BROADCAST) && (! ((CacheableData<?>)data2).isFederated()
+ || ((CacheableData<?>)data2).isFederatedExcept(FType.BROADCAST))) {
+ fedinst = AppendFEDInstruction.parseInstruction(ainstruction.getInstructionString());
}
}
else if (inst instanceof BinaryMatrixScalarSPInstruction
@@ -326,39 +425,76 @@ public class FEDInstructionUtils {
|| inst instanceof BinaryMatrixBVectorSPInstruction
|| inst instanceof BinaryTensorTensorSPInstruction
|| inst instanceof BinaryTensorTensorBroadcastSPInstruction) {
- BinarySPInstruction instruction = (BinarySPInstruction) inst;
Data data = ec.getVariable(instruction.input1);
- if((data instanceof MatrixObject && ((MatrixObject)data).isFederated())
- || (data instanceof TensorObject && ((TensorObject)data).isFederated())) {
+ if((data instanceof MatrixObject && ((MatrixObject)data).isFederatedExcept(FType.BROADCAST))
+ || (data instanceof TensorObject && ((TensorObject)data).isFederatedExcept(FType.BROADCAST))) {
fedinst = BinaryFEDInstruction.parseInstruction(
InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
}
}
+ else if(inst.getOpcode().equals("_map") && inst instanceof BinaryFrameScalarSPInstruction && !inst.getInstructionString().contains("UtilFunctions")
+ && instruction.input1.isFrame() && ec.getFrameObject(instruction.input1).isFederated()) {
+ fedinst = BinaryFrameScalarFEDInstruction.parseInstruction(InstructionUtils
+ .concatOperands(inst.getInstructionString(), FederatedOutput.NONE.name()));
+ }
+ else if( (instruction.input1.isMatrix() && ec.getCacheableData(instruction.input1).isFederatedExcept(FType.BROADCAST))
+ || (instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederatedExcept(FType.BROADCAST))) {
+ if("cov".equals(instruction.getOpcode()) && (ec.getMatrixObject(instruction.input1)
+ .isFederated(FType.ROW) || ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
+ fedinst = CovarianceFEDInstruction.parseInstruction(inst.getInstructionString());
+ else if(inst instanceof CumulativeOffsetSPInstruction) {
+ fedinst = CumulativeOffsetFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ else
+ fedinst = BinaryFEDInstruction.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(), FederatedOutput.NONE.name()));
+ }
}
- else if (inst instanceof WriteSPInstruction) {
- WriteSPInstruction instruction = (WriteSPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
- // Write spark instruction can not be executed for federated matrix objects (tries to get rdds which do
- // not exist), therefore we replace the instruction with the VariableCPInstruction.
- return VariableCPInstruction.parseInstruction(instruction.getInstructionString());
+ else if( inst instanceof ParameterizedBuiltinSPInstruction) {
+ ParameterizedBuiltinSPInstruction pinst = (ParameterizedBuiltinSPInstruction) inst;
+ if( pinst.getOpcode().equalsIgnoreCase("replace") && pinst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
+ fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
+ }
+ else if (inst instanceof MultiReturnParameterizedBuiltinSPInstruction) {
+ MultiReturnParameterizedBuiltinSPInstruction minst = (MultiReturnParameterizedBuiltinSPInstruction) inst;
+ if(minst.getOpcode().equals("transformencode") && minst.input1.isFrame()) {
+ CacheableData<?> fo = ec.getCacheableData(minst.input1);
+ if(fo.isFederatedExcept(FType.BROADCAST)) {
+ fedinst = MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(minst.getInstructionString());
+ }
}
}
- else if(inst instanceof QuaternarySPInstruction) {
- QuaternarySPInstruction instruction = (QuaternarySPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
- fedinst = QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
+ else if(inst instanceof IndexingSPInstruction) {
+ // matrix and frame indexing
+ IndexingSPInstruction minst = (IndexingSPInstruction) inst;
+ if((minst.input1.isMatrix() || minst.input1.isFrame())
+ && ec.getCacheableData(minst.input1).isFederatedExcept(FType.BROADCAST)) {
+ fedinst = IndexingFEDInstruction.parseInstruction(minst.getInstructionString());
+ }
}
- else if(inst instanceof SpoofSPInstruction) {
- SpoofSPInstruction ins = (SpoofSPInstruction) inst;
- Class<?> scla = ins.getOperatorClass().getSuperclass();
- if(((scla == SpoofCellwise.class || scla == SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
- && SpoofFEDInstruction.isFederated(ec, ins.getInputs(), scla))
- || (scla == SpoofRowwise.class && SpoofFEDInstruction.isFederated(ec, FType.ROW, ins.getInputs(), scla))) {
- fedinst = SpoofFEDInstruction.parseInstruction(inst.getInstructionString());
+ else if(inst instanceof TernarySPInstruction) {
+ TernarySPInstruction tinst = (TernarySPInstruction) inst;
+ if((tinst.input1.isMatrix() && ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
+ || (tinst.input2.isMatrix() && ec.getCacheableData(tinst.input2).isFederatedExcept(FType.BROADCAST))
+ || (tinst.input3.isMatrix() && ec.getCacheableData(tinst.input3).isFederatedExcept(FType.BROADCAST))) {
+ fedinst = TernaryFEDInstruction.parseInstruction(tinst.getInstructionString());
}
}
+ else if(inst instanceof AggregateTernarySPInstruction){
+ AggregateTernarySPInstruction ins = (AggregateTernarySPInstruction) inst;
+ if(ins.input1.isMatrix() && ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST) && ins.input2.isMatrix() &&
+ ec.getCacheableData(ins.input2).isFederatedExcept(FType.BROADCAST)) {
+ fedinst = AggregateTernaryFEDInstruction.parseInstruction(ins.getInstructionString());
+ }
+ }
+ else if(inst instanceof CtableSPInstruction) {
+ CtableSPInstruction cinst = (CtableSPInstruction) inst;
+ if(inst.getOpcode().equalsIgnoreCase("ctable")
+ && ( ec.getCacheableData(cinst.input1).isFederated(FType.ROW)
+ || (cinst.input2.isMatrix() && ec.getCacheableData(cinst.input2).isFederated(FType.ROW))
+ || (cinst.input3.isMatrix() && ec.getCacheableData(cinst.input3).isFederated(FType.ROW))))
+ fedinst = CtableFEDInstruction.parseInstruction(cinst.getInstructionString());
+ }
+
//set thread id for federated context management
if( fedinst != null ) {
fedinst.setTID(ec.getTID());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index ef0223d..92968d8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -43,6 +43,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;
public final class IndexingFEDInstruction extends UnaryFEDInstruction {
@@ -79,7 +80,7 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
String opcode = parts[0];
if(opcode.equalsIgnoreCase(RightIndex.OPCODE)) {
- if(parts.length == 7) {
+ if(parts.length == 7 || parts.length == 8) {
CPOperand in, rl, ru, cl, cu, out;
in = new CPOperand(parts[1]);
rl = new CPOperand(parts[2]);
@@ -97,8 +98,8 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
}
}
- else if(opcode.equalsIgnoreCase(LeftIndex.OPCODE)) {
- if ( parts.length == 8 ) {
+ else if(opcode.equalsIgnoreCase(LeftIndex.OPCODE) || opcode.equalsIgnoreCase("mapLeftIndex")) {
+ if ( parts.length == 8 || parts.length == 9) {
CPOperand lhsInput, rhsInput, rl, ru, cl, cu, out;
lhsInput = new CPOperand(parts[1]);
rhsInput = new CPOperand(parts[2]);
@@ -175,8 +176,13 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
}
i++;
}
- FederatedRequest[] fr1 = FederationUtils.callInstruction(instStrings,
- output, new CPOperand[] {input1}, new long[] {fedMap.getID()});
+
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest tmp = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, in.getMetaData().getDataCharacteristics(), in.getDataType());
+
+ FederatedRequest[] fr1 = FederationUtils.callInstruction(instStrings, output, id,
+ new CPOperand[] {input1}, new long[] {fedMap.getID()}, InstructionUtils.getExecType(instString));
+ fedMap.execute(getTID(), true, tmp);
fedMap.execute(getTID(), true, fr1, new FederatedRequest[0]);
if(input1.isFrame()) {
@@ -260,9 +266,13 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
sliceIxs = Arrays.stream(sliceIxs).filter(Objects::nonNull).toArray(int[][] :: new);
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest tmp = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), in1.getDataType());
+ fedMap.execute(getTID(), true, tmp);
+
FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, input2.isFrame(), sliceIxs);
- FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, new CPOperand[]{input1, input2},
- new long[]{fedMap.getID(), fr1[0].getID()});
+ FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2},
+ new long[]{fedMap.getID(), fr1[0].getID()}, null);
FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1[0].getID());
//execute federated instruction and cleanup intermediates
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MapmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MapmmFEDInstruction.java
new file mode 100644
index 0000000..0d09411
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MapmmFEDInstruction.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.concurrent.Future;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.lops.MapMult;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+
+/**
+ * Federated counterpart of the Spark map-side matrix multiplication instruction
+ * ({@code mapmm}). FEDInstructionUtils.checkAndReplaceSP creates it from a
+ * MapmmSPInstruction whose first input is federated (except BROADCAST); the
+ * original Spark instruction string is forwarded to the federated workers.
+ */
+public class MapmmFEDInstruction extends BinaryFEDInstruction
+{
+ // The Spark-specific cache type, outputEmpty flag and aggregation type are accepted
+ // but not stored: only operator, operands, opcode and the raw instruction string are kept.
+ private MapmmFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, MapMult.CacheType type,
+ boolean outputEmpty, AggBinaryOp.SparkAggType aggtype, String opcode, String istr) {
+ super(FEDType.MAPMM, op, in1, in2, out, opcode, istr);
+ }
+
+ /**
+ * Parses a Spark {@code mapmm} instruction string into a federated instruction.
+ * Operand layout: opcode, in1, in2, out, cache type, outputEmpty flag, Spark aggregation type.
+ * NOTE(review): the error message below still names MapmmSPInstruction instead of this class.
+ * @param str Spark mapmm instruction string
+ * @return the parsed federated instruction
+ * @throws DMLRuntimeException if the opcode is not {@code MapMult.OPCODE}
+ */
+ public static MapmmFEDInstruction parseInstruction( String str ) {
+ String parts[] = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+
+ if(!opcode.equalsIgnoreCase(MapMult.OPCODE))
+ throw new DMLRuntimeException("MapmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
+
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand out = new CPOperand(parts[3]);
+ MapMult.CacheType type = MapMult.CacheType.valueOf(parts[4]);
+ boolean outputEmpty = Boolean.parseBoolean(parts[5]);
+ AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[6]);
+
+ // presumably the argument is the degree of parallelism -- TODO confirm
+ AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(1);
+ return new MapmmFEDInstruction(aggbin, in1, in2, out, type, outputEmpty, aggtype, opcode, str);
+ }
+
+ /**
+ * Computes the matrix product input1 x input2 over federated inputs. The strategy
+ * depends on the federation types of both inputs:
+ * (#1) mo1 COL- and mo2 ROW-federated with COL_T-aligned maps: no broadcast, per-worker
+ * partial products are summed;
+ * (ROW/PART) mo1 ROW- or PART-federated: mo2 is broadcast to mo1's workers;
+ * (#2) mo2 ROW-federated: mo1 is used in place if the ranges align, otherwise sliced and broadcast;
+ * (#3) mo1 COL-federated: mo2 is sliced and broadcast.
+ * Any other combination throws a DMLRuntimeException. Depending on {@code _fedOut}, the
+ * result either stays federated (the output gets a derived federation map) or is fetched
+ * via GET_VAR and combined locally (aggAdd for partial sums, bind for row partitions).
+ * @param ec execution context holding input1, input2 and the output variable
+ */
+ public void processInstruction(ExecutionContext ec) {
+ MatrixObject mo1 = ec.getMatrixObject(input1);
+ MatrixObject mo2 = ec.getMatrixObject(input2);
+
+ // Pre-allocate an output ID and a PUT_VAR request for a matrix of unknown size (-1 x -1).
+ // Presumably this creates the output variable at the workers so the Spark instruction
+ // can bind its result to it -- TODO confirm. Only the branches up to #2 send frEmpty.
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest frEmpty = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id,
+ new MatrixCharacteristics(-1, -1), Types.DataType.MATRIX);
+
+ //TODO cleanup unnecessary redundancy
+ //#1 COL-federated lhs x ROW-federated rhs with COL_T-aligned maps (no broadcast needed)
+ if(mo1.isFederated(FType.COL) && mo2.isFederated(FType.ROW)
+ && mo1.getFedMapping().isAligned(mo2.getFedMapping(), AlignType.COL_T) ) {
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id,
+ new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, Types.ExecType.SPARK,false);
+
+ if ( _fedOut.isForcedFederated() ){
+ // keep the per-worker partial products federated (overlapping partial aggregates)
+ mo1.getFedMapping().execute(getTID(), frEmpty, fr1);
+ setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr1.getID(), ec);
+ }
+ else {
+ FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
+ // NOTE(review): cleanup is built from mo2's map but executed via mo1's map; this
+ // relies on the alignment checked above -- confirm.
+ FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), frEmpty, fr1, fr2, fr3);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ }
+ else if(mo1.isFederated(FType.ROW) || mo1.isFederated(FType.PART)) { // MV + MM
+ //construct commands: broadcast rhs, fed mv, retrieve results
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id,
+ new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), fr1.getID()}, Types.ExecType.SPARK, false);
+ if( mo2.getNumColumns() == 1 ) { //MV
+ // MV keeps the result federated only when federated output is forced
+ if ( _fedOut.isForcedFederated() ){
+ mo1.getFedMapping().execute(getTID(), frEmpty, fr1, fr2);
+ if ( mo1.isFederated(FType.PART) )
+ setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ else
+ setOutputFedMapping(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+ // NOTE(review): only fr2's result is cleaned up; confirm the broadcast fr1 is released elsewhere.
+ FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), frEmpty, fr1, fr2, fr3, fr4);
+ MatrixBlock ret;
+ // PART: sum overlapping partial results; ROW: combine per-worker row blocks
+ // (presumably bind(..., false) is a row-wise bind -- TODO confirm)
+ if ( mo1.isFederated(FType.PART) )
+ ret = FederationUtils.aggAdd(tmp);
+ else
+ ret = FederationUtils.bind(tmp, false);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ }
+ else { //MM
+ //execute federated operations and aggregate
+ // MM keeps the result federated unless local output is forced (contrast with MV above)
+ if ( !_fedOut.isForcedLocal() ){
+ mo1.getFedMapping().execute(getTID(), true, frEmpty, fr1, fr2);
+ if ( mo1.isFederated(FType.PART) || mo2.isFederated(FType.PART) )
+ setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ else
+ setOutputFedMapping(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), frEmpty, fr1, fr2, fr3, fr4);
+ MatrixBlock ret;
+ // NOTE(review): the federated path above also checks mo2 for PART; this local path only checks mo1.
+ if ( mo1.isFederated(FType.PART) )
+ ret = FederationUtils.aggAdd(tmp);
+ else
+ ret = FederationUtils.bind(tmp, false);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ }
+ }
+ //#2 ROW-federated rhs (lhs is not ROW/PART-federated and not handled by #1)
+ // NOTE(review): branches #2 and #3 use the callInstruction overload without the pre-allocated
+ // id / ExecType and never send frEmpty, unlike the branches above -- confirm the workers can
+ // execute the forwarded Spark instruction this way.
+ else if (mo2.isFederated(FType.ROW)) {// VM + MM
+ if ( mo1.isFederated(FType.COL) && isAggBinaryFedAligned(mo1,mo2) ){
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, true);
+ if ( _fedOut.isForcedFederated() ){
+ // Partial aggregates (set fedmapping to the partial aggs)
+ mo2.getFedMapping().execute(getTID(), true, fr2);
+ setPartialOutput(mo2.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+ // NOTE(review): unlike the other local-aggregation paths, no cleanup request is sent here.
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr2, fr3);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ }
+ else {
+ //construct commands: broadcast rhs, fed mv, retrieve results
+ // (here the lhs mo1 is sliced and broadcast to mo2's workers)
+ FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, true);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2},
+ new long[]{fr1[0].getID(), mo2.getFedMapping().getID()}, true);
+ if ( _fedOut.isForcedFederated() ){
+ // Partial aggregates (set fedmapping to the partial aggs)
+ mo2.getFedMapping().execute(getTID(), true, fr1, fr2);
+ setPartialOutput(mo2.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), true, fr1, fr2, fr3, fr4);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ }
+ }
+ //#3 COL-federated lhs with a rhs that is not ROW-federated: rhs is sliced and broadcast
+ else if (mo1.isFederated(FType.COL)) {// VM + MM
+ //construct commands: broadcast rhs, fed mv, retrieve results
+ FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, true);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), fr1[0].getID()}, true);
+ if ( _fedOut.isForcedFederated() ){
+ // Partial aggregates (set fedmapping to the partial aggs)
+ mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+ setPartialOutput(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ }
+ else { //other combinations
+ throw new DMLRuntimeException("Federated AggregateBinary not supported with the "
+ + "following federated objects: "+mo1.isFederated()+":"+mo1.getFedMapping()
+ +" "+mo2.isFederated()+":"+mo2.getFedMapping());
+ }
+ }
+
+ /**
+ * Checks alignment of dimensions for the federated aggregate binary processing without broadcast.
+ * If the begin and end ranges of mo1 has cols equal to the rows of the begin and end ranges of mo2,
+ * the two inputs are aligned for the processing of the federated aggregate binary instruction without broadcasting.
+ * Ranges are paired by position, so this assumes mo2 has at least as many federated ranges as mo1
+ * (otherwise the index into mo2's ranges runs out of bounds).
+ * @param mo1 input matrix object 1
+ * @param mo2 input matrix object 2
+ * @return true if the two inputs are aligned for aggregate binary processing without broadcasting
+ */
+ private static boolean isAggBinaryFedAligned(MatrixObject mo1, MatrixObject mo2){
+ FederatedRange[] mo1FederatedRanges = mo1.getFedMapping().getFederatedRanges();
+ FederatedRange[] mo2FederatedRanges = mo2.getFedMapping().getFederatedRanges();
+ for ( int i = 0; i < mo1FederatedRanges.length; i++ ){
+ FederatedRange mo1FedRange = mo1FederatedRanges[i];
+ FederatedRange mo2FedRange = mo2FederatedRanges[i];
+
+ // mo1's column span must equal mo2's row span for each paired range
+ if ( mo1FedRange.getBeginDims()[1] != mo2FedRange.getBeginDims()[0]
+ || mo1FedRange.getEndDims()[1] != mo2FedRange.getEndDims()[0])
+ return false;
+ }
+ return true;
+ }
+
+ /**
+ * Sets the output with a federated mapping of overlapping partial aggregates.
+ * The output's characteristics are set to mo1 rows x mo2 columns, using mo1's block size.
+ * @param federationMap federated map from which the federated metadata is retrieved
+ * @param mo1 matrix object with number of rows used to set the number of rows of the output
+ * @param mo2 matrix object with number of columns used to set the number of columns of the output
+ * @param outputID ID of the output
+ * @param ec execution context
+ */
+ private void setPartialOutput(FederationMap federationMap, MatrixObject mo1, MatrixObject mo2,
+ long outputID, ExecutionContext ec){
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
+ FederationMap outputFedMap = federationMap
+ .copyWithNewIDAndRange(mo1.getNumRows(), mo2.getNumColumns(), outputID);
+ out.setFedMapping(outputFedMap);
+ }
+
+ /**
+ * Sets the output with a federated map copied from federationMap input given.
+ * The output's characteristics are set to mo1 rows x mo2 columns, using mo1's block size;
+ * the copied map is given the new output ID and mo2's column count.
+ * @param federationMap federation map to be set in output
+ * @param mo1 matrix object with number of rows used to set the number of rows of the output
+ * @param mo2 matrix object with number of columns used to set the number of columns of the output
+ * @param outputID ID of the output
+ * @param ec execution context
+ */
+ private void setOutputFedMapping(FederationMap federationMap, MatrixObject mo1, MatrixObject mo2,
+ long outputID, ExecutionContext ec){
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
+ out.setFedMapping(federationMap.copyWithNewID(outputID, mo2.getNumColumns()));
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReblockFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReblockFEDInstruction.java
new file mode 100644
index 0000000..40feddb
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReblockFEDInstruction.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+
+public class ReblockFEDInstruction extends UnaryFEDInstruction {
+ private int blen;
+
+ private ReblockFEDInstruction(Operator op, CPOperand in, CPOperand out, int br, int bc, boolean emptyBlocks,
+ String opcode, String instr) {
+ super(FEDInstruction.FEDType.Reblock, op, in, out, opcode, instr);
+ blen = br;
+ blen = bc;
+ }
+
+ public static ReblockFEDInstruction parseInstruction(String str) {
+ String parts[] = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+
+ if(!opcode.equals("rblk")) {
+ throw new DMLRuntimeException("Incorrect opcode for ReblockFEDInstruction:" + opcode);
+ }
+
+ CPOperand in = new CPOperand(parts[1]);
+ CPOperand out = new CPOperand(parts[2]);
+ int blen=Integer.parseInt(parts[3]);
+ boolean outputEmptyBlocks = Boolean.parseBoolean(parts[4]);
+
+ Operator op = null; // no operator for ReblockFEDInstruction
+ return new ReblockFEDInstruction(op, in, out, blen, blen, outputEmptyBlocks, opcode, str);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ //set the output characteristics
+ CacheableData<?> obj = ec.getCacheableData(input1.getName());
+ DataCharacteristics mc = ec.getDataCharacteristics(input1.getName());
+ DataCharacteristics mcOut = ec.getDataCharacteristics(output.getName());
+ mcOut.set(mc.getRows(), mc.getCols(), blen, mc.getNonZeros());
+
+ //get the source format from the meta data
+ MetaDataFormat iimd = (MetaDataFormat) obj.getMetaData();
+ if(iimd == null)
+ throw new DMLRuntimeException("Error ReblockFEDInstruction: Metadata not found");
+
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest[] fr1 = new FederatedRequest[obj.getFedMapping().getSize()];
+ int i = 0;
+ for(FederatedRange range : obj.getFedMapping().getFederatedRanges()) {
+ fr1[i] = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id,
+ new MatrixCharacteristics(range.getSize(0), range.getSize(1)), obj.getDataType());
+ i++;
+ }
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id,
+ new CPOperand[]{input1}, new long[]{ obj.getFedMapping().getID()}, Types.ExecType.SPARK, false);
+
+ //execute federated operations and set output
+ obj.getFedMapping().execute(getTID(), true, fr1, fr2);
+ CacheableData<?> out = ec.getCacheableData(output);
+ out.setFedMapping(obj.getFedMapping().copyWithNewID(fr2.getID()));
+ out.getDataCharacteristics().set(mcOut);
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index cb3074f..c32b15b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -51,8 +51,10 @@ import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
public class ReorgFEDInstruction extends UnaryFEDInstruction {
+ @SuppressWarnings("unused")
private static boolean fedoutFlagInString = false;
public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
@@ -73,8 +75,9 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
InstructionUtils.checkNumFields(str, 2, 3, 4);
in.split(parts[1]);
out.split(parts[2]);
- int k = Integer.parseInt(parts[3]);
- FederatedOutput fedOut = FederatedOutput.valueOf(parts[4]);
+ int k = str.startsWith(Types.ExecMode.SPARK.name()) ? 0 : Integer.parseInt(parts[3]);
+ FederatedOutput fedOut = str.startsWith(Types.ExecMode.SPARK.name()) ? FederatedOutput.valueOf(parts[3]) :
+ FederatedOutput.valueOf(parts[4]);
return new ReorgFEDInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str, fedOut);
}
else if ( opcode.equalsIgnoreCase("rdiag") ) {
@@ -95,6 +98,7 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
ReorgOperator r_op = (ReorgOperator) _optr;
+ boolean isSpark = instString.startsWith("SPARK");
if( !mo1.isFederated() )
throw new DMLRuntimeException("Federated Reorg: "
@@ -105,12 +109,15 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
if(instOpcode.equals("r'")) {
//execute transpose at federated site
- FederatedRequest fr1 = FederationUtils.callInstruction(instString,
- output, new CPOperand[] {input1},
- new long[] {mo1.getFedMapping().getID()}, true);
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest fr = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), mo1.getDataType());
+
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {input1},
+ new long[] {mo1.getFedMapping().getID()}, isSpark ? Types.ExecType.SPARK : Types.ExecType.CP, true);
+ mo1.getFedMapping().execute(getTID(), true, fr, fr1);
+
if (_fedOut != null && !_fedOut.isForcedLocal()){
mo1.getFedMapping().execute(getTID(), true, fr1);
-
//drive output federated mapping
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
@@ -123,11 +130,13 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
}
}
else if(instOpcode.equalsIgnoreCase("rev")) {
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest fr = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), mo1.getDataType());
+
//execute transpose at federated site
- FederatedRequest fr1 = FederationUtils.callInstruction(instString,
- output, new CPOperand[] {input1},
- new long[] {mo1.getFedMapping().getID()}, fedoutFlagInString);
- mo1.getFedMapping().execute(getTID(), true, fr1);
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {input1},
+ new long[] {mo1.getFedMapping().getID()}, isSpark ? Types.ExecType.SPARK : Types.ExecType.CP, true);
+ mo1.getFedMapping().execute(getTID(), true, fr, fr1);
if(mo1.isFederated(FederationMap.FType.ROW))
mo1.getFedMapping().reverseFedMap();
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
index b66f72a..ef7ac4c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
@@ -56,7 +56,7 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
public static ReshapeFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
- InstructionUtils.checkNumFields(parts, 6);
+ InstructionUtils.checkNumFields(parts, 6, 7);
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
@@ -96,9 +96,13 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
String[] newInstString = getNewInstString(mo1, instString, rows, cols, byRow.getBooleanValue());
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest tmp = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, mo1.getMetaData().getDataCharacteristics(), mo1.getDataType());
+
//execute at federated site
- FederatedRequest[] fr1 = FederationUtils.callInstruction(newInstString,
- output, new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()});
+ FederatedRequest[] fr1 = FederationUtils.callInstruction(newInstString, output, id,
+ new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()}, InstructionUtils.getExecType(instString));
+ mo1.getFedMapping().execute(getTID(), true, tmp);
mo1.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[0]);
// set new fed map
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index b334775..7c15c6e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -33,6 +33,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
public class TernaryFEDInstruction extends ComputationFEDInstruction {
@@ -92,8 +93,12 @@ public class TernaryFEDInstruction extends ComputationFEDInstruction {
}
private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
- FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
- sendFederatedRequests(ec, mo1, fr1.getID(), fr1);
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest fr1 = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), mo1.getDataType());
+
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()},
+ InstructionUtils.getExecType(instString), false);
+ sendFederatedRequests(ec, mo1, fr1.getID(), fr1, fr2);
}
private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, CPOperand in1, CPOperand in2) {
@@ -113,13 +118,16 @@ public class TernaryFEDInstruction extends ComputationFEDInstruction {
fr1 = mo1.getFedMapping().broadcastSliced(ec.getMatrixObject(in1), false);
varNewIn = new long[]{fr1[0].getID(), mo1.getFedMapping().getID()};
}
- FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, varOldIn, varNewIn);
+ long id = FederationUtils.getNextFedDataID();
+ Types.ExecType execType = InstructionUtils.getExecType(instString) == Types.ExecType.SPARK ? Types.ExecType.SPARK : Types.ExecType.CP;
+ FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), mo1.getDataType());
+ FederatedRequest fr3 = FederationUtils.callInstruction(instString, output, id, varOldIn, varNewIn, execType, false);
// 2 aligned inputs
if(fr1 == null)
- sendFederatedRequests(ec, mo1, fr2.getID(), fr2);
+ sendFederatedRequests(ec, mo1, fr3.getID(), fr2, fr3);
else
- sendFederatedRequests(ec, mo1, fr2.getID(), fr1, fr2);
+ sendFederatedRequests(ec, mo1, fr3.getID(), fr1, fr2, fr3);
}
/**
@@ -207,17 +215,21 @@ public class TernaryFEDInstruction extends ComputationFEDInstruction {
FederatedRequest[] fr2;
FederatedRequest fr3, fr4;
+ long id = FederationUtils.getNextFedDataID();
+ FederatedRequest fr5 = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), mo1.getDataType());
+ Types.ExecType execType = InstructionUtils.getExecType(instString);
+
// all 3 inputs fed aligned on the one worker
if(retAlignedValues._allAligned) {
- fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
- new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
- sendFederatedRequests(ec, mo1, fr3.getID(), fr3);
+ fr3 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {input1, input2, input3},
+ new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()}, execType, false);
+ sendFederatedRequests(ec, mo1, fr3.getID(), fr5, fr3);
}
// 2 fed aligned inputs
else if(retAlignedValues._twoAligned) {
- fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3}, retAlignedValues._vars);
+ fr3 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {input1, input2, input3}, retAlignedValues._vars, execType, false);
fr4 = mo1.getFedMapping().cleanup(getTID(), retAlignedValues._fr[0].getID());
- sendFederatedRequests(ec, mo1, fr3.getID(), retAlignedValues._fr, fr3, fr4);
+ sendFederatedRequests(ec, mo1, fr3.getID(), retAlignedValues._fr, fr5, fr3, fr4);
}
// 1 fed input or not aligned
else {
@@ -239,8 +251,8 @@ public class TernaryFEDInstruction extends ComputationFEDInstruction {
vars = ec.getMatrixObject(input2).isFederated() ? new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), fr2[0].getID()} : new long[] {fr1[0].getID(), fr2[0].getID(),
mo1.getFedMapping().getID()};
- fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3}, vars);
- sendFederatedRequests(ec, mo1, fr3.getID(), fr1, fr2, fr3);
+ fr3 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {input1, input2, input3}, vars, execType, false);
+ sendFederatedRequests(ec, mo1, fr3.getID(), fr5, fr1[0], fr2[0], fr3);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
index 24a850f..1aac3cc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.Arrays;
import java.util.concurrent.Future;
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -45,7 +46,7 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
protected UnaryMatrixFEDInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String instr) {
super(FEDType.Unary, op, in, out, opcode, instr);
}
-
+
public static boolean isValidOpcode(String opcode) {
return !LibCommonsMath.isSupportedUnaryOperation(opcode);
}
@@ -56,7 +57,7 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
-
+
if(parts.length == 5 && (opcode.equalsIgnoreCase("exp") || opcode.equalsIgnoreCase("log") || opcode.startsWith("ucum"))) {
in.split(parts[1]);
out.split(parts[2]);
@@ -71,8 +72,8 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
opcode = parseUnaryInstruction(str, in, out);
return new UnaryMatrixFEDInstruction(InstructionUtils.parseUnaryOperator(opcode), in, out, opcode, str);
}
-
- @Override
+
+ @Override
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
if(getOpcode().startsWith("ucum") && mo1.isFederated(FederationMap.FType.ROW))
@@ -131,7 +132,7 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
String modifiedInstString = InstructionUtils.replaceOperand(instString, 2, InstructionUtils.createOperand(output));
FederatedRequest fr4 = FederationUtils.callInstruction(modifiedInstString, output, out.getFedMapping().getID(),
- new CPOperand[] {output}, new long[] {out.getFedMapping().getID()});
+ new CPOperand[] {output}, new long[] {out.getFedMapping().getID()}, Types.ExecType.CP, false);
out.getFedMapping().execute(getTID(), true, fr4);
out.setFedMapping(out.getFedMapping().copyWithNewID(fr4.getID()));
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixAppendMSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixAppendMSPInstruction.java
index a36a467..bdc98f2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixAppendMSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixAppendMSPInstruction.java
@@ -68,7 +68,7 @@ public class MatrixAppendMSPInstruction extends AppendMSPInstruction {
out = in1.flatMapToPair(
new MapSideAppendFunction(in2, _cbind, off, blen));
}
-
+
//put output RDD handle into symbol table
updateBinaryAppendOutputDataCharacteristics(sec, _cbind);
sec.setRDDHandleForVariable(output.getName(), out);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 40e152f..1478c5f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -36,6 +36,7 @@ import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -98,6 +99,10 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
return params;
}
+ public CacheableData<?> getTarget(ExecutionContext ec) {
+ return ec.getCacheableData(params.get("target"));
+ }
+
public static HashMap<String, String> constructParameterMap(String[] params) {
// process all elements in "params" except first(opcode) and last(output)
HashMap<String, String> paramMap = new HashMap<>();
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
index 532567b..a88b68f 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
@@ -26,6 +26,7 @@ import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Or;
import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
@@ -67,4 +68,8 @@ public class AggregateUnaryOperator extends Operator {
public boolean isColAggregate() {
return indexFn instanceof ReduceRow;
}
+
+ public boolean isFullAggregate() {
+ return indexFn instanceof ReduceAll;
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 3ac462c..6b21cd6 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -19,11 +19,6 @@
package org.apache.sysds.test;
-import static java.lang.Math.ceil;
-import static java.lang.Thread.sleep;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
@@ -39,6 +34,10 @@ import java.util.Properties;
import java.util.Set;
import java.util.stream.Collectors;
+import static java.lang.Math.ceil;
+import static java.lang.Thread.sleep;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -52,13 +51,13 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;
-import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.ParseException;
@@ -77,6 +76,7 @@ import org.apache.sysds.runtime.io.FrameReaderFactory;
import org.apache.sysds.runtime.io.ReaderWriterFederated;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataAll;
@@ -1780,6 +1780,15 @@ public abstract class AutomatedTestBase {
}
}
+	/**
+	 * Compares each expected comparison file with the output directory at the same index,
+	 * using the given labels for the two sides in any reported mismatch.
+	 *
+	 * @param epsilon tolerance passed to {@code TestUtils.compareMatrices}
+	 * @param name1 label of the expected side
+	 * @param name2 label of the actual side
+	 */
+	protected void compareResults(double epsilon, String name1, String name2) {
+		for(int idx = 0; idx < comparisonFiles.length; idx++) {
+			HashMap<MatrixValue.CellIndex, Double> exp = TestUtils.readDMLMatrixFromHDFS(comparisonFiles[idx]);
+			HashMap<MatrixValue.CellIndex, Double> act = TestUtils.readDMLMatrixFromHDFS(outputDirectories[idx]);
+			TestUtils.compareMatrices(exp, act, epsilon, name1, name2);
+		}
+	}
+
/**
* <p>
* Compares the results of the computation of the frame with the expected ones.
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
index abf7df2..0ee20a8 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
@@ -31,7 +31,6 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -115,31 +114,26 @@ public class FederatedFullAggregateTest extends AutomatedTestBase {
}
@Test
- @Ignore
public void testSumDenseMatrixSP() {
runAggregateOperationTest(OpType.SUM, ExecType.SPARK);
}
@Test
- @Ignore
public void testMeanDenseMatrixSP() {
runAggregateOperationTest(OpType.MEAN, ExecType.SPARK);
}
@Test
- @Ignore
public void testMaxDenseMatrixSP() {
runAggregateOperationTest(OpType.MAX, ExecType.SPARK);
}
@Test
- @Ignore
public void testMinDenseMatrixSP() {
runAggregateOperationTest(OpType.MIN, ExecType.SPARK);
}
@Test
- @Ignore
public void testVarDenseMatrixSP() {
runAggregateOperationTest(OpType.VAR, ExecType.SPARK);
}
@@ -150,6 +144,9 @@ public class FederatedFullAggregateTest extends AutomatedTestBase {
case SPARK:
rtplatform = ExecMode.SPARK;
break;
+ case CP:
+ rtplatform = ExecMode.SINGLE_NODE;
+ break;
default:
rtplatform = ExecMode.HYBRID;
break;
@@ -218,14 +215,14 @@ public class FederatedFullAggregateTest extends AutomatedTestBase {
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ programArgs = new String[] {"-explain", "-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
expected("S"), Boolean.toString(rowPartitioned).toUpperCase()};
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs",
+ programArgs = new String[] {"-explain","-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullCumulativeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullCumulativeTest.java
index 075dd0c..aded5fe 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullCumulativeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullCumulativeTest.java
@@ -80,11 +80,8 @@ public class FederatedFullCumulativeTest extends AutomatedTestBase {
@Test
public void testSumDenseMatrixCP() { runCumOperationTest(OpType.SUM, ExecType.CP); }
-// FIXME offset handling has some remaining issues
-// @Test
-// public void testProdDenseMatrixCP() {
-// runCumOperationTest(OpType.PROD, ExecType.CP);
-// }
+ @Test
+ public void testSumDenseMatrixSP() { runCumOperationTest(OpType.SUM, ExecType.SPARK); }
@Test
public void testMaxDenseMatrixCP() {
@@ -96,11 +93,15 @@ public class FederatedFullCumulativeTest extends AutomatedTestBase {
runCumOperationTest(OpType.MIN, ExecType.CP);
}
-// FIXME offset handling has some remaining issues
-// @Test
-// public void testSumprodDenseMatrixCP() {
-// runCumOperationTest(OpType.SUMPROD, ExecType.CP);
-// }
+	/** Cumulative MAX on a federated dense matrix with ExecType.SPARK. */
+	@Test
+	public void testMaxDenseMatrixSP() { runCumOperationTest(OpType.MAX, ExecType.SPARK); }
+
+	/** Cumulative MIN on a federated dense matrix with ExecType.SPARK. */
+	@Test
+	public void testMinDenseMatrixSP() { runCumOperationTest(OpType.MIN, ExecType.SPARK); }
private void runCumOperationTest(OpType type, ExecType instType) {
ExecMode platformOld = setExecMode(instType);
@@ -179,23 +180,23 @@ public class FederatedFullCumulativeTest extends AutomatedTestBase {
runTest(true, false, null, -1);
// compare via files
- compareResults(1e-6);
+ compareResults(1e-6, "DML1", "DML2");
switch(type) {
case SUM:
- Assert.assertTrue(heavyHittersContainsString("fed_ucumk+"));
+ Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ? "fed_bcumoffk+" : "fed_ucumk+"));
break;
case PROD:
- Assert.assertTrue(heavyHittersContainsString("fed_ucum*"));
+ Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ? "fed_bcumoff*" : "fed_ucum*"));
break;
case MAX:
- Assert.assertTrue(heavyHittersContainsString("fed_ucummax"));
+ Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ? "fed_bcumoffmax" : "fed_ucummax"));
break;
case MIN:
- Assert.assertTrue(heavyHittersContainsString("fed_ucummin"));
+ Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ? "fed_bcumoffmin" : "fed_ucummin"));
break;
case SUMPROD:
- Assert.assertTrue(heavyHittersContainsString("ucumk+*"));
+ Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ? "fed_bcumoff+*" : "ucumk+*"));
break;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
index 11234e3..ee8fabf 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
@@ -79,6 +79,14 @@ public class FederatedIfelseTest extends AutomatedTestBase {
runTernaryTest(ExecMode.SINGLE_NODE, true);
}
+	/** ifelse on federated inputs that are not aligned (aligned = false), in SPARK mode. */
+	@Test
+	public void testIfelseDiffWorkersSP() { runTernaryTest(ExecMode.SPARK, false); }
+
+ @Test
+ public void testIfelseAlignedSP() { runTernaryTest(ExecMode.SPARK, true); }
+
private void runTernaryTest(ExecMode execMode, boolean aligned) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -144,7 +152,7 @@ public class FederatedIfelseTest extends AutomatedTestBase {
runTernary(HOME, TEST_NAME, port1, port2, port3, port4);
// compare via files
- compareResults(1e-9);
+ compareResults(1e-9, "DML1", "DML2");
Assert.assertTrue(heavyHittersContainsString("fed_ifelse"));
// check that federated input files are still existing
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
index e6c3c85..3c33728 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
@@ -101,6 +101,14 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {
runAggregateOperationTest(DataType.FRAME, ExecMode.SINGLE_NODE);
}
+ @Test
+ public void testLeftIndexFullDenseMatrixSP() { runAggregateOperationTest(DataType.MATRIX, ExecMode.SPARK); }
+
+	/** Left indexing on a federated dense frame, in SPARK mode. */
+	@Test
+	public void testLeftIndexFullDenseFrameSP() { runAggregateOperationTest(DataType.FRAME, ExecMode.SPARK); }
+
private void runAggregateOperationTest(DataType dataType, ExecMode execMode) {
setExecMode(execMode);
@@ -186,9 +194,9 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {
runTest(null);
// compare via files
- compareResults(1e-9);
+ compareResults(1e-9, "Stat-DML1", "Stat-DML2");
- Assert.assertTrue(heavyHittersContainsString("fed_leftIndex"));
+ Assert.assertTrue(rtplatform ==ExecMode.SPARK ? heavyHittersContainsString("fed_mapLeftIndex") : heavyHittersContainsString("fed_leftIndex"));
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMMChainTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMMChainTest.java
index f223cf2..d9f28b7 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMMChainTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMMChainTest.java
@@ -75,6 +75,12 @@ public class FederatedMMChainTest extends AutomatedTestBase {
public void testMMChainWeightsCP() { runMMChainTest(ExecMode.SINGLE_NODE, TEST_NAME2); }
@Test
public void testMMChainWeights2CP() { runMMChainTest(ExecMode.SINGLE_NODE, TEST_NAME3); }
+ @Test
+ public void testMMChainSP() { runMMChainTest(ExecMode.SPARK, TEST_NAME1); }
+ @Test
+ public void testMMChainWeightsSP() { runMMChainTest(ExecMode.SPARK, TEST_NAME2); }
+ @Test
+ public void testMMChainWeights2SP() { runMMChainTest(ExecMode.SPARK, TEST_NAME3); }
private void runMMChainTest(ExecMode execMode, String TEST_NAME) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
@@ -149,7 +155,7 @@ public class FederatedMMChainTest extends AutomatedTestBase {
runTest(null);
// compare via files
- compareResults(1e-9);
+ compareResults(1e-9, "Stat-DML1", "Stat-DML2");
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
index 7fc192d..3acc53b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
@@ -20,7 +20,6 @@
package org.apache.sysds.test.functions.federated.primitives;
import org.apache.sysds.hops.OptimizerUtils;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -74,13 +73,16 @@ public class FederatedMultiplyTest extends AutomatedTestBase {
federatedMultiply(Types.ExecMode.SINGLE_NODE, true);
}
- @Test
- @Ignore
+ @Test
public void federatedMultiplySP() {
- // TODO Fix me Spark execution error
federatedMultiply(Types.ExecMode.SPARK);
}
+ @Test
+ public void federatedMultiplySPCompileToFED() {
+ federatedMultiply(Types.ExecMode.SPARK, true);
+ }
+
private void federatedMultiply(Types.ExecMode execMode){
federatedMultiply(execMode,false);
}
@@ -134,7 +136,7 @@ public class FederatedMultiplyTest extends AutomatedTestBase {
runTest(true, false, null, -1);
// compare via files
- compareResults(1e-9);
+ compareResults(1e-9, "Stat-DML1", "Stat-DML2");
TestUtils.shutdownThreads(t1, t2);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedProdTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedProdTest.java
index 70859d4..51d7852 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedProdTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedProdTest.java
@@ -69,6 +69,9 @@ public class FederatedProdTest extends AutomatedTestBase {
@Test
public void testProdCP() { runProdTest(ExecMode.SINGLE_NODE); }
+ @Test
+ public void testProdSP() { runProdTest(ExecMode.SPARK); }
+
private void runProdTest(ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
index 1a71279..226ad53 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
@@ -30,7 +30,6 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -83,27 +82,27 @@ public class FederatedQuantileTest extends AutomatedTestBase {
public void federatedQuantilesCP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
@Test
- @Ignore
+// @Ignore
public void federatedQuantile1SP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.25); }
@Test
- @Ignore
+// @Ignore
public void federatedQuantile2SP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.5); }
@Test
- @Ignore
+// @Ignore
public void federatedQuantile3SP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.75); }
@Test
- @Ignore
+// @Ignore
public void federatedMedianSP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME2, -1); }
@Test
- @Ignore
+// @Ignore
public void federatedIQMSP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, -1); }
@Test
- @Ignore
+// @Ignore
public void federatedQuantilesSP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, -1); }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
index eedbbbb..cda9966 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
@@ -30,7 +30,6 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -68,7 +67,6 @@ public class FederatedRdiagTest extends AutomatedTestBase {
public void federatedRdiagCP() { federatedRdiag(Types.ExecMode.SINGLE_NODE); }
@Test
- @Ignore
public void federatedRdiagSP() { federatedRdiag(Types.ExecMode.SPARK); }
public void federatedRdiag(Types.ExecMode execMode) {
@@ -128,7 +126,7 @@ public class FederatedRdiagTest extends AutomatedTestBase {
runTest(null);
// compare all sums via files
- compareResults(0.01);
+ compareResults(0.01, "DML1", "DML2");
Assert.assertTrue(heavyHittersContainsString("fed_rdiag"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
index 89f67b2..af2b9bb 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
@@ -71,6 +71,11 @@ public class FederatedRemoveEmptyTest extends AutomatedTestBase {
runAggregateOperationTest(ExecMode.SINGLE_NODE);
}
+ @Test
+ public void testRemoveEmptySP() {
+ runAggregateOperationTest(ExecMode.SPARK);
+ }
+
private void runAggregateOperationTest(ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -141,7 +146,7 @@ public class FederatedRemoveEmptyTest extends AutomatedTestBase {
runTest(null);
// compare via files
- compareResults(1e-9);
+ compareResults(1e-9, "DML1", "DML2");
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedReshapeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedReshapeTest.java
index c32a4a9..a9f917a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedReshapeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedReshapeTest.java
@@ -30,7 +30,6 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -76,7 +75,7 @@ public class FederatedReshapeTest extends AutomatedTestBase {
}
@Test
- @Ignore
+// @Ignore
public void federatedReshapeSP() {
federatedReshape(Types.ExecMode.SPARK);
}
@@ -139,7 +138,7 @@ public class FederatedReshapeTest extends AutomatedTestBase {
runTest(null);
// compare all sums via files
- compareResults(0.01);
+ compareResults(0.01, "DML1", "DML2");
Assert.assertTrue(heavyHittersContainsString("fed_rshape"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
index 36996db..847f351 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
@@ -72,6 +72,11 @@ public class FederatedRevTest extends AutomatedTestBase {
runRevTest(ExecMode.SINGLE_NODE);
}
+ @Test
+ public void testRevSP() {
+ runRevTest(ExecMode.SPARK);
+ }
+
private void runRevTest(ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -141,7 +146,7 @@ public class FederatedRevTest extends AutomatedTestBase {
runTest(null);
// compare via files
- compareResults(0.01);
+ compareResults(0.01, "Stat-DML1", "Stat-DML2");
Assert.assertTrue(heavyHittersContainsString("fed_rev"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
index e950c20..0139137 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
@@ -108,6 +108,16 @@ public class FederatedRightIndexTest extends AutomatedTestBase {
runAggregateOperationTest(IndexType.FULL, DataType.FRAME, ExecMode.SINGLE_NODE);
}
+ @Test
+ public void testRightIndexFullDenseMatrixSP() {
+ runAggregateOperationTest(IndexType.FULL, DataType.MATRIX, ExecMode.SPARK);
+ }
+
+ @Test
+ public void testRightIndexFullDenseFrameSP() {
+ runAggregateOperationTest(IndexType.FULL, DataType.FRAME, ExecMode.SPARK);
+ }
+
private void runAggregateOperationTest(IndexType indexType, DataType dataType, ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -200,7 +210,7 @@ public class FederatedRightIndexTest extends AutomatedTestBase {
LOG.debug(runTest(null));
// compare via files
- compareResults(1e-9);
+ compareResults(1e-9, "Stat-DML1", "Stat-DML2");
Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
index 91452eb..ab69143 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
@@ -116,6 +116,41 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
runAggregateOperationTest(OpType.MM, ExecMode.SINGLE_NODE);
}
+ @Test
+ public void testRowSumDenseMatrixSP() {
+ runAggregateOperationTest(OpType.SUM, ExecMode.SPARK);
+ }
+
+ @Test
+ public void testRowMeanDenseMatrixSP() {
+ runAggregateOperationTest(OpType.MEAN, ExecMode.SPARK);
+ }
+
+ @Test
+ public void testRowMaxDenseMatrixSP() {
+ runAggregateOperationTest(OpType.MAX, ExecMode.SPARK);
+ }
+
+ @Test
+ public void testRowMinDenseMatrixSP() {
+ runAggregateOperationTest(OpType.MIN, ExecMode.SPARK);
+ }
+
+ @Test
+ public void testRowVarDenseMatrixSP() {
+ runAggregateOperationTest(OpType.VAR, ExecMode.SPARK);
+ }
+
+ @Test
+ public void testRowProdDenseMatrixSP() {
+ runAggregateOperationTest(OpType.PROD, ExecMode.SPARK);
+ }
+
+ @Test
+ public void testMMDenseMatrixSP() {
+ runAggregateOperationTest(OpType.MM, ExecMode.SPARK);
+ }
+
private void runAggregateOperationTest(OpType type, ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -208,7 +243,7 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
runTest(true, false, null, -1);
// compare via files
- compareResults(type == FederatedRowAggregateTest.OpType.VAR ? 1e-2 : 1e-9);
+ compareResults(type == FederatedRowAggregateTest.OpType.VAR ? 1e-2 : 1e-9, "Stat-DML1", "Stat-DML2");
String fedInst = "fed_uar";
@@ -232,7 +267,7 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
Assert.assertTrue(heavyHittersContainsString(fedInst.concat("*")));
break;
case MM:
- Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+ Assert.assertTrue(heavyHittersContainsString(rtplatform == ExecMode.SPARK ? "fed_mapmm" : "fed_ba+*"));
break;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowIndexTest.java
index e4c7534..eb908ed 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowIndexTest.java
@@ -71,6 +71,11 @@ public class FederatedRowIndexTest extends AutomatedTestBase {
runRowIndexTest(ExecMode.SINGLE_NODE);
}
+ @Test
+ public void testRowIndexSP() {
+ runRowIndexTest(ExecMode.SPARK);
+ }
+
private void runRowIndexTest(ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -137,9 +142,9 @@ public class FederatedRowIndexTest extends AutomatedTestBase {
runTest(null);
// compare via files
- compareResults(1e-9);
+ compareResults(1e-9, "Stat-DML1", "Stat-DML2");
- Assert.assertTrue(heavyHittersContainsString("fed_uarimax"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uarimax") || (!rowPartitioned && execMode == ExecMode.SPARK));
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
index afd2ffe..333f515 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
@@ -70,6 +70,11 @@ public class FederatedSplitTest extends AutomatedTestBase {
federatedSplit(Types.ExecMode.SINGLE_NODE);
}
+ @Test
+ public void federatedSplitSP() {
+ federatedSplit(Types.ExecMode.SPARK);
+ }
+
public void federatedSplit(Types.ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
@@ -122,11 +127,11 @@ public class FederatedSplitTest extends AutomatedTestBase {
LOG.debug(out);
LOG.debug(fedOut);
// compare via files
- compareResults(1e-9);
+ compareResults(1e-9, "Stat-DML1", "Stat-DML2");
if(cont.equals("TRUE"))
Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
- else {
+ else if(execMode != Types.ExecMode.SPARK){
Assert.assertTrue(heavyHittersContainsString("fed_rmempty"));
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
index 6262e03..82ac6eb 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
@@ -29,7 +29,6 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -75,11 +74,15 @@ public class FederatedSumTest extends AutomatedTestBase {
}
@Test
- @Ignore
public void federatedSumSP() {
federatedSum(Types.ExecMode.SPARK);
}
+ @Test
+ public void federatedSumSPToFED() {
+ federatedSum(Types.ExecMode.SPARK, true);
+ }
+
public void federatedSum(Types.ExecMode execMode){
federatedSum(execMode, false);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTokenizeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTokenizeTest.java
index 6c2096d..02dd496 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTokenizeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTokenizeTest.java
@@ -127,8 +127,8 @@ public class FederatedTokenizeTest extends AutomatedTestBase {
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-explain", "-args", input("AH"), HOME + TEST_NAME + ".json", expected("S")};
runTest(null);
- // Run actual dml script with federated matrix
+ // Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("AH")),
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTriTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTriTest.java
index e7ae2d4..0ea200f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTriTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTriTest.java
@@ -68,6 +68,9 @@ public class FederatedTriTest extends AutomatedTestBase {
@Test
public void testTriCP() { runTriTest(ExecMode.SINGLE_NODE); }
+ @Test
+ public void testTriSP() { runTriTest(ExecMode.SPARK); }
+
private void runTriTest(ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -133,7 +136,7 @@ public class FederatedTriTest extends AutomatedTestBase {
runTest(null);
// compare via files
- compareResults(1e-9);
+ compareResults(1e-9, "Stat-DML1", "Stat-DML2");
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml b/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml
index 8566399..cb17fc3 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml
@@ -19,9 +19,11 @@
#
#-------------------------------------------------------------
-if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
+if($6) {
+ A = rbind(read($1), read($2), read($3), read($4));
+}
else { A = cbind(read($1), read($2), read($3), read($4)); }
-
+#A = read($1)
s = mean(A);
write(s, $5);