You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by se...@apache.org on 2022/04/20 12:29:18 UTC

[systemds] branch main updated: [SYSTEMDS-3018] Federated Planner Extended 3

This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 46a30eaef2 [SYSTEMDS-3018] Federated Planner Extended 3
46a30eaef2 is described below

commit 46a30eaef2fb9e25f41c1b46405e60228783b230
Author: sebwrede <sw...@know-center.at>
AuthorDate: Tue Apr 19 17:34:28 2022 +0200

    [SYSTEMDS-3018] Federated Planner Extended 3
    
    This commit adds handling of federated, transient-read, and transient-write DataOps to the allowsFederated and getFederatedOut methods, so that transient reads and writes are allowed to have federated output (FOUT).
    It also changes the tests to load configuration files instead of toggling OptimizerUtils.FEDERATED_COMPILATION directly, and removes those OptimizerUtils calls.
    
    Closes #1586.
---
 .../sysds/hops/fedplanner/AFederatedPlanner.java   |  7 +++++++
 .../hops/fedplanner/FederatedPlannerCostbased.java | 14 ++++++-------
 .../fedplanning/FederatedL2SVMPlanningTest.java    |  9 ++-------
 .../fedplanning/FederatedMultiplyPlanningTest.java | 23 ++++++++++++++++------
 4 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
index b5adb09780..3403cc4bbe 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
@@ -78,6 +78,10 @@ public abstract class AFederatedPlanner {
 		else if ( HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS) ){
 			return ft[0] == FType.COL || ft[0] == FType.ROW;
 		}
+		else if (HopRewriteUtils.isData(hop, Types.OpOpData.FEDERATED)
+			|| HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE)
+			|| HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD))
+			return true;
 		else if(ft.length==1 && ft[0] != null) {
 			return HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS)
 				|| HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM, AggOp.MIN, AggOp.MAX);
@@ -135,6 +139,9 @@ public abstract class AFederatedPlanner {
 		}
 		else if ( HopRewriteUtils.isData(hop, Types.OpOpData.FEDERATED) )
 			return deriveFType((DataOp)hop);
+		else if ( HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE)
+			|| HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD) )
+			return ft[0];
 		return null;
 	}
 	
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index a4c0bb8760..ee39e468bd 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -327,14 +327,12 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
 		}
 		if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTWRITE) )
 			transientWrites.put(currentHop.getName(), currentHop);
-		else {
-			if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.FEDERATED) )
-				hopRels.add(new HopRel(currentHop, FederatedOutput.FOUT, deriveFType((DataOp)currentHop), hopRelMemo, inputHops));
-			else
-				hopRels.addAll(generateHopRels(currentHop, inputHops));
-			if ( isLOUTSupported(currentHop) )
-				hopRels.add(new HopRel(currentHop, FederatedOutput.LOUT, hopRelMemo, inputHops));
-		}
+		if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.FEDERATED) )
+			hopRels.add(new HopRel(currentHop, FederatedOutput.FOUT, deriveFType((DataOp)currentHop), hopRelMemo, inputHops));
+		else
+			hopRels.addAll(generateHopRels(currentHop, inputHops));
+		if ( isLOUTSupported(currentHop) )
+			hopRels.add(new HopRel(currentHop, FederatedOutput.LOUT, hopRelMemo, inputHops));
 		return hopRels;
 	}
 
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
index 3b0ab91f49..e9ab6b6ad0 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
@@ -23,7 +23,6 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
-import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -74,7 +73,8 @@ public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
 	public void runL2SVMCostBasedTest(){
 		//String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*",
 		//	"fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
-		String[] expectedHeavyHitters = new String[]{ "fed_fedinit"};
+		String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*",
+			"fed_max", "fed_1-*", "fed_>"};
 		setTestConf("SystemDS-config-cost-based.xml");
 		loadAndRunTest(expectedHeavyHitters);
 	}
@@ -126,8 +126,6 @@ public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
 		Thread t1 = null, t2 = null;
 
 		try {
-			OptimizerUtils.FEDERATED_COMPILATION = true;
-
 			getAndLoadTestConfiguration(TEST_NAME);
 			String HOME = SCRIPT_DIR + TEST_DIR;
 
@@ -145,8 +143,6 @@ public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
 				"Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
 			runTest(true, false, null, -1);
 
-			OptimizerUtils.FEDERATED_COMPILATION = false;
-
 			// Run reference dml script with normal matrix
 			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
 			programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"),
@@ -160,7 +156,6 @@ public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
 					+ Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
 		}
 		finally {
-			OptimizerUtils.FEDERATED_COMPILATION = false;
 			TestUtils.shutdownThreads(t1, t2);
 			rtplatform = platformOld;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 56a7dae1f6..6ec10232dd 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -19,7 +19,8 @@
 
 package org.apache.sysds.test.functions.privacy.fedplanning;
 
-import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
 import org.junit.Ignore;
@@ -33,6 +34,7 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 
+import java.io.File;
 import java.util.Arrays;
 import java.util.Collection;
 
@@ -41,6 +43,8 @@ import static org.junit.Assert.fail;
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
+	private static final Log LOG = LogFactory.getLog(FederatedMultiplyPlanningTest.class.getName());
+
 	private final static String TEST_DIR = "functions/privacy/fedplanning/";
 	private final static String TEST_NAME = "FederatedMultiplyPlanningTest";
 	private final static String TEST_NAME_2 = "FederatedMultiplyPlanningTest2";
@@ -52,6 +56,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 	private final static String TEST_NAME_8 = "FederatedMultiplyPlanningTest8";
 	private final static String TEST_NAME_9 = "FederatedMultiplyPlanningTest9";
 	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
+	private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-cost-based.xml");
 
 	private final static int blocksize = 1024;
 	@Parameterized.Parameter()
@@ -223,8 +228,6 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 		Thread t1 = null, t2 = null;
 
 		try{
-			OptimizerUtils.FEDERATED_COMPILATION = true;
-
 			getAndLoadTestConfiguration(testName);
 			String HOME = SCRIPT_DIR + TEST_DIR;
 
@@ -244,8 +247,6 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 			rewriteRealProgramArgs(testName, port1, port2);
 			runTest(true, false, null, -1);
 
-			OptimizerUtils.FEDERATED_COMPILATION = false;
-
 			// Run reference dml script with normal matrix
 			fullDMLScriptName = HOME + testName + "Reference.dml";
 			programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
@@ -259,7 +260,6 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 				fail("The following expected heavy hitters are missing: "
 					+ Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
 		} finally {
-			OptimizerUtils.FEDERATED_COMPILATION = false;
 			TestUtils.shutdownThreads(t1, t2);
 			rtplatform = platformOld;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
@@ -289,5 +289,16 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 				"Y2=" + input("Y2"), "W1=" + input("W1"), "W2=" + input("W2"), "Z=" + expected("Z")};
 		}
 	}
+
+	/**
+	 * Override the default configuration template with the cost-based test configuration
+	 * (TEST_CONF_FILE) to ensure scratch space and local temporary directory locations are also updated.
+	 */
+	@Override
+	protected File getConfigTemplateFile() {
+		// Log which configuration file is used as the template so it is visible in this test's output log.
+		LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
+		return TEST_CONF_FILE;
+	}
 }