You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/07/04 17:55:22 UTC
[systemds] branch master updated: [SYSTEMDS-3039] Tracking and
consolidation of federated statistics
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 582b9c3 [SYSTEMDS-3039] Tracking and consolidation of federated statistics
582b9c3 is described below
commit 582b9c3f622d87cd9d11b9dd01abcb0a6f179309
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sun Jul 4 19:52:56 2021 +0200
[SYSTEMDS-3039] Tracking and consolidation of federated statistics
Closes #1321.
---
src/main/java/org/apache/sysds/api/DMLOptions.java | 19 ++
src/main/java/org/apache/sysds/api/DMLScript.java | 72 ++---
.../federated/FederatedStatistics.java | 311 +++++++++++++++++++++
.../instructions/fed/InitFEDInstruction.java | 5 +
.../java/org/apache/sysds/utils/Statistics.java | 22 +-
.../primitives/FederatedStatisticsTest.java | 5 +-
6 files changed, 395 insertions(+), 39 deletions(-)
diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java
index 0949c06..fbdaa90 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -52,6 +52,8 @@ public class DMLOptions {
public boolean clean = false; // Whether to clean up all SystemDS working directories (FS, DFS)
public boolean stats = false; // Whether to record and print the statistics
public int statsCount = 10; // Default statistics count
+ public boolean fedStats = false; // Whether to record and print the federated statistics
+ public int fedStatsCount = 10; // Default federated statistics count
public boolean memStats = false; // max memory statistics
public Explain.ExplainType explainType = Explain.ExplainType.NONE; // Whether to print the "Explain" and if so, what type
public ExecMode execMode = OptimizerUtils.getDefaultExecutionMode(); // Execution mode standalone, MR, Spark or a hybrid
@@ -85,6 +87,8 @@ public class DMLOptions {
", clean=" + clean +
", stats=" + stats +
", statsCount=" + statsCount +
+ ", fedStats=" + fedStats +
+ ", fedStatsCount=" + fedStatsCount +
", memStats=" + memStats +
", explainType=" + explainType +
", execMode=" + execMode +
@@ -193,6 +197,17 @@ public class DMLOptions {
}
}
}
+ dmlOptions.fedStats = line.hasOption("fedStats");
+ if (dmlOptions.fedStats) {
+ String fedStatsCount = line.getOptionValue("fedStats");
+ if(fedStatsCount != null) {
+ try {
+ dmlOptions.fedStatsCount = Integer.parseInt(fedStatsCount);
+ } catch (NumberFormatException e) {
+ throw new org.apache.commons.cli.ParseException("Invalid argument specified for -fedStats option, must be a valid integer");
+ }
+ }
+ }
dmlOptions.memStats = line.hasOption("mem");
dmlOptions.clean = line.hasOption("clean");
@@ -265,6 +280,9 @@ public class DMLOptions {
Option statsOpt = OptionBuilder.withArgName("count")
.withDescription("monitors and reports summary execution statistics; heavy hitter <count> is 10 unless overridden; default off")
.hasOptionalArg().create("stats");
+ Option fedStatsOpt = OptionBuilder.withArgName("count")
+ .withDescription("monitors and reports summary execution statistics of federated workers; heavy hitter <count> is 10 unless overridden; default off")
+ .hasOptionalArg().create("fedStats");
Option memOpt = OptionBuilder.withDescription("monitors and reports max memory consumption in CP; default off")
.create("mem");
Option explainOpt = OptionBuilder.withArgName("level")
@@ -299,6 +317,7 @@ public class DMLOptions {
options.addOption(configOpt);
options.addOption(cleanOpt);
options.addOption(statsOpt);
+ options.addOption(fedStatsOpt);
options.addOption(memOpt);
options.addOption(explainOpt);
options.addOption(execOpt);
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index 7d2bf16..e2e67a5 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -82,26 +82,28 @@ import org.apache.sysds.utils.Explain.ExplainType;
public class DMLScript
{
- private static ExecMode EXEC_MODE = DMLOptions.defaultOptions.execMode; // the execution mode
- public static boolean STATISTICS = DMLOptions.defaultOptions.stats; // whether to print statistics
- public static boolean JMLC_MEM_STATISTICS = false; // whether to gather memory use stats in JMLC
- public static int STATISTICS_COUNT = DMLOptions.defaultOptions.statsCount; // statistics maximum heavy hitter count
- public static int STATISTICS_MAX_WRAP_LEN = 30; // statistics maximum wrap length
- public static ExplainType EXPLAIN = DMLOptions.defaultOptions.explainType; // explain type
- public static String DML_FILE_PATH_ANTLR_PARSER = DMLOptions.defaultOptions.filePath; // filename of dml/pydml script
- public static String FLOATING_POINT_PRECISION = "double"; // data type to use internally
- public static boolean PRINT_GPU_MEMORY_INFO = false; // whether to print GPU memory-related information
- public static long EVICTION_SHADOW_BUFFER_MAX_BYTES = 0; // maximum number of bytes to use for shadow buffer
- public static long EVICTION_SHADOW_BUFFER_CURR_BYTES = 0; // number of bytes to use for shadow buffer
- public static double GPU_MEMORY_UTILIZATION_FACTOR = 0.9; // fraction of available GPU memory to use
- public static String GPU_MEMORY_ALLOCATOR = "cuda"; // GPU memory allocator to use
- public static boolean LINEAGE = DMLOptions.defaultOptions.lineage; // whether compute lineage trace
- public static boolean LINEAGE_DEDUP = DMLOptions.defaultOptions.lineage_dedup; // whether deduplicate lineage items
- public static ReuseCacheType LINEAGE_REUSE = DMLOptions.defaultOptions.linReuseType; // whether lineage-based reuse
- public static LineageCachePolicy LINEAGE_POLICY = DMLOptions.defaultOptions.linCachePolicy; // lineage cache eviction policy
- public static boolean LINEAGE_ESTIMATE = DMLOptions.defaultOptions.lineage_estimate; // whether estimate reuse benefits
- public static boolean LINEAGE_DEBUGGER = DMLOptions.defaultOptions.lineage_debugger; // whether enable lineage debugger
- public static boolean CHECK_PRIVACY = DMLOptions.defaultOptions.checkPrivacy; // Check which privacy constraints are loaded and checked during federated execution
+ private static ExecMode EXEC_MODE = DMLOptions.defaultOptions.execMode; // the execution mode
+ public static boolean STATISTICS = DMLOptions.defaultOptions.stats; // whether to print statistics
+ public static boolean JMLC_MEM_STATISTICS = false; // whether to gather memory use stats in JMLC
+ public static int STATISTICS_COUNT = DMLOptions.defaultOptions.statsCount; // statistics maximum heavy hitter count
+ public static int STATISTICS_MAX_WRAP_LEN = 30; // statistics maximum wrap length
+ public static boolean FED_STATISTICS = DMLOptions.defaultOptions.fedStats; // whether to print federated statistics
+ public static int FED_STATISTICS_COUNT = DMLOptions.defaultOptions.fedStatsCount; // federated statistics maximum heavy hitter count
+ public static ExplainType EXPLAIN = DMLOptions.defaultOptions.explainType; // explain type
+ public static String DML_FILE_PATH_ANTLR_PARSER = DMLOptions.defaultOptions.filePath; // filename of dml/pydml script
+ public static String FLOATING_POINT_PRECISION = "double"; // data type to use internally
+ public static boolean PRINT_GPU_MEMORY_INFO = false; // whether to print GPU memory-related information
+ public static long EVICTION_SHADOW_BUFFER_MAX_BYTES = 0; // maximum number of bytes to use for shadow buffer
+ public static long EVICTION_SHADOW_BUFFER_CURR_BYTES = 0; // number of bytes to use for shadow buffer
+ public static double GPU_MEMORY_UTILIZATION_FACTOR = 0.9; // fraction of available GPU memory to use
+ public static String GPU_MEMORY_ALLOCATOR = "cuda"; // GPU memory allocator to use
+ public static boolean LINEAGE = DMLOptions.defaultOptions.lineage; // whether compute lineage trace
+ public static boolean LINEAGE_DEDUP = DMLOptions.defaultOptions.lineage_dedup; // whether deduplicate lineage items
+ public static ReuseCacheType LINEAGE_REUSE = DMLOptions.defaultOptions.linReuseType; // whether lineage-based reuse
+ public static LineageCachePolicy LINEAGE_POLICY = DMLOptions.defaultOptions.linCachePolicy; // lineage cache eviction policy
+ public static boolean LINEAGE_ESTIMATE = DMLOptions.defaultOptions.lineage_estimate; // whether estimate reuse benefits
+ public static boolean LINEAGE_DEBUGGER = DMLOptions.defaultOptions.lineage_debugger; // whether enable lineage debugger
+ public static boolean CHECK_PRIVACY = DMLOptions.defaultOptions.checkPrivacy; // Check which privacy constraints are loaded and checked during federated execution
public static boolean USE_ACCELERATOR = DMLOptions.defaultOptions.gpu;
public static boolean FORCE_ACCELERATOR = DMLOptions.defaultOptions.forceGPU;
@@ -212,20 +214,22 @@ public class DMLScript
try
{
- STATISTICS = dmlOptions.stats;
- STATISTICS_COUNT = dmlOptions.statsCount;
- JMLC_MEM_STATISTICS = dmlOptions.memStats;
- USE_ACCELERATOR = dmlOptions.gpu;
- FORCE_ACCELERATOR = dmlOptions.forceGPU;
- EXPLAIN = dmlOptions.explainType;
- EXEC_MODE = dmlOptions.execMode;
- LINEAGE = dmlOptions.lineage;
- LINEAGE_DEDUP = dmlOptions.lineage_dedup;
- LINEAGE_REUSE = dmlOptions.linReuseType;
- LINEAGE_POLICY = dmlOptions.linCachePolicy;
- LINEAGE_ESTIMATE = dmlOptions.lineage_estimate;
- CHECK_PRIVACY = dmlOptions.checkPrivacy;
- LINEAGE_DEBUGGER = dmlOptions.lineage_debugger;
+ STATISTICS = dmlOptions.stats;
+ STATISTICS_COUNT = dmlOptions.statsCount;
+ FED_STATISTICS = dmlOptions.fedStats;
+ FED_STATISTICS_COUNT = dmlOptions.fedStatsCount;
+ JMLC_MEM_STATISTICS = dmlOptions.memStats;
+ USE_ACCELERATOR = dmlOptions.gpu;
+ FORCE_ACCELERATOR = dmlOptions.forceGPU;
+ EXPLAIN = dmlOptions.explainType;
+ EXEC_MODE = dmlOptions.execMode;
+ LINEAGE = dmlOptions.lineage;
+ LINEAGE_DEDUP = dmlOptions.lineage_dedup;
+ LINEAGE_REUSE = dmlOptions.linReuseType;
+ LINEAGE_POLICY = dmlOptions.linCachePolicy;
+ LINEAGE_ESTIMATE = dmlOptions.lineage_estimate;
+ CHECK_PRIVACY = dmlOptions.checkPrivacy;
+ LINEAGE_DEBUGGER = dmlOptions.lineage_debugger;
String fnameOptConfig = dmlOptions.configFile;
boolean isFile = dmlOptions.filePath != null;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
new file mode 100644
index 0000000..14f29d9
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+import java.io.Serializable;
+import java.net.InetSocketAddress;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.concurrent.Future;
+import javax.net.ssl.SSLException;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.CacheStatsCollection;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.GCStatsCollection;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.utils.Statistics;
+
+/**
+ * Tracking and consolidation of federated statistics at the coordinator.
+ * <p>
+ * Federated workers are registered via {@link #registerFedWorker(String, int)} (called by the
+ * federated init instruction). When the statistics are displayed, a stats-collecting UDF is sent
+ * to every registered worker; each worker snapshots its local cache, JIT compile, GC, and heavy
+ * hitter statistics, and the snapshots of all workers are summed up and formatted.
+ */
+public class FederatedStatistics {
+	// Addresses (host, port) of all registered federated workers; set semantics drop duplicates.
+	// NOTE(review): this registry is static, never cleared by this class, and not synchronized.
+	// Workers registered by an earlier script run in the same JVM are contacted again on
+	// collection. Confirm whether a reset between runs or concurrent registration must be handled.
+	private static Set<Pair<String, Integer>> _fedWorkerAddresses = new HashSet<>();
+
+	/**
+	 * Registers a federated worker whose statistics are collected when the federated
+	 * statistics are displayed. Registering the same host/port twice has no effect.
+	 *
+	 * @param host host name or address of the federated worker
+	 * @param port port of the federated worker
+	 */
+	public static void registerFedWorker(String host, int port) {
+		// Integer.valueOf instead of the deprecated Integer(int) constructor
+		_fedWorkerAddresses.add(new ImmutablePair<>(host, Integer.valueOf(port)));
+	}
+
+	/**
+	 * Lists the addresses of all registered federated workers, one "host:port" per line.
+	 *
+	 * @return formatted list of federated worker addresses
+	 */
+	public static String displayFedWorkers() {
+		StringBuilder sb = new StringBuilder();
+		sb.append("Federated Worker Addresses:\n");
+		for(Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
+			sb.append(String.format(" %s:%d", fedAddr.getLeft(), fedAddr.getRight().intValue()));
+			sb.append("\n");
+		}
+		return sb.toString();
+	}
+
+	/**
+	 * Collects the statistics from all registered federated workers (remote calls), aggregates
+	 * them, and formats cache, JIT, GC, and heavy hitter statistics.
+	 *
+	 * @param numHeavyHitters maximum number of heavy hitter instructions to display
+	 * @return formatted federated statistics
+	 * @throws DMLRuntimeException if the statistics of any worker cannot be obtained
+	 */
+	public static String displayFedStatistics(int numHeavyHitters) {
+		StringBuilder sb = new StringBuilder();
+		FedStatsCollection fedStats = collectFedStats();
+		sb.append("SystemDS Federated Statistics:\n");
+		sb.append(displayCacheStats(fedStats.cacheStats));
+		sb.append(String.format("Total JIT compile time:\t\t%.3f sec.\n", fedStats.jitCompileTime));
+		sb.append(displayGCStats(fedStats.gcStats));
+		sb.append(displayHeavyHitters(fedStats.heavyHitters, numHeavyHitters));
+		return sb.toString();
+	}
+
+	/**
+	 * Formats aggregated cache hit, write, and time statistics.
+	 *
+	 * @param csc aggregated cache statistics
+	 * @return formatted cache statistics
+	 */
+	public static String displayCacheStats(CacheStatsCollection csc) {
+		StringBuilder sb = new StringBuilder();
+		sb.append(String.format("Cache hits (Mem/Li/WB/FS/HDFS):\t%d/%d/%d/%d/%d.\n",
+			csc.memHits, csc.linHits, csc.fsBuffHits, csc.fsHits, csc.hdfsHits));
+		sb.append(String.format("Cache writes (Li/WB/FS/HDFS):\t%d/%d/%d/%d.\n",
+			csc.linWrites, csc.fsBuffWrites, csc.fsWrites, csc.hdfsWrites));
+		sb.append(String.format("Cache times (ACQr/m, RLS, EXP):\t%.3f/%.3f/%.3f/%.3f sec.\n",
+			csc.acqRTime, csc.acqMTime, csc.rlsTime, csc.expTime));
+		return sb.toString();
+	}
+
+	/**
+	 * Formats aggregated JVM garbage collection statistics.
+	 *
+	 * @param gcsc aggregated GC statistics
+	 * @return formatted GC statistics
+	 */
+	public static String displayGCStats(GCStatsCollection gcsc) {
+		StringBuilder sb = new StringBuilder();
+		sb.append(String.format("Total JVM GC count:\t\t%d.\n", gcsc.gcCount));
+		sb.append(String.format("Total JVM GC time:\t\t%.3f sec.\n", gcsc.gcTime));
+		return sb.toString();
+	}
+
+	/**
+	 * Formats the top 10 heavy hitter instructions.
+	 *
+	 * @param heavyHitters map from opcode to (count, time in seconds)
+	 * @return formatted heavy hitter table
+	 */
+	public static String displayHeavyHitters(HashMap<String, Pair<Long, Double>> heavyHitters) {
+		return displayHeavyHitters(heavyHitters, 10);
+	}
+
+	/**
+	 * Formats the heavy hitter instructions as a table sorted by descending time. Instructions
+	 * longer than {@code DMLScript.STATISTICS_MAX_WRAP_LEN} are wrapped over multiple rows.
+	 *
+	 * @param heavyHitters map from opcode to (count, time in seconds)
+	 * @param num maximum number of instructions to display (values below 0 are treated as 0)
+	 * @return formatted heavy hitter table
+	 */
+	public static String displayHeavyHitters(HashMap<String, Pair<Long, Double>> heavyHitters, int num) {
+		// sort descending by accumulated time (right element of the pair)
+		List<Entry<String, Pair<Long, Double>>> hhList = new ArrayList<>(heavyHitters.entrySet());
+		hhList.sort(Comparator.comparing(
+			(Entry<String, Pair<Long, Double>> e) -> e.getValue().getRight()).reversed());
+
+		final String numCol = "#";
+		final String instCol = "Instruction";
+		final String timeSCol = "Time(s)";
+		final String countCol = "Count";
+		// clamp to [0, #entries] so that a negative count does not widen the '#' column
+		int numHittersToDisplay = Math.max(0, Math.min(num, hhList.size()));
+		DecimalFormat sFormat = new DecimalFormat("#,##0.000");
+
+		// first pass: column widths over the displayed entries
+		int maxNumLen = String.valueOf(numHittersToDisplay).length();
+		int maxInstLen = instCol.length();
+		int maxTimeSLen = timeSCol.length();
+		int maxCountLen = countCol.length();
+		for(int i = 0; i < numHittersToDisplay; i++) {
+			Entry<String, Pair<Long, Double>> hh = hhList.get(i);
+			maxInstLen = Math.max(maxInstLen, hh.getKey().length());
+			maxTimeSLen = Math.max(maxTimeSLen, sFormat.format(hh.getValue().getRight()).length());
+			maxCountLen = Math.max(maxCountLen, String.valueOf(hh.getValue().getLeft()).length());
+		}
+		// long instructions are wrapped instead of widening the instruction column
+		maxInstLen = Math.min(maxInstLen, DMLScript.STATISTICS_MAX_WRAP_LEN);
+
+		StringBuilder sb = new StringBuilder();
+		sb.append("Heavy hitter instructions:\n");
+		sb.append(String.format(" %" + maxNumLen + "s %-" + maxInstLen + "s %"
+			+ maxTimeSLen + "s %" + maxCountLen + "s", numCol, instCol, timeSCol, countCol));
+		sb.append("\n");
+
+		// second pass: one row per heavy hitter plus continuation rows for wrapped instructions
+		for(int i = 0; i < numHittersToDisplay; i++) {
+			Entry<String, Pair<Long, Double>> hh = hhList.get(i);
+			String[] wrappedInstruction = Statistics.wrap(hh.getKey(), maxInstLen);
+			String timeSString = sFormat.format(hh.getValue().getRight());
+			long count = hh.getValue().getLeft();
+			// wrap() yields zero lines for an empty opcode; still emit one row with time and count
+			int numLines = Math.max(1, wrappedInstruction.length);
+
+			for(int wrapIter = 0; wrapIter < numLines; wrapIter++) {
+				String instStr = (wrapIter < wrappedInstruction.length) ? wrappedInstruction[wrapIter] : "";
+				if(wrapIter == 0) {
+					sb.append(String.format(
+						" %" + maxNumLen + "d %-" + maxInstLen + "s %" + maxTimeSLen + "s %"
+						+ maxCountLen + "d", (i + 1), instStr, timeSString, count));
+				}
+				else {
+					sb.append(String.format(
+						" %" + maxNumLen + "s %-" + maxInstLen + "s %" + maxTimeSLen + "s %"
+						+ maxCountLen + "s", "", instStr, "", ""));
+				}
+				sb.append("\n");
+			}
+		}
+
+		return sb.toString();
+	}
+
+	/**
+	 * Requests the statistics of all registered workers and sums them up.
+	 *
+	 * @return aggregated statistics of all workers (all zero if no worker is registered)
+	 * @throws DMLRuntimeException if a response cannot be obtained
+	 */
+	private static FedStatsCollection collectFedStats() {
+		FedStatsCollection aggFedStats = new FedStatsCollection();
+		for(Future<FederatedResponse> res : getFederatedResponses()) {
+			try {
+				Object[] tmp = res.get().getData();
+				// skip responses without a stats payload instead of failing with an index exception
+				if(tmp != null && tmp.length > 0 && tmp[0] instanceof FedStatsCollection)
+					aggFedStats.aggregate((FedStatsCollection) tmp[0]);
+			}
+			catch(InterruptedException ie) {
+				// restore the interrupt status before propagating
+				Thread.currentThread().interrupt();
+				throw new DMLRuntimeException(
+					"Interrupted while getting the federated stats of the federated response: ", ie);
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException("Exception of type " + e.getClass().toString()
+					+ " thrown while getting the federated stats of the federated response: ", e);
+			}
+		}
+		return aggFedStats;
+	}
+
+	/**
+	 * Sends a stats-collecting UDF request to every registered worker.
+	 *
+	 * @return pending responses, one per registered worker
+	 * @throws DMLRuntimeException if a request cannot be sent
+	 */
+	private static List<Future<FederatedResponse>> getFederatedResponses() {
+		List<Future<FederatedResponse>> ret = new ArrayList<>();
+		for(Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
+			InetSocketAddress isa = new InetSocketAddress(fedAddr.getLeft(), fedAddr.getRight());
+			FederatedRequest frUDF = new FederatedRequest(RequestType.EXEC_UDF, -1,
+				new FedStatsCollectFunction());
+			try {
+				ret.add(FederatedData.executeFederatedOperation(isa, frUDF));
+			}
+			catch(SSLException ssle) {
+				throw new DMLRuntimeException("SSLException while getting the federated stats from "
+					+ isa.toString() + ": ", ssle);
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException("Exception of type " + e.getClass().getName()
+					+ " thrown while getting stats from federated worker " + isa.toString() + ": ", e);
+			}
+		}
+		return ret;
+	}
+
+	/** UDF executed at a federated worker; it returns a snapshot of the worker-local statistics. */
+	private static class FedStatsCollectFunction extends FederatedUDF {
+		private static final long serialVersionUID = 1L;
+
+		public FedStatsCollectFunction() {
+			// operates on no federated input data
+			super(new long[] { });
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			FedStatsCollection fedStats = new FedStatsCollection();
+			fedStats.collectStats();
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, fedStats);
+		}
+
+		@Override
+		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+			return null;
+		}
+	}
+
+	/** Serializable snapshot of statistics, collected at a worker and summed at the coordinator. */
+	protected static class FedStatsCollection implements Serializable {
+		private static final long serialVersionUID = 1L;
+
+		/** Fills this snapshot from the statistics of the local JVM. */
+		private void collectStats() {
+			cacheStats.collectStats();
+			// divided by 1000 to get seconds (presumably milliseconds from Statistics)
+			jitCompileTime = ((double)Statistics.getJITCompileTime()) / 1000;
+			gcStats.collectStats();
+			heavyHitters = Statistics.getHeavyHittersHashMap();
+		}
+
+		/** Adds all statistics of {@code that} to this snapshot; heavy hitters are merged per opcode. */
+		private void aggregate(FedStatsCollection that) {
+			cacheStats.aggregate(that.cacheStats);
+			jitCompileTime += that.jitCompileTime;
+			gcStats.aggregate(that.gcStats);
+			that.heavyHitters.forEach(
+				(key, value) -> heavyHitters.merge(key, value, (v1, v2) ->
+					new ImmutablePair<>(v1.getLeft() + v2.getLeft(), v1.getRight() + v2.getRight()))
+			);
+		}
+
+		/** Cache hit/write counters and cache operation times (in seconds). */
+		protected static class CacheStatsCollection implements Serializable {
+			private static final long serialVersionUID = 1L;
+
+			private void collectStats() {
+				memHits = CacheStatistics.getMemHits();
+				linHits = CacheStatistics.getLinHits();
+				fsBuffHits = CacheStatistics.getFSBuffHits();
+				fsHits = CacheStatistics.getFSHits();
+				hdfsHits = CacheStatistics.getHDFSHits();
+				linWrites = CacheStatistics.getLinWrites();
+				fsBuffWrites = CacheStatistics.getFSBuffWrites();
+				fsWrites = CacheStatistics.getFSWrites();
+				hdfsWrites = CacheStatistics.getHDFSWrites();
+				// divided by 1e9 to get seconds (presumably nanoseconds from CacheStatistics)
+				acqRTime = ((double)CacheStatistics.getAcquireRTime()) / 1000000000;
+				acqMTime = ((double)CacheStatistics.getAcquireMTime()) / 1000000000;
+				rlsTime = ((double)CacheStatistics.getReleaseTime()) / 1000000000;
+				expTime = ((double)CacheStatistics.getExportTime()) / 1000000000;
+			}
+
+			private void aggregate(CacheStatsCollection that) {
+				memHits += that.memHits;
+				linHits += that.linHits;
+				fsBuffHits += that.fsBuffHits;
+				fsHits += that.fsHits;
+				hdfsHits += that.hdfsHits;
+				linWrites += that.linWrites;
+				fsBuffWrites += that.fsBuffWrites;
+				fsWrites += that.fsWrites;
+				hdfsWrites += that.hdfsWrites;
+				acqRTime += that.acqRTime;
+				acqMTime += that.acqMTime;
+				rlsTime += that.rlsTime;
+				expTime += that.expTime;
+			}
+
+			private long memHits = 0;
+			private long linHits = 0;
+			private long fsBuffHits = 0;
+			private long fsHits = 0;
+			private long hdfsHits = 0;
+			private long linWrites = 0;
+			private long fsBuffWrites = 0;
+			private long fsWrites = 0;
+			private long hdfsWrites = 0;
+			private double acqRTime = 0;
+			private double acqMTime = 0;
+			private double rlsTime = 0;
+			private double expTime = 0;
+		}
+
+		/** JVM garbage collection count and time (in seconds). */
+		protected static class GCStatsCollection implements Serializable {
+			private static final long serialVersionUID = 1L;
+
+			private void collectStats() {
+				gcCount = Statistics.getJVMgcCount();
+				// divided by 1000 to get seconds (presumably milliseconds from Statistics)
+				gcTime = ((double)Statistics.getJVMgcTime()) / 1000;
+			}
+
+			private void aggregate(GCStatsCollection that) {
+				gcCount += that.gcCount;
+				gcTime += that.gcTime;
+			}
+
+			private long gcCount = 0;
+			private double gcTime = 0;
+		}
+
+		private CacheStatsCollection cacheStats = new CacheStatsCollection();
+		private double jitCompileTime = 0;
+		private GCStatsCollection gcStats = new GCStatsCollection();
+		// opcode -> (instruction count, accumulated time in seconds)
+		private HashMap<String, Pair<Long, Double>> heavyHitters = new HashMap<>();
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 9b6d3f0..b4d2e04 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -47,6 +47,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -120,6 +121,10 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
String host = parsedValues[0];
int port = Integer.parseInt(parsedValues[1]);
String filePath = parsedValues[2];
+
+ // register the federated worker for federated statistics creation
+ FederatedStatistics.registerFedWorker(host, port);
+
// get beginning and end of data ranges
List<Data> rangesData = ranges.getData();
Data beginData = rangesData.get(i * 2);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index d4247a7..dd8ddce 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -25,6 +25,7 @@ import java.lang.management.ManagementFactory;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Comparator;
+import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
@@ -32,12 +33,15 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.concurrent.atomic.LongAdder;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -732,6 +736,17 @@ public class Statistics
return (tmp != null) ? tmp.count.longValue() : 0;
}
+	/**
+	 * Snapshots the per-opcode instruction statistics into a new map, e.g., for shipping the
+	 * statistics of a federated worker to the coordinator.
+	 *
+	 * @return new map from opcode to (instruction count, accumulated time in seconds)
+	 */
+	public static HashMap<String, Pair<Long, Double>> getHeavyHittersHashMap() {
+		HashMap<String, Pair<Long, Double>> heavyHitters = new HashMap<>();
+		// iterate entries directly: keySet() followed by get() can observe a key whose entry
+		// was removed concurrently and would then fail with an NPE on the missing value
+		_instStats.forEach((opcode, val) -> {
+			long count = val.count.longValue();
+			// divided by 1e9 to get seconds (time presumably accumulated in nanoseconds)
+			double time = val.time.longValue() / 1000000000d;
+			heavyHitters.put(opcode, new ImmutablePair<>(count, time));
+		});
+		return heavyHitters;
+	}
+
/**
* Obtain a string tabular representation of the heavy hitter instructions
* that displays the time, instruction count, and optionally GPU stats about
@@ -956,7 +971,7 @@ public class Statistics
}
- private static String [] wrap(String str, int wrapLength) {
+ public static String [] wrap(String str, int wrapLength) {
int numLines = (int) Math.ceil( ((double)str.length()) / wrapLength);
int len = str.length();
String [] ret = new String[numLines];
@@ -1105,6 +1120,11 @@ public class Statistics
if (DMLScript.CHECK_PRIVACY)
sb.append(CheckedConstraintsLog.display());
+ if(DMLScript.FED_STATISTICS) {
+ sb.append("\n");
+ sb.append(FederatedStatistics.displayFedStatistics(DMLScript.FED_STATISTICS_COUNT));
+ }
+
return sb.toString();
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
index 09ca19e..54d89e6 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
@@ -30,14 +30,12 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-@Ignore
public class FederatedStatisticsTest extends AutomatedTestBase {
private final static String TEST_DIR = "functions/federated/";
@@ -105,7 +103,6 @@ public class FederatedStatisticsTest extends AutomatedTestBase {
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y"), expected("Z")};
@@ -113,7 +110,7 @@ public class FederatedStatisticsTest extends AutomatedTestBase {
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "30", "-nvargs",
+ programArgs = new String[] {"-stats", "30", "-fedStats", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
"in_Y=" + input("Y"), "out=" + output("Z")};