You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by yw...@apache.org on 2022/07/28 15:33:46 UTC

[systemds] branch main updated: [MINOR] FederatedLookupTable Eviction Fix

This is an automated email from the ASF dual-hosted git repository.

ywcb00 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f45ae975b9 [MINOR] FederatedLookupTable Eviction Fix
f45ae975b9 is described below

commit f45ae975b904008ac4c5fd6ebbc587dbfa865b10
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Thu Jul 28 17:31:10 2022 +0200

    [MINOR] FederatedLookupTable Eviction Fix
    
    - Remove the coordinator-specific entry from the FederatedLookupTable
      when receiving a CLEAR request.
    - Fix scalar broadcasting of federated left indexing instruction.
    - Avoid ConcurrentModificationException by changing the ArrayList for
      the coordinator's traffic bytes in the federated statistics to a
      CopyOnWriteArrayList.
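    - Avoid a race condition while obtaining the heavy hitters for the
      statistics by sizing the sorted array from the entry-set snapshot
      instead of from a separate size() call.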
    
    Closes #1663.
---
 .../federated/FederatedLookupTable.java            | 22 +++++++++++++++++-----
 .../federated/FederatedStatistics.java             |  3 ++-
 .../federated/FederatedWorkerHandler.java          |  7 +++++--
 .../instructions/fed/IndexingFEDInstruction.java   |  2 +-
 .../java/org/apache/sysds/utils/Statistics.java    | 12 ++++++------
 5 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
index afba8ac42a..188c57ed95 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
@@ -47,10 +47,6 @@ public class FederatedLookupTable {
 		_lookup_table = new ConcurrentHashMap<>();
 	}
 
-	public void clear() {
-		_lookup_table.clear();
-	}
-	
 	/**
 	 * Get the ExecutionContextMap corresponding to the given host and pid of the
 	 * requesting coordinator from the lookup table. Create a new
@@ -61,9 +57,9 @@ public class FederatedLookupTable {
 	 * @return ExecutionContextMap the ECM corresponding to the requesting coordinator
 	 */
 	public ExecutionContextMap getECM(String host, long pid) {
-		LOG.trace("Getting the ExecutionContextMap for coordinator " + pid + "@" + host);
 		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
 		FedUniqueCoordID funCID = new FedUniqueCoordID(host, pid);
+		LOG.trace("Getting the ExecutionContextMap for coordinator " + funCID.toString());
 		ExecutionContextMap ecm = _lookup_table.computeIfAbsent(funCID,
 			k -> createNewECM());
 		if(ecm == null) {
@@ -79,6 +75,22 @@ public class FederatedLookupTable {
 		return ecm;
 	}
 
+	/**
+	 * Remove the ExecutionContextMap corresponding to the given host and pid of the
+	 * requesting coordinator from the lookup table. Do nothing if no entry
+	 * is associated to the host and pid.
+	 *
+	 * @param host the host string of the requesting coordinator (usually IP address)
+	 * @param pid the process id of the requesting coordinator
+	 */
+	public void removeECM(String host, long pid) {
+		FedUniqueCoordID funCID = new FedUniqueCoordID(host, pid);
+		LOG.trace("Removing the ExecutionContextMap of coordinator " + funCID.toString());
+		if(_lookup_table.remove(funCID) == null)
+			LOG.warn("Removing federated execution context map failed. "
+				+ "No valid resolution for " + funCID.toString() + " found.");
+	}
+
 	/**
 	 * Check if there is a mapped ExecutionContextMap for the coordinator
 	 * with the given host and pid.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
index 17b4012fec..b53ef801a5 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -30,6 +30,7 @@ import java.time.format.DateTimeFormatter;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -98,7 +99,7 @@ public class FederatedStatistics {
 	private static final LongAdder fedSerializationReuseBytes = new LongAdder();
 	// Traffic between federated worker and a coordinator site
 	// in the form of [{ datetime, coordinatorAddress, transferredBytes }, { ... }] }
-	private static List<Triple<LocalDateTime, String, Long>> coordinatorsTrafficBytes = new ArrayList<>();
+	private static CopyOnWriteArrayList<Triple<LocalDateTime, String, Long>> coordinatorsTrafficBytes = new CopyOnWriteArrayList<>();
 
 	public static void logServerTraffic(long read, long written) {
 		bytesReceived.add(read);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index bfeb19cc16..509e0998eb 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -188,6 +188,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 			
 		FederatedResponse response = null; // last response
 		boolean containsCLEAR = false;
+		long clearReqPid = -1;
 		for(int i = 0; i < requests.length; i++) {
 			final FederatedRequest request = requests[i];
 			final RequestType t = request.getType();
@@ -233,12 +234,14 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 				}
 			}
 
-			if(t == RequestType.CLEAR)
+			if(t == RequestType.CLEAR) {
 				containsCLEAR = true;
+				clearReqPid = request.getPID();
+			}
 		}
 
 		if(containsCLEAR) {
-			_flt.clear();
+			_flt.removeECM(remoteHost, clearReqPid);
 			printStatistics();
 		}
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index 128fc1d4a6..6fc8c24a7e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -331,7 +331,7 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
 			FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1.getID());
 
 			if(fr2.length == 1)
-				fedMap.execute(getTID(), true, fr2, fr1, fr3);
+				fedMap.execute(getTID(), true, fr1, fr2[0], fr3);
 			else
 				fedMap.execute(getTID(), true, ranges, fr2[cpVarInstIx], fr2[from], fr1, fr3);
 		}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index aece9b655a..454ecac6e1 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -392,12 +392,11 @@ public class Statistics
 	 */
 	@SuppressWarnings("unchecked")
 	public static String getHeavyHitters(int num) {
-		int len = _instStats.size();
-		if (num <= 0 || len <= 0)
+		if (num <= 0 || _instStats.size() <= 0)
 			return "-";
 
 		// get top k via sort
-		Entry<String, InstStats>[] tmp = _instStats.entrySet().toArray(new Entry[len]);
+		Entry<String, InstStats>[] tmp = _instStats.entrySet().toArray(Entry[]::new);
 		Arrays.sort(tmp, new Comparator<Entry<String, InstStats>>() {
 			@Override
 			public int compare(Entry<String, InstStats> e1, Entry<String, InstStats> e2) {
@@ -410,6 +409,7 @@ public class Statistics
 		final String timeSCol = "Time(s)";
 		final String countCol = "Count";
 		StringBuilder sb = new StringBuilder();
+		int len = tmp.length;
 		int numHittersToDisplay = Math.min(num, len);
 		int maxNumLen = String.valueOf(numHittersToDisplay).length();
 		int maxInstLen = instCol.length();
@@ -466,11 +466,10 @@ public class Statistics
 
 	@SuppressWarnings("unchecked")
 	public static String getCPHeavyHittersMem(int num) {
-		int n = _cpMemObjs.size();
-		if ((n <= 0) || (num <= 0))
+		if ((_cpMemObjs.size() <= 0) || (num <= 0))
 			return "-";
 
-		Entry<String,Double>[] entries = _cpMemObjs.entrySet().toArray(new Entry[_cpMemObjs.size()]);
+		Entry<String,Double>[] entries = _cpMemObjs.entrySet().toArray(Entry[]::new);
 		Arrays.sort(entries, new Comparator<Entry<String, Double>>() {
 			@Override
 			public int compare(Entry<String, Double> a, Entry<String, Double> b) {
@@ -478,6 +477,7 @@ public class Statistics
 			}
 		});
 
+		int n = entries.length;
 		int numHittersToDisplay = Math.min(num, n);
 		int numPadLen = String.format("%d", numHittersToDisplay).length();
 		int maxNameLength = 0;