You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2022/06/05 16:09:11 UTC
[systemds] branch main updated: [SYSTEMDS-3374] Federation primitive for local to federated data
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 641949da67 [SYSTEMDS-3374] Federation primitive for local to federated data
641949da67 is described below
commit 641949da67a2abfdbbdab0164359f9b6e387622a
Author: OlgaOvcharenko <ov...@gmail.com>
AuthorDate: Sun Jun 5 17:11:55 2022 +0200
[SYSTEMDS-3374] Federation primitive for local to federated data
Closes #1609.
---
src/main/java/org/apache/sysds/lops/Federated.java | 27 ++-
.../org/apache/sysds/parser/DataExpression.java | 24 ++-
.../controlprogram/federated/FederatedData.java | 14 ++
.../federated/FederatedWorkerHandler.java | 35 ++--
.../instructions/fed/InitFEDInstruction.java | 212 +++++++++++++++++++--
.../sysds/runtime/matrix/data/MatrixBlock.java | 9 +
src/test/java/org/apache/sysds/test/TestUtils.java | 5 +-
.../primitives/FederatedTransferLocalDataTest.java | 125 ++++++++++++
.../federated/FederatedTransferLocalDataTest.dml | 34 ++++
.../FederatedTransferLocalDataTestReference.dml | 23 +++
10 files changed, 467 insertions(+), 41 deletions(-)
diff --git a/src/main/java/org/apache/sysds/lops/Federated.java b/src/main/java/org/apache/sysds/lops/Federated.java
index 52b52be544..2ed1de2fdb 100644
--- a/src/main/java/org/apache/sysds/lops/Federated.java
+++ b/src/main/java/org/apache/sysds/lops/Federated.java
@@ -25,11 +25,12 @@ import java.util.HashMap;
import static org.apache.sysds.common.Types.DataType;
import static org.apache.sysds.common.Types.ValueType;
import static org.apache.sysds.parser.DataExpression.FED_ADDRESSES;
+import static org.apache.sysds.parser.DataExpression.FED_LOCAL_OBJECT;
import static org.apache.sysds.parser.DataExpression.FED_RANGES;
import static org.apache.sysds.parser.DataExpression.FED_TYPE;
public class Federated extends Lop {
- private Lop _type, _addresses, _ranges;
+ private Lop _type, _addresses, _ranges, _localObject;
public Federated(HashMap<String, Lop> inputLops, DataType dataType, ValueType valueType) {
super(Type.Federated, dataType, valueType);
@@ -43,6 +44,12 @@ public class Federated extends Lop {
_addresses.addOutput(this);
addInput(_ranges);
_ranges.addOutput(this);
+
+ if(inputLops.size() == 4) {
+ _localObject = inputLops.get(FED_LOCAL_OBJECT);
+ addInput(_localObject);
+ _localObject.addOutput(this);
+ }
}
@Override
@@ -60,6 +67,24 @@ public class Federated extends Lop {
sb.append(prepOutputOperand(output));
return sb.toString();
}
+
+ @Override
+ public String getInstructions(String type, String addresses, String ranges, String object, String output) {
+ StringBuilder sb = new StringBuilder("FED");
+ sb.append(OPERAND_DELIMITOR);
+ sb.append("fedinit");
+ sb.append(OPERAND_DELIMITOR);
+ sb.append(_type.prepScalarInputOperand(type));
+ sb.append(OPERAND_DELIMITOR);
+ sb.append(_addresses.prepScalarInputOperand(addresses));
+ sb.append(OPERAND_DELIMITOR);
+ sb.append(_ranges.prepScalarInputOperand(ranges));
+ sb.append(OPERAND_DELIMITOR);
+ sb.append(_localObject.prepScalarInputOperand(object));
+ sb.append(OPERAND_DELIMITOR);
+ sb.append(prepOutputOperand(output));
+ return sb.toString();
+ }
@Override
public String toString() {
diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java
index 2f25809762..e2e3996cea 100644
--- a/src/main/java/org/apache/sysds/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysds/parser/DataExpression.java
@@ -87,6 +87,7 @@ public class DataExpression extends DataIdentifier
public static final String FED_ADDRESSES = "addresses";
public static final String FED_RANGES = "ranges";
public static final String FED_TYPE = "type";
+ public static final String FED_LOCAL_OBJECT = "local_matrix";
public static final String FORMAT_TYPE = "format";
@@ -132,7 +133,7 @@ public class DataExpression extends DataIdentifier
Arrays.asList(SQL_CONN, SQL_USER, SQL_PASS, SQL_QUERY));
public static final Set<String> FEDERATED_VALID_PARAM_NAMES = new HashSet<>(
- Arrays.asList(FED_ADDRESSES, FED_RANGES, FED_TYPE));
+ Arrays.asList(FED_ADDRESSES, FED_RANGES, FED_TYPE, FED_LOCAL_OBJECT));
/** Valid parameter names in metadata file */
public static final Set<String> READ_VALID_MTD_PARAM_NAMES =new HashSet<>(
@@ -540,6 +541,16 @@ public class DataExpression extends DataIdentifier
param = passedParamExprs.get(2);
dataExpr.addFederatedExprParam(DataExpression.FED_TYPE, param.getExpr());
}
+ else if(unnamedParamCount == 4) {
+ ParameterExpression param = passedParamExprs.get(0);
+ dataExpr.addFederatedExprParam(DataExpression.FED_LOCAL_OBJECT, param.getExpr());
+ param = passedParamExprs.get(1);
+ dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr());
+ param = passedParamExprs.get(2);
+ dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr());
+ param = passedParamExprs.get(3);
+ dataExpr.addFederatedExprParam(DataExpression.FED_TYPE, param.getExpr());
+ }
else {
errorListener.validationError(parseInfo,
"for federated statement, at most 3 arguments are supported: addresses, ranges, type");
@@ -888,7 +899,7 @@ public class DataExpression extends DataIdentifier
raiseValidateError("UDF function call not supported as parameter to built-in function call", false,LanguageErrorCodes.INVALID_PARAMETERS);
}
inputParamExpr.validateExpression(ids, currConstVars, conditional);
- if (s != null && !s.equals(RAND_DATA) && !s.equals(RAND_DIMS) && !s.equals(FED_ADDRESSES) && !s.equals(FED_RANGES)
+ if (s != null && !s.equals(RAND_DATA) && !s.equals(RAND_DIMS) && !s.equals(FED_ADDRESSES) && !s.equals(FED_RANGES) && !s.equals(FED_LOCAL_OBJECT)
&& !s.equals(DELIM_NA_STRINGS) && !s.equals(SCHEMAPARAM) && getVarParam(s).getOutput().getDataType() != DataType.SCALAR ) {
raiseValidateError("Non-scalar data types are not supported for data expression.", conditional,LanguageErrorCodes.INVALID_PARAMETERS);
}
@@ -2195,7 +2206,16 @@ public class DataExpression extends DataIdentifier
else if(fedType.getValue().equalsIgnoreCase(FED_FRAME_IDENTIFIER)) {
getOutput().setDataType(DataType.FRAME);
}
+
+ if(_varParams.size() == 4) {
+ exp = getVarParam(FED_LOCAL_OBJECT);
+ if( !(exp instanceof DataIdentifier) ) {
+ raiseValidateError("for federated statement " + FED_LOCAL_OBJECT + " has incorrect value type", conditional);
+ }
+ getVarParam(FED_LOCAL_OBJECT).validateExpression(ids, currConstVars, conditional);
+ }
getOutput().setDimensions(-1, -1);
+
break;
default:
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 1fb1e8b1ec..370163aaf2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -34,6 +34,7 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.paramserv.NetworkTrafficCounter;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -150,6 +151,19 @@ public class FederatedData {
return executeFederatedOperation(request);
}
+ public synchronized Future<FederatedResponse> initFederatedDataFromLocal(long id, CacheBlock block) {
+ if(isInitialized())
+ throw new DMLRuntimeException("Tried to init already initialized data");
+ if(!_dataType.isMatrix() && !_dataType.isFrame())
+ throw new DMLRuntimeException("Federated datatype \"" + _dataType.toString() + "\" is not supported.");
+ _varID = id;
+ FederatedRequest request = new FederatedRequest(RequestType.READ_VAR, id);
+ request.appendParam(_filepath);
+ request.appendParam(_dataType.name());
+ request.appendParam(block);
+ return executeFederatedOperation(request);
+ }
+
public Future<FederatedResponse> executeFederatedOperation(FederatedRequest... request) {
return executeFederatedOperation(_address, request);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index d0865df120..592f77ccce 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -89,15 +89,14 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
private static final Log LOG = LogFactory.getLog(FederatedWorkerHandler.class.getName());
/** The Federated Lookup Table of the current Federated Worker. */
- private FederatedLookupTable _flt;
+ private final FederatedLookupTable _flt;
/** Read cache shared by all worker handlers */
- private FederatedReadCache _frc;
+ private final FederatedReadCache _frc;
private Timing _timing = null;
-
-
+
/** Federated workload analyzer */
- private FederatedWorkloadAnalyzer _fan;
+ private final FederatedWorkloadAnalyzer _fan;
/**
* Create a Federated Worker Handler.
@@ -114,12 +113,12 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
_frc = frc;
_fan = fan;
}
-
+
public FederatedWorkerHandler(FederatedLookupTable flt, FederatedReadCache frc, FederatedWorkloadAnalyzer fan, Timing timing) {
this(flt, frc, fan);
_timing = timing;
}
-
+
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
ctx.writeAndFlush(createResponse(msg, ctx.channel().remoteAddress()))
@@ -138,7 +137,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
} catch (RuntimeException ignored) {
// ignore timing if it wasn't started yet
}
-
+
String host;
if(remoteAddress instanceof InetSocketAddress) {
host = ((InetSocketAddress) remoteAddress).getHostString();
@@ -183,7 +182,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
private FederatedResponse createResponse(FederatedRequest[] requests, String remoteHost)
throws DMLPrivacyException, FederatedWorkerHandlerException, Exception {
-
+
FederatedResponse response = null; // last response
boolean containsCLEAR = false;
for(int i = 0; i < requests.length; i++) {
@@ -272,14 +271,15 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
}
private FederatedResponse readData(FederatedRequest request, ExecutionContextMap ecm) {
- checkNumParams(request.getNumParams(), 2);
+ checkNumParams(request.getNumParams(), 2, 3);
String filename = (String) request.getParam(0);
DataType dt = DataType.valueOf((String) request.getParam(1));
- return readData(filename, dt, request.getID(), request.getTID(), ecm);
+ return readData(filename, dt, request.getID(), request.getTID(), ecm,
+ request.getNumParams() == 2 ? null : (CacheBlock)request.getParam(2));
}
private FederatedResponse readData(String filename, DataType dataType,
- long id, long tid, ExecutionContextMap ecm) {
+ long id, long tid, ExecutionContextMap ecm, CacheBlock localBlock) {
MatrixCharacteristics mc = new MatrixCharacteristics();
mc.setBlocksize(ConfigurationManager.getBlocksize());
@@ -299,7 +299,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
cd = _frc.get(filename, !linReuse);
try {
if(cd == null) { // data is neither in lineage cache nor in read cache
- cd = readDataNoReuse(filename, dataType, mc); // actual read of the data
+ cd = localBlock == null ? readDataNoReuse(filename, dataType, mc) : ExecutionContext.createCacheableData(localBlock); // actual read of the data
if(linReuse) // put the object into the lineage cache
LineageCache.putFedReadObject(cd, linItem, ec);
else
@@ -315,7 +315,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
throw ex;
}
}
-
+
if(shouldTryAsyncCompress()) // TODO: replace the reused object
CompressedMatrixBlockFactory.compressAsync(ec, sId);
@@ -426,7 +426,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
throw new FederatedWorkerHandlerException(
"Unsupported object type, has to be of type CacheBlock or ScalarObject");
-
+
// set variable and construct empty response
ec.setVariable(varName, data);
@@ -450,12 +450,13 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
private FederatedResponse getVariable(FederatedRequest request, ExecutionContextMap ecm) {
try{
+
checkNumParams(request.getNumParams(), 0);
ExecutionContext ec = ecm.get(request.getTID());
if(!ec.containsVariable(String.valueOf(request.getID())))
throw new FederatedWorkerHandlerException(
"Variable " + request.getID() + " does not exist at federated worker.");
-
+
// get variable and construct response
Data dataObject = ec.getVariable(String.valueOf(request.getID()));
dataObject = PrivacyMonitor.handlePrivacy(dataObject);
@@ -487,7 +488,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
adaptToWorkload(ec, _fan, tid, ins);
return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
-
+
private static ExecutionContext getContextForInstruction(long id, Instruction ins, ExecutionContextMap ecm){
final ExecutionContext ec = ecm.get(id);
//handle missing spark execution context
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 6e18115835..3e648bbe3b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -33,24 +33,25 @@ import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
-import org.apache.sysds.api.DMLScript;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -60,6 +61,8 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageTraceable;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
public class InitFEDInstruction extends FEDInstruction implements LineageTraceable {
@@ -69,7 +72,7 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
public static final String FED_MATRIX_IDENTIFIER = "matrix";
public static final String FED_FRAME_IDENTIFIER = "frame";
- private CPOperand _type, _addresses, _ranges, _output;
+ private CPOperand _type, _addresses, _ranges, _localObject, _output;
public InitFEDInstruction(CPOperand type, CPOperand addresses, CPOperand ranges, CPOperand out, String opcode,
String instr) {
@@ -80,44 +83,66 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
_output = out;
}
+ public InitFEDInstruction(CPOperand type, CPOperand addresses, CPOperand ranges, CPOperand object, CPOperand out, String opcode,
+ String instr) {
+ this(type, addresses, ranges, out, opcode, instr);
+ _localObject = object;
+ }
+
public static InitFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
// We need 5 parts: Opcode, Type (Frame/Matrix), Addresses (list of Strings with
// url/ip:port/filepath), ranges and the output Operand
- if(parts.length != 5)
+ if(parts.length != 5 && parts.length != 6)
throw new DMLRuntimeException("Invalid number of operands in federated instruction: " + str);
String opcode = parts[0];
- CPOperand type, addresses, ranges, out;
- type = new CPOperand(parts[1]);
- addresses = new CPOperand(parts[2]);
- ranges = new CPOperand(parts[3]);
- out = new CPOperand(parts[4]);
- return new InitFEDInstruction(type, addresses, ranges, out, opcode, str);
+ if(parts.length == 5) {
+ CPOperand type, addresses, ranges, out;
+ type = new CPOperand(parts[1]);
+ addresses = new CPOperand(parts[2]);
+ ranges = new CPOperand(parts[3]);
+ out = new CPOperand(parts[4]);
+ return new InitFEDInstruction(type, addresses, ranges, out, opcode, str);
+ } else {
+ CPOperand type, addresses, object, ranges, out;
+ type = new CPOperand(parts[1]);
+ addresses = new CPOperand(parts[2]);
+ ranges = new CPOperand(parts[3]);
+ object = new CPOperand(parts[4]);
+ out = new CPOperand(parts[5]);
+ return new InitFEDInstruction(type, addresses, ranges, object, out, opcode, str);
+ }
}
@Override
public void processInstruction(ExecutionContext ec) {
+ if(_localObject == null)
+ processFedInit(ec);
+ else
+ processFromLocalFedInit(ec);
+ }
+
+ private void processFedInit(ExecutionContext ec){
String type = ec.getScalarInput(_type).getStringValue();
ListObject addresses = ec.getListObject(_addresses.getName());
ListObject ranges = ec.getListObject(_ranges.getName());
List<Pair<FederatedRange, FederatedData>> feds = new ArrayList<>();
if(addresses.getLength() * 2 != ranges.getLength())
- throw new DMLRuntimeException("Federated read needs twice the amount of addresses as ranges "
- + "(begin and end): addresses=" + addresses.getLength() + " ranges=" + ranges.getLength());
+ throw new DMLRuntimeException("Federated read needs twice the amount of addresses as ranges " + "(begin and end): addresses=" + addresses.getLength() + " ranges=" + ranges.getLength());
//check for duplicate addresses (would lead to overwrite with common variable names)
// TODO relax requirement by using different execution contexts per federated data?
Set<String> addCheck = new HashSet<>();
for( Data dat : addresses.getData() )
if( dat instanceof StringObject ) {
- String address = ((StringObject)dat).getStringValue();
+ String address = ((StringObject) dat).getStringValue();
if(addCheck.contains(address))
LOG.warn("Federated data contains address duplicates: " + addresses);
addCheck.add(address);
}
-
+
Types.DataType fedDataType;
if(type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER))
fedDataType = Types.DataType.MATRIX;
@@ -136,6 +161,103 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
int port = Integer.parseInt(parsedValues[1]);
String filePath = parsedValues[2];
+ if(DMLScript.FED_STATISTICS)
+ // register the federated worker for federated statistics creation
+ FederatedStatistics.registerFedWorker(host, port);
+
+ // get beginning and end of data ranges
+ List<Data> rangesData = ranges.getData();
+ Data beginData = rangesData.get(i * 2);
+ Data endData = rangesData.get(i * 2 + 1);
+ if(beginData.getDataType() != Types.DataType.LIST || endData.getDataType() != Types.DataType.LIST)
+ throw new DMLRuntimeException("Federated read ranges (lower, upper) have to be lists of dimensions");
+ List<Data> beginDimsData = ((ListObject) beginData).getData();
+ List<Data> endDimsData = ((ListObject) endData).getData();
+
+ // fill begin and end dims
+ long[] beginDims = new long[beginDimsData.size()];
+ long[] endDims = new long[beginDims.length];
+ for(int d = 0; d < beginDims.length; d++) {
+ beginDims[d] = ((ScalarObject) beginDimsData.get(d)).getLongValue();
+ endDims[d] = ((ScalarObject) endDimsData.get(d)).getLongValue();
+ }
+
+ usedDims[0] = Math.max(usedDims[0], endDims[0]);
+ usedDims[1] = Math.max(usedDims[1], endDims[1]);
+ try {
+ FederatedData federatedData = new FederatedData(fedDataType,
+ new InetSocketAddress(InetAddress.getByName(host), port), filePath);
+ feds.add(new ImmutablePair<>(new FederatedRange(beginDims, endDims), federatedData));
+ }
+ catch(UnknownHostException e) {
+ throw new DMLRuntimeException("federated host was unknown: " + host);
+ }
+ }
+ else {
+ throw new DMLRuntimeException("federated instruction only takes strings as addresses");
+ }
+ }
+
+ if(type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER)) {
+ CacheableData<?> output = ec.getCacheableData(_output);
+ output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
+ federateMatrix(output, feds, null);
+ }
+ else if(type.equalsIgnoreCase(FED_FRAME_IDENTIFIER)) {
+ if(usedDims[1] > Integer.MAX_VALUE)
+ throw new DMLRuntimeException("federated Frame can not have more than max int columns, because the "
+ + "schema can only be max int length");
+ FrameObject output = ec.getFrameObject(_output);
+ output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
+ federateFrame(output, feds, null);
+ }
+ else {
+ throw new DMLRuntimeException("type \"" + type + "\" non valid federated type");
+ }
+ }
+
+ public void processFromLocalFedInit(ExecutionContext ec) {
+ String type = ec.getScalarInput(_type).getStringValue();
+ ListObject addresses = ec.getListObject(_addresses.getName());
+ ListObject ranges = ec.getListObject(_ranges.getName());
+ List<Pair<FederatedRange, FederatedData>> feds = new ArrayList<>();
+
+ CacheableData<?> co = ec.getCacheableData(_localObject);
+ CacheBlock cb = co.acquireReadAndRelease();
+
+ if(addresses.getLength() * 2 != ranges.getLength())
+ throw new DMLRuntimeException("Federated read needs twice the amount of addresses as ranges "
+ + "(begin and end): addresses=" + addresses.getLength() + " ranges=" + ranges.getLength());
+
+ //check for duplicate addresses (would lead to overwrite with common variable names)
+ Set<String> addCheck = new HashSet<>();
+ for(Data dat : addresses.getData())
+ if(dat instanceof StringObject) {
+ String address = ((StringObject) dat).getStringValue();
+ if(addCheck.contains(address))
+ LOG.warn("Federated data contains address duplicates: " + addresses);
+ addCheck.add(address);
+ }
+
+ Types.DataType fedDataType;
+ if(type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER))
+ fedDataType = Types.DataType.MATRIX;
+ else if(type.equalsIgnoreCase(FED_FRAME_IDENTIFIER))
+ fedDataType = Types.DataType.FRAME;
+ else
+ throw new DMLRuntimeException("type \"" + type + "\" non valid federated type");
+
+ long[] usedDims = new long[] {0, 0};
+ CacheBlock[] cbs = new CacheBlock[addresses.getLength()];
+ for(int i = 0; i < addresses.getLength(); i++) {
+ Data addressData = addresses.getData().get(i);
+ if(addressData instanceof StringObject) {
+ // We split the address into url/ip and port; the file path is taken from the local object (co.getFileName())
+ String[] parsedValues = parseURLNoFilePath(((StringObject) addressData).getStringValue());
+ String host = parsedValues[0];
+ int port = Integer.parseInt(parsedValues[1]);
+ String filePath = co.getFileName();
+
if(DMLScript.FED_STATISTICS)
// register the federated worker for federated statistics creation
FederatedStatistics.registerFedWorker(host, port);
@@ -159,6 +281,11 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
}
usedDims[0] = Math.max(usedDims[0], endDims[0]);
usedDims[1] = Math.max(usedDims[1], endDims[1]);
+
+ CacheBlock slice = cb instanceof MatrixBlock ? ((MatrixBlock)cb).slice((int) beginDims[0], (int) endDims[0]-1, (int) beginDims[1], (int) endDims[1]-1, true) :
+ ((FrameBlock)cb).slice((int) beginDims[0], (int) endDims[0]-1, (int) beginDims[1], (int) endDims[1]-1, true, new FrameBlock());
+ cbs[i] = slice;
+
try {
FederatedData federatedData = new FederatedData(fedDataType,
new InetSocketAddress(InetAddress.getByName(host), port), filePath);
@@ -172,10 +299,11 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
throw new DMLRuntimeException("federated instruction only takes strings as addresses");
}
}
+
if(type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER)) {
CacheableData<?> output = ec.getCacheableData(_output);
output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
- federateMatrix(output, feds);
+ federateMatrix(output, feds, cbs);
}
else if(type.equalsIgnoreCase(FED_FRAME_IDENTIFIER)) {
if(usedDims[1] > Integer.MAX_VALUE)
@@ -183,13 +311,44 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
+ "schema can only be max int length");
FrameObject output = ec.getFrameObject(_output);
output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
- federateFrame(output, feds);
+ federateFrame(output, feds, cbs);
}
else {
throw new DMLRuntimeException("type \"" + type + "\" non valid federated type");
}
}
+ public static String[] parseURLNoFilePath(String input) {
+ try {
+ // Artificially making it http protocol.
+ // This is to avoid malformed address error in the URL parsing.
+ // TODO: Construct new protocol name for Federated communication
+ URL address = new URL("http://" + input);
+ String host = address.getHost();
+ if(host.length() == 0)
+ throw new IllegalArgumentException("Missing Host name for federated address");
+ // The current system does not support ipv6, only ipv4.
+ // TODO: Support IPV6 address for Federated communication
+ String ipRegex = "^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$";
+ if(host.matches("^\\d+\\.\\d+\\.\\d+\\.\\d+$") && !host.matches(ipRegex))
+ throw new IllegalArgumentException("Input Host address looks like an IP address but is outside range");
+ int port = address.getPort();
+ if(port == -1)
+ port = DMLConfig.DEFAULT_FEDERATED_PORT;
+ if(address.getQuery() != null)
+ throw new IllegalArgumentException("Query is not supported");
+
+ if(address.getRef() != null)
+ throw new IllegalArgumentException("Reference is not supported");
+
+ return new String[] {host, String.valueOf(port)};
+ }
+ catch(MalformedURLException e) {
+ throw new IllegalArgumentException(
+ "federated address `" + input + "` does not fit required URL pattern of \"host:port/directory\"", e);
+ }
+ }
+
public static String[] parseURL(String input) {
try {
// Artificially making it http protocol.
@@ -231,6 +390,10 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
}
public static void federateMatrix(CacheableData<?> output, List<Pair<FederatedRange, FederatedData>> workers) {
+ federateMatrix(output, workers, null);
+ }
+
+ public static void federateMatrix(CacheableData<?> output, List<Pair<FederatedRange, FederatedData>> workers, CacheBlock[] blocks) {
List<Pair<FederatedRange, FederatedData>> fedMapping = new ArrayList<>();
for(Pair<FederatedRange, FederatedData> e : workers)
@@ -239,6 +402,7 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
long id = FederationUtils.getNextFedDataID();
boolean rowPartitioned = true;
boolean colPartitioned = true;
+ int k = 0;
for(Pair<FederatedRange, FederatedData> entry : fedMapping) {
FederatedRange range = entry.getKey();
FederatedData value = entry.getValue();
@@ -248,7 +412,10 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
long[] dims = output.getDataCharacteristics().getDims();
for(int i = 0; i < dims.length; i++)
dims[i] = endDims[i] - beginDims[i];
- idResponses.add(new ImmutablePair<>(value, value.initFederatedData(id)));
+ if(blocks == null || blocks.length == 0)
+ idResponses.add(new ImmutablePair<>(value, value.initFederatedData(id)));
+ else
+ idResponses.add(new ImmutablePair<>(value, value.initFederatedDataFromLocal(id, blocks[k++])));
}
rowPartitioned &= (range.getSize(1) == output.getNumColumns());
colPartitioned &= (range.getSize(0) == output.getNumRows());
@@ -284,7 +451,7 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
LOG.debug("Fed map Inited:" + output.getFedMapping());
}
- public static void federateFrame(FrameObject output, List<Pair<FederatedRange, FederatedData>> workers) {
+ public static void federateFrame(FrameObject output, List<Pair<FederatedRange, FederatedData>> workers, CacheBlock[] blocks) {
List<Pair<FederatedRange, FederatedData>> fedMapping = new ArrayList<>();
for(Pair<FederatedRange, FederatedData> e : workers)
fedMapping.add(e);
@@ -295,6 +462,7 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
long id = FederationUtils.getNextFedDataID();
boolean rowPartitioned = true;
boolean colPartitioned = true;
+ int k = 0;
for(Pair<FederatedRange, FederatedData> entry : fedMapping) {
FederatedRange range = entry.getKey();
FederatedData value = entry.getValue();
@@ -305,8 +473,12 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
for(int i = 0; i < dims.length; i++) {
dims[i] = endDims[i] - beginDims[i];
}
- idResponses.add(
- new ImmutablePair<>(value, new ImmutablePair<>((int) beginDims[1], value.initFederatedData(id))));
+ if(blocks == null || blocks.length == 0)
+ idResponses.add(
+ new ImmutablePair<>(value, new ImmutablePair<>((int) beginDims[1], value.initFederatedData(id))));
+ else
+ idResponses.add(
+ new ImmutablePair<>(value, new ImmutablePair<>((int) beginDims[1], value.initFederatedDataFromLocal(id, blocks[k++]))));
}
rowPartitioned &= (range.getSize(1) == output.getNumColumns());
colPartitioned &= (range.getSize(0) == output.getNumRows());
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 315871ef50..ee4e668c55 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -4147,6 +4147,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
return slice(rl, ru, cl, cu, true, ret);
}
+ /**
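+ * Slice out the sub-block of rows rl..ru and columns cl..cu (all bounds inclusive)
+ * @param rl The row lower bound to start from (inclusive)
+ * @param ru The row upper bound to end at (inclusive)
+ * @param cl The col lower bound to start from (inclusive)
+ * @param cu The col upper bound to end at (inclusive)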
+ * @param deep Deep copy or not
+ * @return The sliced out matrix block.
+ */
public final MatrixBlock slice(int rl, int ru, int cl, int cu, boolean deep){
return slice(rl, ru, cl, cu, deep, null);
}
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index 64785c4ed8..1096a5147e 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -1839,7 +1839,6 @@ public class TestUtils
public static double[][] generateTestMatrix(int rows, int cols, double min, double max, double sparsity, long seed) {
double[][] matrix = new double[rows][cols];
Random random = (seed == -1) ? TestUtils.random : new Random(seed);
-
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
if (random.nextDouble() > sparsity)
@@ -3022,6 +3021,10 @@ public class TestUtils
return host + ':' + port + '/' + input;
}
+ public static String federatedAddressNoInput(String host, int port) { // "host:port", no "/input" suffix
+ return String.join(":", host, String.valueOf(port));
+ }
+
public static double gaussian_probability (double point)
// "Handbook of Mathematical Functions", ed. by M. Abramowitz and I.A. Stegun,
// U.S. Nat-l Bureau of Standards, 10th print (Dec 1972), Sec. 7.1.26, p. 299
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTransferLocalDataTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTransferLocalDataTest.java
new file mode 100644
index 0000000000..f612ca14e4
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTransferLocalDataTest.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedTransferLocalDataTest extends AutomatedTestBase {
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME1 = "FederatedTransferLocalDataTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedTransferLocalDataTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public boolean rowPartitioned;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {12, 4, true}, {12, 4, false},
+ });
+ }
+
+ @Test
+ public void federatedTransferCP() { runTransferTest(Types.ExecMode.SINGLE_NODE); }
+
+ @Test
+ public void federatedTransferSP() { runTransferTest(Types.ExecMode.SPARK); }
+
+ private void runTransferTest(Types.ExecMode execMode) {
+ String TEST_NAME = TEST_NAME1;
+ ExecMode platformOld = setExecMode(execMode);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ double[][] X = getRandomMatrix(rows, cols, 1, 5, 1, 3);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, blocksize, (long) rows * cols);
+ writeInputMatrixWithMTD("X", X, false, mc);
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ int port3 = getRandomAvailablePort();
+ int port4 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+ Thread t4 = startLocalFedWorkerThread(port4);
+
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "100", "-args", input("X"), expected("S")};
+
+ runTest(null);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X=" + input("X"),
+ "in_X1=" + TestUtils.federatedAddressNoInput("localhost", port1),
+ "in_X2=" + TestUtils.federatedAddressNoInput("localhost", port2),
+ "in_X3=" + TestUtils.federatedAddressNoInput("localhost", port3),
+ "in_X4=" + TestUtils.federatedAddressNoInput("localhost", port4), "rows=" + rows, "cols=" + cols,
+ "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+
+ runTest(null);
+
+ // compare via files
+ compareResults(1e-9, "Stat-DML1", "Stat-DML2");
+
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+ resetExecMode(platformOld);
+ }
+}
diff --git a/src/test/scripts/functions/federated/FederatedTransferLocalDataTest.dml b/src/test/scripts/functions/federated/FederatedTransferLocalDataTest.dml
new file mode 100644
index 0000000000..e4dd5180db
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedTransferLocalDataTest.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($in_X);
+rStep = $rows / 4; cStep = $cols / 4; # partition extent per worker (4 workers)
+if ($rP) {
+	A = federated(local_matrix=X, addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+		ranges=list(list(0, 0), list(rStep, $cols), list(rStep, 0), list(2*rStep, $cols),
+		list(2*rStep, 0), list(3*rStep, $cols), list(3*rStep, 0), list($rows, $cols)));
+} else {
+	A = federated(local_matrix=X, addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+		ranges=list(list(0, 0), list($rows, cStep), list(0, cStep), list($rows, 2*cStep),
+		list(0, 2*cStep), list($rows, 3*cStep), list(0, 3*cStep), list($rows, $cols)));
+}
+print(toString(A));
+write(A, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedTransferLocalDataTestReference.dml b/src/test/scripts/functions/federated/FederatedTransferLocalDataTestReference.dml
new file mode 100644
index 0000000000..dcce9747e8
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedTransferLocalDataTestReference.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1); # reference run: load the input matrix from path $1
+write(A, $2); # and write it back unchanged to path $2 as the expected result