You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this rendering.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/07/04 17:55:22 UTC

[systemds] branch master updated: [SYSTEMDS-3039] Tracking and consolidation of federated statistics

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 582b9c3  [SYSTEMDS-3039] Tracking and consolidation of federated statistics
582b9c3 is described below

commit 582b9c3f622d87cd9d11b9dd01abcb0a6f179309
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sun Jul 4 19:52:56 2021 +0200

    [SYSTEMDS-3039] Tracking and consolidation of federated statistics
    
    Closes #1321.
---
 src/main/java/org/apache/sysds/api/DMLOptions.java |  19 ++
 src/main/java/org/apache/sysds/api/DMLScript.java  |  72 ++---
 .../federated/FederatedStatistics.java             | 311 +++++++++++++++++++++
 .../instructions/fed/InitFEDInstruction.java       |   5 +
 .../java/org/apache/sysds/utils/Statistics.java    |  22 +-
 .../primitives/FederatedStatisticsTest.java        |   5 +-
 6 files changed, 395 insertions(+), 39 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java
index 0949c06..fbdaa90 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -52,6 +52,8 @@ public class DMLOptions {
 	public boolean              clean         = false;            // Whether to clean up all SystemDS working directories (FS, DFS)
 	public boolean              stats         = false;            // Whether to record and print the statistics
 	public int                  statsCount    = 10;               // Default statistics count
+	public boolean              fedStats      = false;            // Whether to record and print the federated statistics
+	public int                  fedStatsCount = 10;               // Default federated statistics count
 	public boolean              memStats      = false;            // max memory statistics
 	public Explain.ExplainType  explainType   = Explain.ExplainType.NONE;  // Whether to print the "Explain" and if so, what type
 	public ExecMode             execMode      = OptimizerUtils.getDefaultExecutionMode();  // Execution mode standalone, MR, Spark or a hybrid
@@ -85,6 +87,8 @@ public class DMLOptions {
 			", clean=" + clean +
 			", stats=" + stats +
 			", statsCount=" + statsCount +
+			", fedStats=" + fedStats +
+			", fedStatsCount=" + fedStatsCount +
 			", memStats=" + memStats +
 			", explainType=" + explainType +
 			", execMode=" + execMode +
@@ -193,6 +197,17 @@ public class DMLOptions {
 				}
 			}
 		}
+		dmlOptions.fedStats = line.hasOption("fedStats");
+		if (dmlOptions.fedStats) {
+			String fedStatsCount = line.getOptionValue("fedStats");
+			if(fedStatsCount != null) {
+				try {
+					dmlOptions.fedStatsCount = Integer.parseInt(fedStatsCount);
+				} catch (NumberFormatException e) {
+					throw new org.apache.commons.cli.ParseException("Invalid argument specified for -fedStats option, must be a valid integer");
+				}
+			}
+		}
 		dmlOptions.memStats = line.hasOption("mem");
 
 		dmlOptions.clean = line.hasOption("clean");
