You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original archive page and is not preserved in this text.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/07/04 21:20:49 UTC
[systemds] branch master updated: [SYSTEMDS-3018] Compilation of
federated ops under privacy constraints
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new d522183 [SYSTEMDS-3018] Compilation of federated ops under privacy constraints
d522183 is described below
commit d522183249bd792fffa7e00b833213eb172f1fbf
Author: sebwrede <sw...@know-center.at>
AuthorDate: Sun Jul 4 23:06:54 2021 +0200
[SYSTEMDS-3018] Compilation of federated ops under privacy constraints
Closes #1313.
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 1 -
.../java/org/apache/sysds/hops/AggUnaryOp.java | 1 -
src/main/java/org/apache/sysds/hops/BinaryOp.java | 1 -
src/main/java/org/apache/sysds/hops/DataOp.java | 4 +
src/main/java/org/apache/sysds/hops/Hop.java | 35 +-
src/main/java/org/apache/sysds/hops/ReorgOp.java | 5 +-
src/main/java/org/apache/sysds/hops/TernaryOp.java | 2 -
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 3 +
.../RewriteAlgebraicSimplificationDynamic.java | 2 +-
.../hops/rewrite/RewriteFederatedExecution.java | 197 ++++++++
.../runtime/instructions/FEDInstructionParser.java | 6 +
.../runtime/instructions/cp/SqlCPInstruction.java | 9 +
.../fed/AggregateBinaryFEDInstruction.java | 74 ++-
.../runtime/privacy/propagation/OperatorType.java | 47 ++
.../privacy/propagation/PrivacyPropagator.java | 503 +++++++++------------
.../fedplanning/FederatedMultiplyPlanningTest.java | 8 +-
16 files changed, 552 insertions(+), 346 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index a17430e..9b3356f 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -102,7 +102,6 @@ public class AggBinaryOp extends MultiThreadedHop
outerOp = outOp;
getInput().add(0, in1);
getInput().add(1, in2);
- updateETFed();
in1.getParent().add(this);
in2.getParent().add(this);
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 4840f97..9c18f49 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -59,7 +59,6 @@ public class AggUnaryOp extends MultiThreadedHop
_direction = idx;
getInput().add(0, inp);
inp.getParent().add(this);
- updateETFed();
}
@Override
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index f114bc0..8826f92 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -97,7 +97,6 @@ public class BinaryOp extends MultiThreadedHop
op = o;
getInput().add(0, inp1);
getInput().add(1, inp2);
- updateETFed();
inp1.getParent().add(this);
inp2.getParent().add(this);
diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java
index 52d424e..03dfd08 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -355,6 +355,10 @@ public class DataOp extends Hop {
return( _op == OpOpData.PERSISTENTREAD || _op == OpOpData.PERSISTENTWRITE );
}
+ public boolean isFederatedData(){
+ return _op == OpOpData.FEDERATED;
+ }
+
@Override
public String getOpString() {
String s = new String("");
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index e01ffa1..0ef6d96 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -86,9 +86,9 @@ public abstract class Hop implements ParseInfo {
protected ExecType _etypeForced = null; //exec type forced via platform or external optimizer
/**
- * Boolean defining if the output of the operation should be federated.
- * If it is true, the output should be kept at federated sites.
- * If it is false, the output should be retrieved by the coordinator.
+ * Field defining if the output of the operation should be federated.
+ * If it is fout, the output should be kept at federated sites.
+ * If it is lout, the output should be retrieved by the coordinator.
*/
protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
@@ -173,6 +173,14 @@ public abstract class Hop implements ParseInfo {
{
return _etype;
}
+
+ public void setExecType(ExecType execType){
+ _etype = execType;
+ }
+
+ public void setFederatedOutput(FederatedOutput federatedOutput){
+ _federatedOutput = federatedOutput;
+ }
public void resetExecType()
{
@@ -770,25 +778,12 @@ public abstract class Hop implements ParseInfo {
/**
* Update the execution type if input is federated and federated compilation is activated.
* Federated compilation is activated in OptimizerUtils.
+ * This method only has an effect if FEDERATED_COMPILATION is activated.
*/
protected void updateETFed(){
- if ( inputIsFED() )
+ if ( _federatedOutput == FederatedOutput.FOUT || _federatedOutput == FederatedOutput.LOUT )
_etype = ExecType.FED;
}
-
- /**
- * Returns true if any input has federated ExecType.
- * This method can only return true if FedDecision is activated.
- * @return true if any input has federated ExecType
- */
- protected boolean inputIsFED(){
- if ( !OptimizerUtils.FEDERATED_COMPILATION )
- return false;
- for ( Hop input : _input )
- if ( input.isFederated() || input.isFederatedOutput() )
- return true;
- return false;
- }
public boolean isFederated(){
return getExecType() == ExecType.FED;
@@ -798,6 +793,10 @@ public abstract class Hop implements ParseInfo {
return _federatedOutput == FederatedOutput.FOUT;
}
+ public boolean someInputFederated(){
+ return getInput().stream().anyMatch(Hop::hasFederatedOutput);
+ }
+
public ArrayList<Hop> getParent() {
return _parent;
}
diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java b/src/main/java/org/apache/sysds/hops/ReorgOp.java
index 89aeb03..d2dcebf 100644
--- a/src/main/java/org/apache/sysds/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java
@@ -61,8 +61,7 @@ public class ReorgOp extends MultiThreadedHop
_op = o;
getInput().add(0, inp);
inp.getParent().add(this);
- updateETFed();
-
+
//compute unknown dims and nnz
refreshSizeInformation();
}
@@ -78,8 +77,6 @@ public class ReorgOp extends MultiThreadedHop
in.getParent().add(this);
}
- updateETFed();
-
//compute unknown dims and nnz
refreshSizeInformation();
}
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index f254d0d..7d9fca6 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -80,7 +80,6 @@ public class TernaryOp extends MultiThreadedHop
getInput().add(0, inp1);
getInput().add(1, inp2);
getInput().add(2, inp3);
- updateETFed();
inp1.getParent().add(this);
inp2.getParent().add(this);
inp3.getParent().add(this);
@@ -98,7 +97,6 @@ public class TernaryOp extends MultiThreadedHop
getInput().add(3, inp4);
getInput().add(4, inp5);
getInput().add(5, inp6);
- updateETFed();
inp1.getParent().add(this);
inp2.getParent().add(this);
inp3.getParent().add(this);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 467d476..2e3edb0 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -138,6 +138,9 @@ public class ProgramRewriter
_dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse
_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
}
+ if ( OptimizerUtils.FEDERATED_COMPILATION ) {
+ _dagRuleSet.add( new RewriteFederatedExecution() );
+ }
}
// cleanup after all rewrites applied
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 71a7240..269050c 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -139,7 +139,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//recursively process children
for( int i=0; i<hop.getInput().size(); i++)
{
- Hop hi = hop.getInput().get(i);
+ Hop hi = hop.getInput(i);
//process childs recursively first (to allow roll-up)
if( descendFirst )
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
new file mode 100644
index 0000000..29cda4a
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
+import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.privacy.DMLPrivacyException;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
+import org.apache.sysds.utils.JSONHelper;
+import org.apache.wink.json4j.JSONObject;
+
+import javax.net.ssl.SSLException;
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.UnknownHostException;
+import java.util.ArrayList;
+import java.util.concurrent.Future;
+
+public class RewriteFederatedExecution extends HopRewriteRule {
+@Override
+public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
+if ( roots == null )
+return null;
+for ( Hop root : roots )
+visitHop(root);
+return roots;
+}
+
+@Override
+public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+if( root == null )
+return null;
+visitHop(root);
+return root;
+}
+
+private void visitHop(Hop hop){
+if (hop.isVisited())
+return;
+
+// Depth first to get to the input
+for ( Hop input : hop.getInput() )
+visitHop(input);
+
+privacyBasedHopDecisionWithFedCall(hop);
+hop.setVisited();
+}
+
+private static void privacyBasedHopDecision(Hop hop){
+PrivacyPropagator.hopPropagation(hop);
+PrivacyConstraint privacyConstraint = hop.getPrivacy();
+if ( privacyConstraint != null && privacyConstraint.hasConstraints() )
+hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
+else if ( hop.someInputFederated() )
+hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
+}
+
+/**
+ * Get privacy constraints of DataOps from federated worker,
+ * propagate privacy constraints from input to current hop,
+ * and set federated output flag.
+ * @param hop current hop
+ */
+private static void privacyBasedHopDecisionWithFedCall(Hop hop){
+loadFederatedPrivacyConstraints(hop);
+privacyBasedHopDecision(hop);
+}
+
+/**
+ * Get privacy constraints from all federated workers of a federated DataOp and merge them (most restrictive wins).
+ * @param hop hop for which privacy constraints are loaded
+ */
+private static void loadFederatedPrivacyConstraints(Hop hop){
+if ( isFederatedDataOp(hop) && hop.getPrivacy() == null){
+try {
+PrivacyConstraint privConstraint = null;
+for ( Hop addressHop : hop.getInput(0).getInput() )
+privConstraint = PrivacyPropagator.mergeBinary(privConstraint,
+unwrapPrivConstraint(sendPrivConstraintRequest((LiteralOp) addressHop)));
+hop.setPrivacy(privConstraint);
+}
+catch(Exception e) {
+throw new DMLException(e.getMessage(), e);
+}
+}
+}
+
+private static Future<FederatedResponse> sendPrivConstraintRequest(LiteralOp addressHop)
+throws UnknownHostException, SSLException
+{
+String[] parsedAddress = InitFEDInstruction.parseURL(addressHop.getStringValue());
+int port = Integer.parseInt(parsedAddress[1]);
+PrivacyConstraintRetriever retriever = new PrivacyConstraintRetriever(parsedAddress[2]);
+FederatedRequest privacyRetrieval =
+new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, retriever);
+InetSocketAddress inetAddress = new InetSocketAddress(InetAddress.getByName(parsedAddress[0]), port);
+return FederatedData.executeFederatedOperation(inetAddress, privacyRetrieval);
+}
+
+private static PrivacyConstraint unwrapPrivConstraint(Future<FederatedResponse> privConstraintFuture)
+throws Exception
+{
+return (PrivacyConstraint) privConstraintFuture.get().getData()[0];
+}
+
+private static boolean isFederatedDataOp(Hop hop){
+return hop instanceof DataOp && ((DataOp) hop).isFederatedData();
+}
+
+/**
+ * FederatedUDF for retrieving privacy constraint of data stored in file name.
+ */
+public static class PrivacyConstraintRetriever extends FederatedUDF {
+private static final long serialVersionUID = 3551741240135587183L;
+private final String filename;
+
+public PrivacyConstraintRetriever(String filename){
+super(new long[]{});
+this.filename = filename;
+}
+
+/**
+ * Reads metadata JSON object, parses privacy constraint and returns the constraint in FederatedResponse.
+ * @param ec execution context
+ * @param data one or many data objects
+ * @return FederatedResponse with privacy constraint object
+ */
+@Override
+public FederatedResponse execute(ExecutionContext ec, Data... data) {
+PrivacyConstraint privacyConstraint;
+FileSystem fs = null;
+try {
+String mtdname = DataExpression.getMTDFileName(filename);
+Path path = new Path(mtdname);
+fs = IOUtilFunctions.getFileSystem(mtdname);
+try(BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
+JSONObject metadataObject = JSONHelper.parse(br);
+privacyConstraint = PrivacyPropagator.parseAndReturnPrivacyConstraint(metadataObject);
+}
+}
+catch (DMLPrivacyException | FederatedWorkerHandlerException ex){
+throw ex;
+}
+catch (Exception ex) {
+String msg = "Exception in reading metadata of: " + filename;
+throw new DMLRuntimeException(msg, ex);
+}
+finally {
+IOUtilFunctions.closeSilently(fs);
+}
+return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, privacyConstraint);
+}
+
+@Override
+public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+return null;
+}
+}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 34db155..bae38a2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -19,9 +19,11 @@
package org.apache.sysds.runtime.instructions;
+import org.apache.sysds.lops.Append;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.AppendFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
@@ -66,6 +68,8 @@ public class FEDInstructionParser extends InstructionParser
// Ternary Instruction Opcodes
String2FEDInstructionType.put( "+*" , FEDType.Ternary);
String2FEDInstructionType.put( "-*" , FEDType.Ternary);
+
+ String2FEDInstructionType.put(Append.OPCODE, FEDType.Append);
}
public static FEDInstruction parseSingleInstruction (String str ) {
@@ -98,6 +102,8 @@ public class FEDInstructionParser extends InstructionParser
return TernaryFEDInstruction.parseInstruction(str);
case Reorg:
return ReorgFEDInstruction.parseInstruction(str);
+ case Append:
+ return AppendFEDInstruction.parseInstruction(str);
default:
throw new DMLRuntimeException("Invalid FEDERATED Instruction Type: " + fedtype );
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/SqlCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/SqlCPInstruction.java
index add9bb7..e4453bc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/SqlCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/SqlCPInstruction.java
@@ -136,4 +136,13 @@ public class SqlCPInstruction extends CPInstruction {
public CPOperand getOutput(){
return _output;
}
+
+ /**
+ * Returns the inputs of the instruction.
+ * Inputs are conn, user, pass, and query.
+ * @return inputs of the instruction
+ */
+ public CPOperand[] getInputs(){
+ return new CPOperand[]{_conn, _user, _pass, _query};
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 10dd7c6..535e12d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -24,6 +24,7 @@ import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
@@ -145,24 +146,43 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
}
//#2 vector - federated matrix multiplication
else if (mo2.isFederated(FType.ROW)) {// VM + MM
- //construct commands: broadcast rhs, fed mv, retrieve results
- FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, true);
- FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2},
- new long[]{fr1[0].getID(), mo2.getFedMapping().getID()}, true);
- if ( _fedOut.isForcedFederated() ){
- // Partial aggregates (set fedmapping to the partial aggs)
- FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1[0].getID());
- mo2.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
- setPartialOutput(mo2.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ if ( mo1.isFederated(FType.COL) && isAggBinaryFedAligned(mo1,mo2) ){
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2},
+ new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, true);
+ if ( _fedOut.isForcedFederated() ){
+ // Partial aggregates (set fedmapping to the partial aggs)
+ mo2.getFedMapping().execute(getTID(), true, fr2);
+ setPartialOutput(mo2.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr2, fr3);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
}
else {
- FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
- FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
- //execute federated operations and aggregate
- Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
- MatrixBlock ret = FederationUtils.aggAdd(tmp);
- ec.setMatrixOutput(output.getName(), ret);
+ //construct commands: broadcast rhs, fed mv, retrieve results
+ FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, true);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2},
+ new long[]{fr1[0].getID(), mo2.getFedMapping().getID()}, true);
+ if ( _fedOut.isForcedFederated() ){
+ // Partial aggregates (set fedmapping to the partial aggs)
+ FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1[0].getID());
+ mo2.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+ setPartialOutput(mo2.getFedMapping(), mo1, mo2, fr2.getID(), ec);
+ }
+ else {
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
}
}
//#3 col-federated matrix vector multiplication
@@ -195,6 +215,28 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
}
/**
+ * Checks alignment of dimensions for the federated aggregate binary processing without broadcast.
+ * The inputs are aligned if both have the same number of federated ranges and, for every index i,
+ * the begin and end columns of range i of mo1 equal the begin and end rows of range i of mo2.
+ * @param mo1 input matrix object 1
+ * @param mo2 input matrix object 2
+ * @return true if the two inputs are aligned for aggregate binary processing without broadcasting
+ */
+private static boolean isAggBinaryFedAligned(MatrixObject mo1, MatrixObject mo2){
+FederatedRange[] mo1FederatedRanges = mo1.getFedMapping().getFederatedRanges();
+FederatedRange[] mo2FederatedRanges = mo2.getFedMapping().getFederatedRanges();
+if ( mo1FederatedRanges.length != mo2FederatedRanges.length )
+return false;
+for ( int i = 0; i < mo1FederatedRanges.length; i++ ){
+FederatedRange mo1FedRange = mo1FederatedRanges[i], mo2FedRange = mo2FederatedRanges[i];
+if ( mo1FedRange.getBeginDims()[1] != mo2FedRange.getBeginDims()[0]
+|| mo1FedRange.getEndDims()[1] != mo2FedRange.getEndDims()[0])
+return false;
+}
+return true;
+}
+
+ /**
* Sets the output with a federated mapping of overlapping partial aggregates.
* @param federationMap federated map from which the federated metadata is retrieved
* @param mo1 matrix object with number of rows used to set the number of rows of the output
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/propagation/OperatorType.java b/src/main/java/org/apache/sysds/runtime/privacy/propagation/OperatorType.java
index 18a94b1..ebd9adf 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/propagation/OperatorType.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/propagation/OperatorType.java
@@ -19,7 +19,54 @@
package org.apache.sysds.runtime.privacy.propagation;
+import org.apache.sysds.lops.MMTSJ;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
public enum OperatorType {
Aggregate,
NonAggregate;
+
+ /**
+ * Returns the operator type of MMChainCPInstruction based on the input data characteristics.
+ * @param inst MMChainCPInstruction for which operator type is returned
+ * @param ec execution context
+ * @return operator type of instruction
+ */
+ public static OperatorType getAggregationType(MMChainCPInstruction inst, ExecutionContext ec){
+ DataCharacteristics inputDataCharacteristics = ec.getDataCharacteristics(inst.getInputs()[0].getName());
+ if ( inputDataCharacteristics.getRows() == 1 && inputDataCharacteristics.getCols() == 1)
+ return NonAggregate;
+ else return Aggregate;
+ }
+
+ /**
+ * Returns the operator type of MMTSJCPInstruction based on the input data characteristics and the MMTSJType.
+ * @param inst MMTSJCPInstruction for which operator type is returned
+ * @param ec execution context
+ * @return operator type of instruction
+ */
+ public static OperatorType getAggregationType(MMTSJCPInstruction inst, ExecutionContext ec){
+ DataCharacteristics inputDataCharacteristics = ec.getDataCharacteristics(inst.getInputs()[0].getName());
+ if ( (inputDataCharacteristics.getRows() == 1 && inst.getMMTSJType() == MMTSJ.MMTSJType.LEFT)
+ || (inputDataCharacteristics.getCols() == 1 && inst.getMMTSJType() != MMTSJ.MMTSJType.LEFT) )
+ return OperatorType.NonAggregate;
+ else return OperatorType.Aggregate;
+ }
+
+ /**
+ * Returns the operator type of AggregateBinaryCPInstruction based on the input data characteristics and the transpose.
+ * @param inst AggregateBinaryCPInstruction for which operator type is returned
+ * @param ec execution context
+ * @return operator type of instruction
+ */
+ public static OperatorType getAggregationType(AggregateBinaryCPInstruction inst, ExecutionContext ec){
+ DataCharacteristics inputDC = ec.getDataCharacteristics(inst.input1.getName());
+ if ((inputDC.getCols() == 1 && !inst.transposeLeft) || (inputDC.getRows() == 1 && inst.transposeLeft) )
+ return OperatorType.NonAggregate;
+ else return OperatorType.Aggregate;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
index 71e1d46..945df73 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
@@ -19,8 +19,18 @@
package org.apache.sysds.runtime.privacy.propagation;
-import java.util.*;
-
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
@@ -39,15 +49,39 @@ import org.apache.wink.json4j.JSONObject;
*/
public class PrivacyPropagator
{
+ /**
+ * Parses the privacy constraint of the given metadata object
+ * and sets the field of the given Data if the privacy constraint is not null.
+ * @param cd data for which privacy constraint is set
+ * @param mtd metadata object
+ * @return data object with privacy constraint set
+ * @throws JSONException during parsing of metadata
+ */
public static Data parseAndSetPrivacyConstraint(Data cd, JSONObject mtd)
throws JSONException
{
+ PrivacyConstraint mtdPrivConstraint = parseAndReturnPrivacyConstraint(mtd);
+ if ( mtdPrivConstraint != null )
+ cd.setPrivacyConstraints(mtdPrivConstraint);
+ return cd;
+ }
+
+ /**
+ * Parses the privacy constraint of the given metadata object
+ * or returns null if no privacy constraint is set in the metadata.
+ * @param mtd metadata
+ * @return privacy constraint parsed from metadata object
+ * @throws JSONException during parsing of metadata
+ */
+ public static PrivacyConstraint parseAndReturnPrivacyConstraint(JSONObject mtd)
+ throws JSONException
+ {
if ( mtd.containsKey(DataExpression.PRIVACY) ) {
String privacyLevel = mtd.getString(DataExpression.PRIVACY);
if ( privacyLevel != null )
- cd.setPrivacyConstraints(new PrivacyConstraint(PrivacyLevel.valueOf(privacyLevel)));
+ return new PrivacyConstraint(PrivacyLevel.valueOf(privacyLevel));
}
- return cd;
+ return null;
}
private static boolean anyInputHasLevel(PrivacyLevel[] inputLevels, PrivacyLevel targetLevel){
@@ -92,7 +126,13 @@ public class PrivacyPropagator
return PrivacyLevel.None;
}
- public static PrivacyConstraint mergeNary(PrivacyConstraint[] privacyConstraints, OperatorType operatorType){
+ /**
+ * Merges the given privacy constraints with the core propagation using the given operator type.
+ * @param privacyConstraints array of privacy constraints to merge
+ * @param operatorType type of operation to use when merging with the core propagation
+ * @return merged privacy constraint
+ */
+ private static PrivacyConstraint mergeNary(PrivacyConstraint[] privacyConstraints, OperatorType operatorType){
PrivacyLevel[] privacyLevels = Arrays.stream(privacyConstraints)
.map(constraint -> {
if (constraint != null)
@@ -104,21 +144,17 @@ public class PrivacyPropagator
return new PrivacyConstraint(outputPrivacyLevel);
}
+ /**
+ * Merges the input privacy constraints using the core propagation with NonAggregate operator type.
+ * @param privacyConstraint1 first privacy constraint
+ * @param privacyConstraint2 second privacy constraint
+ * @return merged privacy constraint
+ */
public static PrivacyConstraint mergeBinary(PrivacyConstraint privacyConstraint1, PrivacyConstraint privacyConstraint2) {
if (privacyConstraint1 != null && privacyConstraint2 != null){
- PrivacyLevel privacyLevel1 = privacyConstraint1.getPrivacyLevel();
- PrivacyLevel privacyLevel2 = privacyConstraint2.getPrivacyLevel();
-
- // One of the inputs are private, hence the output must be private.
- if (privacyLevel1 == PrivacyLevel.Private || privacyLevel2 == PrivacyLevel.Private)
- return new PrivacyConstraint(PrivacyLevel.Private);
- // One of the inputs are private with aggregation allowed, but none of the inputs are completely private,
- // hence the output must be private with aggregation.
- else if (privacyLevel1 == PrivacyLevel.PrivateAggregation || privacyLevel2 == PrivacyLevel.PrivateAggregation)
- return new PrivacyConstraint(PrivacyLevel.PrivateAggregation);
- // Both inputs have privacy level "None", hence the privacy constraint can be removed.
- else
- return null;
+ PrivacyLevel[] privacyLevels = new PrivacyLevel[]{
+ privacyConstraint1.getPrivacyLevel(),privacyConstraint2.getPrivacyLevel()};
+ return new PrivacyConstraint(corePropagation(privacyLevels, OperatorType.NonAggregate));
}
else if (privacyConstraint1 != null)
return privacyConstraint1;
@@ -127,12 +163,36 @@ public class PrivacyPropagator
return null;
}
- public static PrivacyConstraint mergeNary(PrivacyConstraint[] privacyConstraints){
- PrivacyConstraint mergedPrivacyConstraint = privacyConstraints[0];
- for ( int i = 1; i < privacyConstraints.length; i++ ){
- mergedPrivacyConstraint = mergeBinary(mergedPrivacyConstraint, privacyConstraints[i]);
+ /**
+ * Propagate privacy constraints from input hops to given hop.
+ * @param hop which the privacy constraints are propagated to
+ */
+ public static void hopPropagation(Hop hop){
+ PrivacyConstraint[] inputConstraints = hop.getInput().stream()
+ .map(Hop::getPrivacy).toArray(PrivacyConstraint[]::new);
+ if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp )
+ hop.setPrivacy(mergeNary(inputConstraints, OperatorType.NonAggregate));
+ else if ( hop instanceof AggBinaryOp || hop instanceof AggUnaryOp || hop instanceof UnaryOp )
+ hop.setPrivacy(mergeNary(inputConstraints, OperatorType.Aggregate));
+ }
+
+ /**
+ * Propagate privacy constraints to output variables
+ * based on privacy constraint of CPOperand output in instruction
+ * which has been set during privacy propagation preprocessing.
+ * @param inst instruction for which privacy constraints are propagated
+ * @param ec execution context
+ */
+ public static void postProcessInstruction(Instruction inst, ExecutionContext ec){
+ // if inst has output
+ List<CPOperand> instOutputs = getOutputOperands(inst);
+ if (!instOutputs.isEmpty()){
+ for ( CPOperand output : instOutputs ){
+ PrivacyConstraint outputPrivacyConstraint = output.getPrivacyConstraint();
+ if ( PrivacyUtils.someConstraintSetUnary(outputPrivacyConstraint) )
+ setOutputPrivacyConstraint(ec, outputPrivacyConstraint, output.getName());
+ }
}
- return mergedPrivacyConstraint;
}
/**
@@ -145,127 +205,88 @@ public class PrivacyPropagator
public static Instruction preprocessInstruction(Instruction inst, ExecutionContext ec){
switch ( inst.getType() ){
case CONTROL_PROGRAM:
- return preprocessCPInstructionFineGrained( (CPInstruction) inst, ec );
+ return preprocessCPInstruction( (CPInstruction) inst, ec );
case BREAKPOINT:
case SPARK:
case GPU:
case FEDERATED:
return inst;
default:
- throwExceptionIfPrivacyActivated(inst);
- return inst;
+ return throwExceptionIfInputOrInstPrivacy(inst, ec);
}
}
- public static Instruction preprocessCPInstructionFineGrained(CPInstruction inst, ExecutionContext ec){
- switch ( inst.getCPInstructionType() ){
- case AggregateBinary:
- if ( inst instanceof AggregateBinaryCPInstruction ){
- // This can only be a matrix multiplication and it does not count as an aggregation in terms of privacy.
- return preprocessAggregateBinaryCPInstruction((AggregateBinaryCPInstruction)inst, ec);
- } else if ( inst instanceof CovarianceCPInstruction ){
- return preprocessCovarianceCPInstruction((CovarianceCPInstruction)inst, ec);
- } else preprocessInstructionSimple(inst, ec);
- case AggregateTernary:
- //TODO: Support propagation of fine-grained privacy constraints
- return preprocessTernaryCPInstruction((ComputationCPInstruction) inst, ec);
- case AggregateUnary:
- // Assumption: aggregates in one or several dimensions, number of dimensions may change, only certain slices of the data may be aggregated upon, elements do not change position
- return preprocessAggregateUnaryCPInstruction((AggregateUnaryCPInstruction)inst, ec);
- case Append:
- return preprocessAppendCPInstruction((AppendCPInstruction) inst, ec);
+ private static Instruction preprocessCPInstruction(CPInstruction inst, ExecutionContext ec){
+ switch(inst.getCPInstructionType()){
case Binary:
- // TODO: Support propagation of fine-grained privacy constraints
- return preprocessBinaryCPInstruction((BinaryCPInstruction) inst, ec);
case Builtin:
case BuiltinNary:
- //TODO: Support propagation of fine-grained privacy constraints
- return preprocessBuiltinNary((BuiltinNaryCPInstruction) inst, ec);
- /*case CentralMoment:
- break;
- case Compression:
- break;
- case Covariance:
- break;
- case Ctable:
- break;
- case Dnn:
- break;
- */
case FCall:
- //TODO: Support propagation of fine-grained privacy constraints
- return preprocessExternal((FunctionCallCPInstruction) inst, ec);
- /*
- case MMChain:
- break;
- case MMTSJ:
- break;
- case MatrixIndexing:
- break;*/
- case MultiReturnBuiltin:
- case MultiReturnParameterizedBuiltin:
- // TODO: Support propagation of fine-grained privacy constraints
- return preprocessMultiReturn((ComputationCPInstruction)inst, ec);
- /*case PMMJ:
- break;*/
case ParameterizedBuiltin:
- // TODO: Support propagation of fine-grained privacy constraints
- return preprocessParameterizedBuiltin((ParameterizedBuiltinCPInstruction) inst, ec);
- /*case Partition:
- break;
- case QPick:
- break;
- case QSort:
- break;*/
case Quaternary:
- // TODO: Support propagation of fine-grained privacy constraints
- return preprocessQuaternary((QuaternaryCPInstruction) inst, ec);
- /*case Rand:
- break;*/
case Reorg:
- // TODO: Support propagation of fine-grained privacy constraints
- return preprocessUnaryCPInstruction((UnaryCPInstruction) inst, ec);
- /*case Reshape:
- break;
- case SpoofFused:
- break;
- case Sql:
- break;
- case StringInit:
- break;*/
case Ternary:
- // TODO: Support propagation of fine-grained privacy constraints
- return preprocessTernaryCPInstruction((ComputationCPInstruction) inst, ec);
- /*case UaggOuterChain:
- break;*/
case Unary:
- // Assumption: No aggregation, elements do not change position, no change of dimensions
- return preprocessUnaryCPInstruction((UnaryCPInstruction) inst, ec);
+ case MultiReturnBuiltin:
+ case MultiReturnParameterizedBuiltin:
+ case MatrixIndexing:
+ return mergePrivacyConstraintsFromInput( inst, ec, OperatorType.NonAggregate );
+ case AggregateTernary:
+ case AggregateUnary:
+ return mergePrivacyConstraintsFromInput(inst, ec, OperatorType.Aggregate);
+ case Append:
+ return preprocessAppendCPInstruction((AppendCPInstruction) inst, ec);
+ case AggregateBinary:
+ if ( inst instanceof AggregateBinaryCPInstruction )
+ return preprocessAggregateBinaryCPInstruction((AggregateBinaryCPInstruction)inst, ec);
+ else return throwExceptionIfInputOrInstPrivacy(inst, ec);
+ case MMTSJ:
+ OperatorType mmtsjOpType = OperatorType.getAggregationType((MMTSJCPInstruction) inst, ec);
+ return mergePrivacyConstraintsFromInput(inst, ec, mmtsjOpType);
+ case MMChain:
+ OperatorType mmChainOpType = OperatorType.getAggregationType((MMChainCPInstruction) inst, ec);
+ return mergePrivacyConstraintsFromInput(inst, ec, mmChainOpType);
case Variable:
return preprocessVariableCPInstruction((VariableCPInstruction) inst, ec);
default:
- return preprocessInstructionSimple(inst, ec);
-
+ return throwExceptionIfInputOrInstPrivacy(inst, ec);
}
}
- /**
- * Throw exception if privacy constraint activated for instruction or for input to instruction.
- * @param inst covariance instruction
- * @param ec execution context
- * @return input instruction if privacy constraints are not activated
- */
- private static Instruction preprocessCovarianceCPInstruction(CovarianceCPInstruction inst, ExecutionContext ec){
- throwExceptionIfPrivacyActivated(inst);
- for ( CPOperand input : inst.getInputs() ){
- PrivacyConstraint privacyConstraint = getInputPrivacyConstraint(ec, input);
- if ( privacyConstraint != null){
- throw new DMLPrivacyException("Input of instruction " + inst + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction.");
- }
+	private static Instruction preprocessVariableCPInstruction(VariableCPInstruction inst, ExecutionContext ec){
+		switch ( inst.getVariableOpcode() ) {
+			case CopyVariable:
+			case MoveVariable:
+			case RemoveVariableAndFile:
+			case CastAsMatrixVariable:
+			case CastAsFrameVariable:
+			case Write:
+			case SetFileName:
+			case CastAsScalarVariable:
+			case CastAsDoubleVariable:
+			case CastAsIntegerVariable:
+			case CastAsBooleanVariable:
+				return propagateFirstInputPrivacy(inst, ec);
+			case CreateVariable:
+				return propagateSecondInputPrivacy(inst, ec);
+			case AssignVariable:
+			case RemoveVariable:
+				return mergePrivacyConstraintsFromInput(inst, ec, inst.getInputs().toArray(new CPOperand[0]), getOutputOperands(inst), OperatorType.NonAggregate); // inputs passed explicitly: getInputOperands returns null for VariableCPInstruction
+			case Read:
+				// Adds scalar object to variable map, hence input (data type and filename) privacy should not be propagated
+				return inst;
+			default:
+				return throwExceptionIfInputOrInstPrivacy(inst, ec);
}
- return inst;
}
+ /**
+ * Propagates fine-grained constraints if input has fine-grained constraints,
+ * otherwise it propagates general constraints.
+ * @param inst aggregate binary instruction for which constraints are propagated
+ * @param ec execution context
+ * @return instruction with merged privacy constraints propagated to it and output CPOperand
+ */
private static Instruction preprocessAggregateBinaryCPInstruction(AggregateBinaryCPInstruction inst, ExecutionContext ec){
PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, inst.getInputs());
if ( PrivacyUtils.someConstraintSetBinary(privacyConstraints) ){
@@ -279,7 +300,7 @@ public class PrivacyPropagator
ec.releaseMatrixInput(inst.input1.getName(), inst.input2.getName());
}
else {
- mergedPrivacyConstraint = mergeNary(privacyConstraints, OperatorType.Aggregate);
+ mergedPrivacyConstraint = mergeNary(privacyConstraints, OperatorType.getAggregationType(inst, ec));
inst.setPrivacyConstraint(mergedPrivacyConstraint);
}
inst.output.setPrivacyConstraint(mergedPrivacyConstraint);
@@ -287,7 +308,13 @@ public class PrivacyPropagator
return inst;
}
- public static Instruction preprocessAppendCPInstruction(AppendCPInstruction inst, ExecutionContext ec){
+ /**
+ * Propagates input privacy constraints using general and fine-grained constraints depending on the AppendType.
+ * @param inst append instruction for which constraints are propagated
+ * @param ec execution context
+ * @return instruction with merged privacy constraints propagated to it and output CPOperand
+ */
+ private static Instruction preprocessAppendCPInstruction(AppendCPInstruction inst, ExecutionContext ec){
PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, inst.getInputs());
if ( PrivacyUtils.someConstraintSetBinary(privacyConstraints) ){
if ( inst.getAppendType() == AppendCPInstruction.AppendType.STRING ){
@@ -327,92 +354,35 @@ public class PrivacyPropagator
return inst;
}
- public static Instruction preprocessBinaryCPInstruction(BinaryCPInstruction inst, ExecutionContext ec){
- PrivacyConstraint privacyConstraint1 = getInputPrivacyConstraint(ec, inst.input1);
- PrivacyConstraint privacyConstraint2 = getInputPrivacyConstraint(ec, inst.input2);
- if ( privacyConstraint1 != null || privacyConstraint2 != null) {
- PrivacyConstraint mergedPrivacyConstraint = mergeBinary(privacyConstraint1, privacyConstraint2);
- inst.setPrivacyConstraint(mergedPrivacyConstraint);
- inst.output.setPrivacyConstraint(mergedPrivacyConstraint);
- }
- return inst;
- }
-
/**
- * Propagate privacy constraint to output if any of the elements are private.
- * Privacy constraint is always propagated to instruction.
- * @param inst aggregate instruction
+ * Propagates privacy constraints from input to instruction and output CPOperand based on given operator type.
+ * The propagation is done through the core propagation.
+ * @param inst instruction for which privacy is propagated
* @param ec execution context
- * @return updated instruction with propagated privacy constraints
+ * @param operatorType defining whether the instruction is aggregating the input
+ * @return instruction with the merged privacy constraint propagated to it and output CPOperand
*/
- private static Instruction preprocessAggregateUnaryCPInstruction(AggregateUnaryCPInstruction inst, ExecutionContext ec){
- PrivacyConstraint privacyConstraint = getInputPrivacyConstraint(ec, inst.input1);
- if ( privacyConstraint != null ) {
- inst.setPrivacyConstraint(privacyConstraint);
- if ( inst.output != null){
- //Only propagate to output if any of the elements are private.
- //It is an aggregation, hence the constraint can be removed in case of any other privacy level.
- if(privacyConstraint.hasPrivateElements())
- inst.output.setPrivacyConstraint(new PrivacyConstraint(PrivacyLevel.Private));
- }
- }
- return inst;
+ private static Instruction mergePrivacyConstraintsFromInput(Instruction inst, ExecutionContext ec,
+ OperatorType operatorType){
+ return mergePrivacyConstraintsFromInput(inst, ec, getInputOperands(inst), getOutputOperands(inst), operatorType);
}
/**
- * Throw exception if privacy constraints are activated or return instruction if privacy is not activated
- * @param inst instruction
+ * Propagates privacy constraints from input to instruction and output CPOperand based on given operator type.
+ * The propagation is done through the core propagation.
+ * @param inst instruction for which privacy is propagated
* @param ec execution context
- * @return instruction
+ * @param inputs to instruction
+ * @param outputs of instruction
+ * @param operatorType defining whether the instruction is aggregating the input
+ * @return instruction with the merged privacy constraint propagated to it and output CPOperand
*/
- public static Instruction preprocessInstructionSimple(Instruction inst, ExecutionContext ec){
- throwExceptionIfPrivacyActivated(inst);
- return inst;
- }
-
-
- public static Instruction preprocessExternal(FunctionCallCPInstruction inst, ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(
- inst,
- ec,
- inst.getInputs(),
- inst.getBoundOutputParamNames().toArray(new String[0])
- );
- }
-
- public static Instruction preprocessMultiReturn(ComputationCPInstruction inst, ExecutionContext ec){
- List<CPOperand> outputs = getOutputOperands(inst);
- return mergePrivacyConstraintsFromInput(inst, ec, inst.getInputs(), outputs);
- }
-
- public static Instruction preprocessParameterizedBuiltin(ParameterizedBuiltinCPInstruction inst, ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(inst, ec, inst.getInputs(), inst.getOutput() );
- }
-
- private static Instruction mergePrivacyConstraintsFromInput(Instruction inst, ExecutionContext ec, CPOperand[] inputs, String[] outputNames){
+ private static Instruction mergePrivacyConstraintsFromInput(Instruction inst, ExecutionContext ec,
+ CPOperand[] inputs, List<CPOperand> outputs, OperatorType operatorType){
if ( inputs != null && inputs.length > 0 ){
PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, inputs);
if ( privacyConstraints != null ){
- PrivacyConstraint mergedPrivacyConstraint = mergeNary(privacyConstraints);
- inst.setPrivacyConstraint(mergedPrivacyConstraint);
- if ( outputNames != null ){
- for (String outputName : outputNames)
- setOutputPrivacyConstraint(ec, mergedPrivacyConstraint, outputName);
- }
- }
- }
- return inst;
- }
-
- private static Instruction mergePrivacyConstraintsFromInput(Instruction inst, ExecutionContext ec, CPOperand[] inputs, CPOperand output){
- return mergePrivacyConstraintsFromInput(inst, ec, inputs, getSingletonList(output));
- }
-
- private static Instruction mergePrivacyConstraintsFromInput(Instruction inst, ExecutionContext ec, CPOperand[] inputs, List<CPOperand> outputs){
- if ( inputs != null && inputs.length > 0 ){
- PrivacyConstraint[] privacyConstraints = getInputPrivacyConstraints(ec, inputs);
- if ( privacyConstraints != null ){
- PrivacyConstraint mergedPrivacyConstraint = mergeNary(privacyConstraints);
+ PrivacyConstraint mergedPrivacyConstraint = mergeNary(privacyConstraints, operatorType);
inst.setPrivacyConstraint(mergedPrivacyConstraint);
for ( CPOperand output : outputs ){
if ( output != null ) {
@@ -424,54 +394,24 @@ public class PrivacyPropagator
return inst;
}
- public static Instruction preprocessBuiltinNary(BuiltinNaryCPInstruction inst, ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(inst, ec, inst.getInputs(), inst.getOutput() );
- }
-
- public static Instruction preprocessQuaternary(QuaternaryCPInstruction inst, ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(
- inst,
- ec,
- new CPOperand[] {inst.input1,inst.input2,inst.input3,inst.getInput4()},
- inst.output
- );
- }
-
- public static Instruction preprocessTernaryCPInstruction(ComputationCPInstruction inst, ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(inst, ec, inst.getInputs(), inst.output);
- }
-
- public static Instruction preprocessUnaryCPInstruction(UnaryCPInstruction inst, ExecutionContext ec){
- return propagateInputPrivacy(inst, ec, inst.input1, inst.output);
- }
-
- public static Instruction preprocessVariableCPInstruction(VariableCPInstruction inst, ExecutionContext ec){
- switch ( inst.getVariableOpcode() ) {
- case CreateVariable:
- return propagateSecondInputPrivacy(inst, ec);
- case AssignVariable:
- return propagateInputPrivacy(inst, ec, inst.getInput1(), inst.getInput2());
- case CopyVariable:
- case MoveVariable:
- case RemoveVariableAndFile:
- case CastAsMatrixVariable:
- case CastAsFrameVariable:
- case Write:
- case SetFileName:
- return propagateFirstInputPrivacy(inst, ec);
- case RemoveVariable:
- return propagateAllInputPrivacy(inst, ec);
- case CastAsScalarVariable:
- case CastAsDoubleVariable:
- case CastAsIntegerVariable:
- case CastAsBooleanVariable:
- return propagateCastAsScalarVariablePrivacy(inst, ec);
- case Read:
- return inst;
- default:
- throwExceptionIfPrivacyActivated(inst);
- return inst;
+ /**
+ * Throw exception if privacy constraint activated for instruction or for input to instruction.
+	 * @param inst instruction whose own privacy constraint and input privacy constraints are checked
+ * @param ec execution context
+ * @return input instruction if privacy constraints are not activated
+ */
+ private static Instruction throwExceptionIfInputOrInstPrivacy(Instruction inst, ExecutionContext ec){
+ throwExceptionIfPrivacyActivated(inst);
+ CPOperand[] inputOperands = getInputOperands(inst);
+ if (inputOperands != null){
+ for ( CPOperand input : inputOperands ){
+ PrivacyConstraint privacyConstraint = getInputPrivacyConstraint(ec, input);
+ if ( privacyConstraint != null){
+ throw new DMLPrivacyException("Input of instruction " + inst + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction.");
+ }
+ }
}
+ return inst;
}
private static void throwExceptionIfPrivacyActivated(Instruction inst){
@@ -481,28 +421,6 @@ public class PrivacyPropagator
}
/**
- * Propagate privacy from first input.
- * @param inst Instruction
- * @param ec execution context
- * @return instruction with or without privacy constraints
- */
- private static Instruction propagateCastAsScalarVariablePrivacy(VariableCPInstruction inst, ExecutionContext ec){
- inst = (VariableCPInstruction) propagateFirstInputPrivacy(inst, ec);
- return inst;
- }
-
- /**
- * Propagate privacy constraints from all inputs if privacy constraints are set.
- * @param inst instruction
- * @param ec execution context
- * @return instruction with or without privacy constraints
- */
- private static Instruction propagateAllInputPrivacy(VariableCPInstruction inst, ExecutionContext ec){
- return mergePrivacyConstraintsFromInput(
- inst, ec, inst.getInputs().toArray(new CPOperand[0]), inst.getOutput());
- }
-
- /**
* Propagate privacy constraint to instruction and output of instruction
* if data of first input is CacheableData and
* privacy constraint is activated.
@@ -561,7 +479,12 @@ public class PrivacyPropagator
return null;
}
-
+ /**
+ * Returns input privacy constraints as array or returns null if no privacy constraints are found in the inputs.
+ * @param ec execution context
+ * @param inputs from which privacy constraints are retrieved
+ * @return array of privacy constraints from inputs
+ */
private static PrivacyConstraint[] getInputPrivacyConstraints(ExecutionContext ec, CPOperand[] inputs){
if ( inputs != null && inputs.length > 0){
boolean privacyFound = false;
@@ -595,41 +518,29 @@ public class PrivacyPropagator
}
/**
- * Propagate privacy constraints to output variables
- * based on privacy constraint of CPOperand output in instruction
- * which has been set during privacy propagation preprocessing.
- * @param inst instruction for which privacy constraints are propagated
- * @param ec execution context
+ * Returns input CPOperands of instruction or returns null if instruction type is not supported by this method.
+ * @param inst instruction from which the inputs are retrieved
+ * @return array of input CPOperands or null
*/
- public static void postProcessInstruction(Instruction inst, ExecutionContext ec){
- // if inst has output
- List<CPOperand> instOutputs = getOutputOperands(inst);
- if (!instOutputs.isEmpty()){
- for ( CPOperand output : instOutputs ){
- PrivacyConstraint outputPrivacyConstraint = output.getPrivacyConstraint();
- if ( PrivacyUtils.someConstraintSetUnary(outputPrivacyConstraint) )
- setOutputPrivacyConstraint(ec, outputPrivacyConstraint, output.getName());
- }
- }
- }
-
- @SuppressWarnings("unused")
- private static String[] getOutputVariableName(Instruction inst){
- String[] instructionOutputNames = null;
- // The order of the following statements is important
- if ( inst instanceof MultiReturnParameterizedBuiltinCPInstruction )
- instructionOutputNames = ((MultiReturnParameterizedBuiltinCPInstruction) inst).getOutputNames();
- else if ( inst instanceof MultiReturnBuiltinCPInstruction )
- instructionOutputNames = ((MultiReturnBuiltinCPInstruction) inst).getOutputNames();
- else if ( inst instanceof ComputationCPInstruction )
- instructionOutputNames = new String[]{((ComputationCPInstruction) inst).getOutputVariableName()};
- else if ( inst instanceof VariableCPInstruction )
- instructionOutputNames = new String[]{((VariableCPInstruction) inst).getOutputVariableName()};
- else if ( inst instanceof SqlCPInstruction )
- instructionOutputNames = new String[]{((SqlCPInstruction) inst).getOutputVariableName()};
- return instructionOutputNames;
+ private static CPOperand[] getInputOperands(Instruction inst){
+ if ( inst instanceof ComputationCPInstruction )
+ return ((ComputationCPInstruction)inst).getInputs();
+ if ( inst instanceof BuiltinNaryCPInstruction )
+ return ((BuiltinNaryCPInstruction)inst).getInputs();
+ if ( inst instanceof FunctionCallCPInstruction )
+ return ((FunctionCallCPInstruction)inst).getInputs();
+ if ( inst instanceof SqlCPInstruction )
+ return ((SqlCPInstruction)inst).getInputs();
+ else return null;
}
+ /**
+ * Returns a list of output CPOperands of instruction or an empty list if the instruction has no outputs.
+ * Note that this method needs to be extended as new instruction types are added, otherwise it will
+ * return an empty list for instructions that may have outputs.
+ * @param inst instruction from which the outputs are retrieved
+ * @return list of outputs
+ */
private static List<CPOperand> getOutputOperands(Instruction inst){
// The order of the following statements is important
if ( inst instanceof MultiReturnParameterizedBuiltinCPInstruction )
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 79fe54f..342a26b 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.privacy.fedplanning;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -81,7 +80,6 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
}
@Test
- @Ignore
public void federatedRowSum(){
federatedTwoMatricesSingleNodeTest(TEST_NAME_2);
}
@@ -98,14 +96,12 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
}
@Test
- @Ignore
public void federatedAggregateBinaryColFedSequence(){
cols = rows;
federatedTwoMatricesSingleNodeTest(TEST_NAME_5);
}
@Test
- @Ignore
public void federatedAggregateBinarySequence2(){
federatedTwoMatricesSingleNodeTest(TEST_NAME_6);
}
@@ -147,8 +143,8 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
if ( testName.equals(TEST_NAME_5) ){
writeColStandardMatrix("X1", 42);
writeColStandardMatrix("X2", 1340);
- writeColStandardMatrix("Y1", 44);
- writeColStandardMatrix("Y2", 21);
+ writeColStandardMatrix("Y1", 44, null);
+ writeColStandardMatrix("Y2", 21, null);
}
else if ( testName.equals(TEST_NAME_6) ){
writeColStandardMatrix("X1", 42);