You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2020/09/29 12:35:02 UTC
[systemds] branch master updated: [SYSTEMDS-2546,2547] Fix Federated rbind/cbind
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 7fb86b9 [SYSTEMDS-2546,2547] Fix Federated rbind/cbind
7fb86b9 is described below
commit 7fb86b9e403f9031c3dad094daaa5c411a139aac
Author: Kevin Innerebner <ke...@yahoo.com>
AuthorDate: Tue Sep 22 17:06:40 2020 +0200
[SYSTEMDS-2546,2547] Fix Federated rbind/cbind
Adds FederatedLocalData so that local data can be used without having
to send it to a worker. This allows reusing a lot of code, but
might introduce overhead. Other options for handling this scenario exist.
- Adds support for local data rbind and cbind.
- Fix federated rbind/cbind with support for local data
- Adds `FederatedLocalData` so that local data can be used without having
to send it to a worker. This allows reusing a lot of code,
but might introduce overhead.
- Add return comment to `FederatedData.copyWithNewID()`
- Ignore failing privacy transfer tests
Closing #1062
---
.../controlprogram/federated/FederatedData.java | 25 +++---
.../federated/FederatedLocalData.java | 59 +++++++++++++
.../federated/FederatedWorkerHandler.java | 14 ++--
.../controlprogram/federated/FederationMap.java | 37 +++++---
.../controlprogram/federated/FederationUtils.java | 12 +++
.../runtime/instructions/InstructionUtils.java | 20 +++++
.../instructions/fed/AppendFEDInstruction.java | 98 ++++++++++++----------
.../instructions/fed/FEDInstructionUtils.java | 23 ++++-
.../federated/primitives/FederatedRCBindTest.java | 28 +++++--
.../privacy/FederatedWorkerHandlerTest.java | 6 +-
.../functions/federated/FederatedRCBindTest.dml | 24 ++++--
.../federated/FederatedRCBindTestReference.dml | 17 ++--
12 files changed, 268 insertions(+), 95 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 8a3fbd2..4a8387f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -63,20 +63,10 @@ public class FederatedData {
_dataType = dataType;
_address = address;
_filepath = filepath;
- if( _address != null )
+ if(_address != null)
_allFedSites.add(_address);
}
-
- /**
- * Make a copy of the <code>FederatedData</code> metadata, but use another varID (refer to another object on worker)
- * @param other the <code>FederatedData</code> of which we want to copy the worker information from
- * @param varID the varID of the variable we refer to
- */
- public FederatedData(FederatedData other, long varID) {
- this(other._dataType, other._address, other._filepath);
- _varID = varID;
- }
-
+
public InetSocketAddress getAddress() {
return _address;
}
@@ -102,6 +92,17 @@ public class FederatedData {
&& _address.equals(that._address);
}
+ /**
+ * Creates a copy of this <code>FederatedData</code>'s data type, address and file path that points to another variable on the worker.
+ * @param varID the varID the returned copy refers to
+ * @return a fresh <code>FederatedData</code> carrying the given varID
+ */
+ public FederatedData copyWithNewID(long varID) {
+ FederatedData duplicate = new FederatedData(_dataType, _address, _filepath);
+ duplicate.setVarID(varID);
+ return duplicate;
+ }
+
public synchronized Future<FederatedResponse> initFederatedData(long id) {
if(isInitialized())
throw new DMLRuntimeException("Tried to init already initialized data");
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
new file mode 100644
index 0000000..1589dc3
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLocalData.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
+
+import org.apache.log4j.Logger;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+
+public class FederatedLocalData extends FederatedData { //FederatedData whose object is held in this process instead of on a remote worker
+ protected final static Logger log = Logger.getLogger(FederatedLocalData.class);
+
+ private static final ExecutionContextMap ecm = new ExecutionContextMap(); //single static context map shared by all local data objects
+ private static final FederatedWorkerHandler fwh = new FederatedWorkerHandler(ecm); //answers requests against ecm in-process (no channel I/O)
+
+ private final CacheableData<?> _data;
+
+ public FederatedLocalData(long id, CacheableData<?> data) { //registers data in ecm under the variable name String.valueOf(id)
+ super(data.getDataType(), null, data.getFileName()); //null address: not added to the set of remote fed sites
+ _data = data;
+ synchronized(ecm) {
+ ecm.get(-1).setVariable(Long.toString(id), _data); //NOTE(review): -1 presumably selects a default/no-thread context; confirm
+ }
+ setVarID(id);
+ }
+
+ @Override
+ boolean equalAddress(FederatedData that) { //all local data share one pseudo-address: equal iff same concrete class
+ return that.getClass().equals(this.getClass());
+ }
+
+ @Override
+ public FederatedData copyWithNewID(long varID) { //re-registers the same local object under varID
+ return new FederatedLocalData(varID, _data);
+ }
+
+ @Override
+ public synchronized Future<FederatedResponse> executeFederatedOperation(FederatedRequest... request) {
+ return CompletableFuture.completedFuture(fwh.createResponse(request)); //computed synchronously, wrapped in an already-completed future
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 323248e..6764f12 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -76,6 +76,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
+ ctx.writeAndFlush(createResponse(msg)).addListener(new CloseListener());
+ }
+
+ public FederatedResponse createResponse(Object msg) {
if( log.isDebugEnabled() ){
log.debug("Received: " + msg.getClass().getSimpleName());
}
@@ -94,7 +98,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
}
PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
PrivacyMonitor.clearCheckedConstraints();
-
+
//execute command and handle privacy constraints
FederatedResponse tmp = executeCommand(request);
conditionalAddCheckedConstraints(request, tmp);
@@ -102,9 +106,9 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
//select the response for the entire batch of requests
if (!tmp.isSuccessful()) {
log.error("Command " + request.getType() + " failed: "
- + tmp.getErrorMessage() + "full command: \n" + request.toString());
+ + tmp.getErrorMessage() + "full command: \n" + request.toString());
response = (response == null || response.isSuccessful())
- ? tmp : response; //return first error
+ ? tmp : response; //return first error
}
else if( request.getType() == RequestType.GET_VAR ) {
if( response != null && response.isSuccessful() )
@@ -114,13 +118,13 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
else if( response == null && i == requests.length-1 ) {
response = tmp; //return last
}
-
+
if (DMLScript.STATISTICS && request.getType() == RequestType.CLEAR && Statistics.allowWorkerStatistics){
System.out.println("Federated Worker " + Statistics.display());
Statistics.reset();
}
}
- ctx.writeAndFlush(response).addListener(new CloseListener());
+ return response;
}
private static void conditionalAddCheckedConstraints(FederatedRequest request, FederatedResponse response){
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 7d537c9..6d2e7c1 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -194,6 +194,25 @@ public class FederationMap
return ret;
}
+ public FederationMap identCopy(long tid, long id) { //ask the fed sites to copy the object under _ID to var id; returned map refers to the copies
+ String copyInst = VariableCPInstruction.prepareCopyInstruction(Long.toString(_ID), Long.toString(id)).toString();
+ Future<FederatedResponse>[] copyInstr = execute(tid, new FederatedRequest(RequestType.EXEC_INST, _ID, copyInst));
+ for(Future<FederatedResponse> future : copyInstr) {
+ try {
+ FederatedResponse response = future.get();
+ if(!response.isSuccessful())
+ response.throwExceptionFromResponse();
+ }
+ catch(Exception e) {
+ if(e instanceof InterruptedException) Thread.currentThread().interrupt(); //restore interrupt status before propagating
+ throw new DMLRuntimeException(e);
+ }
+ }
+ FederationMap copyFederationMap = copyWithNewID(id);
+ copyFederationMap._type = _type; //presumably already preserved by copyWithNewID(id); kept explicit
+ return copyFederationMap;
+ }
+
public FederationMap copyWithNewID() {
return copyWithNewID(FederationUtils.getNextFedDataID());
}
@@ -202,7 +221,7 @@ public class FederationMap
Map<FederatedRange, FederatedData> map = new TreeMap<>();
//TODO handling of file path, but no danger as never written
for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
- map.put(new FederatedRange(e.getKey()), new FederatedData(e.getValue(), id));
+ map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
return new FederationMap(id, map, _type);
}
@@ -210,26 +229,22 @@ public class FederationMap
Map<FederatedRange, FederatedData> map = new TreeMap<>();
//TODO handling of file path, but no danger as never written
for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
- map.put(new FederatedRange(e.getKey(), clen), new FederatedData(e.getValue(), id));
+ map.put(new FederatedRange(e.getKey(), clen), e.getValue().copyWithNewID(id));
return new FederationMap(id, map);
}
- public FederationMap rbind(long offset, FederationMap that) {
- for( Entry<FederatedRange, FederatedData> e : that._fedMap.entrySet() ) {
- _fedMap.put(
- new FederatedRange(e.getKey()).shift(offset, 0),
- new FederatedData(e.getValue(), _ID));
+ public FederationMap bind(long rOffset, long cOffset, FederationMap that) {
+ for(Entry<FederatedRange, FederatedData> e : that._fedMap.entrySet()) {
+ _fedMap.put(new FederatedRange(e.getKey()).shift(rOffset, cOffset), e.getValue().copyWithNewID(_ID));
}
return this;
}
-
+
public FederationMap transpose() {
Map<FederatedRange, FederatedData> tmp = new TreeMap<>(_fedMap);
_fedMap.clear();
for( Entry<FederatedRange, FederatedData> e : tmp.entrySet() ) {
- _fedMap.put(
- new FederatedRange(e.getKey()).transpose(),
- new FederatedData(e.getValue(), _ID));
+ _fedMap.put(new FederatedRange(e.getKey()).transpose(), e.getValue().copyWithNewID(_ID));
}
//derive output type
switch(_type) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index bdec97f..0872c59 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -20,13 +20,16 @@
package org.apache.sysds.runtime.controlprogram.federated;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.Future;
import org.apache.log4j.Logger;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.functionobjects.Builtin;
@@ -206,4 +209,13 @@ public class FederationUtils {
throw new DMLRuntimeException(ex);
}
}
+
+ public static FederationMap federateLocalData(CacheableData<?> data) {
+ //wrap a local object as a fed map with a single partition covering all rows and columns
+ long fedId = getNextFedDataID();
+ FederatedData localPart = new FederatedLocalData(fedId, data);
+ Map<FederatedRange, FederatedData> mapping = new HashMap<>();
+ mapping.put(new FederatedRange(new long[] {0, 0}, new long[] {data.getNumRows(), data.getNumColumns()}), localPart);
+ return new FederationMap(fedId, mapping);
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 904f46d..a92f4dd 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.instructions;
+import java.util.Arrays;
import java.util.StringTokenizer;
import org.apache.sysds.common.Types;
@@ -140,6 +141,25 @@ public class InstructionUtils
return numFields;
}
+ public static int checkNumFields( String[] parts, int... expected ) {
+ int numParts = parts.length;
+ int numFields = numParts - 1; //account for opcode
+
+ if (Arrays.stream(expected).noneMatch((i) -> numFields == i)) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("checkNumFields() -- expected number (");
+ for (int i = 0; i < expected.length; i++) {
+ sb.append(expected[i]);
+ if (i != expected.length - 1)
+ sb.append(", ");
+ }
+ sb.append(") != is not equal to actual number (").append(numFields).append(").");
+ throw new DMLRuntimeException(sb.toString());
+ }
+
+ return numFields;
+ }
+
public static int checkNumFields( String str, int expected1, int expected2 ) {
//note: split required for empty tokens
int numParts = str.split(Instruction.OPERAND_DELIM).length;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index d17b7b5..ee0d8aa 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -22,8 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -33,71 +32,80 @@ import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
public class AppendFEDInstruction extends BinaryFEDInstruction {
- protected boolean _cbind; //otherwise rbind
-
- protected AppendFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
- boolean cbind, String opcode, String istr) {
+ protected boolean _cbind; // otherwise rbind
+
+ protected AppendFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, boolean cbind,
+ String opcode, String istr) {
super(FEDType.Append, op, in1, in2, out, opcode, istr);
_cbind = cbind;
}
-
+
public static AppendFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
- InstructionUtils.checkNumFields(parts, 5, 4);
-
+ InstructionUtils.checkNumFields(parts, 6, 5, 4);
+
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[parts.length - 2]);
boolean cbind = Boolean.parseBoolean(parts[parts.length - 1]);
-
+
Operator op = new ReorgOperator(OffsetColumnIndex.getOffsetColumnIndexFnObject(-1));
return new AppendFEDInstruction(op, in1, in2, out, cbind, opcode, str);
}
-
+
@Override
public void processInstruction(ExecutionContext ec) {
- //get inputs
+ // get inputs
MatrixObject mo1 = ec.getMatrixObject(input1.getName());
MatrixObject mo2 = ec.getMatrixObject(input2.getName());
DataCharacteristics dc1 = mo1.getDataCharacteristics();
DataCharacteristics dc2 = mo1.getDataCharacteristics();
-
- //check input dimensions
- if (_cbind && mo1.getNumRows() != mo2.getNumRows()) {
- throw new DMLRuntimeException(
- "Append-cbind is not possible for federated input matrices " + input1.getName() + " and "
- + input2.getName() + " with different number of rows: " + mo1.getNumRows() + " vs "
- + mo2.getNumRows());
- }
- else if (!_cbind && mo1.getNumColumns() != mo2.getNumColumns()) {
- throw new DMLRuntimeException(
- "Append-rbind is not possible for federated input matrices " + input1.getName() + " and "
- + input2.getName() + " with different number of columns: " + mo1.getNumColumns()
- + " vs " + mo2.getNumColumns());
+
+ // check input dimensions
+ if(_cbind && mo1.getNumRows() != mo2.getNumRows()) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("Append-cbind is not possible for federated input matrices ");
+ sb.append(input1.getName()).append(" and ").append(input2.getName());
+ sb.append(" with different number of rows: ");
+ sb.append(mo1.getNumRows()).append(" vs ").append(mo2.getNumRows());
+ throw new DMLRuntimeException(sb.toString());
}
-
- if( mo1.isFederated(FType.ROW) && _cbind ) {
- FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
- FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
- mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
- //derive new fed mapping for output
- MatrixObject out = ec.getMatrixObject(output);
- out.getDataCharacteristics().set(dc1.getRows(), dc1.getCols()+dc2.getCols(),
- dc1.getBlocksize(), dc1.getNonZeros()+dc2.getNonZeros());
- out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
+ else if(!_cbind && mo1.getNumColumns() != mo2.getNumColumns()) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("Append-rbind is not possible for federated input matrices ");
+ sb.append(input1.getName()).append(" and ").append(input2.getName());
+ sb.append(" with different number of columns: ");
+ sb.append(mo1.getNumColumns()).append(" vs ").append(mo2.getNumColumns());
+ throw new DMLRuntimeException(sb.toString());
}
- else if( mo1.isFederated(FType.ROW) && mo2.isFederated(FType.ROW) && !_cbind ) {
- MatrixObject out = ec.getMatrixObject(output);
- out.getDataCharacteristics().set(dc1.getRows()+dc2.getRows(), dc1.getCols(),
- dc1.getBlocksize(), dc1.getNonZeros()+dc2.getNonZeros());
- long id = FederationUtils.getNextFedDataID();
- out.setFedMapping(mo1.getFedMapping().copyWithNewID(id).rbind(dc1.getRows(), mo2.getFedMapping()));
+
+ FederationMap fm1;
+ if(mo1.isFederated())
+ fm1 = mo1.getFedMapping();
+ else
+ fm1 = FederationUtils.federateLocalData(mo1);
+ FederationMap fm2;
+ if(mo2.isFederated())
+ fm2 = mo2.getFedMapping();
+ else
+ fm2 = FederationUtils.federateLocalData(mo2);
+
+ MatrixObject out = ec.getMatrixObject(output);
+ long id = FederationUtils.getNextFedDataID();
+ if(_cbind) {
+ out.getDataCharacteristics().set(dc1.getRows(),
+ dc1.getCols() + dc2.getCols(),
+ dc1.getBlocksize(),
+ dc1.getNonZeros() + dc2.getNonZeros());
+ out.setFedMapping(fm1.identCopy(getTID(), id).bind(0, dc1.getCols(), fm2.identCopy(getTID(), id)));
}
- else { //other combinations
- throw new DMLRuntimeException("Federated AggregateBinary not supported with the "
- + "following federated objects: "+mo1.isFederated()+" "+mo2.isFederated());
+ else {
+ out.getDataCharacteristics().set(dc1.getRows() + dc2.getRows(),
+ dc1.getCols(),
+ dc1.getBlocksize(),
+ dc1.getNonZeros() + dc2.getNonZeros());
+ out.setFedMapping(fm1.identCopy(getTID(), id).bind(dc1.getRows(), 0, fm2.identCopy(getTID(), id)));
}
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 2e41aa5..0101954 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -24,9 +24,19 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.instructions.Instruction;
-import org.apache.sysds.runtime.instructions.cp.*;
+import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
@@ -70,7 +80,9 @@ public class FEDInstructionUtils {
BinaryCPInstruction instruction = (BinaryCPInstruction) inst;
if( (instruction.input1.isMatrix() && ec.getMatrixObject(instruction.input1).isFederated())
|| (instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederated()) ) {
- if(!instruction.getOpcode().equals("append")) //TODO support rbind/cbind
+ if(instruction.getOpcode().equals("append"))
+ fedinst = AppendFEDInstruction.parseInstruction(inst.getInstructionString());
+ else
fedinst = BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
}
}
@@ -145,6 +157,13 @@ public class FEDInstructionUtils {
fedinst = AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
}
}
+ else if (inst instanceof AppendGSPInstruction) {
+ AppendGSPInstruction instruction = (AppendGSPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
+ fedinst = AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
+ }
+ }
//set thread id for federated context management
if( fedinst != null ) {
fedinst.setTID(ec.getTID());
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
index 712c041..ca745b9 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
@@ -54,7 +54,11 @@ public class FederatedRCBindTest extends AutomatedTestBase {
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R", "C"}));
+ // we generate 3 datasets, both with rbind and cbind (F...Federated, L...Local):
+ // F-F, F-L, L-F
+ addTestConfiguration(TEST_NAME,
+ new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,
+ new String[] {"R_FF", "R_FL", "R_LF", "C_FF", "C_FL", "C_LF"}));
}
@Test
@@ -76,15 +80,21 @@ public class FederatedRCBindTest extends AutomatedTestBase {
double[][] A = getRandomMatrix(rows, cols, -10, 10, 1, 1);
writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols));
+ double[][] B = getRandomMatrix(rows, cols, -10, 10, 1, 2);
+ writeInputMatrixWithMTD("B", B, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols));
- int port = getRandomAvailablePort();
- Thread t = startLocalFedWorkerThread(port);
+ int port1 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1);
+ int port2 = getRandomAvailablePort();
+ Thread t2 = startLocalFedWorkerThread(port2);
// we need the reference file to not be written to hdfs, so we get the correct format
rtplatform = Types.ExecMode.SINGLE_NODE;
// Run reference dml script with normal matrix for Row/Col sum
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-args", input("A"), expected("R"), expected("C")};
+ programArgs = new String[] {"-nvargs", "in1=" + input("A"), "in2=" + input("B"), "out_R_FF=" + expected("R_FF"),
+ "out_R_FL=" + expected("R_FL"), "out_R_LF=" + expected("R_LF"), "out_C_FF=" + expected("C_FF"),
+ "out_C_FL=" + expected("C_FL"), "out_C_LF=" + expected("C_LF")};
runTest(true, false, null, -1);
// reference file should not be written to hdfs, so we set platform here
@@ -95,16 +105,18 @@ public class FederatedRCBindTest extends AutomatedTestBase {
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-nvargs",
- "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
- "cols=" + cols, "out_R=" + output("R"), "out_C=" + output("C")};
+ programArgs = new String[] {"-nvargs", "in1=" + TestUtils.federatedAddress(port1, input("A")),
+ "in2=" + TestUtils.federatedAddress(port2, input("B")), "in2_local=" + input("B"), "rows=" + rows,
+ "cols=" + cols, "out_R_FF=" + output("R_FF"), "out_R_FL=" + output("R_FL"),
+ "out_R_LF=" + output("R_LF"), "out_C_FF=" + output("C_FF"), "out_C_FL=" + output("C_FL"),
+ "out_C_LF=" + output("C_LF")};
runTest(true, false, null, -1);
// compare all sums via files
compareResults(1e-11);
- TestUtils.shutdownThread(t);
+ TestUtils.shutdownThreads(t1, t2);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
index 7b18293..c75e9a2 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
@@ -22,6 +22,7 @@ package org.apache.sysds.test.functions.privacy;
import java.util.Arrays;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
@@ -29,8 +30,8 @@ import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
import org.junit.Test;
-import org.apache.sysds.common.Types;
import static java.lang.Thread.sleep;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -207,16 +208,19 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
}
@Test
+ @Ignore
public void transferPrivateTest() {
federatedRCBind(Types.ExecMode.SINGLE_NODE, PrivacyLevel.Private, DMLRuntimeException.class);
}
@Test
+ @Ignore
public void transferPrivateAggregationTest() {
federatedRCBind(Types.ExecMode.SINGLE_NODE, PrivacyLevel.PrivateAggregation, DMLRuntimeException.class);
}
@Test
+ @Ignore
public void transferNonePrivateTest() {
federatedRCBind(Types.ExecMode.SINGLE_NODE, PrivacyLevel.None, null);
}
diff --git a/src/test/scripts/functions/federated/FederatedRCBindTest.dml b/src/test/scripts/functions/federated/FederatedRCBindTest.dml
index 1084f8c..4447b95 100644
--- a/src/test/scripts/functions/federated/FederatedRCBindTest.dml
+++ b/src/test/scripts/functions/federated/FederatedRCBindTest.dml
@@ -19,9 +19,21 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in), ranges=list(list(0, 0), list($rows, $cols)))
-B = federated(addresses=list($in), ranges=list(list(0, 0), list($rows, $cols)))
-R = rbind(A, B)
-C = cbind(A, B)
-write(R, $out_R)
-write(C, $out_C)
+A = federated(addresses=list($in1), ranges=list(list(0, 0), list($rows, $cols)))
+BF = federated(addresses=list($in2), ranges=list(list(0, 0), list($rows, $cols)))
+B = read($in2_local)
+
+R_FF = rbind(A, BF)
+C_FF = cbind(A, BF)
+R_FL = rbind(A, B)
+C_FL = cbind(A, B)
+R_LF = rbind(B, A)
+C_LF = cbind(B, A)
+
+write(R_FF, $out_R_FF)
+write(R_FL, $out_R_FL)
+write(R_LF, $out_R_LF)
+
+write(C_FF, $out_C_FF)
+write(C_FL, $out_C_FL)
+write(C_LF, $out_C_LF)
diff --git a/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml b/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
index dd6d3cb..034a957 100644
--- a/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
@@ -19,8 +19,15 @@
#
#-------------------------------------------------------------
-A = read($1)
-R = rbind(A, A)
-C = cbind(A, A)
-write(R, $2)
-write(C, $3)
+A = read($in1)
+B = read($in2)
+R = rbind(A, B)
+C = cbind(A, B)
+R_LF = rbind(B, A)
+C_LF = cbind(B, A)
+write(R, $out_R_FF)
+write(R, $out_R_FL)
+write(R_LF, $out_R_LF)
+write(C, $out_C_FF)
+write(C, $out_C_FL)
+write(C_LF, $out_C_LF)