@@ -265,6 +280,9 @@ public class DMLOptions {
 		Option statsOpt = OptionBuilder.withArgName("count")
 			.withDescription("monitors and reports summary execution statistics; heavy hitter <count> is 10 unless overridden; default off")
 			.hasOptionalArg().create("stats");
+		Option fedStatsOpt = OptionBuilder.withArgName("count")
+			.withDescription("monitors and reports summary execution statistics of federated workers; heavy hitter <count> is 10 unless overridden; default off")
+			.hasOptionalArg().create("fedStats");
 		Option memOpt = OptionBuilder.withDescription("monitors and reports max memory consumption in CP; default off")
 			.create("mem");
 		Option explainOpt = OptionBuilder.withArgName("level")
@@ -299,6 +317,7 @@ public class DMLOptions {
 		options.addOption(configOpt);
 		options.addOption(cleanOpt);
 		options.addOption(statsOpt);
+		options.addOption(fedStatsOpt);
 		options.addOption(memOpt);
 		options.addOption(explainOpt);
 		options.addOption(execOpt);
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index 7d2bf16..e2e67a5 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -82,26 +82,28 @@ import org.apache.sysds.utils.Explain.ExplainType;
 
 public class DMLScript 
 {
-	private static ExecMode   EXEC_MODE          = DMLOptions.defaultOptions.execMode;     // the execution mode
-	public static boolean     STATISTICS          = DMLOptions.defaultOptions.stats;       // whether to print statistics
-	public static boolean     JMLC_MEM_STATISTICS = false;                                 // whether to gather memory use stats in JMLC
-	public static int         STATISTICS_COUNT    = DMLOptions.defaultOptions.statsCount;  // statistics maximum heavy hitter count
-	public static int         STATISTICS_MAX_WRAP_LEN = 30;                                // statistics maximum wrap length
-	public static ExplainType EXPLAIN             = DMLOptions.defaultOptions.explainType; // explain type
-	public static String      DML_FILE_PATH_ANTLR_PARSER = DMLOptions.defaultOptions.filePath; // filename of dml/pydml script
-	public static String      FLOATING_POINT_PRECISION = "double";                         // data type to use internally
-	public static boolean     PRINT_GPU_MEMORY_INFO = false;                               // whether to print GPU memory-related information
-	public static long        EVICTION_SHADOW_BUFFER_MAX_BYTES = 0;                        // maximum number of bytes to use for shadow buffer
-	public static long        EVICTION_SHADOW_BUFFER_CURR_BYTES = 0;                       // number of bytes to use for shadow buffer
-	public static double      GPU_MEMORY_UTILIZATION_FACTOR = 0.9;                         // fraction of available GPU memory to use
-	public static String      GPU_MEMORY_ALLOCATOR = "cuda";                               // GPU memory allocator to use
-	public static boolean     LINEAGE = DMLOptions.defaultOptions.lineage;                 // whether compute lineage trace
-	public static boolean     LINEAGE_DEDUP = DMLOptions.defaultOptions.lineage_dedup;     // whether deduplicate lineage items
-	public static ReuseCacheType LINEAGE_REUSE = DMLOptions.defaultOptions.linReuseType;   // whether lineage-based reuse
-	public static LineageCachePolicy LINEAGE_POLICY = DMLOptions.defaultOptions.linCachePolicy; // lineage cache eviction policy
-	public static boolean     LINEAGE_ESTIMATE = DMLOptions.defaultOptions.lineage_estimate; // whether estimate reuse benefits
-	public static boolean     LINEAGE_DEBUGGER = DMLOptions.defaultOptions.lineage_debugger; // whether enable lineage debugger
-	public static boolean     CHECK_PRIVACY = DMLOptions.defaultOptions.checkPrivacy;      // Check which privacy constraints are loaded and checked during federated execution
+	private static ExecMode   EXEC_MODE          = DMLOptions.defaultOptions.execMode;           // the execution mode
+	public static boolean     STATISTICS          = DMLOptions.defaultOptions.stats;             // whether to print statistics
+	public static boolean     JMLC_MEM_STATISTICS = false;                                       // whether to gather memory use stats in JMLC
+	public static int         STATISTICS_COUNT    = DMLOptions.defaultOptions.statsCount;        // statistics maximum heavy hitter count
+	public static int         STATISTICS_MAX_WRAP_LEN = 30;                                      // statistics maximum wrap length
+	public static boolean     FED_STATISTICS        = DMLOptions.defaultOptions.fedStats;        // whether to print federated statistics
+	public static int         FED_STATISTICS_COUNT  = DMLOptions.defaultOptions.fedStatsCount;   // federated statistics maximum heavy hitter count
+	public static ExplainType EXPLAIN             = DMLOptions.defaultOptions.explainType;       // explain type
+	public static String      DML_FILE_PATH_ANTLR_PARSER = DMLOptions.defaultOptions.filePath;   // filename of dml/pydml script
+	public static String      FLOATING_POINT_PRECISION = "double";                               // data type to use internally
+	public static boolean     PRINT_GPU_MEMORY_INFO = false;                                     // whether to print GPU memory-related information
+	public static long        EVICTION_SHADOW_BUFFER_MAX_BYTES = 0;                              // maximum number of bytes to use for shadow buffer
+	public static long        EVICTION_SHADOW_BUFFER_CURR_BYTES = 0;                             // number of bytes to use for shadow buffer
+	public static double      GPU_MEMORY_UTILIZATION_FACTOR = 0.9;                               // fraction of available GPU memory to use
+	public static String      GPU_MEMORY_ALLOCATOR = "cuda";                                     // GPU memory allocator to use
+	public static boolean     LINEAGE = DMLOptions.defaultOptions.lineage;                       // whether compute lineage trace
+	public static boolean     LINEAGE_DEDUP = DMLOptions.defaultOptions.lineage_dedup;           // whether deduplicate lineage items
+	public static ReuseCacheType LINEAGE_REUSE = DMLOptions.defaultOptions.linReuseType;         // whether lineage-based reuse
+	public static LineageCachePolicy LINEAGE_POLICY = DMLOptions.defaultOptions.linCachePolicy;  // lineage cache eviction policy
+	public static boolean     LINEAGE_ESTIMATE = DMLOptions.defaultOptions.lineage_estimate;     // whether estimate reuse benefits
+	public static boolean     LINEAGE_DEBUGGER = DMLOptions.defaultOptions.lineage_debugger;     // whether enable lineage debugger
+	public static boolean     CHECK_PRIVACY = DMLOptions.defaultOptions.checkPrivacy;            // Check which privacy constraints are loaded and checked during federated execution
 
 	public static boolean           USE_ACCELERATOR     = DMLOptions.defaultOptions.gpu;
 	public static boolean           FORCE_ACCELERATOR   = DMLOptions.defaultOptions.forceGPU;
@@ -212,20 +214,22 @@ public class DMLScript
 
 		try
 		{
-			STATISTICS          = dmlOptions.stats;
-			STATISTICS_COUNT    = dmlOptions.statsCount;
-			JMLC_MEM_STATISTICS = dmlOptions.memStats;
-			USE_ACCELERATOR     = dmlOptions.gpu;
-			FORCE_ACCELERATOR   = dmlOptions.forceGPU;
-			EXPLAIN             = dmlOptions.explainType;
-			EXEC_MODE           = dmlOptions.execMode;
-			LINEAGE             = dmlOptions.lineage;
-			LINEAGE_DEDUP       = dmlOptions.lineage_dedup;
-			LINEAGE_REUSE       = dmlOptions.linReuseType;
-			LINEAGE_POLICY      = dmlOptions.linCachePolicy;
-			LINEAGE_ESTIMATE    = dmlOptions.lineage_estimate;
-			CHECK_PRIVACY       = dmlOptions.checkPrivacy;
-			LINEAGE_DEBUGGER	= dmlOptions.lineage_debugger;
+			STATISTICS            = dmlOptions.stats;
+			STATISTICS_COUNT      = dmlOptions.statsCount;
+			FED_STATISTICS        = dmlOptions.fedStats;
+			FED_STATISTICS_COUNT  = dmlOptions.fedStatsCount;
+			JMLC_MEM_STATISTICS   = dmlOptions.memStats;
+			USE_ACCELERATOR       = dmlOptions.gpu;
+			FORCE_ACCELERATOR     = dmlOptions.forceGPU;
+			EXPLAIN               = dmlOptions.explainType;
+			EXEC_MODE             = dmlOptions.execMode;
+			LINEAGE               = dmlOptions.lineage;
+			LINEAGE_DEDUP         = dmlOptions.lineage_dedup;
+			LINEAGE_REUSE         = dmlOptions.linReuseType;
+			LINEAGE_POLICY        = dmlOptions.linCachePolicy;
+			LINEAGE_ESTIMATE      = dmlOptions.lineage_estimate;
+			CHECK_PRIVACY         = dmlOptions.checkPrivacy;
+			LINEAGE_DEBUGGER      = dmlOptions.lineage_debugger;
 
 			String fnameOptConfig = dmlOptions.configFile;
 			boolean isFile = dmlOptions.filePath != null;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
new file mode 100644
index 0000000..14f29d9
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+import java.io.Serializable;
+import java.net.InetSocketAddress;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.concurrent.Future;
+import javax.net.ssl.SSLException;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.CacheStatsCollection;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.GCStatsCollection;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.utils.Statistics;
+
+/** Registry of federated workers and collection/display of their execution statistics. */
+public class FederatedStatistics {
+	// (host, port) of every worker registered via registerFedWorker
+	private static final Set<Pair<String, Integer>> _fedWorkerAddresses = new HashSet<>();
+
+	/** Registers a federated worker so that its statistics are collected later. */
+	public static void registerFedWorker(String host, int port) {
+		_fedWorkerAddresses.add(new ImmutablePair<>(host, port));
+	}
+
+	/** Returns a listing of all registered workers, one "host:port" per line. */
+	public static String displayFedWorkers() {
+		StringBuilder sb = new StringBuilder();
+		sb.append("Federated Worker Addresses:\n");
+		for(Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
+			sb.append(String.format("  %s:%d", fedAddr.getLeft(), fedAddr.getRight().intValue()));
+			sb.append("\n");
+		}
+		return sb.toString();
+	}
+
+	/** Collects and sums the stats of all registered workers; lists up to numHeavyHitters instructions. */
+	public static String displayFedStatistics(int numHeavyHitters) {
+		StringBuilder sb = new StringBuilder();
+		FedStatsCollection fedStats = collectFedStats();
+		sb.append("SystemDS Federated Statistics:\n");
+		sb.append(displayCacheStats(fedStats.cacheStats));
+		sb.append(String.format("Total JIT compile time:\t\t%.3f sec.\n", fedStats.jitCompileTime));
+		sb.append(displayGCStats(fedStats.gcStats));
+		sb.append(displayHeavyHitters(fedStats.heavyHitters, numHeavyHitters));
+		return sb.toString();
+	}
+
+	/** Formats aggregated cache hit/write counts and cache times (in sec). */
+	public static String displayCacheStats(CacheStatsCollection csc) {
+		StringBuilder sb = new StringBuilder();
+		sb.append(String.format("Cache hits (Mem/Li/WB/FS/HDFS):\t%d/%d/%d/%d/%d.\n",
+			csc.memHits, csc.linHits, csc.fsBuffHits, csc.fsHits, csc.hdfsHits));
+		sb.append(String.format("Cache writes (Li/WB/FS/HDFS):\t%d/%d/%d/%d.\n",
+			csc.linWrites, csc.fsBuffWrites, csc.fsWrites, csc.hdfsWrites));
+		sb.append(String.format("Cache times (ACQr/m, RLS, EXP):\t%.3f/%.3f/%.3f/%.3f sec.\n",
+			csc.acqRTime, csc.acqMTime, csc.rlsTime, csc.expTime));
+		return sb.toString();
+	}
+
+	/** Formats the aggregated JVM garbage collection count and time (in sec). */
+	public static String displayGCStats(GCStatsCollection gcsc) {
+		StringBuilder sb = new StringBuilder();
+		sb.append(String.format("Total JVM GC count:\t\t%d.\n", gcsc.gcCount));
+		sb.append(String.format("Total JVM GC time:\t\t%.3f sec.\n", gcsc.gcTime));
+		return sb.toString();
+	}
+
+	/** Formats the 10 most expensive instructions of the given heavy hitter map. */
+	public static String displayHeavyHitters(HashMap<String, Pair<Long, Double>> heavyHitters) {
+		return displayHeavyHitters(heavyHitters, 10);
+	}
+
+	/**
+	 * Formats a table of the num most expensive instructions (no rows if num <= 0).
+	 * @param heavyHitters map of opcode to (count, total time in sec)
+	 */
+	public static String displayHeavyHitters(HashMap<String, Pair<Long, Double>> heavyHitters, int num) {
+		StringBuilder sb = new StringBuilder();
+		@SuppressWarnings("unchecked")
+		Entry<String, Pair<Long, Double>>[] hhArr = heavyHitters.entrySet().toArray(new Entry[0]);
+		// descending by total time, i.e., most expensive instruction first
+		Arrays.sort(hhArr, (e1, e2) -> e2.getValue().getRight().compareTo(e1.getValue().getRight()));
+		sb.append("Heavy hitter instructions:\n");
+		final String numCol = "#";
+		final String instCol = "Instruction";
+		final String timeSCol = "Time(s)";
+		final String countCol = "Count";
+		int numHittersToDisplay = Math.max(0, Math.min(num, hhArr.length));
+		int maxNumLen = String.valueOf(numHittersToDisplay).length();
+		int maxInstLen = instCol.length();
+		int maxTimeSLen = timeSCol.length();
+		int maxCountLen = countCol.length();
+		DecimalFormat sFormat = new DecimalFormat("#,##0.000");
+		for (int counter = 0; counter < numHittersToDisplay; counter++) {
+			Entry<String, Pair<Long, Double>> hh = hhArr[counter];
+			maxInstLen = Math.max(maxInstLen, hh.getKey().length());
+			maxTimeSLen = Math.max(maxTimeSLen, sFormat.format(hh.getValue().getRight()).length());
+			maxCountLen = Math.max(maxCountLen, String.valueOf(hh.getValue().getLeft()).length());
+		}
+		maxInstLen = Math.min(maxInstLen, DMLScript.STATISTICS_MAX_WRAP_LEN); // longer ones are wrapped
+		sb.append(String.format( " %" + maxNumLen + "s  %-" + maxInstLen + "s  %"
+			+ maxTimeSLen + "s  %" + maxCountLen + "s", numCol, instCol, timeSCol, countCol));
+		sb.append("\n");
+		for (int counter = 0; counter < numHittersToDisplay; counter++) {
+			Entry<String, Pair<Long, Double>> hh = hhArr[counter];
+			String [] wrappedInstruction = Statistics.wrap(hh.getKey(), maxInstLen);
+			String timeSString = sFormat.format(hh.getValue().getRight());
+			long count = hh.getValue().getLeft();
+			// emit at least one row, since wrap() yields no lines for an empty string
+			int numLines = Math.max(1, wrappedInstruction.length);
+			for(int wrapIter = 0; wrapIter < numLines; wrapIter++) {
+				String instStr = (wrapIter < wrappedInstruction.length) ? wrappedInstruction[wrapIter] : "";
+				if(wrapIter == 0) {
+					sb.append(String.format(
+						" %" + maxNumLen + "d  %-" + maxInstLen + "s  %" + maxTimeSLen + "s  %"
+						+ maxCountLen + "d", (counter + 1), instStr, timeSString, count));
+				}
+				else {
+					sb.append(String.format(
+						" %" + maxNumLen + "s  %-" + maxInstLen + "s  %" + maxTimeSLen + "s  %"
+						+ maxCountLen + "s", "", instStr, "", ""));
+				}
+				sb.append("\n");
+			}
+		}
+		return sb.toString();
+	}
+
+	/** Requests the statistics of every registered worker and sums them up. */
+	private static FedStatsCollection collectFedStats() {
+		Future<FederatedResponse>[] responses = getFederatedResponses();
+		FedStatsCollection aggFedStats = new FedStatsCollection();
+		for(Future<FederatedResponse> res : responses) {
+			try {
+				Object[] tmp = res.get().getData();
+				// skip responses without a stats payload instead of failing on tmp[0]
+				if(tmp != null && tmp.length > 0 && tmp[0] instanceof FedStatsCollection)
+					aggFedStats.aggregate((FedStatsCollection)tmp[0]);
+			} catch(InterruptedException ie) {
+				Thread.currentThread().interrupt(); // restore the interrupt status for callers
+				throw new DMLRuntimeException("Interrupted while getting the federated stats: ", ie);
+			} catch(Exception e) {
+				throw new DMLRuntimeException("Exception of type " + e.getClass().getName()
+					+ " thrown while getting the federated stats of the federated response: ", e);
+			}
+		}
+		return aggFedStats;
+	}
+
+	/** Sends a stats-collection UDF to every registered worker; returns the response futures. */
+	private static Future<FederatedResponse>[] getFederatedResponses() {
+		List<Future<FederatedResponse>> ret = new ArrayList<>();
+		for(Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
+			InetSocketAddress isa = new InetSocketAddress(fedAddr.getLeft(), fedAddr.getRight());
+			FederatedRequest frUDF = new FederatedRequest(RequestType.EXEC_UDF, -1,
+				new FedStatsCollectFunction());
+			try {
+				ret.add(FederatedData.executeFederatedOperation(isa, frUDF));
+			} catch(SSLException ssle) {
+				throw new DMLRuntimeException("SSLException while getting the federated stats from "
+					+ isa.toString() + ": ", ssle);
+			} catch (Exception e) {
+				throw new DMLRuntimeException("Exception of type " + e.getClass().getName()
+					+ " thrown while getting stats from federated worker " + isa + ": ", e);
+			}
+		}
+		@SuppressWarnings("unchecked")
+		Future<FederatedResponse>[] retArr = ret.toArray(new Future[0]);
+		return retArr;
+	}
+
+	/** UDF executed on a federated worker; responds with the worker's local statistics. */
+	private static class FedStatsCollectFunction extends FederatedUDF {
+		private static final long serialVersionUID = 1L;
+
+		public FedStatsCollectFunction() {
+			super(new long[] { });
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			FedStatsCollection fedStats = new FedStatsCollection();
+			fedStats.collectStats();
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, fedStats);
+		}
+
+		@Override
+		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+			return null;
+		}
+	}
+
+	/** Serializable statistics of one worker; also used to hold the sum over all workers. */
+	protected static class FedStatsCollection implements Serializable {
+		private static final long serialVersionUID = 1L;
+
+		// fills this collection from the statistics of the local process
+		private void collectStats() {
+			cacheStats.collectStats();
+			jitCompileTime = ((double)Statistics.getJITCompileTime()) / 1000; // in sec
+			gcStats.collectStats();
+			heavyHitters = Statistics.getHeavyHittersHashMap();
+		}
+
+		// adds the stats of that to this collection; heavy hitters with equal opcodes are summed
+		private void aggregate(FedStatsCollection that) {
+			cacheStats.aggregate(that.cacheStats);
+			jitCompileTime += that.jitCompileTime;
+			gcStats.aggregate(that.gcStats);
+			that.heavyHitters.forEach((key, value) -> heavyHitters.merge(key, value, (v1, v2) ->
+				new ImmutablePair<>(v1.getLeft() + v2.getLeft(), v1.getRight() + v2.getRight())));
+		}
+
+		protected static class CacheStatsCollection implements Serializable {
+			private static final long serialVersionUID = 1L;
+
+			private void collectStats() {
+				memHits = CacheStatistics.getMemHits();
+				linHits = CacheStatistics.getLinHits();
+				fsBuffHits = CacheStatistics.getFSBuffHits();
+				fsHits = CacheStatistics.getFSHits();
+				hdfsHits = CacheStatistics.getHDFSHits();
+				linWrites = CacheStatistics.getLinWrites();
+				fsBuffWrites = CacheStatistics.getFSBuffWrites();
+				fsWrites = CacheStatistics.getFSWrites();
+				hdfsWrites = CacheStatistics.getHDFSWrites();
+				acqRTime = ((double)CacheStatistics.getAcquireRTime()) / 1000000000; // in sec
+				acqMTime = ((double)CacheStatistics.getAcquireMTime()) / 1000000000; // in sec
+				rlsTime = ((double)CacheStatistics.getReleaseTime()) / 1000000000; // in sec
+				expTime = ((double)CacheStatistics.getExportTime()) / 1000000000; // in sec
+			}
+
+			private void aggregate(CacheStatsCollection that) {
+				memHits += that.memHits;
+				linHits += that.linHits;
+				fsBuffHits += that.fsBuffHits;
+				fsHits += that.fsHits;
+				hdfsHits += that.hdfsHits;
+				linWrites += that.linWrites;
+				fsBuffWrites += that.fsBuffWrites;
+				fsWrites += that.fsWrites;
+				hdfsWrites += that.hdfsWrites;
+				acqRTime += that.acqRTime;
+				acqMTime += that.acqMTime;
+				rlsTime += that.rlsTime;
+				expTime += that.expTime;
+			}
+
+			private long memHits = 0, linHits = 0, fsBuffHits = 0, fsHits = 0, hdfsHits = 0;
+			private long linWrites = 0, fsBuffWrites = 0, fsWrites = 0, hdfsWrites = 0;
+			private double acqRTime = 0, acqMTime = 0, rlsTime = 0, expTime = 0;
+		}
+
+		protected static class GCStatsCollection implements Serializable {
+			private static final long serialVersionUID = 1L;
+
+			private void collectStats() {
+				gcCount = Statistics.getJVMgcCount();
+				gcTime = ((double)Statistics.getJVMgcTime()) / 1000; // in sec
+			}
+
+			private void aggregate(GCStatsCollection that) {
+				gcCount += that.gcCount;
+				gcTime += that.gcTime;
+			}
+
+			private long gcCount = 0;
+			private double gcTime = 0;
+		}
+
+		private CacheStatsCollection cacheStats = new CacheStatsCollection();
+		private double jitCompileTime = 0;
+		private GCStatsCollection gcStats = new GCStatsCollection();
+		private HashMap<String, Pair<Long, Double>> heavyHitters = new HashMap<>();
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 9b6d3f0..b4d2e04 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -47,6 +47,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -120,6 +121,10 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
 				String host = parsedValues[0];
 				int port = Integer.parseInt(parsedValues[1]);
 				String filePath = parsedValues[2];
+
+				// register the federated worker for federated statistics creation
+				FederatedStatistics.registerFedWorker(host, port);
+
 				// get beginning and end of data ranges
 				List<Data> rangesData = ranges.getData();
 				Data beginData = rangesData.get(i * 2);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index d4247a7..dd8ddce 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -25,6 +25,7 @@ import java.lang.management.ManagementFactory;
 import java.text.DecimalFormat;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map.Entry;
 import java.util.Set;
@@ -32,12 +33,15 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.DoubleAdder;
 import java.util.concurrent.atomic.LongAdder;
 
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -732,6 +736,17 @@ public class Statistics
 		return (tmp != null) ? tmp.count.longValue() : 0;
 	}
 
+	/** Returns a new map of opcode to (count, total time in sec) for every entry in _instStats. */
+	public static HashMap<String, Pair<Long, Double>> getHeavyHittersHashMap() {
+		HashMap<String, Pair<Long, Double>> heavyHitters = new HashMap<>();
+		for(Entry<String, InstStats> e : _instStats.entrySet()) { // one lookup per entry
+			long count = e.getValue().count.longValue();
+			double time = e.getValue().time.longValue() / 1000000000d; // in sec
+			heavyHitters.put(e.getKey(), new ImmutablePair<>(count, time));
+		}
+		return heavyHitters;
+	}
+
 	/**
 	 * Obtain a string tabular representation of the heavy hitter instructions
 	 * that displays the time, instruction count, and optionally GPU stats about
@@ -956,7 +971,7 @@ public class Statistics
 	}
 	
 	
-	private static String [] wrap(String str, int wrapLength) {
+	public static String [] wrap(String str, int wrapLength) {
 		int numLines = (int) Math.ceil( ((double)str.length()) / wrapLength);
 		int len = str.length();
 		String [] ret = new String[numLines];
@@ -1105,6 +1120,11 @@ public class Statistics
 		if (DMLScript.CHECK_PRIVACY)
 			sb.append(CheckedConstraintsLog.display());
 
+		if(DMLScript.FED_STATISTICS) {
+			sb.append("\n");
+			sb.append(FederatedStatistics.displayFedStatistics(DMLScript.FED_STATISTICS_COUNT));
+		}
+
 		return sb.toString();
 	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
index 09ca19e..54d89e6 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
@@ -30,14 +30,12 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
-@Ignore
 public class FederatedStatisticsTest extends AutomatedTestBase {
 
 	private final static String TEST_DIR = "functions/federated/";
@@ -105,7 +103,6 @@ public class FederatedStatisticsTest extends AutomatedTestBase {
 		TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
 		loadTestConfiguration(config);
 		
-
 		// Run reference dml script with normal matrix
 		fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
 		programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y"), expected("Z")};
@@ -113,7 +110,7 @@ public class FederatedStatisticsTest extends AutomatedTestBase {
 
 		// Run actual dml script with federated matrix
 		fullDMLScriptName = HOME + TEST_NAME + ".dml";
-		programArgs = new String[] {"-stats", "30", "-nvargs",
+		programArgs = new String[] {"-stats", "30", "-fedStats", "-nvargs",
 			"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
 			"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
 			"in_Y=" + input("Y"), "out=" + output("Z")};