You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2022/06/05 21:46:30 UTC
[systemds] branch main updated: [SYSTEMDS-3383] Extended federated monitoring tool (stats collection)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new ffc2c36246 [SYSTEMDS-3383] Extended federated monitoring tool (stats collection)
ffc2c36246 is described below
commit ffc2c36246f16b8d2dc4d4a84145b3ce25084039
Author: Mito <mk...@arakt.com>
AuthorDate: Sun Jun 5 23:27:30 2022 +0200
[SYSTEMDS-3383] Extended federated monitoring tool (stats collection)
Closes #1624.
---
.../federated/FederatedStatistics.java | 104 ++++++++++-
.../federated/FederatedWorkerHandler.java | 17 ++
.../controllers/CoordinatorController.java | 37 +++-
.../monitoring/controllers/WorkerController.java | 28 ++-
.../monitoring/models/BaseEntityModel.java | 58 +------
.../{BaseEntityModel.java => NodeEntityModel.java} | 28 +--
.../monitoring/models/StatsEntityModel.java | 139 +++++++++++++++
.../{EntityEnum.java => Constants.java} | 15 +-
.../monitoring/repositories/DerbyRepository.java | 190 +++++++++++++++++----
.../monitoring/repositories/EntityEnum.java | 1 +
.../monitoring/repositories/IRepository.java | 7 +-
.../CoordinatorService.java} | 40 ++---
.../monitoring/services/MapperService.java | 92 ++++++++++
.../{WorkerService.java => StatsService.java} | 39 ++---
.../monitoring/services/WorkerService.java | 115 ++++++++-----
.../controlprogram/paramserv/NativeHEHelper.java | 186 ++++++++++----------
.../homomorphicEncryption/SEALClient.java | 100 +++++------
.../FederatedCoordinatorIntegrationCRUDTest.java | 97 +++++++++++
.../monitoring/FederatedMonitoringTestBase.java | 74 +++++++-
.../FederatedWorkerIntegrationCRUDTest.java | 38 ++++-
.../monitoring/FederatedWorkerStatisticsTest.java | 31 +++-
21 files changed, 1056 insertions(+), 380 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
index 5907776898..17b4012fec 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -20,8 +20,13 @@
package org.apache.sysds.runtime.controlprogram.federated;
import java.io.Serializable;
+import java.lang.management.ManagementFactory;
+import java.lang.management.MemoryMXBean;
+import java.lang.management.ThreadMXBean;
import java.net.InetSocketAddress;
import java.text.DecimalFormat;
+import java.time.LocalDateTime;
+import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
@@ -34,7 +39,9 @@ import java.util.concurrent.Future;
import java.util.concurrent.atomic.LongAdder;
import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.Triple;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
@@ -89,6 +96,9 @@ public class FederatedStatistics {
private static final LongAdder fedPutLineageItems = new LongAdder();
private static final LongAdder fedSerializationReuseCount = new LongAdder();
private static final LongAdder fedSerializationReuseBytes = new LongAdder();
+	// Traffic between this federated worker and each coordinator site,
+	// recorded as a list of triples: [(datetime, coordinatorAddress, transferredBytes), ...]
+ private static List<Triple<LocalDateTime, String, Long>> coordinatorsTrafficBytes = new ArrayList<>();
public static void logServerTraffic(long read, long written) {
bytesReceived.add(read);
@@ -131,13 +141,20 @@ public class FederatedStatistics {
}
private static void incFedTransfer(Object dataObj) {
+ incFedTransfer(dataObj, null);
+ }
+
+ public static void incFedTransfer(Object dataObj, String host) {
+ long byteAmount = 0;
if(dataObj instanceof MatrixBlock) {
transferredMatrixCount.increment();
- transferredMatrixBytes.add(((MatrixBlock)dataObj).getInMemorySize());
+ byteAmount = ((MatrixBlock)dataObj).getInMemorySize();
+ transferredMatrixBytes.add(byteAmount);
}
else if(dataObj instanceof FrameBlock) {
transferredFrameCount.increment();
- transferredFrameBytes.add(((FrameBlock)dataObj).getInMemorySize());
+ byteAmount = ((FrameBlock)dataObj).getInMemorySize();
+ transferredFrameBytes.add(byteAmount);
}
else if(dataObj instanceof ScalarObject)
transferredScalarCount.increment();
@@ -145,6 +162,10 @@ public class FederatedStatistics {
transferredListCount.increment();
else if(dataObj instanceof MatrixCharacteristics)
transferredMatCharCount.increment();
+
+ if (host != null && byteAmount > 0) {
+ coordinatorsTrafficBytes.add(new ImmutableTriple<>(LocalDateTime.now(), host, byteAmount));
+ }
}
public static void incAsyncPrefetchCount(long c) {
@@ -184,6 +205,8 @@ public class FederatedStatistics {
bytesReceived.reset();
fedBytesSent.reset();
fedBytesReceived.reset();
+ //TODO merge with existing
+ coordinatorsTrafficBytes.clear();
}
public static String displayFedIOExecStatistics() {
@@ -248,6 +271,9 @@ public class FederatedStatistics {
sb.append(displayFedReuseReadStats());
sb.append(displayFedPutLineageStats());
sb.append(displayFedSerializationReuseStats());
+ sb.append(displayFedTransfer());
+ sb.append(displayCPUUsage());
+ sb.append(displayMemoryUsage());
return sb.toString();
}
@@ -264,6 +290,9 @@ public class FederatedStatistics {
sb.append(displayGCStats(fedStats.gcStats));
sb.append(displayLinCacheStats(fedStats.linCacheStats));
sb.append(displayMultiTenantStats(fedStats.mtStats));
+ sb.append(displayCPUUsage());
+ sb.append(displayMemoryUsage());
+ sb.append(displayFedTransfer());
sb.append(displayHeavyHitters(fedStats.heavyHitters, numHeavyHitters));
sb.append(displayNetworkTrafficStatistics());
return sb.toString();
@@ -312,6 +341,38 @@ public class FederatedStatistics {
return displayHeavyHitters(heavyHitters, 10);
}
+	/**
+	 * Renders the recorded coordinator traffic, one line per recorded transfer.
+	 *
+	 * @return a header line followed by one "datetime/host/bytes." line for every entry
+	 *         of {@code coordinatorsTrafficBytes}
+	 */
+	private static String displayFedTransfer() {
+		StringBuilder sb = new StringBuilder();
+		// The header lists the fields in the order they are printed below
+		// (left / middle / right element of each triple).
+		sb.append("Transferred bytes (Datetime/Host/ByteAmount):\n");
+
+		for (var entry : coordinatorsTrafficBytes) {
+			sb.append(entry.getLeft().format(DateTimeFormatter.ISO_DATE_TIME))
+				.append('/').append(entry.getMiddle())
+				.append('/').append(entry.getRight())
+				.append(".\n");
+		}
+
+		return sb.toString();
+	}
+
+	/**
+	 * Formats the value reported by {@link #getCPUUsage()} as a single display line.
+	 *
+	 * @return the text "CPU usage %: " followed by the value with two decimals and a newline
+	 */
+	private static String displayCPUUsage() {
+		final double reported = getCPUUsage();
+		return String.format("CPU usage %%: %.2f\n", reported);
+	}
+
+	/**
+	 * Formats the value reported by {@link #getMemoryUsage()} as a single display line.
+	 *
+	 * @return the text "Memory usage %: " followed by the value with two decimals and a newline
+	 */
+	private static String displayMemoryUsage() {
+		final double reported = getMemoryUsage();
+		return String.format("Memory usage %%: %.2f\n", reported);
+	}
+
private static String displayHeavyHitters(HashMap<String, Pair<Long, Double>> heavyHitters, int num) {
StringBuilder sb = new StringBuilder();
@SuppressWarnings("unchecked")
@@ -414,6 +475,32 @@ public class FederatedStatistics {
return fedLookupTableGetCount.longValue();
}
+	/**
+	 * Returns the coordinator traffic recorded so far.
+	 * <p>
+	 * A snapshot copy is returned instead of the internal list, so that callers which
+	 * store or extend the result cannot modify the statistics kept by this class.
+	 *
+	 * @return a new list of (datetime, coordinator address, transferred bytes) triples
+	 */
+	public static List<Triple<LocalDateTime, String, Long>> getCoordinatorsTrafficBytes() {
+		return new ArrayList<>(coordinatorsTrafficBytes);
+	}
+
+	/**
+	 * Sums the CPU time consumed so far by all live threads of this JVM.
+	 * <p>
+	 * Note: despite the name, the result is a cumulative CPU time in seconds,
+	 * not a utilization percentage.
+	 *
+	 * @return total thread CPU time in seconds; 0 if thread CPU time measurement
+	 *         is unsupported or disabled in this JVM
+	 */
+	public static double getCPUUsage() {
+		ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
+
+		// getThreadCpuTime throws UnsupportedOperationException when unsupported
+		// and reports -1 for every thread when measurement is disabled
+		if(!threadMXBean.isThreadCpuTimeSupported() || !threadMXBean.isThreadCpuTimeEnabled())
+			return 0.0;
+
+		double cpuUsage = 0.0;
+		for(long threadID : threadMXBean.getAllThreadIds()) {
+			long cpuTimeNanos = threadMXBean.getThreadCpuTime(threadID);
+			// -1 is reported for threads that terminated after getAllThreadIds()
+			if(cpuTimeNanos > 0)
+				cpuUsage += cpuTimeNanos;
+		}
+
+		return cpuUsage / 1000000000; // nanoseconds to seconds
+	}
+
+	/**
+	 * Computes how much of the JVM heap is currently in use.
+	 *
+	 * @return used heap as a percentage of the maximum heap size; if the maximum is
+	 *         undefined, the committed heap size is used as the reference instead
+	 */
+	public static double getMemoryUsage() {
+		MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
+		// take a single snapshot so that used and max belong to the same measurement
+		var heapUsage = memoryMXBean.getHeapMemoryUsage();
+
+		long referenceBytes = heapUsage.getMax();
+		// getMax() returns -1 when the maximum heap size is undefined
+		if(referenceBytes <= 0)
+			referenceBytes = heapUsage.getCommitted();
+		if(referenceBytes <= 0)
+			return 0.0;
+
+		return ((double) heapUsage.getUsed() / referenceBytes) * 100;
+	}
+
public static long getFedLookupTableGetTime() {
return fedLookupTableGetTime.longValue();
}
@@ -563,15 +650,20 @@ public class FederatedStatistics {
private void collectStats() {
cacheStats.collectStats();
jitCompileTime = ((double)Statistics.getJITCompileTime()) / 1000; // in sec
+ cpuUsage = getCPUUsage();
+ memoryUsage = getMemoryUsage();
gcStats.collectStats();
linCacheStats.collectStats();
mtStats.collectStats();
heavyHitters = Statistics.getHeavyHittersHashMap();
+ coordinatorsTrafficBytes = getCoordinatorsTrafficBytes();
}
public void aggregate(FedStatsCollection that) {
cacheStats.aggregate(that.cacheStats);
jitCompileTime += that.jitCompileTime;
+ cpuUsage += that.cpuUsage;
+ memoryUsage += that.memoryUsage;
gcStats.aggregate(that.gcStats);
linCacheStats.aggregate(that.linCacheStats);
mtStats.aggregate(that.mtStats);
@@ -579,6 +671,7 @@ public class FederatedStatistics {
(key, value) -> heavyHitters.merge(key, value, (v1, v2) ->
new ImmutablePair<>(v1.getLeft() + v2.getLeft(), v1.getRight() + v2.getRight()))
);
+ that.coordinatorsTrafficBytes.addAll(coordinatorsTrafficBytes);
}
protected static class CacheStatsCollection implements Serializable {
@@ -725,10 +818,13 @@ public class FederatedStatistics {
}
private CacheStatsCollection cacheStats = new CacheStatsCollection();
- private double jitCompileTime = 0;
+ public double jitCompileTime = 0;
+ public double cpuUsage = 0;
+ public double memoryUsage = 0;
private GCStatsCollection gcStats = new GCStatsCollection();
private LineageCacheStatsCollection linCacheStats = new LineageCacheStatsCollection();
private MultiTenantStatsCollection mtStats = new MultiTenantStatsCollection();
- private HashMap<String, Pair<Long, Double>> heavyHitters = new HashMap<>();
+ public HashMap<String, Pair<Long, Double>> heavyHitters = new HashMap<>();
+ public List<Triple<LocalDateTime, String, Long>> coordinatorsTrafficBytes = new ArrayList<>();
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 47cedd739c..bfeb19cc16 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -98,6 +98,8 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
/** Federated workload analyzer */
private final FederatedWorkloadAnalyzer _fan;
+ private String _remoteAddress = FederatedLookupTable.NOHOST;
+
/**
* Create a Federated Worker Handler.
*
@@ -139,6 +141,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
}
String host;
+ _remoteAddress = remoteAddress.toString();
if(remoteAddress instanceof InetSocketAddress) {
host = ((InetSocketAddress) remoteAddress).getHostString();
}
@@ -216,6 +219,20 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
response = tmp; // return last
}
+
+ if(t == RequestType.PUT_VAR || t == RequestType.EXEC_UDF) {
+ for (int paramIndex = 0; paramIndex < request.getNumParams(); paramIndex++) {
+ FederatedStatistics.incFedTransfer(request.getParam(paramIndex), _remoteAddress);
+ }
+ }
+
+ if(t == RequestType.GET_VAR) {
+ var data = response.getData();
+ for (int dataObjIndex = 0; dataObjIndex < Arrays.stream(data).count(); dataObjIndex++) {
+ FederatedStatistics.incFedTransfer(data[dataObjIndex], _remoteAddress);
+ }
+ }
+
if(t == RequestType.CLEAR)
containsCLEAR = true;
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java
index 8c81ffd24d..c6e4041542 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java
@@ -22,30 +22,57 @@ package org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers
import io.netty.handler.codec.http.FullHttpResponse;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Request;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Response;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.CoordinatorService;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.MapperService;
public class CoordinatorController implements IController {
+ private final CoordinatorService _coordinatorService = new CoordinatorService();
+
@Override
public FullHttpResponse create(Request request) {
- return null;
+
+ var model = MapperService.getModelFromBody(request);
+
+ _coordinatorService.create(model);
+
+ return Response.ok("Success");
}
@Override
public FullHttpResponse update(Request request, Long objectId) {
- return null;
+ var model = MapperService.getModelFromBody(request);
+
+ _coordinatorService.update(model);
+
+ return Response.ok("Success");
}
@Override
public FullHttpResponse delete(Request request, Long objectId) {
- return null;
+ _coordinatorService.remove(objectId);
+
+ return Response.ok("Success");
}
@Override
public FullHttpResponse get(Request request, Long objectId) {
- return Response.ok("Success");
+ var result = _coordinatorService.get(objectId);
+
+ if (result == null) {
+ return Response.notFound("No such coordinator can be found");
+ }
+
+ return Response.ok(result.toString());
}
@Override
public FullHttpResponse getAll(Request request) {
- return Response.ok("Success");
+ var coordinators = _coordinatorService.getAll();
+
+ if (coordinators.isEmpty()) {
+ return Response.notFound("No coordinators can be found");
+ }
+
+ return Response.ok(coordinators.toString());
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java
index bdc46304f6..63f68a6e86 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java
@@ -19,15 +19,12 @@
package org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers;
-import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.handler.codec.http.FullHttpResponse;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Request;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Response;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.MapperService;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.WorkerService;
-import java.io.IOException;
-
public class WorkerController implements IController {
private final WorkerService _workerService = new WorkerService();
@@ -35,26 +32,27 @@ public class WorkerController implements IController {
@Override
public FullHttpResponse create(Request request) {
- ObjectMapper mapper = new ObjectMapper();
+ var model = MapperService.getModelFromBody(request);
- try {
- BaseEntityModel model = mapper.readValue(request.getBody(), BaseEntityModel.class);
- _workerService.create(model);
- return Response.ok("Success");
- }
- catch (IOException e) {
- throw new RuntimeException(e);
- }
+ _workerService.create(model);
+
+ return Response.ok("Success");
}
@Override
public FullHttpResponse update(Request request, Long objectId) {
- return null;
+ var model = MapperService.getModelFromBody(request);
+
+ _workerService.update(model);
+
+ return Response.ok("Success");
}
@Override
public FullHttpResponse delete(Request request, Long objectId) {
- return null;
+ _workerService.remove(objectId);
+
+ return Response.ok("Success");
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/BaseEntityModel.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/BaseEntityModel.java
index d42e76556f..41cf507696 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/BaseEntityModel.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/BaseEntityModel.java
@@ -19,60 +19,4 @@
package org.apache.sysds.runtime.controlprogram.federated.monitoring.models;
-public class BaseEntityModel {
- private Long _id;
- private String _name;
- private String _address;
-
- private String _data;
-
- public BaseEntityModel() { }
-
- public BaseEntityModel(final Long id, final String name, final String address) {
- _id = id;
- _name = name;
- _address = address;
- }
-
- public Long getId() {
- return _id;
- }
-
- public void setId(final Long id) {
- _id = id;
- }
-
- public String getName() {
- return _name;
- }
-
- public void setName(final String name) {
- _name = name;
- }
-
- public String getAddress() {
- return _address;
- }
-
- public void setAddress(final String address) {
- _address = address;
- }
-
- public String getData() {
- return _data;
- }
-
- public void setData(final String data) {
- _data = data;
- }
-
- @Override
- public String toString() {
- return String.format("{" +
- "\"id\": %d," +
- "\"name\": \"%s\"," +
- "\"address\": \"%s\"," +
- "\"data\": \"%s\"" +
- "}", _id, _name, _address, _data);
- }
-}
+public abstract class BaseEntityModel { }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/BaseEntityModel.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/NodeEntityModel.java
similarity index 74%
copy from src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/BaseEntityModel.java
copy to src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/NodeEntityModel.java
index d42e76556f..725274509f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/BaseEntityModel.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/NodeEntityModel.java
@@ -19,16 +19,18 @@
package org.apache.sysds.runtime.controlprogram.federated.monitoring.models;
-public class BaseEntityModel {
+import java.util.List;
+
+public class NodeEntityModel extends BaseEntityModel {
private Long _id;
private String _name;
private String _address;
- private String _data;
+ private List<BaseEntityModel> _stats;
- public BaseEntityModel() { }
+ public NodeEntityModel() { }
- public BaseEntityModel(final Long id, final String name, final String address) {
+ public NodeEntityModel(final Long id, final String name, final String address) {
_id = id;
_name = name;
_address = address;
@@ -58,21 +60,21 @@ public class BaseEntityModel {
_address = address;
}
- public String getData() {
- return _data;
+ public List<BaseEntityModel> getStats() {
+ return _stats;
}
- public void setData(final String data) {
- _data = data;
+ public void setStats(final List<BaseEntityModel> stats) {
+ _stats = stats;
}
@Override
public String toString() {
return String.format("{" +
- "\"id\": %d," +
- "\"name\": \"%s\"," +
- "\"address\": \"%s\"," +
- "\"data\": \"%s\"" +
- "}", _id, _name, _address, _data);
+ "\"id\": %d," +
+ "\"name\": \"%s\"," +
+ "\"address\": \"%s\"," +
+ "\"stats\": %s" +
+ "}", _id, _name, _address, _stats);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/StatsEntityModel.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/StatsEntityModel.java
new file mode 100644
index 0000000000..bfd9c9e840
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/StatsEntityModel.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.models;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.Triple;
+
+import java.time.LocalDateTime;
+import java.time.format.DateTimeFormatter;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Statistics snapshot of a federated worker: CPU and memory figures plus the
+ * heavy hitter instructions and the traffic exchanged with coordinators.
+ * <p>
+ * The heavy hitters and the traffic are available both as objects (when the model is
+ * built through the full constructor) and as JSON strings (rendered lazily from the
+ * objects, or injected directly through the setters).
+ */
+public class StatsEntityModel extends BaseEntityModel {
+	private Long _workerId;
+	private double _cpuUsage;
+	private double _memoryUsage;
+	private Map<String, Pair<Long, Double>> _heavyHitterInstructionsObj;
+	// cached JSON rendering; stays empty until first requested or explicitly set
+	private String _heavyHitterInstructions = "";
+	private List<Triple<LocalDateTime, String, Long>> _transferredBytesObj;
+	// cached JSON rendering; stays empty until first requested or explicitly set
+	private String _transferredBytes = "";
+
+	public StatsEntityModel() { }
+
+	/**
+	 * @param workerId id of the worker the statistics belong to
+	 * @param cpuUsage CPU figure reported by the worker
+	 * @param memoryUsage memory figure reported by the worker
+	 * @param heavyHitterInstructionsObj instruction name mapped to (count, duration)
+	 * @param transferredBytesObj (datetime, coordinator address, byte amount) triples
+	 */
+	public StatsEntityModel(Long workerId, double cpuUsage, double memoryUsage,
+		Map<String, Pair<Long, Double>> heavyHitterInstructionsObj,
+		List<Triple<LocalDateTime, String, Long>> transferredBytesObj)
+	{
+		_workerId = workerId;
+		_cpuUsage = cpuUsage;
+		_memoryUsage = memoryUsage;
+		_heavyHitterInstructionsObj = heavyHitterInstructionsObj;
+		_transferredBytesObj = transferredBytesObj;
+		_heavyHitterInstructions = "";
+		_transferredBytes = "";
+	}
+
+	public Long getWorkerId() {
+		return _workerId;
+	}
+
+	public void setWorkerId(final Long workerId) {
+		_workerId = workerId;
+	}
+
+	public double getCPUUsage() {
+		return _cpuUsage;
+	}
+
+	public void setCPUUsage(final double cpuUsage) {
+		_cpuUsage = cpuUsage;
+	}
+
+	public double getMemoryUsage() {
+		return _memoryUsage;
+	}
+
+	public void setMemoryUsage(final double memoryUsage) {
+		_memoryUsage = memoryUsage;
+	}
+
+	/**
+	 * Returns the heavy hitter instructions as a JSON array. The string is rendered from
+	 * the object representation on first use and cached afterwards; a string injected
+	 * through the setter is returned unchanged.
+	 *
+	 * @return JSON array of {instruction, count, duration} objects ("[]" if there are none)
+	 */
+	public String getHeavyHitterInstructions() {
+		if (_heavyHitterInstructions == null || _heavyHitterInstructions.isBlank()) {
+			StringBuilder sb = new StringBuilder("[");
+
+			// the object form is absent when the model was created via the default constructor
+			if (_heavyHitterInstructionsObj != null) {
+				String separator = "";
+				for (Map.Entry<String, Pair<Long, Double>> entry : _heavyHitterInstructionsObj.entrySet()) {
+					// Locale.ROOT keeps '.' as decimal separator regardless of the JVM locale
+					sb.append(separator).append(String.format(java.util.Locale.ROOT, "{" +
+						"\"instruction\": \"%s\"," +
+						"\"count\": \"%d\"," +
+						"\"duration\": \"%.2f\"" +
+						"}", escapeJson(entry.getKey()), entry.getValue().getLeft(), entry.getValue().getRight()));
+					separator = ",";
+				}
+			}
+			sb.append("]");
+
+			_heavyHitterInstructions = sb.toString();
+		}
+
+		return _heavyHitterInstructions;
+	}
+
+	public void setHeavyHitterInstructions(final String heavyHitterInstructionsJsonString) {
+		_heavyHitterInstructions = heavyHitterInstructionsJsonString;
+	}
+
+	/**
+	 * Returns the coordinator traffic as a JSON array. The string is rendered from the
+	 * object representation on first use and cached afterwards; a string injected
+	 * through the setter is returned unchanged.
+	 *
+	 * @return JSON array of {datetime, coordinatorAddress, byteAmount} objects ("[]" if there are none)
+	 */
+	public String getTransferredBytes() {
+		if (_transferredBytes == null || _transferredBytes.isBlank()) {
+			StringBuilder sb = new StringBuilder("[");
+
+			// the object form is absent when the model was created via the default constructor
+			if (_transferredBytesObj != null) {
+				String separator = "";
+				for (var entry : _transferredBytesObj) {
+					sb.append(separator).append(String.format(java.util.Locale.ROOT, "{" +
+						"\"datetime\": \"%s\"," +
+						"\"coordinatorAddress\": \"%s\"," +
+						"\"byteAmount\": \"%d\"" +
+						"}", entry.getLeft().format(DateTimeFormatter.ISO_DATE_TIME),
+						escapeJson(entry.getMiddle()), entry.getRight()));
+					separator = ",";
+				}
+			}
+			sb.append("]");
+
+			_transferredBytes = sb.toString();
+		}
+
+		return _transferredBytes;
+	}
+
+	public void setTransferredBytes(final String transferredBytesJsonString) {
+		_transferredBytes = transferredBytesJsonString;
+	}
+
+	/** Escapes backslashes and double quotes so the value can be embedded in a JSON string. */
+	private static String escapeJson(final String value) {
+		if (value == null)
+			return "";
+		return value.replace("\\", "\\\\").replace("\"", "\\\"");
+	}
+
+	@Override
+	public String toString() {
+		// Locale.ROOT keeps '.' as decimal separator so the numbers remain valid JSON
+		return String.format(java.util.Locale.ROOT, "{" +
+			"\"cpuUsage\": %.2f," +
+			"\"memoryUsage\": %.2f," +
+			"\"coordinatorTraffic\": %s," +
+			"\"heavyHitters\": %s" +
+			"}", _cpuUsage, _memoryUsage, getTransferredBytes(), getHeavyHitterInstructions());
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/EntityEnum.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/Constants.java
similarity index 56%
copy from src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/EntityEnum.java
copy to src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/Constants.java
index 7384257bf3..40ce052716 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/EntityEnum.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/Constants.java
@@ -19,7 +19,16 @@
package org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories;
-public enum EntityEnum {
- WORKER,
- COORDINATOR
+public class Constants {
+ public static final String WORKERS_TABLE_NAME= "workers";
+ public static final String COORDINATORS_TABLE_NAME= "coordinators";
+ public static final String STATS_TABLE_NAME= "statistics";
+ public static final String ENTITY_NAME_COL = "name";
+ public static final String ENTITY_ADDR_COL = "address";
+ public static final String ENTITY_CPU_COL = "cpuUsage";
+ public static final String ENTITY_MEM_COL = "memoryUsage";
+ public static final String ENTITY_TRAFFIC_COL = "coordinatorTraffic";
+ public static final String ENTITY_HEAVY_HITTERS_COL = "heavyHitters";
+ public static final String ENTITY_ID_COL = "id";
+ public static final String ENTITY_WORKER_ID_COL = "workerId";
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java
index 9e94a41d61..02a948769f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java
@@ -19,32 +19,39 @@
package org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.NodeEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.StatsEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.MapperService;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
-import java.sql.Types;
import java.util.ArrayList;
import java.util.List;
public class DerbyRepository implements IRepository {
private final static String DB_CONNECTION = "jdbc:derby:memory:derbyDB";
private final Connection _db;
-
- private static final String WORKERS_TABLE_NAME= "workers";
- private static final String ENTITY_NAME_COL = "name";
- private static final String ENTITY_ADDR_COL = "address";
-
private static final String ENTITY_SCHEMA_CREATE_STMT = "CREATE TABLE %s " +
"(id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), " +
"%s VARCHAR(60), " +
"%s VARCHAR(120))";
+ private static final String ENTITY_SCHEMA_CREATE_STATS_STMT = "CREATE TABLE %s " +
+ "(id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), " +
+ "%s INTEGER, " +
+ "%s DOUBLE, " +
+ "%s DOUBLE," +
+ "%s VARCHAR(1000)," +
+ "%s VARCHAR(1000))";
private static final String ENTITY_INSERT_STMT = "INSERT INTO %s (%s, %s) VALUES (?, ?)";
-
- private static final String GET_ENTITY_WITH_ID_STMT = "SELECT * FROM %s WHERE id = ?";
+ private static final String ENTITY_STATS_INSERT_STMT = "INSERT INTO %s (%s, %s, %s, %s, %s) VALUES (?, ?, ?, ?, ?)";
+ private static final String GET_ENTITY_WITH_COL_STMT = "SELECT * FROM %s WHERE %s = ?";
+ private static final String DELETE_ENTITY_WITH_COL_STMT = "DELETE FROM %s WHERE %s = ?";
+ private static final String UPDATE_ENTITY_WITH_COL_STMT = "UPDATE %s SET %s = ?, %s = ? WHERE %s = ?";
private static final String GET_ALL_ENTITIES_STMT = "SELECT * FROM %s";
public DerbyRepository() {
@@ -68,15 +75,33 @@ public class DerbyRepository implements IRepository {
private void createMonitoringEntitiesInDB(Connection db) {
try {
var dbMetaData = db.getMetaData();
- var workersExist = dbMetaData.getTables(null, null, WORKERS_TABLE_NAME.toUpperCase(),null);
+ var workersExist = dbMetaData.getTables(null, null, Constants.WORKERS_TABLE_NAME.toUpperCase(),null);
+ var statsExist = dbMetaData.getTables(null, null, Constants.STATS_TABLE_NAME.toUpperCase(),null);
+ var coordinatorsExist = dbMetaData.getTables(null, null, Constants.COORDINATORS_TABLE_NAME.toUpperCase(),null);
// Check if table already exists and create if not
- if(!workersExist.next())
- {
+ if(!workersExist.next()) {
+ PreparedStatement st = db.prepareStatement(
+ String.format(ENTITY_SCHEMA_CREATE_STMT, Constants.WORKERS_TABLE_NAME, Constants.ENTITY_NAME_COL, Constants.ENTITY_ADDR_COL));
+ st.executeUpdate();
+ }
+
+ if(!statsExist.next()) {
PreparedStatement st = db.prepareStatement(
- String.format(ENTITY_SCHEMA_CREATE_STMT, WORKERS_TABLE_NAME, ENTITY_NAME_COL, ENTITY_ADDR_COL));
+ String.format(ENTITY_SCHEMA_CREATE_STATS_STMT, Constants.STATS_TABLE_NAME,
+ Constants.ENTITY_WORKER_ID_COL,
+ Constants.ENTITY_CPU_COL,
+ Constants.ENTITY_MEM_COL,
+ Constants.ENTITY_TRAFFIC_COL,
+ Constants.ENTITY_HEAVY_HITTERS_COL));
st.executeUpdate();
+ }
+ if(!coordinatorsExist.next()) {
+ PreparedStatement st = db.prepareStatement(String.format(
+ ENTITY_SCHEMA_CREATE_STMT, Constants.COORDINATORS_TABLE_NAME,
+ Constants.ENTITY_NAME_COL, Constants.ENTITY_ADDR_COL));
+ st.executeUpdate();
}
}
catch (SQLException e) {
@@ -84,23 +109,56 @@ public class DerbyRepository implements IRepository {
}
}
- public void createEntity(EntityEnum type, BaseEntityModel model) {
+ public Long createEntity(EntityEnum type, BaseEntityModel model) {
+
+ PreparedStatement st = null;
+ long id = -1L;
try {
- PreparedStatement st = _db.prepareStatement(
- String.format(ENTITY_INSERT_STMT, WORKERS_TABLE_NAME, ENTITY_NAME_COL, ENTITY_ADDR_COL));
+ if (type == EntityEnum.WORKER_STATS) {
+ st = _db.prepareStatement(
+ String.format(ENTITY_STATS_INSERT_STMT, Constants.STATS_TABLE_NAME,
+ Constants.ENTITY_WORKER_ID_COL,
+ Constants.ENTITY_CPU_COL,
+ Constants.ENTITY_MEM_COL,
+ Constants.ENTITY_TRAFFIC_COL,
+ Constants.ENTITY_HEAVY_HITTERS_COL), PreparedStatement.RETURN_GENERATED_KEYS);
+
+ StatsEntityModel newModel = (StatsEntityModel) model;
+
+ st.setLong(1, newModel.getWorkerId());
+ st.setDouble(2, newModel.getCPUUsage());
+ st.setDouble(3, newModel.getMemoryUsage());
+ st.setString(4, newModel.getTransferredBytes());
+ st.setString(5, newModel.getHeavyHitterInstructions());
+ } else {
+ st = _db.prepareStatement(
+ String.format(ENTITY_INSERT_STMT, Constants.WORKERS_TABLE_NAME, Constants.ENTITY_NAME_COL, Constants.ENTITY_ADDR_COL),
+ PreparedStatement.RETURN_GENERATED_KEYS);
+ NodeEntityModel newModel = (NodeEntityModel) model;
+
+ if (type == EntityEnum.COORDINATOR) {
+ st = _db.prepareStatement(
+ String.format(ENTITY_INSERT_STMT, Constants.COORDINATORS_TABLE_NAME, Constants.ENTITY_NAME_COL, Constants.ENTITY_ADDR_COL),
+ PreparedStatement.RETURN_GENERATED_KEYS);
+ }
- if (type == EntityEnum.COORDINATOR) {
- // Change statement
+ st.setString(1, newModel.getName());
+ st.setString(2, newModel.getAddress());
}
- st.setString(1, model.getName());
- st.setString(2, model.getAddress());
st.executeUpdate();
+ ResultSet rs = st.getGeneratedKeys();
+ if (rs.next()) {
+ id = rs.getLong(1); // this is the auto-generated id key
+ }
+
} catch (SQLException e) {
throw new RuntimeException(e);
}
+
+ return id;
}
public BaseEntityModel getEntity(EntityEnum type, Long id) {
@@ -108,17 +166,21 @@ public class DerbyRepository implements IRepository {
try {
PreparedStatement st = _db.prepareStatement(
- String.format(GET_ENTITY_WITH_ID_STMT, WORKERS_TABLE_NAME));
+ String.format(GET_ENTITY_WITH_COL_STMT, Constants.WORKERS_TABLE_NAME, Constants.ENTITY_ID_COL));
if (type == EntityEnum.COORDINATOR) {
- // Change statement
+ st = _db.prepareStatement(
+ String.format(GET_ENTITY_WITH_COL_STMT, Constants.COORDINATORS_TABLE_NAME, Constants.ENTITY_ID_COL));
+ } else if (type == EntityEnum.WORKER_STATS) {
+ st = _db.prepareStatement(
+ String.format(GET_ENTITY_WITH_COL_STMT, Constants.STATS_TABLE_NAME, Constants.ENTITY_WORKER_ID_COL));
}
st.setLong(1, id);
var resultSet = st.executeQuery();
if (resultSet.next()){
- resultModel = mapEntityToModel(resultSet);
+ resultModel = MapperService.mapEntityToModel(resultSet, type);
}
} catch (SQLException e) {
throw new RuntimeException(e);
@@ -132,16 +194,40 @@ public class DerbyRepository implements IRepository {
try {
PreparedStatement st = _db.prepareStatement(
- String.format(GET_ALL_ENTITIES_STMT, WORKERS_TABLE_NAME));
+ String.format(GET_ALL_ENTITIES_STMT, Constants.WORKERS_TABLE_NAME));
if (type == EntityEnum.COORDINATOR) {
- // Change statement
+ st = _db.prepareStatement(
+ String.format(GET_ALL_ENTITIES_STMT, Constants.COORDINATORS_TABLE_NAME));
}
var resultSet = st.executeQuery();
+ while (resultSet.next()){
+ resultModels.add(MapperService.mapEntityToModel(resultSet, type));
+ }
+ } catch (SQLException e) {
+ throw new RuntimeException(e);
+ }
+
+ return resultModels;
+ }
+
+ public List<BaseEntityModel> getAllEntitiesByField(EntityEnum type, Object fieldValue) {
+ List<BaseEntityModel> resultModels = new ArrayList<>();
+ PreparedStatement st = null;
+
+ try {
+ if (type == EntityEnum.WORKER_STATS) {
+ st = _db.prepareStatement(
+ String.format(GET_ENTITY_WITH_COL_STMT, Constants.STATS_TABLE_NAME, Constants.ENTITY_WORKER_ID_COL));
+ st.setLong(1, (Long) fieldValue);
+ } else {
+ throw new NotImplementedException();
+ }
+ var resultSet = st.executeQuery();
while (resultSet.next()){
- resultModels.add(mapEntityToModel(resultSet));
+ resultModels.add(MapperService.mapEntityToModel(resultSet, type));
}
} catch (SQLException e) {
throw new RuntimeException(e);
@@ -150,22 +236,52 @@ public class DerbyRepository implements IRepository {
return resultModels;
}
- private BaseEntityModel mapEntityToModel(ResultSet resultSet) throws SQLException {
- BaseEntityModel tmpModel = new BaseEntityModel();
+ @Override
+ public void updateEntity(EntityEnum type, BaseEntityModel model) {
+
+ try {
+ PreparedStatement st = _db.prepareStatement(
+ String.format(UPDATE_ENTITY_WITH_COL_STMT, Constants.WORKERS_TABLE_NAME,
+ Constants.ENTITY_NAME_COL,
+ Constants.ENTITY_ADDR_COL,
+ Constants.ENTITY_ID_COL));
+ NodeEntityModel editModel = (NodeEntityModel) model;
- for (int column = 1; column <= resultSet.getMetaData().getColumnCount(); column++) {
- if (resultSet.getMetaData().getColumnType(column) == Types.INTEGER) {
- tmpModel.setId(resultSet.getLong(column));
+ if (type == EntityEnum.COORDINATOR) {
+ st = _db.prepareStatement(
+ String.format(UPDATE_ENTITY_WITH_COL_STMT, Constants.COORDINATORS_TABLE_NAME,
+ Constants.ENTITY_NAME_COL,
+ Constants.ENTITY_ADDR_COL,
+ Constants.ENTITY_ID_COL));
}
- if (resultSet.getMetaData().getColumnType(column) == Types.VARCHAR) {
- if (resultSet.getMetaData().getColumnName(column).equalsIgnoreCase(ENTITY_NAME_COL)) {
- tmpModel.setName(resultSet.getString(column));
- } else if (resultSet.getMetaData().getColumnName(column).equalsIgnoreCase(ENTITY_ADDR_COL)) {
- tmpModel.setAddress(resultSet.getString(column));
- }
+ st.setString(1, editModel.getName());
+ st.setString(2, editModel.getAddress());
+ st.setLong(3, editModel.getId());
+
+ st.executeUpdate();
+
+ } catch (SQLException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void removeEntity(EntityEnum type, Long id) {
+ PreparedStatement st = null;
+ try {
+ if (type == EntityEnum.WORKER) {
+ st = _db.prepareStatement(
+ String.format(DELETE_ENTITY_WITH_COL_STMT, Constants.WORKERS_TABLE_NAME, Constants.ENTITY_ID_COL));
+ st.setLong(1, id);
+ } else {
+ st = _db.prepareStatement(
+ String.format(DELETE_ENTITY_WITH_COL_STMT, Constants.COORDINATORS_TABLE_NAME, Constants.ENTITY_ID_COL));
+ st.setLong(1, id);
}
+ st.executeUpdate();
+ } catch (SQLException e) {
+ throw new RuntimeException(e);
}
- return tmpModel;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/EntityEnum.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/EntityEnum.java
index 7384257bf3..18b17ea7fc 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/EntityEnum.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/EntityEnum.java
@@ -21,5 +21,6 @@ package org.apache.sysds.runtime.controlprogram.federated.monitoring.repositorie
public enum EntityEnum {
WORKER,
+ WORKER_STATS,
COORDINATOR
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/IRepository.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/IRepository.java
index d441693e59..dd683080e2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/IRepository.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/IRepository.java
@@ -25,9 +25,14 @@ import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseE
import java.util.List;
public interface IRepository {
- void createEntity(EntityEnum type, BaseEntityModel model);
+ Long createEntity(EntityEnum type, BaseEntityModel model);
BaseEntityModel getEntity(EntityEnum type, Long id);
List<BaseEntityModel> getAllEntities(EntityEnum type);
+
+ List<BaseEntityModel> getAllEntitiesByField(EntityEnum type, Object fieldValue);
+ void updateEntity(EntityEnum type, BaseEntityModel model);
+
+ void removeEntity(EntityEnum type, Long id);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/CoordinatorService.java
similarity index 52%
copy from src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java
copy to src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/CoordinatorService.java
index 8c81ffd24d..91137acaa2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/CoordinatorService.java
@@ -17,35 +17,35 @@
* under the License.
*/
-package org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers;
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.services;
-import io.netty.handler.codec.http.FullHttpResponse;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Request;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Response;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.DerbyRepository;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.EntityEnum;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.IRepository;
-public class CoordinatorController implements IController {
- @Override
- public FullHttpResponse create(Request request) {
- return null;
+import java.util.List;
+
+public class CoordinatorService {
+ private static final IRepository _entityRepository = new DerbyRepository();
+
+ public void create(BaseEntityModel model) {
+ _entityRepository.createEntity(EntityEnum.COORDINATOR, model);
}
- @Override
- public FullHttpResponse update(Request request, Long objectId) {
- return null;
+ public void update(BaseEntityModel model) {
+ _entityRepository.updateEntity(EntityEnum.COORDINATOR, model);
}
- @Override
- public FullHttpResponse delete(Request request, Long objectId) {
- return null;
+ public void remove(Long id) {
+ _entityRepository.removeEntity(EntityEnum.COORDINATOR, id);
}
- @Override
- public FullHttpResponse get(Request request, Long objectId) {
- return Response.ok("Success");
+ public BaseEntityModel get(Long id) {
+ return _entityRepository.getEntity(EntityEnum.COORDINATOR, id);
}
- @Override
- public FullHttpResponse getAll(Request request) {
- return Response.ok("Success");
+ public List<BaseEntityModel> getAll() {
+ return _entityRepository.getAllEntities(EntityEnum.COORDINATOR);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/MapperService.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/MapperService.java
new file mode 100644
index 0000000000..cd0efcc732
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/MapperService.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.services;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.NodeEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Request;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.StatsEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.Constants;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.EntityEnum;
+
+import java.io.IOException;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Types;
+
+public class MapperService {
+ public static BaseEntityModel getModelFromBody(Request request) {
+ ObjectMapper mapper = new ObjectMapper();
+
+ try {
+ return mapper.readValue(request.getBody(), NodeEntityModel.class);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static BaseEntityModel mapEntityToModel(ResultSet resultSet, EntityEnum targetModel) {
+ try {
+ if (targetModel != EntityEnum.WORKER_STATS) {
+ NodeEntityModel tmpModel = new NodeEntityModel();
+
+ for (int column = 1; column <= resultSet.getMetaData().getColumnCount(); column++) {
+ if (resultSet.getMetaData().getColumnType(column) == Types.INTEGER) {
+ tmpModel.setId(resultSet.getLong(column));
+ }
+
+ if (resultSet.getMetaData().getColumnType(column) == Types.VARCHAR) {
+ if (resultSet.getMetaData().getColumnName(column).equalsIgnoreCase(Constants.ENTITY_NAME_COL)) {
+ tmpModel.setName(resultSet.getString(column));
+ } else if (resultSet.getMetaData().getColumnName(column).equalsIgnoreCase(Constants.ENTITY_ADDR_COL)) {
+ tmpModel.setAddress(resultSet.getString(column));
+ }
+ }
+ }
+ return tmpModel;
+ } else {
+ StatsEntityModel tmpModel = new StatsEntityModel();
+
+ for (int column = 1; column <= resultSet.getMetaData().getColumnCount(); column++) {
+
+ if (resultSet.getMetaData().getColumnType(column) == Types.VARCHAR) {
+ if (resultSet.getMetaData().getColumnName(column).equalsIgnoreCase(Constants.ENTITY_TRAFFIC_COL)) {
+ tmpModel.setTransferredBytes(resultSet.getString(column));
+ } else if (resultSet.getMetaData().getColumnName(column).equalsIgnoreCase(Constants.ENTITY_HEAVY_HITTERS_COL)) {
+ tmpModel.setHeavyHitterInstructions(resultSet.getString(column));
+ }
+ } else {
+ if (resultSet.getMetaData().getColumnName(column).equalsIgnoreCase(Constants.ENTITY_CPU_COL)) {
+ tmpModel.setCPUUsage(resultSet.getDouble(column));
+ } else if (resultSet.getMetaData().getColumnName(column).equalsIgnoreCase(Constants.ENTITY_MEM_COL)) {
+ tmpModel.setMemoryUsage(resultSet.getDouble(column));
+ }
+ }
+ }
+
+ return tmpModel;
+ }
+ } catch (SQLException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/StatsService.java
similarity index 68%
copy from src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java
copy to src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/StatsService.java
index 98fdfed672..565f1b2712 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/StatsService.java
@@ -25,26 +25,17 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.DerbyRepository;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.EntityEnum;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.IRepository;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.StatsEntityModel;
import java.net.InetSocketAddress;
-import java.util.List;
import java.util.concurrent.Future;
-public class WorkerService {
- private static final IRepository _entityRepository = new DerbyRepository();
-
- public void create(BaseEntityModel model) {
- _entityRepository.createEntity(EntityEnum.WORKER, model);
- }
-
- public BaseEntityModel get(Long id) {
- var model = _entityRepository.getEntity(EntityEnum.WORKER, id);
+public class StatsService {
+ public static BaseEntityModel getWorkerStatistics(Long id, String address) {
+ StatsEntityModel parsedStats = null;
try {
- var statisticsResponse = getWorkerStatistics(model.getAddress()).get();
+ var statisticsResponse = sendStatisticsRequest(address).get();
if (statisticsResponse.isSuccessful()) {
FederatedStatistics.FedStatsCollection aggFedStats = new FederatedStatistics.FedStatsCollection();
@@ -53,33 +44,31 @@ public class WorkerService {
if(tmp[0] instanceof FederatedStatistics.FedStatsCollection)
aggFedStats.aggregate((FederatedStatistics.FedStatsCollection)tmp[0]);
- var statsStr = FederatedStatistics.displayStatistics(aggFedStats, 5);
- model.setData(statsStr);
+ parsedStats = new StatsEntityModel(
+ id, aggFedStats.cpuUsage, aggFedStats.memoryUsage,
+ aggFedStats.heavyHitters, aggFedStats.coordinatorsTrafficBytes);
}
+ } catch(DMLRuntimeException dre) {
+ // silently ignore -> caused by offline federated workers
} catch (Exception e) {
throw new RuntimeException(e);
}
- return model;
+ return parsedStats;
}
- public List<BaseEntityModel> getAll() {
- return _entityRepository.getAllEntities(EntityEnum.WORKER);
- }
-
- private Future<FederatedResponse> getWorkerStatistics(String address) {
+ private static Future<FederatedResponse> sendStatisticsRequest(String address) {
Future<FederatedResponse> result = null;
-
String host = address.split(":")[0];
int port = Integer.parseInt(address.split(":")[1]);
InetSocketAddress isa = new InetSocketAddress(host, port);
FederatedRequest frUDF = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new FederatedStatistics.FedStatsCollectFunction());
+ new FederatedStatistics.FedStatsCollectFunction());
try {
result = FederatedData.executeFederatedOperation(isa, frUDF);
} catch(DMLRuntimeException dre) {
- // silently ignore this exception --> caused by offline federated workers
+ throw dre; // caused by offline federated workers
} catch (Exception e) {
throw new RuntimeException(e);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java
index 98fdfed672..82845177c3 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java
@@ -19,71 +19,106 @@
package org.apache.sysds.runtime.controlprogram.federated.monitoring.services;
-import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.NodeEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.StatsEntityModel;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.DerbyRepository;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.EntityEnum;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.IRepository;
-import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
-import java.util.concurrent.Future;
+import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
public class WorkerService {
private static final IRepository _entityRepository = new DerbyRepository();
+ private static final Map<Long, String> _cachedWorkers = new HashMap<>();
+
+ public WorkerService() {
+ updateCachedWorkers(null);
+
+ ScheduledExecutorService executor = Executors.newScheduledThreadPool(1);
+ executor.scheduleAtFixedRate(syncWorkerStatisticsWithDB(), 0, 3, TimeUnit.SECONDS);
+ }
public void create(BaseEntityModel model) {
- _entityRepository.createEntity(EntityEnum.WORKER, model);
+ long id = _entityRepository.createEntity(EntityEnum.WORKER, model);
+
+ var modelEntity = (NodeEntityModel) model;
+
+ _cachedWorkers.putIfAbsent(id, modelEntity.getAddress());
}
- public BaseEntityModel get(Long id) {
- var model = _entityRepository.getEntity(EntityEnum.WORKER, id);
+ public void update(BaseEntityModel model) {
+ _entityRepository.updateEntity(EntityEnum.WORKER, model);
+ }
- try {
- var statisticsResponse = getWorkerStatistics(model.getAddress()).get();
+ public void remove(Long id) {
+ _entityRepository.removeEntity(EntityEnum.WORKER, id);
- if (statisticsResponse.isSuccessful()) {
- FederatedStatistics.FedStatsCollection aggFedStats = new FederatedStatistics.FedStatsCollection();
+ _cachedWorkers.remove(id);
+ }
- Object[] tmp = statisticsResponse.getData();
- if(tmp[0] instanceof FederatedStatistics.FedStatsCollection)
- aggFedStats.aggregate((FederatedStatistics.FedStatsCollection)tmp[0]);
+ public BaseEntityModel get(Long id) {
+ var model = (NodeEntityModel) _entityRepository.getEntity(EntityEnum.WORKER, id);
+ var stats = (List<BaseEntityModel>) _entityRepository.getAllEntitiesByField(EntityEnum.WORKER_STATS, id);
- var statsStr = FederatedStatistics.displayStatistics(aggFedStats, 5);
- model.setData(statsStr);
- }
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
+ updateCachedWorkers(null);
+
+ model.setStats(stats);
return model;
}
public List<BaseEntityModel> getAll() {
- return _entityRepository.getAllEntities(EntityEnum.WORKER);
+ var workersRaw = _entityRepository.getAllEntities(EntityEnum.WORKER);
+ var workersResult = new ArrayList<BaseEntityModel>();
+
+ updateCachedWorkers(workersRaw);
+
+ for (var worker: workersRaw) {
+ var workerModel = (NodeEntityModel) worker;
+ var stats = (List<BaseEntityModel>) _entityRepository.getAllEntitiesByField(EntityEnum.WORKER_STATS, workerModel.getId());
+
+ workerModel.setStats(stats);
+
+ workersResult.add(workerModel);
+ }
+
+ return workersResult;
}
- private Future<FederatedResponse> getWorkerStatistics(String address) {
- Future<FederatedResponse> result = null;
-
- String host = address.split(":")[0];
- int port = Integer.parseInt(address.split(":")[1]);
-
- InetSocketAddress isa = new InetSocketAddress(host, port);
- FederatedRequest frUDF = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new FederatedStatistics.FedStatsCollectFunction());
- try {
- result = FederatedData.executeFederatedOperation(isa, frUDF);
- } catch(DMLRuntimeException dre) {
- // silently ignore this exception --> caused by offline federated workers
- } catch (Exception e) {
- throw new RuntimeException(e);
+ private void updateCachedWorkers(List<BaseEntityModel> workersRaw) {
+ List<BaseEntityModel> workersBaseModel = workersRaw;
+
+ if (workersBaseModel == null) {
+ workersBaseModel = getAll();
}
- return result;
+ for(var workerBaseModel : workersBaseModel) {
+ var worker = (NodeEntityModel) workerBaseModel;
+
+ _cachedWorkers.putIfAbsent(worker.getId(), worker.getAddress());
+ }
+ }
+
+ private static Runnable syncWorkerStatisticsWithDB() {
+ return () -> {
+
+ for(Map.Entry<Long, String> entry : _cachedWorkers.entrySet()) {
+ Long id = entry.getKey();
+ String address = entry.getValue();
+
+ var stats = (StatsEntityModel) StatsService.getWorkerStatistics(id, address);
+
+ if (stats != null) {
+ _entityRepository.createEntity(EntityEnum.WORKER_STATS, stats);
+ }
+ }
+ };
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NativeHEHelper.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NativeHEHelper.java
index 38e4dec553..b2874fa908 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NativeHEHelper.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/NativeHEHelper.java
@@ -23,97 +23,97 @@ import org.apache.commons.lang.SystemUtils;
import org.apache.sysds.utils.NativeHelper;
public class NativeHEHelper {
- public static boolean initialize() {
- String platform_suffix = (SystemUtils.IS_OS_WINDOWS ? "-Windows-AMD64.dll" : "-Linux-x86_64.so");
- String library_name = "libhe" + platform_suffix;
- return NativeHelper.loadLibraryHelperFromResource(library_name);
- }
-
- // ----------------------------------------------------------------------------------------------------------------
- // SEAL integration
- // ----------------------------------------------------------------------------------------------------------------
-
- // these are called by SEALClient
-
- /**
- * generates a Client object
- * @param a a constant generated by generateA
- * @return a pointer to the native client object
- */
- public static native long initClient(byte[] a);
-
- /**
- * generates a partial public key
- * stores a partial private key corresponding to the partial public key in client
- * @param client A pointer to a Client, obtained from initClient
- * @return a serialized partial public key
- */
- public static native byte[] generatePartialPublicKey(long client);
-
- /**
- * sets the public key and stores it in client
- * @param client A pointer to a Client, obtained from initClient
- * @param public_key serialized public key obtained from generatePartialPublicKey
- */
- public static native void setPublicKey(long client, byte[] public_key);
-
- /**
- * encrypts data with public key stored in client
- * setPublicKey() must have been called before calling this
- * @param client A pointer to a Client, obtained from initClient
- * @param plaintexts array of double values to be encrypted
- * @return serialized ciphertext
- */
- public static native byte[] encrypt(long client, double[] plaintexts);
-
- /**
- * partially decrypts ciphertexts with the partial private key. generatePartialPublicKey() must
- * have been called before calling this function
- * @param client A pointer to a Client, obtained from initClient
- * @param ciphertext serialized ciphertext
- * @return serialized partial decryption
- */
- public static native byte[] partiallyDecrypt(long client, byte[] ciphertext);
-
- // ----------------------------------------------------------------------------------------------------------------
-
- // these are called by SEALServer
-
- /**
- * generates the Server object and returns a pointer to it
- * @return pointer to a native Server object
- */
- public static native long initServer();
-
- /**
- * this generates the a constant. in a future version we want to generate this together with the clients to prevent misuse
- * @param server A pointer to a Server, obtained from initServer
- * @return serialized a constant
- */
- public static native byte[] generateA(long server);
-
- /**
- * accumulates the given partial public keys into a public key, stores it in server and returns it
- * @param server A pointer to a Server, obtained from initServer
- * @param partial_public_keys array of serialized partial public keys
- * @return serialized partial public key
- */
- public static native byte[] aggregatePartialPublicKeys(long server, byte[][] partial_public_keys);
-
- /**
- * accumulates the given ciphertexts into a sum ciphertext and returns it
- * @param server A pointer to a Server, obtained from initServer
- * @param ciphertexts array of serialized ciphertexts
- * @return serialized accumulated ciphertext
- */
- public static native byte[] accumulateCiphertexts(long server, byte[][] ciphertexts);
-
- /**
- * averages the partial decryptions and returns the result
- * @param server A pointer to a Server, obtained from initServer
- * @param encrypted_sum the result of accumulateCiphertexts()
- * @param partial_plaintexts the result of partiallyDecrypt of each ciphertext fed into accumulateCiphertexts
- * @return average of original data
- */
- public static native double[] average(long server, byte[] encrypted_sum, byte[][] partial_plaintexts);
+ public static boolean initialize() {
+ String platform_suffix = (SystemUtils.IS_OS_WINDOWS ? "-Windows-AMD64.dll" : "-Linux-x86_64.so");
+ String library_name = "libhe" + platform_suffix;
+ return NativeHelper.loadLibraryHelperFromResource(library_name);
+ }
+
+ // ----------------------------------------------------------------------------------------------------------------
+ // SEAL integration
+ // ----------------------------------------------------------------------------------------------------------------
+
+ // these are called by SEALClient
+
+ /**
+ * generates a Client object
+ * @param a a constant generated by generateA
+ * @return a pointer to the native client object
+ */
+ public static native long initClient(byte[] a);
+
+ /**
+ * generates a partial public key
+ * stores a partial private key corresponding to the partial public key in client
+ * @param client A pointer to a Client, obtained from initClient
+ * @return a serialized partial public key
+ */
+ public static native byte[] generatePartialPublicKey(long client);
+
+ /**
+ * sets the public key and stores it in client
+ * @param client A pointer to a Client, obtained from initClient
+ * @param public_key serialized public key obtained from generatePartialPublicKey
+ */
+ public static native void setPublicKey(long client, byte[] public_key);
+
+ /**
+ * encrypts data with public key stored in client
+ * setPublicKey() must have been called before calling this
+ * @param client A pointer to a Client, obtained from initClient
+ * @param plaintexts array of double values to be encrypted
+ * @return serialized ciphertext
+ */
+ public static native byte[] encrypt(long client, double[] plaintexts);
+
+ /**
+ * partially decrypts ciphertexts with the partial private key. generatePartialPublicKey() must
+ * have been called before calling this function
+ * @param client A pointer to a Client, obtained from initClient
+ * @param ciphertext serialized ciphertext
+ * @return serialized partial decryption
+ */
+ public static native byte[] partiallyDecrypt(long client, byte[] ciphertext);
+
+ // ----------------------------------------------------------------------------------------------------------------
+
+ // these are called by SEALServer
+
+ /**
+ * generates the Server object and returns a pointer to it
+ * @return pointer to a native Server object
+ */
+ public static native long initServer();
+
+ /**
+ * this generates the a constant. in a future version we want to generate this together with the clients to prevent misuse
+ * @param server A pointer to a Server, obtained from initServer
+ * @return serialized a constant
+ */
+ public static native byte[] generateA(long server);
+
+ /**
+ * accumulates the given partial public keys into a public key, stores it in server and returns it
+ * @param server A pointer to a Server, obtained from initServer
+ * @param partial_public_keys array of serialized partial public keys
+ * @return serialized partial public key
+ */
+ public static native byte[] aggregatePartialPublicKeys(long server, byte[][] partial_public_keys);
+
+ /**
+ * accumulates the given ciphertexts into a sum ciphertext and returns it
+ * @param server A pointer to a Server, obtained from initServer
+ * @param ciphertexts array of serialized ciphertexts
+ * @return serialized accumulated ciphertext
+ */
+ public static native byte[] accumulateCiphertexts(long server, byte[][] ciphertexts);
+
+ /**
+ * averages the partial decryptions and returns the result
+ * @param server A pointer to a Server, obtained from initServer
+ * @param encrypted_sum the result of accumulateCiphertexts()
+ * @param partial_plaintexts the result of partiallyDecrypt of each ciphertext fed into accumulateCiphertexts
+ * @return average of original data
+ */
+ public static native double[] average(long server, byte[] encrypted_sum, byte[][] partial_plaintexts);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java
index 935f2808af..6be5f8c246 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/homomorphicEncryption/SEALClient.java
@@ -29,60 +29,60 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import java.util.stream.IntStream;
public class SEALClient {
- public SEALClient(byte[] a) {
- // TODO take params here, like slot_count etc.
- ctx = NativeHEHelper.initClient(a);
- }
+ public SEALClient(byte[] a) {
+ // TODO take params here, like slot_count etc.
+ ctx = NativeHEHelper.initClient(a);
+ }
- // this is a pointer to the context used by all native methods of this class
- private final long ctx;
+ // this is a pointer to the context used by all native methods of this class
+ private final long ctx;
- /**
- * generates a partial public key
- * stores a partial private key corresponding to the partial public key in ctx
- *
- * @return the partial public key
- */
- public PublicKey generatePartialPublicKey() {
- return new PublicKey(NativeHEHelper.generatePartialPublicKey(ctx));
- }
+ /**
+ * generates a partial public key
+ * stores a partial private key corresponding to the partial public key in ctx
+ *
+ * @return the partial public key
+ */
+ public PublicKey generatePartialPublicKey() {
+ return new PublicKey(NativeHEHelper.generatePartialPublicKey(ctx));
+ }
- /**
- * sets the public key and stores it in ctx
- *
- * @param public_key the public key to set
- */
- public void setPublicKey(PublicKey public_key) {
- NativeHEHelper.setPublicKey(ctx, public_key.getData());
- }
+ /**
+ * sets the public key and stores it in ctx
+ *
+ * @param public_key the public key to set
+ */
+ public void setPublicKey(PublicKey public_key) {
+ NativeHEHelper.setPublicKey(ctx, public_key.getData());
+ }
- /**
- * encrypts one block of data with public key stored statically and returns it
- * setPublicKey() must have been called before calling this
- * @param plaintext the MatrixObject to encrypt
- * @return the encrypted matrix
- */
- public CiphertextMatrix encrypt(MatrixObject plaintext) {
- MatrixBlock mb = plaintext.acquireReadAndRelease();
- if (mb.isInSparseFormat()) {
- mb.allocateSparseRowsBlock();
- mb.sparseToDense();
- }
- DenseBlock db = mb.getDenseBlock();
- int[] dims = IntStream.range(0, db.numDims()).map(db::getDim).toArray();
- double[] raw_data = mb.getDenseBlockValues();
- return new CiphertextMatrix(dims, plaintext.getDataCharacteristics(), NativeHEHelper.encrypt(ctx, raw_data));
- }
+ /**
+ * encrypts one block of data with public key stored statically and returns it
+ * setPublicKey() must have been called before calling this
+ * @param plaintext the MatrixObject to encrypt
+ * @return the encrypted matrix
+ */
+ public CiphertextMatrix encrypt(MatrixObject plaintext) {
+ MatrixBlock mb = plaintext.acquireReadAndRelease();
+ if (mb.isInSparseFormat()) {
+ mb.allocateSparseRowsBlock();
+ mb.sparseToDense();
+ }
+ DenseBlock db = mb.getDenseBlock();
+ int[] dims = IntStream.range(0, db.numDims()).map(db::getDim).toArray();
+ double[] raw_data = mb.getDenseBlockValues();
+ return new CiphertextMatrix(dims, plaintext.getDataCharacteristics(), NativeHEHelper.encrypt(ctx, raw_data));
+ }
- /**
- * partially decrypts ciphertext with the partial private key. generatePartialPublicKey() must
- * have been called before calling this function
- *
- * @param ciphertext the ciphertext to partially decrypt
- * @return the partial decryption of ciphertext
- */
- public PlaintextMatrix partiallyDecrypt(CiphertextMatrix ciphertext) {
- return new PlaintextMatrix(ciphertext.getDims(), ciphertext.getDataCharacteristics(), NativeHEHelper.partiallyDecrypt(ctx, ciphertext.getData()));
- }
+ /**
+ * partially decrypts ciphertext with the partial private key. generatePartialPublicKey() must
+ * have been called before calling this function
+ *
+ * @param ciphertext the ciphertext to partially decrypt
+ * @return the partial decryption of ciphertext
+ */
+ public PlaintextMatrix partiallyDecrypt(CiphertextMatrix ciphertext) {
+ return new PlaintextMatrix(ciphertext.getDims(), ciphertext.getDataCharacteristics(), NativeHEHelper.partiallyDecrypt(ctx, ciphertext.getData()));
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedCoordinatorIntegrationCRUDTest.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedCoordinatorIntegrationCRUDTest.java
new file mode 100644
index 0000000000..c6612f2cb9
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedCoordinatorIntegrationCRUDTest.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.monitoring;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.http.HttpStatus;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.NodeEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.EntityEnum;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+
+public class FederatedCoordinatorIntegrationCRUDTest extends FederatedMonitoringTestBase {
+ private final static String TEST_NAME = "FederatedCoordinatorIntegrationCRUDTest";
+
+ private final static String TEST_DIR = "functions/federated/monitoring/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCoordinatorIntegrationCRUDTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+ startFedMonitoring(null);
+ }
+
+ @Test
+ public void testCoordinatorAddedForMonitoring() {
+ var addedCoordinators = addEntities(EntityEnum.COORDINATOR,1);
+ var firstCoordinatorStatus = addedCoordinators.get(0).statusCode();
+
+ Assert.assertEquals("Added coordinator status code", HttpStatus.SC_OK, firstCoordinatorStatus);
+ }
+
+ @Test
+ @Ignore
+ public void testCoordinatorRemovedFromMonitoring() {
+ addEntities(EntityEnum.COORDINATOR,2);
+ var statusCode = removeEntity(EntityEnum.COORDINATOR,1L).statusCode();
+
+ var getAllCoordinatorsResponse = getEntities(EntityEnum.COORDINATOR);
+ var numReturnedCoordinators = StringUtils.countMatches(getAllCoordinatorsResponse.body().toString(), "id");
+
+ Assert.assertEquals("Removed coordinator status code", HttpStatus.SC_OK, statusCode);
+ Assert.assertEquals("Removed coordinators num", 1, numReturnedCoordinators);
+ }
+
+ @Test
+ @Ignore
+ public void testCoordinatorDataUpdated() {
+ addEntities(EntityEnum.COORDINATOR,3);
+ var newCoordinatorData = new NodeEntityModel(1L, "NonExistentName", "nonexistent.address");
+
+ var editedCoordinator = updateEntity(EntityEnum.COORDINATOR, newCoordinatorData);
+
+ var getAllCoordinatorsResponse = getEntities(EntityEnum.COORDINATOR);
+ var numCoordinatorsNewData = StringUtils.countMatches(getAllCoordinatorsResponse.body().toString(), newCoordinatorData.getName());
+
+ Assert.assertEquals("Updated coordinator status code", HttpStatus.SC_OK, editedCoordinator.statusCode());
+ Assert.assertEquals("Updated coordinators num", 1, numCoordinatorsNewData);
+ }
+
+ @Test
+ @Ignore
+ public void testCorrectAmountAddedCoordinatorsForMonitoring() {
+ int numCoordinators = 3;
+ var addedCoordinators = addEntities(EntityEnum.COORDINATOR, numCoordinators);
+
+ for (int i = 0; i < numCoordinators; i++) {
+ var coordinatorStatus = addedCoordinators.get(i).statusCode();
+ Assert.assertEquals("Added coordinator status code", HttpStatus.SC_OK, coordinatorStatus);
+ }
+
+ var getAllCoordinatorsResponse = getEntities(EntityEnum.COORDINATOR);
+ var numReturnedCoordinators = StringUtils.countMatches(getAllCoordinatorsResponse.body().toString(), "id");
+
+ Assert.assertEquals("Amount of coordinators to get", numCoordinators, numReturnedCoordinators);
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java
index 483e3eee9e..4206151686 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java
@@ -20,7 +20,8 @@
package org.apache.sysds.test.functions.federated.monitoring;
import com.fasterxml.jackson.databind.ObjectMapper;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.NodeEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.EntityEnum;
import org.apache.sysds.test.functions.federated.multitenant.MultiTenantTestBase;
import org.junit.After;
@@ -36,7 +37,10 @@ public abstract class FederatedMonitoringTestBase extends MultiTenantTestBase {
protected Process monitoringProcess;
private int monitoringPort;
+ private static final String MAIN_URI = "http://localhost";
+
private static final String WORKER_MAIN_PATH = "/workers";
+ private static final String COORDINATOR_MAIN_PATH = "/coordinators";
@Override
public abstract void setUp();
@@ -59,16 +63,22 @@ public abstract class FederatedMonitoringTestBase extends MultiTenantTestBase {
monitoringProcess = startLocalFedMonitoring(monitoringPort, addArgs);
}
- protected List<HttpResponse<?>> addWorkers(int numWorkers) {
- String uriStr = String.format("http://localhost:%d%s", monitoringPort, WORKER_MAIN_PATH);
+ protected List<HttpResponse<?>> addEntities(EntityEnum type, int count) {
+ String uriStr = MAIN_URI + ":" + monitoringPort + WORKER_MAIN_PATH;
+ String name = "Worker";
+
+ if (type == EntityEnum.COORDINATOR) {
+ uriStr = MAIN_URI + ":" + monitoringPort + COORDINATOR_MAIN_PATH;
+ name = "Coordinator";
+ }
List<HttpResponse<?>> responses = new ArrayList<>();
try {
ObjectMapper objectMapper = new ObjectMapper();
- for (int i = 0; i < numWorkers; i++) {
+ for (int i = 0; i < count; i++) {
String requestBody = objectMapper
.writerWithDefaultPrettyPrinter()
- .writeValueAsString(new BaseEntityModel((i + 1L), "Worker", "localhost"));
+ .writeValueAsString(new NodeEntityModel((i + 1L), name, "localhost"));
var client = HttpClient.newHttpClient();
var request = HttpRequest.newBuilder(URI.create(uriStr))
.header("accept", "application/json")
@@ -84,8 +94,58 @@ public abstract class FederatedMonitoringTestBase extends MultiTenantTestBase {
}
}
- protected HttpResponse<?> getWorkers() {
- String uriStr = String.format("http://localhost:%d%s", monitoringPort, WORKER_MAIN_PATH);
+ protected HttpResponse<?> updateEntity(EntityEnum type, NodeEntityModel editModel) {
+ String uriStr = MAIN_URI + ":" + monitoringPort + WORKER_MAIN_PATH;
+
+ if (type == EntityEnum.COORDINATOR) {
+ uriStr = MAIN_URI + ":" + monitoringPort + COORDINATOR_MAIN_PATH;
+ }
+
+ try {
+ ObjectMapper objectMapper = new ObjectMapper();
+ String requestBody = objectMapper
+ .writerWithDefaultPrettyPrinter()
+ .writeValueAsString(new NodeEntityModel(editModel.getId(), editModel.getName(), editModel.getAddress()));
+ var client = HttpClient.newHttpClient();
+ var request = HttpRequest.newBuilder(URI.create(uriStr))
+ .header("accept", "application/json")
+ .PUT(HttpRequest.BodyPublishers.ofString(requestBody))
+ .build();
+
+ return client.send(request, HttpResponse.BodyHandlers.ofString());
+ }
+ catch (IOException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ protected HttpResponse<?> removeEntity(EntityEnum type, Long id) {
+ String uriStr = MAIN_URI + ":" + monitoringPort + WORKER_MAIN_PATH + "/" + id;
+
+ if (type == EntityEnum.COORDINATOR) {
+ uriStr = MAIN_URI + ":" + monitoringPort + COORDINATOR_MAIN_PATH + "/" + id;
+ }
+
+ try {
+ var client = HttpClient.newHttpClient();
+ var request = HttpRequest.newBuilder(URI.create(uriStr))
+ .header("accept", "application/json")
+ .DELETE()
+ .build();
+
+ return client.send(request, HttpResponse.BodyHandlers.ofString());
+ }
+ catch (IOException | InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ protected HttpResponse<?> getEntities(EntityEnum type) {
+ String uriStr = MAIN_URI + ":" + monitoringPort + WORKER_MAIN_PATH;
+
+ if (type == EntityEnum.COORDINATOR) {
+ uriStr = MAIN_URI + ":" + monitoringPort + COORDINATOR_MAIN_PATH;
+ }
try {
var client = HttpClient.newHttpClient();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java
index d9fd9d5d8e..2282c06871 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java
@@ -21,9 +21,12 @@ package org.apache.sysds.test.functions.federated.monitoring;
import org.apache.commons.lang.StringUtils;
import org.apache.http.HttpStatus;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.NodeEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.EntityEnum;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
public class FederatedWorkerIntegrationCRUDTest extends FederatedMonitoringTestBase {
@@ -41,23 +44,52 @@ public class FederatedWorkerIntegrationCRUDTest extends FederatedMonitoringTestB
@Test
public void testWorkerAddedForMonitoring() {
- var addedWorkers = addWorkers(1);
+ var addedWorkers = addEntities(EntityEnum.WORKER,1);
var firstWorkerStatus = addedWorkers.get(0).statusCode();
Assert.assertEquals("Added worker status code", HttpStatus.SC_OK, firstWorkerStatus);
}
@Test
+ @Ignore
+ public void testWorkerRemovedFromMonitoring() {
+ addEntities(EntityEnum.WORKER,2);
+ var statusCode = removeEntity(EntityEnum.WORKER,1L).statusCode();
+
+ var getAllWorkersResponse = getEntities(EntityEnum.WORKER);
+ var numReturnedWorkers = StringUtils.countMatches(getAllWorkersResponse.body().toString(), "id");
+
+ Assert.assertEquals("Removed worker status code", HttpStatus.SC_OK, statusCode);
+ Assert.assertEquals("Removed workers num", 1, numReturnedWorkers);
+ }
+
+ @Test
+ @Ignore
+ public void testWorkerDataUpdated() {
+ addEntities(EntityEnum.WORKER,3);
+ var newWorkerData = new NodeEntityModel(1L, "NonExistentName", "nonexistent.address");
+
+ var editedWorker = updateEntity(EntityEnum.WORKER, newWorkerData);
+
+ var getAllWorkersResponse = getEntities(EntityEnum.WORKER);
+ var numWorkersNewData = StringUtils.countMatches(getAllWorkersResponse.body().toString(), newWorkerData.getName());
+
+ Assert.assertEquals("Updated worker status code", HttpStatus.SC_OK, editedWorker.statusCode());
+ Assert.assertEquals("Updated workers num", 1, numWorkersNewData);
+ }
+
+ @Test
+ @Ignore
public void testCorrectAmountAddedWorkersForMonitoring() {
int numWorkers = 3;
- var addedWorkers = addWorkers(numWorkers);
+ var addedWorkers = addEntities(EntityEnum.WORKER, numWorkers);
for (int i = 0; i < numWorkers; i++) {
var workerStatus = addedWorkers.get(i).statusCode();
Assert.assertEquals("Added worker status code", HttpStatus.SC_OK, workerStatus);
}
- var getAllWorkersResponse = getWorkers();
+ var getAllWorkersResponse = getEntities(EntityEnum.WORKER);
var numReturnedWorkers = StringUtils.countMatches(getAllWorkersResponse.body().toString(), "id");
Assert.assertEquals("Amount of workers to get", numWorkers, numReturnedWorkers);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java
index dc8fca39b6..2eaa3a6232 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java
@@ -19,7 +19,9 @@
package org.apache.sysds.test.functions.federated.monitoring;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.NodeEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.StatsEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.StatsService;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.WorkerService;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -42,15 +44,30 @@ public class FederatedWorkerStatisticsTest extends FederatedMonitoringTestBase {
workerPorts = startFedWorkers(3);
}
+ @Test
+ public void testWorkerStatisticsParsedCorrectly() {
+
+ var model = (StatsEntityModel) StatsService.getWorkerStatistics(1L, "localhost:" + workerPorts[0]);
+
+ Assert.assertNotNull("Stats parsed correctly", model);
+ Assert.assertNotEquals("CPU stats parsed correctly", 0, model.getCPUUsage());
+ Assert.assertNotEquals("Memory Stats parsed correctly", 0, model.getMemoryUsage());
+ }
+
@Test
public void testWorkerStatisticsReturnedForMonitoring() {
- workerMonitoringService.create(new BaseEntityModel(1L, "Worker", "localhost:" + workerPorts[0]));
+ workerMonitoringService.create(new NodeEntityModel(1L, "Worker", "localhost:" + workerPorts[0]));
- var model = workerMonitoringService.get(1L);
- var modelData = model.getData();
+ var model = (NodeEntityModel) workerMonitoringService.get(1L);
+
+ Assert.assertNotNull("Stats field of model contains worker statistics", model.getStats());
+ }
+
+ @Test
+ public void testNonExistentWorkerStatistics() {
+ workerMonitoringService.create(new NodeEntityModel(1L, "Worker", "not-running.address"));
+ var model = (NodeEntityModel) workerMonitoringService.get(1L);
- Assert.assertNotNull("Data field of model contains worker statistics", model.getData());
- Assert.assertNotEquals("Data field of model contains worker statistics",0, modelData.length());
- Assert.assertTrue("Data field of model contains worker statistics", modelData.contains("JVM"));
+ Assert.assertEquals("Stats field of model contains worker statistics", 0, model.getStats().size());
}
}