You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2022/05/15 16:43:39 UTC
[systemds] branch main updated: [SYSTEMDS-3355] Federated monitoring backend worker communication
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new f33b516d10 [SYSTEMDS-3355] Federated monitoring backend worker communication
f33b516d10 is described below
commit f33b516d102115433ad101d0f76136cab92d01ae
Author: Mito <mk...@arakt.com>
AuthorDate: Sun May 15 18:41:51 2022 +0200
[SYSTEMDS-3355] Federated monitoring backend worker communication
Closes #1608.
---
bin/systemds | 38 ++++-
src/main/java/org/apache/sysds/api/DMLOptions.java | 31 +++-
src/main/java/org/apache/sysds/api/DMLScript.java | 6 +
.../federated/FederatedStatistics.java | 12 +-
.../monitoring/FederatedMonitoringServer.java | 70 +++++---
.../FederatedMonitoringServerHandler.java | 180 +++++++++++----------
.../controllers/CoordinatorController.java | 46 +++---
.../{BaseController.java => IController.java} | 14 +-
.../monitoring/controllers/WorkerController.java | 81 ++++++++++
.../monitoring/models/BaseEntityModel.java | 78 +++++++++
.../federated/monitoring/{ => models}/Request.java | 30 ++--
.../monitoring/{ => models}/Response.java | 38 ++---
.../monitoring/repositories/DerbyRepository.java | 171 ++++++++++++++++++++
.../{Request.java => repositories/EntityEnum.java} | 26 +--
.../IRepository.java} | 17 +-
.../monitoring/services/WorkerService.java | 89 ++++++++++
.../org/apache/sysds/test/AutomatedTestBase.java | 25 +++
.../monitoring/FederatedMonitoringTestBase.java | 101 ++++++++++++
.../FederatedWorkerIntegrationCRUDTest.java | 65 ++++++++
.../monitoring/FederatedWorkerStatisticsTest.java | 56 +++++++
20 files changed, 952 insertions(+), 222 deletions(-)
diff --git a/bin/systemds b/bin/systemds
index 707f7e58ab..ffff838991 100755
--- a/bin/systemds
+++ b/bin/systemds
@@ -170,6 +170,10 @@ Worker Usage: $0 [-r] WORKER [SystemDS.jar] <portnumber> [arguments] [-help]
port : The port to open for the federated worker.
+Federated Monitoring Usage: $0 [-r] FEDMONITOR [SystemDS.jar] <portnumber> [arguments] [-help]
+
+ port : The port to open for the federated monitoring tool.
+
Set custom launch configuration by setting/editing SYSTEMDS_STANDALONE_OPTS
and/or SYSTEMDS_DISTRIBUTED_OPTS.
@@ -256,6 +260,20 @@ elif echo "$1" | grep -q "WORKER"; then
printUsage
fi
shift
+elif echo "$1" | grep -q "FEDMONITOR"; then
+ FEDMONITOR=1
+ shift
+ if echo "$1" | grep -q "jar"; then
+ SYSTEMDS_JAR_FILE=$1
+ shift
+ fi
+ PORT=$1
+ re='^[0-9]+$'
+ if ! [[ $PORT =~ $re ]] ; then
+ echo "error: Port is not a number"
+ printUsage
+ fi
+ shift
else
# handle optional '-f' before DML file (for consistency)
if echo "$1" | grep -q "\-f"; then
@@ -272,6 +290,9 @@ if [ -z "$WORKER" ] ; then
WORKER=0
fi
+if [ -z "$FEDMONITOR" ] ; then
+ FEDMONITOR=0
+fi
# find me a SystemDS jar file to run
if [ -z "$SYSTEMDS_JAR_FILE" ];then
@@ -409,7 +430,7 @@ print_out "# HADOOP_HOME= $HADOOP_HOME"
#build the command to run
if [ $WORKER == 1 ]; then
print_out "#"
- print_out "# starting Fedederated worker on port $PORT"
+ print_out "# starting Federated worker on port $PORT"
print_out "###############################################################################"
CMD=" \
java $SYSTEMDS_STANDALONE_OPTS \
@@ -422,6 +443,21 @@ if [ $WORKER == 1 ]; then
print_out "Executing command: $CMD"
print_out ""
+elif [ $FEDMONITOR == 1 ]; then # FEDMONITOR is the flag set when the FEDMONITOR sub-command is parsed
+ print_out "#"
+ print_out "# starting Federated backend monitoring on port $PORT"
+ print_out "###############################################################################"
+ CMD=" \
+ java $SYSTEMDS_STANDALONE_OPTS \
+ -cp $CLASSPATH \
+ $LOG4JPROP \
+ org.apache.sysds.api.DMLScript \
+ -fedMonitor $PORT \
+ $CONFIG_FILE \
+ $*"
+ print_out "Executing command: $CMD"
+ print_out ""
+
elif [ $SYSDS_DISTRIBUTED == 0 ]; then
print_out "#"
print_out "# Running script $SCRIPT_FILE locally with opts: $*"
diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java
index f2c3f5664c..11bcea1604 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -73,7 +73,9 @@ public class DMLOptions {
public boolean lineage_debugger = false; // whether enable lineage debugger
public boolean fedWorker = false;
public int fedWorkerPort = -1;
- public int pythonPort = -1;
+ public boolean fedMonitoring = false;
+ public int fedMonitoringPort = -1;
+ public int pythonPort = -1;
public boolean checkPrivacy = false; // Check which privacy constraints are loaded and checked during federated execution
public boolean federatedCompilation = false; // Compile federated instructions based on input federation state and privacy constraints.
public boolean noFedRuntimeConversion = false; // If activated, no runtime conversion of CP instructions to FED instructions will be performed.
@@ -95,6 +97,7 @@ public class DMLOptions {
", statsCount=" + statsCount +
", fedStats=" + fedStats +
", fedStatsCount=" + fedStatsCount +
+ ", fedMonitor=" + fedMonitoring +
", memStats=" + memStats +
", explainType=" + explainType +
", execMode=" + execMode +
@@ -217,6 +220,7 @@ public class DMLOptions {
}
}
}
+
dmlOptions.memStats = line.hasOption("mem");
dmlOptions.clean = line.hasOption("clean");
@@ -230,6 +234,11 @@ public class DMLOptions {
dmlOptions.fedWorkerPort = Integer.parseInt(line.getOptionValue("w"));
}
+ if (line.hasOption("fedMonitor")) {
+ dmlOptions.fedMonitoring= true;
+ dmlOptions.fedMonitoringPort = Integer.parseInt(line.getOptionValue("fedMonitor"));
+ }
+
if (line.hasOption("f")){
dmlOptions.filePath = line.getOptionValue("f");
}
@@ -314,7 +323,8 @@ public class DMLOptions {
Option configOpt = OptionBuilder.withArgName("filename")
.withDescription("uses a given configuration file (can be on local/hdfs/gpfs; default values in SystemDS-config.xml")
.hasArg().create("config");
- Option cleanOpt = OptionBuilder.withDescription("cleans up all SystemDS working directories (FS, DFS); all other flags are ignored in this mode.")
+ Option cleanOpt = OptionBuilder
+ .withDescription("cleans up all SystemDS working directories (FS, DFS); all other flags are ignored in this mode.")
.create("clean");
Option statsOpt = OptionBuilder.withArgName("count")
.withDescription("monitors and reports summary execution statistics; heavy hitter <count> is 10 unless overridden; default off")
@@ -335,7 +345,8 @@ public class DMLOptions {
.hasOptionalArg().create("gpu");
Option debugOpt = OptionBuilder.withDescription("runs in debug mode; default off")
.create("debug");
- Option pythonOpt = OptionBuilder.withDescription("Python Context start with port argument for communication to python")
+ Option pythonOpt = OptionBuilder
+ .withDescription("Python Context start with port argument for communication to python")
.isRequired().hasArg().create("python");
Option fileOpt = OptionBuilder.withArgName("filename")
.withDescription("specifies dml/pydml file to execute; path can be local/hdfs/gpfs (prefixed with appropriate URI)")
@@ -343,12 +354,18 @@ public class DMLOptions {
Option scriptOpt = OptionBuilder.withArgName("script_contents")
.withDescription("specified script string to execute directly")
.isRequired().hasArg().create("s");
- Option helpOpt = OptionBuilder.withDescription("shows usage message")
+ Option helpOpt = OptionBuilder
+ .withDescription("shows usage message")
.create("help");
- Option lineageOpt = OptionBuilder.withDescription("computes lineage traces")
+ Option lineageOpt = OptionBuilder
+ .withDescription("computes lineage traces")
.hasOptionalArgs().create("lineage");
- Option fedOpt = OptionBuilder.withDescription("starts a federated worker with the given argument as the port.")
+ Option fedOpt = OptionBuilder
+ .withDescription("starts a federated worker with the given argument as the port.")
.hasOptionalArg().create("w");
+ Option monitorOpt = OptionBuilder
+ .withDescription("Starts a federated monitoring backend with the given argument as the port.")
+ .hasOptionalArg().create("fedMonitor");
Option checkPrivacy = OptionBuilder
.withDescription("Check which privacy constraints are loaded and checked during federated execution")
.create("checkPrivacy");
@@ -375,6 +392,7 @@ public class DMLOptions {
options.addOption(debugOpt);
options.addOption(lineageOpt);
options.addOption(fedOpt);
+ options.addOption(monitorOpt);
options.addOption(checkPrivacy);
options.addOption(federatedCompilation);
options.addOption(noFedRuntimeConversion);
@@ -387,6 +405,7 @@ public class DMLOptions {
.addOption(cleanOpt)
.addOption(helpOpt)
.addOption(fedOpt)
+ .addOption(monitorOpt)
.addOption(pythonOpt);
fileOrScriptOpt.setRequired(true);
options.addOptionGroup(fileOrScriptOpt);
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index d74cf59bf7..9f6eb656e0 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -64,6 +64,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedWorker;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.FederatedMonitoringServer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
@@ -284,6 +285,11 @@ public class DMLScript
return true;
}
+ if(dmlOptions.fedMonitoring) {
+ new FederatedMonitoringServer(dmlOptions.fedMonitoringPort, dmlOptions.debug);
+ return true;
+ }
+
LineageCacheConfig.setConfig(LINEAGE_REUSE);
LineageCacheConfig.setCachePolicy(LINEAGE_POLICY);
LineageCacheConfig.setEstimator(LINEAGE_ESTIMATE);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
index 58a9480266..d95b02afd2 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -219,8 +219,12 @@ public class FederatedStatistics {
}
public static String displayStatistics(int numHeavyHitters) {
- StringBuilder sb = new StringBuilder();
FedStatsCollection fedStats = collectFedStats();
+ return displayStatistics(fedStats, numHeavyHitters);
+ }
+
+ public static String displayStatistics(FedStatsCollection fedStats, int numHeavyHitters) {
+ StringBuilder sb = new StringBuilder();
sb.append("SystemDS Federated Statistics:\n");
sb.append(displayCacheStats(fedStats.cacheStats));
sb.append(String.format("Total JIT compile time:\t\t%.3f sec.\n", fedStats.jitCompileTime));
@@ -499,7 +503,7 @@ public class FederatedStatistics {
return "";
}
- private static class FedStatsCollectFunction extends FederatedUDF {
+ public static class FedStatsCollectFunction extends FederatedUDF {
private static final long serialVersionUID = 1L;
public FedStatsCollectFunction() {
@@ -519,7 +523,7 @@ public class FederatedStatistics {
}
}
- protected static class FedStatsCollection implements Serializable {
+ public static class FedStatsCollection implements Serializable {
private static final long serialVersionUID = 1L;
private void collectStats() {
@@ -531,7 +535,7 @@ public class FederatedStatistics {
heavyHitters = Statistics.getHeavyHittersHashMap();
}
- private void aggregate(FedStatsCollection that) {
+ public void aggregate(FedStatsCollection that) {
cacheStats.aggregate(that.cacheStats);
jitCompileTime += that.jitCompileTime;
gcStats.aggregate(that.gcStats);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/FederatedMonitoringServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/FederatedMonitoringServer.java
index 61bc6e5dc3..8976d65194 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/FederatedMonitoringServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/FederatedMonitoringServer.java
@@ -28,37 +28,55 @@ import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpServerCodec;
+import org.apache.log4j.Logger;
public class FederatedMonitoringServer {
- private final int _port;
+ protected static Logger log = Logger.getLogger(FederatedMonitoringServer.class);
+ private final int _port;
- public FederatedMonitoringServer(int port) {
- _port = (port == -1) ? 4201 : port;
- }
+ private final boolean _debug;
- public void run() throws Exception {
- EventLoopGroup bossGroup = new NioEventLoopGroup();
- EventLoopGroup workerGroup = new NioEventLoopGroup();
+ public FederatedMonitoringServer(int port, boolean debug) {
+ _port = (port == -1) ? 4201 : port;
- try {
- ServerBootstrap server = new ServerBootstrap();
- server.group(bossGroup, workerGroup)
- .channel(NioServerSocketChannel.class)
- .childHandler(new ChannelInitializer<>() {
- @Override
- protected void initChannel(Channel ch) {
- ChannelPipeline pipeline = ch.pipeline();
+ _debug = debug;
- pipeline.addLast(new HttpServerCodec());
- pipeline.addLast(new FederatedMonitoringServerHandler());
- }
- });
+ run();
+ }
- ChannelFuture f = server.bind(_port).sync();
- f.channel().closeFuture().sync();
- } finally {
- workerGroup.shutdownGracefully();
- bossGroup.shutdownGracefully();
- }
- }
+ public void run() {
+ log.info("Setting up Federated Monitoring Backend on port " + _port);
+ EventLoopGroup bossGroup = new NioEventLoopGroup();
+ EventLoopGroup workerGroup = new NioEventLoopGroup();
+
+ try {
+ ServerBootstrap server = new ServerBootstrap();
+ server.group(bossGroup, workerGroup)
+ .channel(NioServerSocketChannel.class)
+ .childHandler(new ChannelInitializer<>() {
+ @Override
+ protected void initChannel(Channel ch) {
+ ChannelPipeline pipeline = ch.pipeline();
+
+ pipeline.addLast(new HttpServerCodec());
+ pipeline.addLast(new FederatedMonitoringServerHandler());
+ }
+ });
+
+ log.info("Starting Federated Monitoring Backend server at port: " + _port);
+ ChannelFuture f = server.bind(_port).sync();
+ log.info("Started Federated Monitoring Backend at port: " + _port);
+ f.channel().closeFuture().sync();
+ } catch(Exception e) {
+ log.info("Federated Monitoring Backend Interrupted");
+ if (_debug) {
+ log.error(e.getMessage());
+ e.printStackTrace();
+ }
+ } finally{
+ log.info("Federated Monitoring Backend Shutting down.");
+ workerGroup.shutdownGracefully();
+ bossGroup.shutdownGracefully();
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/FederatedMonitoringServerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/FederatedMonitoringServerHandler.java
index ac392b5000..2e7006055b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/FederatedMonitoringServerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/FederatedMonitoringServerHandler.java
@@ -27,8 +27,10 @@ import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.CharsetUtil;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers.BaseController;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers.IController;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers.CoordinatorController;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers.WorkerController;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Request;
import java.util.HashMap;
import java.util.Map;
@@ -38,92 +40,92 @@ import java.util.regex.Pattern;
public class FederatedMonitoringServerHandler extends SimpleChannelInboundHandler<HttpObject> {
- private final Map<String, BaseController> _allControllers = new HashMap<>();
- {
- _allControllers.put("/coordinators", new CoordinatorController());
- }
-
- private final static ThreadLocal<Request> _currentRequest = new ThreadLocal<>();
-
- @Override
- protected void channelRead0(ChannelHandlerContext ctx, HttpObject msg) {
-
- if (msg instanceof LastHttpContent) {
- final ByteBuf jsonBuf = ((LastHttpContent) msg).content();
- final Request request = _currentRequest.get();
- request.setBody(jsonBuf.toString(CharsetUtil.UTF_8));
-
- _currentRequest.remove();
-
- final FullHttpResponse response = processRequest(request);
- ctx.write(response);
-
- } else if (msg instanceof HttpRequest) {
- final HttpRequest httpRequest = (HttpRequest) msg;
- final Request request = new Request();
- request.setContext(httpRequest);
-
- _currentRequest.set(request);
- }
-
- }
-
- @Override
- public void channelReadComplete(ChannelHandlerContext ctx) {
- ctx.flush();
- }
-
- @Override
- public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
- cause.printStackTrace();
- ctx.close();
- }
-
- private FullHttpResponse processRequest(final Request request) {
- try {
- final BaseController controller = parseController(request.getContext().uri());
- final String method = request.getContext().method().name();
-
- switch (method) {
- case "GET":
- final Long id = parseId(request.getContext().uri());
-
- if (id != null) {
- return controller.get(request, id);
- }
-
- return controller.getAll(request);
- case "PUT":
- return controller.create(request);
- case "POST":
- return controller.update(request, parseId(request.getContext().uri()));
- case "DELETE":
- return controller.delete(request, parseId(request.getContext().uri()));
- default:
- throw new IllegalArgumentException("Method is not supported!");
- }
- } catch (RuntimeException ex) {
- ex.printStackTrace();
- return null;
- }
- }
-
- private BaseController parseController(final String currentPath) {
- final Optional<String> controller = _allControllers.keySet().stream()
- .filter(currentPath::startsWith)
- .findFirst();
-
- return controller.map(_allControllers::get).orElseThrow(() ->
- new IllegalArgumentException("Such controller does not exist!"));
- }
-
- private Long parseId(final String uri) {
- final Pattern pattern = Pattern.compile("^[/][a-z]+[/]");
- final Matcher matcher = pattern.matcher(uri);
-
- if (matcher.find()) {
- return Long.valueOf(uri.substring(matcher.end()));
- }
- return null;
- }
+ private final Map<String, IController> _allControllers = new HashMap<>();
+ {
+ _allControllers.put("/coordinators", new CoordinatorController());
+ _allControllers.put("/workers", new WorkerController());
+ }
+
+ private final static ThreadLocal<Request> _currentRequest = new ThreadLocal<>();
+
+ @Override
+ protected void channelRead0(ChannelHandlerContext ctx, HttpObject msg) {
+
+ if (msg instanceof LastHttpContent) {
+ ByteBuf jsonBuf = ((LastHttpContent) msg).content();
+ Request request = _currentRequest.get();
+ request.setBody(jsonBuf.toString(CharsetUtil.UTF_8));
+
+ _currentRequest.remove();
+
+ final FullHttpResponse response = processRequest(request);
+ ctx.write(response);
+
+ } else if (msg instanceof HttpRequest) {
+ HttpRequest httpRequest = (HttpRequest) msg;
+ Request request = new Request();
+ request.setContext(httpRequest);
+
+ _currentRequest.set(request);
+ }
+
+ }
+
+ @Override
+ public void channelReadComplete(ChannelHandlerContext ctx) {
+ ctx.flush();
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ cause.printStackTrace();
+ ctx.close();
+ }
+
+ private FullHttpResponse processRequest(final Request request) {
+ try {
+ final IController controller = parseController(request.getContext().uri());
+ final String method = request.getContext().method().name();
+
+ switch (method) {
+ case "GET":
+ final Long id = parseId(request.getContext().uri());
+
+ if (id != null) {
+ return controller.get(request, id);
+ }
+
+ return controller.getAll(request);
+ case "PUT":
+ return controller.update(request, parseId(request.getContext().uri()));
+ case "POST":
+ return controller.create(request);
+ case "DELETE":
+ return controller.delete(request, parseId(request.getContext().uri()));
+ default:
+ throw new IllegalArgumentException("Method is not supported!");
+ }
+ } catch (RuntimeException ex) {
+ throw ex;
+ }
+ }
+
+ private IController parseController(final String currentPath) {
+ final Optional<String> controller = _allControllers.keySet().stream()
+ .filter(currentPath::startsWith)
+ .findFirst();
+
+ return controller.map(_allControllers::get).orElseThrow(() ->
+ new IllegalArgumentException("Such controller does not exist!"));
+ }
+
+ private Long parseId(final String uri) {
+ final Pattern pattern = Pattern.compile("^[/][a-z]+[/]");
+ final Matcher matcher = pattern.matcher(uri);
+
+ if (matcher.find()) {
+ return Long.valueOf(uri.substring(matcher.end()));
+ }
+ return null;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java
index be807721b9..8c81ffd24d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java
@@ -20,32 +20,32 @@
package org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers;
import io.netty.handler.codec.http.FullHttpResponse;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.Request;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.Response;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Request;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Response;
-public class CoordinatorController implements BaseController {
- @Override
- public FullHttpResponse create(Request request) {
- return null;
- }
+public class CoordinatorController implements IController {
+ @Override
+ public FullHttpResponse create(Request request) {
+ return null;
+ }
- @Override
- public FullHttpResponse update(Request request, Long objectId) {
- return null;
- }
+ @Override
+ public FullHttpResponse update(Request request, Long objectId) {
+ return null;
+ }
- @Override
- public FullHttpResponse delete(Request request, Long objectId) {
- return null;
- }
+ @Override
+ public FullHttpResponse delete(Request request, Long objectId) {
+ return null;
+ }
- @Override
- public FullHttpResponse get(Request request, Long objectId) {
- return Response.ok("Success");
- }
+ @Override
+ public FullHttpResponse get(Request request, Long objectId) {
+ return Response.ok("Success");
+ }
- @Override
- public FullHttpResponse getAll(Request request) {
- return Response.ok("Success");
- }
+ @Override
+ public FullHttpResponse getAll(Request request) {
+ return Response.ok("Success");
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/BaseController.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/IController.java
similarity index 73%
copy from src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/BaseController.java
copy to src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/IController.java
index 34a415b6e3..6016748bc8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/BaseController.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/IController.java
@@ -20,17 +20,17 @@
package org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers;
import io.netty.handler.codec.http.FullHttpResponse;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.Request;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Request;
-public interface BaseController {
+public interface IController {
- FullHttpResponse create(final Request request);
+ FullHttpResponse create(final Request request);
- FullHttpResponse update(final Request request, final Long objectId);
+ FullHttpResponse update(final Request request, final Long objectId);
- FullHttpResponse delete(final Request request, final Long objectId);
+ FullHttpResponse delete(final Request request, final Long objectId);
- FullHttpResponse get(final Request request, final Long objectId);
+ FullHttpResponse get(final Request request, final Long objectId);
- FullHttpResponse getAll(final Request request);
+ FullHttpResponse getAll(final Request request);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java
new file mode 100644
index 0000000000..bdc46304f6
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.netty.handler.codec.http.FullHttpResponse;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Request;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.Response;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.WorkerService;
+
+import java.io.IOException;
+
+public class WorkerController implements IController {
+
+ private final WorkerService _workerService = new WorkerService();
+
+ @Override
+ public FullHttpResponse create(Request request) {
+
+ ObjectMapper mapper = new ObjectMapper();
+
+ try {
+ BaseEntityModel model = mapper.readValue(request.getBody(), BaseEntityModel.class);
+ _workerService.create(model);
+ return Response.ok("Success");
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public FullHttpResponse update(Request request, Long objectId) {
+ return null;
+ }
+
+ @Override
+ public FullHttpResponse delete(Request request, Long objectId) {
+ return null;
+ }
+
+ @Override
+ public FullHttpResponse get(Request request, Long objectId) {
+ var result = _workerService.get(objectId);
+
+ if (result == null) {
+ return Response.notFound("No such worker can be found");
+ }
+
+ return Response.ok(result.toString());
+ }
+
+ @Override
+ public FullHttpResponse getAll(Request request) {
+ var workers = _workerService.getAll();
+
+ if (workers.isEmpty()) {
+ return Response.notFound("No workers can be found");
+ }
+
+ return Response.ok(workers.toString());
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/BaseEntityModel.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/BaseEntityModel.java
new file mode 100644
index 0000000000..d42e76556f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/BaseEntityModel.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.models;
+
+public class BaseEntityModel {
+ private Long _id;
+ private String _name;
+ private String _address;
+
+ private String _data;
+
+ public BaseEntityModel() { }
+
+ public BaseEntityModel(final Long id, final String name, final String address) {
+ _id = id;
+ _name = name;
+ _address = address;
+ }
+
+ public Long getId() {
+ return _id;
+ }
+
+ public void setId(final Long id) {
+ _id = id;
+ }
+
+ public String getName() {
+ return _name;
+ }
+
+ public void setName(final String name) {
+ _name = name;
+ }
+
+ public String getAddress() {
+ return _address;
+ }
+
+ public void setAddress(final String address) {
+ _address = address;
+ }
+
+ public String getData() {
+ return _data;
+ }
+
+ public void setData(final String data) {
+ _data = data;
+ }
+
+ @Override
+ public String toString() {
+ return String.format("{" +
+ "\"id\": %d," +
+ "\"name\": \"%s\"," +
+ "\"address\": \"%s\"," +
+ "\"data\": \"%s\"" +
+ "}", _id, _name, _address, _data);
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/Request.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/Request.java
similarity index 71%
copy from src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/Request.java
copy to src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/Request.java
index b9d71fe428..21d6812644 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/Request.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/Request.java
@@ -17,27 +17,27 @@
* under the License.
*/
-package org.apache.sysds.runtime.controlprogram.federated.monitoring;
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.models;
import io.netty.handler.codec.http.HttpRequest;
public class Request {
- private HttpRequest _context;
- private String _body;
+ private HttpRequest _context;
+ private String _body;
- public HttpRequest getContext() {
- return _context;
- }
+ public HttpRequest getContext() {
+ return _context;
+ }
- public void setContext(final HttpRequest requestContext) {
- this._context = requestContext;
- }
+ public void setContext(final HttpRequest requestContext) {
+ this._context = requestContext;
+ }
- public String getBody() {
- return _body;
- }
+ public String getBody() {
+ return _body;
+ }
- public void setBody(final String content) {
- this._body = content;
- }
+ public void setBody(final String content) {
+ this._body = content;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/Response.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/Response.java
similarity index 55%
rename from src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/Response.java
rename to src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/Response.java
index 7a3814835b..9693af6060 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/Response.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/models/Response.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.sysds.runtime.controlprogram.federated.monitoring;
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.models;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
@@ -27,27 +27,27 @@ import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
public class Response {
- public static FullHttpResponse ok(final String result) {
- FullHttpResponse response = new DefaultFullHttpResponse(
- HttpVersion.HTTP_1_1,
- HttpResponseStatus.OK,
- Unpooled.wrappedBuffer(result.getBytes()));
+ public static FullHttpResponse ok(final String result) {
+ FullHttpResponse response = new DefaultFullHttpResponse(
+ HttpVersion.HTTP_1_1,
+ HttpResponseStatus.OK,
+ Unpooled.wrappedBuffer(result.getBytes()));
- response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json");
- response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes());
+ response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json");
+ response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes());
- return response;
- }
+ return response;
+ }
- public static FullHttpResponse notFound(final String exception) {
- FullHttpResponse response = new DefaultFullHttpResponse(
- HttpVersion.HTTP_1_1,
- HttpResponseStatus.NOT_FOUND,
- Unpooled.wrappedBuffer(exception.getBytes()));
+ public static FullHttpResponse notFound(final String exception) {
+ FullHttpResponse response = new DefaultFullHttpResponse(
+ HttpVersion.HTTP_1_1,
+ HttpResponseStatus.NOT_FOUND,
+ Unpooled.wrappedBuffer(exception.getBytes()));
- response.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/plain");
- response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes());
+ response.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/plain");
+ response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes());
- return response;
- }
+ return response;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java
new file mode 100644
index 0000000000..9e94a41d61
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories;
+
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Types;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * {@link IRepository} implementation that stores monitored entities in the
+ * Derby database addressed by {@code jdbc:derby:memory:derbyDB}.
+ *
+ * All operations currently read from and write to the workers table; the
+ * {@link EntityEnum} argument is accepted but does not yet select a table.
+ */
+public class DerbyRepository implements IRepository {
+	private final static String DB_CONNECTION = "jdbc:derby:memory:derbyDB";
+	private final Connection _db;
+
+	private static final String WORKERS_TABLE_NAME= "workers";
+	private static final String ENTITY_NAME_COL = "name";
+	private static final String ENTITY_ADDR_COL = "address";
+
+	private static final String ENTITY_SCHEMA_CREATE_STMT = "CREATE TABLE %s " +
+		"(id INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), " +
+		"%s VARCHAR(60), " +
+		"%s VARCHAR(120))";
+	private static final String ENTITY_INSERT_STMT = "INSERT INTO %s (%s, %s) VALUES (?, ?)";
+
+	private static final String GET_ENTITY_WITH_ID_STMT = "SELECT * FROM %s WHERE id = ?";
+	private static final String GET_ALL_ENTITIES_STMT = "SELECT * FROM %s";
+
+	/**
+	 * Opens the monitoring database and makes sure the workers table exists.
+	 *
+	 * @throws RuntimeException if the database cannot be opened or initialized
+	 */
+	public DerbyRepository() {
+		_db = createMonitoringDatabase();
+	}
+
+	/**
+	 * Connects to the monitoring database and creates the schema if needed.
+	 *
+	 * @return the open connection used by all repository operations
+	 */
+	private Connection createMonitoringDatabase() {
+		try {
+			// Creates only if DB doesn't exist
+			Connection db = DriverManager.getConnection(DB_CONNECTION + ";create=true");
+			createMonitoringEntitiesInDB(db);
+
+			return db;
+		}
+		catch (SQLException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	/**
+	 * Creates the workers table unless the database metadata already lists it.
+	 *
+	 * @param db connection to the monitoring database
+	 */
+	private void createMonitoringEntitiesInDB(Connection db) {
+		try {
+			var dbMetaData = db.getMetaData();
+
+			// Check if table already exists and create if not; the upper-cased
+			// table name is what is looked up in the metadata
+			try (ResultSet workersExist =
+				dbMetaData.getTables(null, null, WORKERS_TABLE_NAME.toUpperCase(), null)) {
+				if (workersExist.next())
+					return;
+			}
+
+			try (PreparedStatement st = db.prepareStatement(
+				String.format(ENTITY_SCHEMA_CREATE_STMT, WORKERS_TABLE_NAME, ENTITY_NAME_COL, ENTITY_ADDR_COL))) {
+				st.executeUpdate();
+			}
+		}
+		catch (SQLException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	/**
+	 * Inserts the name and address of the given model as a new row.
+	 *
+	 * The id of the model is not written; the id column is declared as
+	 * GENERATED ALWAYS AS IDENTITY in the table definition.
+	 *
+	 * @param type  entity type; TODO not evaluated yet, rows always go to the workers table
+	 * @param model entity whose name and address are stored
+	 * @throws RuntimeException wrapping any {@link SQLException}
+	 */
+	@Override
+	public void createEntity(EntityEnum type, BaseEntityModel model) {
+		final String sql = String.format(ENTITY_INSERT_STMT, WORKERS_TABLE_NAME, ENTITY_NAME_COL, ENTITY_ADDR_COL);
+
+		// try-with-resources releases the statement even if the insert fails
+		try (PreparedStatement st = _db.prepareStatement(sql)) {
+			st.setString(1, model.getName());
+			st.setString(2, model.getAddress());
+			st.executeUpdate();
+		}
+		catch (SQLException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	/**
+	 * Looks up a single entity by id.
+	 *
+	 * @param type entity type; TODO not evaluated yet, the workers table is always queried
+	 * @param id   identifier of the entity
+	 * @return the matching entity, or {@code null} if no row has this id
+	 * @throws RuntimeException wrapping any {@link SQLException}
+	 */
+	@Override
+	public BaseEntityModel getEntity(EntityEnum type, Long id) {
+		final String sql = String.format(GET_ENTITY_WITH_ID_STMT, WORKERS_TABLE_NAME);
+
+		try (PreparedStatement st = _db.prepareStatement(sql)) {
+			st.setLong(1, id);
+
+			try (ResultSet resultSet = st.executeQuery()) {
+				return resultSet.next() ? mapEntityToModel(resultSet) : null;
+			}
+		}
+		catch (SQLException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	/**
+	 * Returns all stored entities.
+	 *
+	 * @param type entity type; TODO not evaluated yet, the workers table is always queried
+	 * @return the stored entities; empty if there are none
+	 * @throws RuntimeException wrapping any {@link SQLException}
+	 */
+	@Override
+	public List<BaseEntityModel> getAllEntities(EntityEnum type) {
+		final List<BaseEntityModel> resultModels = new ArrayList<>();
+		final String sql = String.format(GET_ALL_ENTITIES_STMT, WORKERS_TABLE_NAME);
+
+		try (PreparedStatement st = _db.prepareStatement(sql);
+			ResultSet resultSet = st.executeQuery()) {
+			while (resultSet.next()) {
+				resultModels.add(mapEntityToModel(resultSet));
+			}
+		}
+		catch (SQLException e) {
+			throw new RuntimeException(e);
+		}
+
+		return resultModels;
+	}
+
+	/**
+	 * Maps the current row of the result set to a model.
+	 *
+	 * Every INTEGER column is taken as the id; VARCHAR columns are assigned to
+	 * name or address according to their (case-insensitive) column name.
+	 *
+	 * @param resultSet result set positioned on the row to convert
+	 * @return a new model filled from the row
+	 * @throws SQLException if the row or its metadata cannot be read
+	 */
+	private BaseEntityModel mapEntityToModel(ResultSet resultSet) throws SQLException {
+		final BaseEntityModel tmpModel = new BaseEntityModel();
+		// fetch the metadata once instead of once per column access
+		final var metaData = resultSet.getMetaData();
+		final int numColumns = metaData.getColumnCount();
+
+		for (int column = 1; column <= numColumns; column++) {
+			final int columnType = metaData.getColumnType(column);
+
+			if (columnType == Types.INTEGER) {
+				tmpModel.setId(resultSet.getLong(column));
+			}
+			else if (columnType == Types.VARCHAR) {
+				final String columnName = metaData.getColumnName(column);
+				if (columnName.equalsIgnoreCase(ENTITY_NAME_COL)) {
+					tmpModel.setName(resultSet.getString(column));
+				} else if (columnName.equalsIgnoreCase(ENTITY_ADDR_COL)) {
+					tmpModel.setAddress(resultSet.getString(column));
+				}
+			}
+		}
+		return tmpModel;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/Request.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/EntityEnum.java
similarity index 65%
rename from src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/Request.java
rename to src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/EntityEnum.java
index b9d71fe428..7384257bf3 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/Request.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/EntityEnum.java
@@ -17,27 +17,9 @@
* under the License.
*/
-package org.apache.sysds.runtime.controlprogram.federated.monitoring;
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories;
-import io.netty.handler.codec.http.HttpRequest;
-
-public class Request {
- private HttpRequest _context;
- private String _body;
-
- public HttpRequest getContext() {
- return _context;
- }
-
- public void setContext(final HttpRequest requestContext) {
- this._context = requestContext;
- }
-
- public String getBody() {
- return _body;
- }
-
- public void setBody(final String content) {
- this._body = content;
- }
+public enum EntityEnum {
+ WORKER,
+ COORDINATOR
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/BaseController.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/IRepository.java
similarity index 68%
rename from src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/BaseController.java
rename to src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/IRepository.java
index 34a415b6e3..d441693e59 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/BaseController.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/IRepository.java
@@ -17,20 +17,17 @@
* under the License.
*/
-package org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers;
-import io.netty.handler.codec.http.FullHttpResponse;
-import org.apache.sysds.runtime.controlprogram.federated.monitoring.Request;
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories;
-public interface BaseController {
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
- FullHttpResponse create(final Request request);
+import java.util.List;
- FullHttpResponse update(final Request request, final Long objectId);
+public interface IRepository {
+ void createEntity(EntityEnum type, BaseEntityModel model);
- FullHttpResponse delete(final Request request, final Long objectId);
+ BaseEntityModel getEntity(EntityEnum type, Long id);
- FullHttpResponse get(final Request request, final Long objectId);
-
- FullHttpResponse getAll(final Request request);
+ List<BaseEntityModel> getAllEntities(EntityEnum type);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java
new file mode 100644
index 0000000000..98fdfed672
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated.monitoring.services;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.DerbyRepository;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.EntityEnum;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.IRepository;
+
+import java.net.InetSocketAddress;
+import java.util.List;
+import java.util.concurrent.Future;
+
+public class WorkerService {
+ private static final IRepository _entityRepository = new DerbyRepository();
+
+ public void create(BaseEntityModel model) {
+ _entityRepository.createEntity(EntityEnum.WORKER, model);
+ }
+
+ public BaseEntityModel get(Long id) {
+ var model = _entityRepository.getEntity(EntityEnum.WORKER, id);
+
+ try {
+ var statisticsResponse = getWorkerStatistics(model.getAddress()).get();
+
+ if (statisticsResponse.isSuccessful()) {
+ FederatedStatistics.FedStatsCollection aggFedStats = new FederatedStatistics.FedStatsCollection();
+
+ Object[] tmp = statisticsResponse.getData();
+ if(tmp[0] instanceof FederatedStatistics.FedStatsCollection)
+ aggFedStats.aggregate((FederatedStatistics.FedStatsCollection)tmp[0]);
+
+ var statsStr = FederatedStatistics.displayStatistics(aggFedStats, 5);
+ model.setData(statsStr);
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+
+ return model;
+ }
+
+ public List<BaseEntityModel> getAll() {
+ return _entityRepository.getAllEntities(EntityEnum.WORKER);
+ }
+
+ private Future<FederatedResponse> getWorkerStatistics(String address) {
+ Future<FederatedResponse> result = null;
+
+ String host = address.split(":")[0];
+ int port = Integer.parseInt(address.split(":")[1]);
+
+ InetSocketAddress isa = new InetSocketAddress(host, port);
+ FederatedRequest frUDF = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new FederatedStatistics.FedStatsCollectFunction());
+ try {
+ result = FederatedData.executeFederatedOperation(isa, frUDF);
+ } catch(DMLRuntimeException dre) {
+ // silently ignore this exception --> caused by offline federated workers
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+
+ return result;
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index d960dac327..6ebff8eacd 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -1602,6 +1602,31 @@ public abstract class AutomatedTestBase {
return process;
}
+	/**
+	 * Start new JVM for a federated monitoring backend at the port.
+	 *
+	 * @param port    Port to use for the JVM
+	 * @param addArgs Additional arguments appended after the -fedMonitor option.
+	 *                NOTE(review): callers pass null when there are none, which
+	 *                relies on ArrayUtils.addAll accepting a null array - confirm.
+	 * @return the process associated with the monitoring backend.
+	 */
+	protected Process startLocalFedMonitoring(int port, String[] addArgs) {
+		Process process = null;
+		String separator = System.getProperty("file.separator");
+		String classpath = System.getProperty("java.class.path");
+		String path = System.getProperty("java.home") + separator + "bin" + separator + "java";
+		String[] args = ArrayUtils.addAll(new String[]{path, "-cp", classpath, DMLScript.class.getName(),
+			"-fedMonitor", Integer.toString(port)}, addArgs);
+		ProcessBuilder processBuilder = new ProcessBuilder(args);
+
+		try {
+			process = processBuilder.start();
+			// fixed wait after launching the JVM, presumably to let the backend come up
+			sleep(1000);
+		}
+		catch(IOException e) {
+			throw new RuntimeException(e);
+		}
+		catch(InterruptedException e) {
+			// restore the interrupt flag before propagating
+			Thread.currentThread().interrupt();
+			throw new RuntimeException(e);
+		}
+		return process;
+	}
+
/**
* Start a thread for a worker. This will share the same JVM, so all static variables will be shared.!
*
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java
new file mode 100644
index 0000000000..483e3eee9e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.monitoring;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.test.functions.federated.multitenant.MultiTenantTestBase;
+import org.junit.After;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Base class for tests of the federated monitoring backend: starts the
+ * backend in a separate process and offers helpers that talk to its
+ * /workers HTTP endpoint.
+ */
+public abstract class FederatedMonitoringTestBase extends MultiTenantTestBase {
+	protected Process monitoringProcess;
+	private int monitoringPort;
+
+	private static final String WORKER_MAIN_PATH = "/workers";
+
+	@Override
+	public abstract void setUp();
+
+	// ensure that the processes are killed - even if the test throws an exception
+	@After
+	public void stopMonitoringProcesses() {
+		if (monitoringProcess != null) {
+			monitoringProcess.destroyForcibly();
+		}
+	}
+
+	/**
+	 * Start the federated monitoring backend process on an available port.
+	 *
+	 * @param addArgs additional arguments passed on to the monitoring process
+	 */
+	protected void startFedMonitoring(String[] addArgs) {
+		monitoringPort = getRandomAvailablePort();
+		monitoringProcess = startLocalFedMonitoring(monitoringPort, addArgs);
+	}
+
+	/**
+	 * Registers workers through POST requests to the workers endpoint.
+	 *
+	 * @param numWorkers number of workers to register
+	 * @return the HTTP responses, one per registered worker, in request order
+	 */
+	protected List<HttpResponse<?>> addWorkers(int numWorkers) {
+		URI workersUri = workersUri();
+
+		List<HttpResponse<?>> responses = new ArrayList<>();
+		try {
+			ObjectMapper objectMapper = new ObjectMapper();
+			for (int i = 0; i < numWorkers; i++) {
+				String requestBody = objectMapper
+					.writerWithDefaultPrettyPrinter()
+					.writeValueAsString(new BaseEntityModel((i + 1L), "Worker", "localhost"));
+				var request = HttpRequest.newBuilder(workersUri)
+					.header("accept", "application/json")
+					.POST(HttpRequest.BodyPublishers.ofString(requestBody))
+					.build();
+				responses.add(sendMonitoringRequest(request));
+			}
+
+			return responses;
+		}
+		catch (IOException e) {
+			throw new RuntimeException(e);
+		}
+	}
+
+	/**
+	 * Reads all registered workers through a GET request to the workers endpoint.
+	 *
+	 * @return the HTTP response of the monitoring backend
+	 */
+	protected HttpResponse<?> getWorkers() {
+		var request = HttpRequest.newBuilder(workersUri())
+			.header("accept", "application/json")
+			.GET().build();
+		return sendMonitoringRequest(request);
+	}
+
+	/** Builds the URI of the workers endpoint of the started monitoring backend. */
+	private URI workersUri() {
+		return URI.create(String.format("http://localhost:%d%s", monitoringPort, WORKER_MAIN_PATH));
+	}
+
+	/**
+	 * Sends the request with a fresh client and reads the body as a string;
+	 * checked exceptions are rethrown as RuntimeException.
+	 */
+	private HttpResponse<String> sendMonitoringRequest(HttpRequest request) {
+		try {
+			var client = HttpClient.newHttpClient();
+			return client.send(request, HttpResponse.BodyHandlers.ofString());
+		}
+		catch (IOException e) {
+			throw new RuntimeException(e);
+		}
+		catch (InterruptedException e) {
+			// restore the interrupt flag before propagating
+			Thread.currentThread().interrupt();
+			throw new RuntimeException(e);
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java
new file mode 100644
index 0000000000..d9fd9d5d8e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.monitoring;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.http.HttpStatus;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Integration test that registers workers at the monitoring backend over
+ * HTTP and reads them back.
+ */
+public class FederatedWorkerIntegrationCRUDTest extends FederatedMonitoringTestBase {
+	private final static String TEST_NAME = "FederatedWorkerIntegrationCRUDTest";
+
+	private final static String TEST_DIR = "functions/federated/monitoring/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FederatedWorkerIntegrationCRUDTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"});
+		addTestConfiguration(TEST_NAME, config);
+		// no additional arguments for the monitoring process
+		startFedMonitoring(null);
+	}
+
+	/** Registering a single worker is answered with HTTP 200. */
+	@Test
+	public void testWorkerAddedForMonitoring() {
+		var responses = addWorkers(1);
+
+		Assert.assertEquals("Added worker status code", HttpStatus.SC_OK, responses.get(0).statusCode());
+	}
+
+	/** Every registration succeeds and the listing contains one "id" per registered worker. */
+	@Test
+	public void testCorrectAmountAddedWorkersForMonitoring() {
+		final int numWorkers = 3;
+		var responses = addWorkers(numWorkers);
+
+		for (int idx = 0; idx < numWorkers; idx++) {
+			Assert.assertEquals("Added worker status code", HttpStatus.SC_OK, responses.get(idx).statusCode());
+		}
+
+		// the number of occurrences of "id" in the response body is taken as the worker count
+		String listing = getWorkers().body().toString();
+		int numReturnedWorkers = StringUtils.countMatches(listing, "id");
+
+		Assert.assertEquals("Amount of workers to get", numWorkers, numReturnedWorkers);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java
new file mode 100644
index 0000000000..dc8fca39b6
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.monitoring;
+
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseEntityModel;
+import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.WorkerService;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test that the worker service attaches statistics of a running federated
+ * worker to the model it returns.
+ */
+public class FederatedWorkerStatisticsTest extends FederatedMonitoringTestBase {
+	private final static String TEST_NAME = "FederatedWorkerStatisticsTest";
+
+	private final static String TEST_DIR = "functions/federated/monitoring/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FederatedWorkerStatisticsTest.class.getSimpleName() + "/";
+
+	private static int[] workerPorts;
+	private final WorkerService workerMonitoringService = new WorkerService();
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"});
+		addTestConfiguration(TEST_NAME, config);
+		workerPorts = startFedWorkers(3);
+	}
+
+	/** The data field of a registered, running worker is non-empty and mentions "JVM". */
+	@Test
+	public void testWorkerStatisticsReturnedForMonitoring() {
+		// register the first started worker, then read it back by id
+		String workerAddress = "localhost:" + workerPorts[0];
+		workerMonitoringService.create(new BaseEntityModel(1L, "Worker", workerAddress));
+
+		BaseEntityModel worker = workerMonitoringService.get(1L);
+		String statistics = worker.getData();
+
+		Assert.assertNotNull("Data field of model contains worker statistics", worker.getData());
+		Assert.assertNotEquals("Data field of model contains worker statistics",0, statistics.length());
+		Assert.assertTrue("Data field of model contains worker statistics", statistics.contains("JVM"));
+	}
+}
+}