You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2022/07/05 10:20:50 UTC
[systemds] branch main updated: [SYSTEMDS-3018] Federated Coordinator Privacy Constraint Retrieval
This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 7b36729701 [SYSTEMDS-3018] Federated Coordinator Privacy Constraint Retrieval
7b36729701 is described below
commit 7b367297012ea3bb9639d47783baeda99f2f7057
Author: sebwrede <sw...@know-center.at>
AuthorDate: Tue Jun 28 14:14:22 2022 +0200
[SYSTEMDS-3018] Federated Coordinator Privacy Constraint Retrieval
This commit will:
- Include all privacy constraints in remote retrieval
- Add privacy constraint propagation to all compiled federated planners
- Add PrivacyConstraintLoader which handles loading of privacy constraints from federated workers and propagation of the constraints at the coordinator
- Add privacy constraint to Explain output
- Add FederatedPlannerUtils class
- Edit hop propagation to throw exception when hop type is unknown and hop has privacy constraint on input
Closes #1651.
---
.../hops/fedplanner/FederatedPlannerCostbased.java | 38 ++-
.../hops/fedplanner/FederatedPlannerUtils.java | 67 +++++
.../hops/fedplanner/PrivacyConstraintLoader.java | 281 +++++++++++++++++++++
.../hops/ipa/IPAPassRewriteFederatedPlan.java | 17 +-
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 7 -
.../hops/rewrite/RewriteFederatedExecution.java | 187 --------------
.../sysds/runtime/privacy/PrivacyConstraint.java | 7 +-
.../privacy/propagation/PrivacyPropagator.java | 44 +++-
src/main/java/org/apache/sysds/utils/Explain.java | 3 +
.../fedplanning/FederatedMultiplyPlanningTest.java | 8 +
.../FederatedMultiplyPlanningTest11.dml | 34 +++
.../FederatedMultiplyPlanningTest11Reference.dml | 32 +++
12 files changed, 497 insertions(+), 228 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index 1f9abb4c18..3c33d783ab 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -170,7 +170,7 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
selectFederatedExecutionPlan(sbHop, paramMap);
if(sbHop instanceof FunctionOp) {
String funcName = ((FunctionOp) sbHop).getFunctionName();
- Map<String, Hop> funcParamMap = getParamMap((FunctionOp) sbHop);
+ Map<String, Hop> funcParamMap = FederatedPlannerUtils.getParamMap((FunctionOp) sbHop);
if ( paramMap != null && funcParamMap != null)
funcParamMap.putAll(paramMap);
paramMap = funcParamMap;
@@ -182,22 +182,6 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
return new ArrayList<>(Collections.singletonList(sb));
}
- /**
- * Return parameter map containing the mapping from parameter name to input hop
- * for all parameters of the function hop.
- * @param funcOp hop for which the mapping of parameter names to input hops are made
- * @return parameter map or empty map if function has no parameters
- */
- private Map<String,Hop> getParamMap(FunctionOp funcOp){
- String[] inputNames = funcOp.getInputVariableNames();
- Map<String,Hop> paramMap = new HashMap<>();
- if ( inputNames != null ){
- for ( int i = 0; i < funcOp.getInput().size(); i++ )
- paramMap.put(inputNames[i],funcOp.getInput(i));
- }
- return paramMap;
- }
-
/**
* Set final fedouts of all hops starting from terminal hops.
*/
@@ -327,13 +311,21 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
ArrayList<HopRel> hopRels = getFedPlans(currentHop, paramMap);
// Put NONE HopRel into memo table if no FOUT or LOUT HopRels were added
if(hopRels.isEmpty())
- hopRels.add(getNONEHopRel(currentHop));
+ hopRels.add(getNONEHopRel(currentHop, paramMap));
addTrace(hopRels);
hopRelMemo.put(currentHop, hopRels);
}
- private HopRel getNONEHopRel(Hop currentHop){
- HopRel noneHopRel = new HopRel(currentHop, FederatedOutput.NONE, hopRelMemo);
+ private ArrayList<Hop> getHopInputs(Hop currentHop, Map<String, Hop> paramMap){
+ if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD) )
+ return FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites);
+ else
+ return currentHop.getInput();
+ }
+
+ private HopRel getNONEHopRel(Hop currentHop, Map<String, Hop> paramMap){
+ ArrayList<Hop> inputs = getHopInputs(currentHop, paramMap);
+ HopRel noneHopRel = new HopRel(currentHop, FederatedOutput.NONE, hopRelMemo, inputs);
FType[] inputFType = noneHopRel.getInputDependency().stream().map(HopRel::getFType).toArray(FType[]::new);
FType outputFType = getFederatedOut(currentHop, inputFType);
noneHopRel.setFType(outputFType);
@@ -348,9 +340,7 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
*/
private ArrayList<HopRel> getFedPlans(Hop currentHop, Map<String, Hop> paramMap){
ArrayList<HopRel> hopRels = new ArrayList<>();
- ArrayList<Hop> inputHops = currentHop.getInput();
- if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD) )
- inputHops = getTransientInputs(currentHop, paramMap);
+ ArrayList<Hop> inputHops = getHopInputs(currentHop, paramMap);
if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTWRITE) )
transientWrites.put(currentHop.getName(), currentHop);
if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.FEDERATED) )
@@ -453,6 +443,8 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
private void debugLog(Hop currentHop){
if ( LOG.isDebugEnabled() ){
LOG.debug("Visiting HOP: " + currentHop + " Input size: " + currentHop.getInput().size());
+ if (currentHop.getPrivacy() != null)
+ LOG.debug(currentHop.getPrivacy());
int index = 0;
for ( Hop hop : currentHop.getInput()){
if ( hop == null )
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
new file mode 100644
index 0000000000..45b711a41d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.DMLRuntimeException;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+public class FederatedPlannerUtils {
+ /**
+ * Get transient inputs from either paramMap or transientWrites.
+ * Inputs from paramMap have higher priority than inputs from transientWrites.
+ * @param currentHop hop for which inputs are read from maps
+ * @param paramMap of local parameters
+ * @param transientWrites map of transient writes
+ * @return inputs of currentHop
+ */
+ public static ArrayList<Hop> getTransientInputs(Hop currentHop, Map<String, Hop> paramMap, Map<String,Hop> transientWrites){
+ Hop tWriteHop = null;
+ if ( paramMap != null)
+ tWriteHop = paramMap.get(currentHop.getName());
+ if ( tWriteHop == null )
+ tWriteHop = transientWrites.get(currentHop.getName());
+ if ( tWriteHop == null )
+ throw new DMLRuntimeException("Transient write not found for " + currentHop);
+ else
+ return new ArrayList<>(Collections.singletonList(tWriteHop));
+ }
+
+ /**
+ * Return parameter map containing the mapping from parameter name to input hop
+ * for all parameters of the function hop.
+ * @param funcOp hop for which the mapping of parameter names to input hops is made
+ * @return parameter map or empty map if function has no parameters
+ */
+ public static Map<String,Hop> getParamMap(FunctionOp funcOp){
+ String[] inputNames = funcOp.getInputVariableNames();
+ Map<String,Hop> paramMap = new HashMap<>();
+ if ( inputNames != null ){
+ for ( int i = 0; i < funcOp.getInput().size(); i++ )
+ paramMap.put(inputNames[i],funcOp.getInput(i));
+ }
+ return paramMap;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java
new file mode 100644
index 0000000000..82e4316988
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
+import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.privacy.DMLPrivacyException;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
+import org.apache.sysds.utils.JSONHelper;
+import org.apache.wink.json4j.JSONObject;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.UnknownHostException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Future;
+
+public class PrivacyConstraintLoader {
+
+ private final Map<Long, Hop> memo = new HashMap<>();
+ private final Map<String, Hop> transientWrites = new HashMap<>();
+
+ public void loadConstraints(DMLProgram prog){
+ rewriteStatementBlocks(prog, prog.getStatementBlocks(), null);
+ }
+
+ private void rewriteStatementBlocks(DMLProgram prog, List<StatementBlock> sbs, Map<String, Hop> paramMap) {
+ sbs.forEach(block -> rewriteStatementBlock(prog, block, paramMap));
+ }
+
+ private void rewriteStatementBlock(DMLProgram prog, StatementBlock block, Map<String, Hop> paramMap){
+ if(block instanceof WhileStatementBlock)
+ rewriteWhileStatementBlock(prog, (WhileStatementBlock) block, paramMap);
+ else if(block instanceof IfStatementBlock)
+ rewriteIfStatementBlock(prog, (IfStatementBlock) block, paramMap);
+ else if(block instanceof ForStatementBlock) {
+ // This also includes ParForStatementBlocks
+ rewriteForStatementBlock(prog, (ForStatementBlock) block, paramMap);
+ }
+ else if(block instanceof FunctionStatementBlock)
+ rewriteFunctionStatementBlock(prog, (FunctionStatementBlock) block, paramMap);
+ else {
+ // StatementBlock type (no subclass)
+ rewriteDefaultStatementBlock(prog, block, paramMap);
+ }
+ }
+
+ private void rewriteWhileStatementBlock(DMLProgram prog, WhileStatementBlock whileSB, Map<String, Hop> paramMap) {
+ Hop whilePredicateHop = whileSB.getPredicateHops();
+ loadPrivacyConstraint(whilePredicateHop, paramMap);
+ for(Statement stm : whileSB.getStatements()) {
+ WhileStatement whileStm = (WhileStatement) stm;
+ rewriteStatementBlocks(prog, whileStm.getBody(), paramMap);
+ }
+ }
+
+ private void rewriteIfStatementBlock(DMLProgram prog, IfStatementBlock ifSB, Map<String, Hop> paramMap) {
+ loadPrivacyConstraint(ifSB.getPredicateHops(), paramMap);
+ for(Statement statement : ifSB.getStatements()) {
+ IfStatement ifStatement = (IfStatement) statement;
+ rewriteStatementBlocks(prog, ifStatement.getIfBody(), paramMap);
+ rewriteStatementBlocks(prog, ifStatement.getElseBody(), paramMap);
+ }
+ }
+
+ private void rewriteForStatementBlock(DMLProgram prog, ForStatementBlock forSB, Map<String, Hop> paramMap) {
+ loadPrivacyConstraint(forSB.getFromHops(), paramMap);
+ loadPrivacyConstraint(forSB.getToHops(), paramMap);
+ loadPrivacyConstraint(forSB.getIncrementHops(), paramMap);
+ for(Statement statement : forSB.getStatements()) {
+ ForStatement forStatement = ((ForStatement) statement);
+ rewriteStatementBlocks(prog, forStatement.getBody(), paramMap);
+ }
+ }
+
+ private void rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB, Map<String, Hop> paramMap) {
+ for(Statement statement : funcSB.getStatements()) {
+ FunctionStatement funcStm = (FunctionStatement) statement;
+ rewriteStatementBlocks(prog, funcStm.getBody(), paramMap);
+ }
+ }
+
+ private void rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb, Map<String, Hop> paramMap) {
+ if(sb.hasHops()) {
+ for(Hop sbHop : sb.getHops()) {
+ loadPrivacyConstraint(sbHop, paramMap);
+ if(sbHop instanceof FunctionOp) {
+ String funcName = ((FunctionOp) sbHop).getFunctionName();
+ Map<String, Hop> funcParamMap = FederatedPlannerUtils.getParamMap((FunctionOp) sbHop);
+ if ( paramMap != null && funcParamMap != null)
+ funcParamMap.putAll(paramMap);
+ paramMap = funcParamMap;
+ FunctionStatementBlock sbFuncBlock = prog.getBuiltinFunctionDictionary().getFunction(funcName);
+ rewriteStatementBlock(prog, sbFuncBlock, paramMap);
+ }
+ }
+ }
+ }
+
+ private void loadPrivacyConstraint(Hop root, Map<String, Hop> paramMap){
+ if ( root != null && !memo.containsKey(root.getHopID()) ){
+ for ( Hop input : root.getInput() ){
+ loadPrivacyConstraint(input, paramMap);
+ }
+ propagatePrivConstraintsLocal(root, paramMap);
+ memo.put(root.getHopID(), root);
+ }
+ }
+
+ private void propagatePrivConstraintsLocal(Hop currentHop, Map<String, Hop> paramMap){
+ if ( currentHop.isFederatedDataOp() )
+ loadFederatedPrivacyConstraints(currentHop);
+ else if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTWRITE) ){
+ currentHop.setPrivacy(currentHop.getInput(0).getPrivacy());
+ transientWrites.put(currentHop.getName(), currentHop);
+ }
+ else if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD) ){
+ currentHop.setPrivacy(FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites).get(0).getPrivacy());
+ } else {
+ PrivacyPropagator.hopPropagation(currentHop);
+ }
+ }
+
+ /**
+ * Get privacy constraints from federated workers for DataOps.
+ * @param hop hop for which privacy constraints are loaded
+ */
+ private static void loadFederatedPrivacyConstraints(Hop hop){
+ try {
+ PrivacyConstraint.PrivacyLevel constraintLevel = hop.getInput(0).getInput().stream().parallel()
+ .map( in -> ((LiteralOp)in).getStringValue() )
+ .map(PrivacyConstraintLoader::sendPrivConstraintRequest)
+ .map(PrivacyConstraintLoader::unwrapPrivConstraint)
+ .map(constraint -> (constraint != null) ? constraint.getPrivacyLevel() : PrivacyConstraint.PrivacyLevel.None)
+ .reduce(PrivacyConstraint.PrivacyLevel.None, (out,in) -> {
+ if ( out == PrivacyConstraint.PrivacyLevel.Private || in == PrivacyConstraint.PrivacyLevel.Private )
+ return PrivacyConstraint.PrivacyLevel.Private;
+ else if ( out == PrivacyConstraint.PrivacyLevel.PrivateAggregation || in == PrivacyConstraint.PrivacyLevel.PrivateAggregation )
+ return PrivacyConstraint.PrivacyLevel.PrivateAggregation;
+ else
+ return out;
+ });
+ PrivacyConstraint fedDataPrivConstraint = (constraintLevel != PrivacyConstraint.PrivacyLevel.None) ?
+ new PrivacyConstraint(constraintLevel) : null;
+
+ hop.setPrivacy(fedDataPrivConstraint);
+ }
+ catch(Exception ex) {
+ throw new DMLException(ex);
+ }
+ }
+
+ private static Future<FederatedResponse> sendPrivConstraintRequest(String address)
+ {
+ try{
+ String[] parsedAddress = InitFEDInstruction.parseURL(address);
+ String host = parsedAddress[0];
+ int port = Integer.parseInt(parsedAddress[1]);
+ PrivacyConstraintRetriever retriever = new PrivacyConstraintRetriever(parsedAddress[2]);
+ FederatedRequest privacyRetrieval =
+ new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, retriever);
+ InetSocketAddress inetAddress = new InetSocketAddress(InetAddress.getByName(host), port);
+ return FederatedData.executeFederatedOperation(inetAddress, privacyRetrieval);
+ } catch(UnknownHostException ex){
+ throw new DMLException(ex);
+ }
+ }
+
+ private static PrivacyConstraint unwrapPrivConstraint(Future<FederatedResponse> privConstraintFuture)
+ {
+ try {
+ FederatedResponse privConstraintResponse = privConstraintFuture.get();
+ return (PrivacyConstraint) privConstraintResponse.getData()[0];
+ } catch(Exception ex){
+ throw new DMLException(ex);
+ }
+ }
+
+ /**
+ * FederatedUDF for retrieving privacy constraint of data stored in file name.
+ */
+ public static class PrivacyConstraintRetriever extends FederatedUDF {
+ private static final long serialVersionUID = 3551741240135587183L;
+ private final String filename;
+
+ public PrivacyConstraintRetriever(String filename){
+ super(new long[]{});
+ this.filename = filename;
+ }
+
+ /**
+ * Reads metadata JSON object, parses privacy constraint and returns the constraint in FederatedResponse.
+ * @param ec execution context
+ * @param data one or many data objects
+ * @return FederatedResponse with privacy constraint object
+ */
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ PrivacyConstraint privacyConstraint;
+ FileSystem fs = null;
+ try {
+ String mtdname = DataExpression.getMTDFileName(filename);
+ Path path = new Path(mtdname);
+ fs = IOUtilFunctions.getFileSystem(mtdname);
+ try(BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
+ JSONObject metadataObject = JSONHelper.parse(br);
+ privacyConstraint = PrivacyPropagator.parseAndReturnPrivacyConstraint(metadataObject);
+ }
+ }
+ catch (DMLPrivacyException | FederatedWorkerHandlerException ex){
+ throw ex;
+ }
+ catch (Exception ex) {
+ String msg = "Exception in reading metadata of: " + filename;
+ throw new DMLRuntimeException(msg);
+ }
+ finally {
+ IOUtilFunctions.closeSilently(fs);
+ }
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, privacyConstraint);
+ }
+
+ @Override
+ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+ return null;
+ }
+ }
+
+}
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
index 6be3b9c8ec..e6c683eb38 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
@@ -23,6 +23,7 @@ import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner;
+import org.apache.sysds.hops.fedplanner.PrivacyConstraintLoader;
import org.apache.sysds.parser.DMLProgram;
/**
@@ -58,16 +59,24 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
*/
@Override
public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
- // obtain planner instance according to config
String splanner = ConfigurationManager.getDMLConfig()
.getTextValue(DMLConfig.FEDERATED_PLANNER);
+ loadPrivacyConstraints(prog, splanner);
+ generatePlan(prog, fgraph, fcallSizes, splanner);
+ return false;
+ }
+
+ private void loadPrivacyConstraints(DMLProgram prog, String splanner){
+ if (FederatedPlanner.isCompiled(splanner))
+ new PrivacyConstraintLoader().loadConstraints(prog);
+ }
+
+ private void generatePlan(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes, String splanner){
FederatedPlanner planner = FederatedPlanner.isCompiled(splanner) ?
FederatedPlanner.valueOf(splanner.toUpperCase()) :
FederatedPlanner.COMPILE_COST_BASED;
-
+
// run planner rewrite with forced federated exec types
planner.getPlanner().rewriteProgram(prog, fgraph, fcallSizes);
-
- return false;
}
}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index faec3504e9..db20ada280 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -27,10 +27,8 @@ import org.apache.log4j.Logger;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.CompilerConfig.ConfigType;
-import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
@@ -141,11 +139,6 @@ public class ProgramRewriter
_dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse
_dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse
}
- String planner = ConfigurationManager.getDMLConfig()
- .getTextValue(DMLConfig.FEDERATED_PLANNER);
- if ( OptimizerUtils.FEDERATED_COMPILATION || FTypes.FederatedPlanner.isCompiled(planner) ) {
- _dagRuleSet.add( new RewriteFederatedExecution() );
- }
}
// cleanup after all rewrites applied
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
deleted file mode 100644
index 822b4b5d95..0000000000
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.hops.rewrite;
-
-import org.apache.commons.lang3.tuple.Pair;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.log4j.Logger;
-import org.apache.sysds.api.DMLException;
-import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.parser.DataExpression;
-import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
-import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
-import org.apache.sysds.runtime.io.IOUtilFunctions;
-import org.apache.sysds.runtime.lineage.LineageItem;
-import org.apache.sysds.runtime.privacy.DMLPrivacyException;
-import org.apache.sysds.runtime.privacy.PrivacyConstraint;
-import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
-import org.apache.sysds.utils.JSONHelper;
-import org.apache.wink.json4j.JSONObject;
-
-import javax.net.ssl.SSLException;
-import java.io.BufferedReader;
-import java.io.InputStreamReader;
-import java.net.InetAddress;
-import java.net.InetSocketAddress;
-import java.net.UnknownHostException;
-import java.util.ArrayList;
-import java.util.concurrent.Future;
-
-public class RewriteFederatedExecution extends HopRewriteRule {
- private static final Logger LOG = Logger.getLogger(RewriteFederatedExecution.class);
-
- @Override
- public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
- if ( roots != null )
- for ( Hop root : roots )
- rewriteHopDAG(root, state);
- return roots;
- }
-
- @Override
- public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
- if ( root != null )
- visitHop(root);
- return root;
- }
-
- private void visitHop(Hop hop){
- if (hop.isVisited())
- return;
-
- LOG.debug("RewriteFederatedExecution visitHop + " + hop);
-
- // Depth first to get to the input
- for ( Hop input : hop.getInput() )
- visitHop(input);
-
- privacyBasedHopDecisionWithFedCall(hop);
- hop.setVisited();
- }
-
- /**
- * Get privacy constraints of DataOps from federated worker,
- * propagate privacy constraints from input to current hop,
- * and set federated output flag.
- * @param hop current hop
- */
- private static void privacyBasedHopDecisionWithFedCall(Hop hop){
- loadFederatedPrivacyConstraints(hop);
- PrivacyPropagator.hopPropagation(hop);
- }
-
- /**
- * Get privacy constraints from federated workers for DataOps.
- * @hop hop for which privacy constraints are loaded
- */
- private static void loadFederatedPrivacyConstraints(Hop hop){
- if ( hop.isFederatedDataOp() && hop.getPrivacy() == null){
- try {
- LOG.debug("Load privacy constraints of " + hop);
- PrivacyConstraint privConstraint = unwrapPrivConstraint(sendPrivConstraintRequest(hop));
- LOG.debug("PrivacyConstraint retrieved: " + privConstraint);
- hop.setPrivacy(privConstraint);
- }
- catch(Exception e) {
- throw new DMLException(e);
- }
- }
- }
-
- private static Future<FederatedResponse> sendPrivConstraintRequest(Hop hop)
- throws UnknownHostException, SSLException
- {
- String address = ((LiteralOp) hop.getInput(0).getInput(0)).getStringValue();
- String[] parsedAddress = InitFEDInstruction.parseURL(address);
- String host = parsedAddress[0];
- int port = Integer.parseInt(parsedAddress[1]);
- PrivacyConstraintRetriever retriever = new PrivacyConstraintRetriever(parsedAddress[2]);
- FederatedRequest privacyRetrieval =
- new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, retriever);
- InetSocketAddress inetAddress = new InetSocketAddress(InetAddress.getByName(host), port);
- return FederatedData.executeFederatedOperation(inetAddress, privacyRetrieval);
- }
-
- private static PrivacyConstraint unwrapPrivConstraint(Future<FederatedResponse> privConstraintFuture)
- throws Exception
- {
- FederatedResponse privConstraintResponse = privConstraintFuture.get();
- return (PrivacyConstraint) privConstraintResponse.getData()[0];
- }
-
- /**
- * FederatedUDF for retrieving privacy constraint of data stored in file name.
- */
- public static class PrivacyConstraintRetriever extends FederatedUDF {
- private static final long serialVersionUID = 3551741240135587183L;
- private final String filename;
-
- public PrivacyConstraintRetriever(String filename){
- super(new long[]{});
- this.filename = filename;
- }
-
- /**
- * Reads metadata JSON object, parses privacy constraint and returns the constraint in FederatedResponse.
- * @param ec execution context
- * @param data one or many data objects
- * @return FederatedResponse with privacy constraint object
- */
- @Override
- public FederatedResponse execute(ExecutionContext ec, Data... data) {
- PrivacyConstraint privacyConstraint;
- FileSystem fs = null;
- try {
- String mtdname = DataExpression.getMTDFileName(filename);
- Path path = new Path(mtdname);
- fs = IOUtilFunctions.getFileSystem(mtdname);
- try(BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
- JSONObject metadataObject = JSONHelper.parse(br);
- privacyConstraint = PrivacyPropagator.parseAndReturnPrivacyConstraint(metadataObject);
- }
- }
- catch (DMLPrivacyException | FederatedWorkerHandlerException ex){
- throw ex;
- }
- catch (Exception ex) {
- String msg = "Exception in reading metadata of: " + filename;
- throw new DMLRuntimeException(msg);
- }
- finally {
- IOUtilFunctions.closeSilently(fs);
- }
- return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, privacyConstraint);
- }
-
- @Override
- public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
- return null;
- }
- }
-}
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
index 8ea061844a..fc9ba440c8 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
@@ -262,8 +262,11 @@ public class PrivacyConstraint implements Externalizable
@Override
public String toString(){
- return "General privacy level: " + privacyLevel + System.getProperty("line.separator")
- + "Fine-grained privacy level: " + fineGrainedPrivacy.toString();
+ StringBuilder description = new StringBuilder("General privacy level: ").append(privacyLevel);
+ if ( fineGrainedPrivacy != null && fineGrainedPrivacy.hasConstraints() ) // skip the fine-grained part when it is absent or reports no constraints
+ description.append(System.getProperty("line.separator"))
+ .append("Fine-grained privacy level: ").append(fineGrainedPrivacy.toString());
+ return description.toString();
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
index 7e6c0127e5..94834ebc6e 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
@@ -23,11 +23,18 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
+import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataGenOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
@@ -168,12 +175,39 @@ public class PrivacyPropagator
* @param hop which the privacy constraints are propagated to
*/
public static void hopPropagation(Hop hop){
- PrivacyConstraint[] inputConstraints = hop.getInput().stream()
+ hopPropagation(hop, hop.getInput());
+ }
+
+ /** Propagate privacy constraints from the given input hops to the given hop.
+ * @param hop which the privacy constraints are propagated to
+ * @param inputHops inputs to given hop
+ * @throws DMLException if an input carries a constraint but the hop type is not recognized
+ */
+ public static void hopPropagation(Hop hop, ArrayList<Hop> inputHops){
+ PrivacyConstraint[] inputConstraints = inputHops.stream()
.map(Hop::getPrivacy).toArray(PrivacyConstraint[]::new);
- if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp )
- hop.setPrivacy(mergeNary(inputConstraints, OperatorType.NonAggregate));
+ OperatorType opType = getOpType(hop);
+ if (opType == null && Arrays.stream(inputConstraints).anyMatch(Objects::nonNull)) // check first so a rejected hop is never mutated
+ throw new DMLException("Input has constraint but hop type not recognized by PrivacyPropagator. " +
+ "Hop is " + hop + " " + hop.getClass());
+ hop.setPrivacy(mergeNary(inputConstraints, opType));
+ }
+
+ /**
+ * Map the class of the given hop to the operator type used when merging input privacy constraints.
+ * Hops of a class not listed in the checks below have no known operator type.
+ * @param hop for which operator type is returned
+ * @return OperatorType.NonAggregate or OperatorType.Aggregate, or null if hop type is unknown
+ */
+ private static OperatorType getOpType(Hop hop){
+ if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp
+ || hop instanceof DataOp || hop instanceof LiteralOp || hop instanceof NaryOp
+ || hop instanceof DataGenOp || hop instanceof FunctionOp )
+ return OperatorType.NonAggregate;
else if ( hop instanceof AggBinaryOp || hop instanceof AggUnaryOp || hop instanceof UnaryOp )
- hop.setPrivacy(mergeNary(inputConstraints, OperatorType.Aggregate));
+ return OperatorType.Aggregate;
+ else
+ return null; // unknown hop class; hopPropagation decides whether this is an error
}
/**
@@ -406,7 +440,7 @@ public class PrivacyPropagator
if (inputOperands != null){
for ( CPOperand input : inputOperands ){
PrivacyConstraint privacyConstraint = getInputPrivacyConstraint(ec, input);
- if ( privacyConstraint != null){
+ if ( privacyConstraint != null && privacyConstraint.hasConstraints()){
throw new DMLPrivacyException("Input of instruction " + inst + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction.");
}
}
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java
index 589f23a845..ded46c039a 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -626,6 +626,9 @@ public class Explain
}
}
+ if ( hop.getPrivacy() != null )
+ sb.append(" ").append(hop.getPrivacy().getPrivacyLevel().name());
+
sb.append('\n');
hop.setVisited();
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 14c093ebe8..2477bdef85 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -55,6 +55,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
private final static String TEST_NAME_8 = "FederatedMultiplyPlanningTest8";
private final static String TEST_NAME_9 = "FederatedMultiplyPlanningTest9";
private final static String TEST_NAME_10 = "FederatedMultiplyPlanningTest10";
+ private final static String TEST_NAME_11 = "FederatedMultiplyPlanningTest11";
private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-cost-based.xml");
@@ -77,6 +78,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME_8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"}));
addTestConfiguration(TEST_NAME_9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
addTestConfiguration(TEST_NAME_10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME_11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"}));
}
@Parameterized.Parameters
@@ -153,6 +155,12 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
federatedTwoMatricesSingleNodeTest(TEST_NAME_10, expectedHeavyHitters);
}
+ @Test
+ public void federatedMultiplyPlanningTest11(){ // runs the TEST_NAME_11 script through the shared two-matrix single-node helper
+ String[] expectedHeavyHitters = new String[]{"fed_fedinit"}; // NOTE(review): presumably instruction names the helper expects among the heavy hitters - confirm
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_11, expectedHeavyHitters);
+ }
+
private void writeStandardMatrix(String matrixName, long seed){
writeStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
}
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml
new file mode 100644
index 0000000000..147bf2cd13
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2), # X is assembled from two federated parts
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2), # Y uses the same ranges as X
+ ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), list($r, $c)))
+
+i = 0
+while(i < 10){ # the body does not depend on i; only the counter changes per iteration
+ Z0 = X * Y # element-wise product
+ Z = t(Z0) %*% X # matrix product of the transposed Z0 with X
+ i=i+1
+}
+
+write(Z, $Z) # Z as assigned in the last loop iteration
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml
new file mode 100644
index 0000000000..187623bbfe
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2)) # the two X parts are read and stacked row-wise
+Y = rbind(read($Y1), read($Y2)) # same for Y
+
+i = 0
+while(i < 10){ # the body does not depend on i; only the counter changes per iteration
+ Z0 = X * Y # element-wise product
+ Z = t(Z0) %*% X # matrix product of the transposed Z0 with X
+ i=i+1
+}
+
+write(Z, $Z) # Z as assigned in the last loop iteration