This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit a6d3131ca5d479ea81337bc9aab9c696a75bf24e Author: mito <[email protected]> AuthorDate: Fri Sep 16 14:29:54 2022 +0300 [SYSTEMDS-3422] Federated monitoring tool extensions This commit adds extensions and fixes to the federated monitoring tool. And makes it easier to run via documentation and examples. Closes #1698 --- bin/systemds | 12 +- conf/SystemDS-config.xml.template | 3 + docs/img/monitoring-arch-overview.png | Bin 0 -> 35194 bytes docs/site/federated-monitoring.md | 147 ++++++++++++++++++++ .../dashboard/main/dashboard.component.html | 4 + .../dashboard/main/dashboard.component.scss | 6 + .../modules/dashboard/main/dashboard.component.ts | 38 ++++- .../modules/dashboard/worker/worker.component.scss | 4 +- .../app/modules/events/view/view.component.html | 1 - .../src/app/modules/events/view/view.component.ts | 135 +++++++++--------- scripts/monitoring/src/app/utils.ts | 6 +- scripts/staging/SIMD-double-vectors/systemds | 14 +- src/main/java/org/apache/sysds/api/DMLOptions.java | 18 ++- src/main/java/org/apache/sysds/api/DMLScript.java | 50 +++++++ src/main/java/org/apache/sysds/conf/DMLConfig.java | 5 +- .../federated/FederatedWorkerHandler.java | 10 +- .../controlprogram/federated/monitoring/README.md | 4 +- .../controllers/CoordinatorController.java | 23 ++- .../monitoring/controllers/IController.java | 3 +- .../controllers/StatisticsController.java | 3 +- .../monitoring/controllers/WorkerController.java | 13 ++ .../monitoring/repositories/DerbyRepository.java | 55 ++++---- .../monitoring/services/StatisticsService.java | 52 +++++-- .../monitoring/services/WorkerService.java | 124 ++++++++++------- .../org/apache/sysds/test/AutomatedTestBase.java | 2 +- .../FederatedBackendPerformanceTest.java | 111 +++++++++++++++ .../FederatedCoordinatorIntegrationCRUDTest.java | 22 +-- .../monitoring/FederatedMonitoringTestBase.java | 51 ++++++- .../FederatedWorkerIntegrationCRUDTest.java | 18 +-- .../monitoring/FederatedWorkerStatisticsTest.java | 154 
++++++++++++++++++++- 30 files changed, 873 insertions(+), 215 deletions(-) diff --git a/bin/systemds b/bin/systemds index b88da6e00d..0855807754 100755 --- a/bin/systemds +++ b/bin/systemds @@ -167,7 +167,7 @@ Worker Usage: $0 [-r] WORKER [SystemDS.jar] <portnumber> [arguments] [-help] port : The port to open for the federated worker. -Federated Monitoring Usage: $0 [-r] FEDMONITOR [SystemDS.jar] <portnumber> [arguments] [-help] +Federated Monitoring Usage: $0 [-r] FEDMONITORING [SystemDS.jar] <portnumber> [arguments] [-help] port : The port to open for the federated monitoring tool. @@ -257,8 +257,8 @@ elif echo "$1" | grep -q "WORKER"; then printUsage fi shift -elif echo "$1" | grep -q "FEDMONITOR"; then - FEDMONITOR=1 +elif echo "$1" | grep -q "FEDMONITORING"; then + FEDMONITORING=1 shift if echo "$1" | grep -q "jar"; then SYSTEMDS_JAR_FILE=$1 @@ -287,8 +287,8 @@ if [ -z "$WORKER" ] ; then WORKER=0 fi -if [ -z "$FEDMONITOR" ] ; then - FEDMONITOR=0 +if [ -z "$FEDMONITORING" ] ; then + FEDMONITORING=0 fi # find me a SystemDS jar file to run @@ -449,7 +449,7 @@ elif [ "$FEDMONITORING" == 1 ]; then -cp $CLASSPATH \ $LOG4JPROP \ org.apache.sysds.api.DMLScript \ - -fedMonitor $PORT \ + -fedMonitoring $PORT \ $CONFIG_FILE \ $*" print_out "Executing command: $CMD" diff --git a/conf/SystemDS-config.xml.template b/conf/SystemDS-config.xml.template index 6e51cb047f..2dc98d1e4f 100644 --- a/conf/SystemDS-config.xml.template +++ b/conf/SystemDS-config.xml.template @@ -127,6 +127,9 @@ <!-- set the degree of parallelism of the federated worker event loop (<=0 means number of virtual cores) --> <sysds.federated.par_conn>0</sysds.federated.par_conn> + <!-- Set worker polling frequency for the monitoring backend in seconds --> + <sysds.federated.monitorFreq>3</sysds.federated.monitorFreq> + <!-- set the degree of parallelism of the federated worker instructions (<=0 means number of virtual cores) --> <sysds.federated.par_inst>0</sysds.federated.par_inst> diff --git 
a/docs/img/monitoring-arch-overview.png b/docs/img/monitoring-arch-overview.png new file mode 100644 index 0000000000..04fde48e99 Binary files /dev/null and b/docs/img/monitoring-arch-overview.png differ diff --git a/docs/site/federated-monitoring.md b/docs/site/federated-monitoring.md new file mode 100644 index 0000000000..753c492697 --- /dev/null +++ b/docs/site/federated-monitoring.md @@ -0,0 +1,147 @@ +--- +layout: site +title: Use SystemDS Federated Monitoring Software +--- +<!-- +{% comment %} +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to you under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +{% endcomment %} +--> + +## SystemDS Federated Monitoring Software + +### Introduction + +To monitor the federated infrastructure of SystemDS, a monitoring tool was developed for this purpose. +A general overview of the architecture can be seen in [**Figure 1**](figure-1). +The monitoring tool consists of two separate decoupled modules, the Java-based **monitoring backend** and +the **monitoring frontend** developed in [Angular](https://angular.io/). + +**NOTE:** To work with the monitoring tool both the back- and frontend services must be running! + + + +### Installation & Build + +#### 1. 
Monitoring Backend + +To compile the project, run the following command (more information can be found [here](./install.md)): + +```bash +mvn package -P distribution +``` + +```bash +[INFO] ------------------------------------------------------------------------ +[INFO] BUILD SUCCESS +[INFO] ------------------------------------------------------------------------ +[INFO] Total time: 31.730 s +[INFO] Finished at: 2020-06-18T11:00:29+02:00 +[INFO] ------------------------------------------------------------------------ +``` + +The following example works if you open a terminal at the root of the downloaded release +or a cloned repository (you can also replace `$(pwd)` with the full path to the folder). +More information can be found [here](./run.md): + +```bash +export SYSTEMDS_ROOT=$(pwd) +export PATH=$SYSTEMDS_ROOT/bin:$PATH +``` + +#### 2. Monitoring Frontend + +Since the frontend is built with **Angular v13**, **Node.js version 12, 14 or 16** (or a later minor release of these) is required. +To install `nodejs` and `npm`, go to [https://nodejs.org/en/](https://nodejs.org/en/) and install either version **12.x**, +**14.x** or **16.x**: + +```bash +# Verify installation ------- +node --version +# Output +# v14.2.0 + +npm --version +# Output +# 6.14.4 +# --------------------------- +``` + +To install the npm packages required for the Angular app to run, open the directory with +the SystemDS code and run: + +```bash +# 1. Go into the directory with the frontend app +cd scripts/monitoring +# 2. Install all npm packages +npm install +``` +After these steps, all the packages needed for running the monitoring tool should be installed. + +### Running + +The back- and frontend applications are separate modules of the same tool and can be started and stopped independently.
+Since they are designed with loose coupling in mind, the frontend can integrate with different backends, and +the backend can work with different frontends, provided that the data format and the communication protocol are +preserved. + +#### 1. Monitoring Backend + +To run the backend, start it with the `-fedMonitoring` flag followed by a `port`, for example via the systemds binary: + +```bash +# Start the backend with the binary +systemds FEDMONITORING 8080 + +# You should see something like this +#[ INFO] Setting up Federated Monitoring Backend on port 8080 +#[ INFO] Starting Federated Monitoring Backend server at port: 8080 +#[ INFO] Started Federated Monitoring Backend at port: 8080 +``` +This will start the backend server, which listens for REST requests on `http://localhost:8080`. + +**NOTE:** The backend polls all registered workers periodically. The polling interval can be changed by setting +`<sysds.federated.monitorFreq>3</sysds.federated.monitorFreq>` in the `SystemDS-config.xml` file. The value is a +**double** in seconds (e.g., 0.5 makes the backend poll every half second). The example shown +here starts the backend polling every **3 seconds**, which is also the default value. + +#### 2. Monitoring Frontend + +To run the Angular app: + +```bash +# 1. While in the systemds directory go to the folder holding the frontend app +cd scripts/monitoring +# 2. Start the angular app +npm start +``` +After this step, the Angular UI should be running on [http://localhost:4200](http://localhost:4200) and can be viewed by opening that +address in a browser. + +**NOTE:** The address of the backend is hardcoded in the frontend application and can be changed by editing `BASE_URI` in the `systemds/scripts/monitoring/src/app/constants.ts` file. **DO NOT** include a trailing slash `/` at the end of the address. + +#### 3.
Coordinator self-registration for monitoring + +In addition to the manual registration of coordinators for monitoring, the self-registration feature can be used by +setting the `-fedMonitoringAddress` flag followed by the address of the backend: + +```bash +# Start the coordinator process with the -fedMonitoringAddress flag and the address of the backend +systemds -f testFederated.dml -exec singlenode -explain -debug -stats 20 -fedMonitoringAddress http://localhost:8080 +``` + +**NOTE:** The backend service should already be running, otherwise the coordinator will not start. + diff --git a/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.html b/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.html index 69587234ff..fee5399cbf 100644 --- a/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.html +++ b/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.html @@ -17,6 +17,10 @@ under the License. --> +<button (click)="zoom('in')" class="md-btn-zoom" color="primary" mat-raised-button>+</button> +<button (click)="zoom('zero')" class="md-btn-zoom" color="primary" mat-raised-button>Reset</button> +<button (click)="zoom('out')" class="md-btn-zoom" color="primary" mat-raised-button>-</button> + <button (click)="openConfigDialog()" class="md-btn-right" color="primary" mat-raised-button>Config</button> <div id="dashboard-content"> diff --git a/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.scss b/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.scss index 7f043c1a0a..e0716551cf 100644 --- a/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.scss +++ b/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.scss @@ -34,3 +34,9 @@ top: 1em; margin: 0; } + +.md-btn-zoom { + right: 1em; + top: 1em; + margin: 0 0.2em 0 0; +} diff --git a/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.ts 
b/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.ts index b0336cdbba..cc551e4778 100644 --- a/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.ts +++ b/scripts/monitoring/src/app/modules/dashboard/main/dashboard.component.ts @@ -44,10 +44,12 @@ export class DashboardComponent implements OnInit { @ViewChild(DashboardDirective, {static: true}) fedSiteHost!: DashboardDirective; private jsPlumbInstance: jsPlumbInstance; + private scale: number = 1; + + private dashboardId: string = 'dashboard-content'; constructor(public dialog: MatDialog, - private fedSiteService: FederatedSiteService) { - } + private fedSiteService: FederatedSiteService) { } ngOnInit(): void { @@ -57,11 +59,27 @@ export class DashboardComponent implements OnInit { }; this.jsPlumbInstance = jsPlumb.getInstance(); - this.jsPlumbInstance.setContainer('dashboard-content'); + this.jsPlumbInstance.setContainer(this.dashboardId); this.openConfigDialog(); } + zoom(type: string): void { + let element = document.getElementById(this.dashboardId) + + if (type === 'in') { + this.scale += 0.1 + } else if (type === 'out') { + this.scale -= 0.1 + } else if (type === 'zero') { + this.scale = 1; + } + + // @ts-ignore + element.style.transform = `scale(${this.scale})`; + this.jsPlumbInstance.setZoom(this.scale); + } + openConfigDialog(): void { this.fedSiteService.getAllCoordinators().subscribe(coordinators => this.fedSiteData.coordinators = coordinators); @@ -78,9 +96,19 @@ export class DashboardComponent implements OnInit { let selectedWorkers = this.fedSiteData.workers.filter(w => result['selectedWorkerIds'].includes(w.id)); this.fedSiteHost.viewContainerRef.clear(); - this.jsPlumbInstance.removeAllEndpoints('dashboard-content'); + this.jsPlumbInstance.reset(); + + let mainDashboard = document.getElementById(this.dashboardId); + + // @ts-ignore + for (const child of mainDashboard.children) { + if (child.tagName === 'DIV' || child.tagName === 'svg') { + 
mainDashboard!.removeChild(child); + } + } - this.redrawDiagram(selectedCoordinators, selectedWorkers); + // Wait for previous components to be destroyed + setTimeout(() => this.redrawDiagram(selectedCoordinators, selectedWorkers), 500) } }); } diff --git a/scripts/monitoring/src/app/modules/dashboard/worker/worker.component.scss b/scripts/monitoring/src/app/modules/dashboard/worker/worker.component.scss index 12eb692580..1f71cf1422 100644 --- a/scripts/monitoring/src/app/modules/dashboard/worker/worker.component.scss +++ b/scripts/monitoring/src/app/modules/dashboard/worker/worker.component.scss @@ -18,8 +18,8 @@ */ .worker-card { - width: 300px; - height: 465px; + width: 31em; + height: 38em; } .worker-chart { diff --git a/scripts/monitoring/src/app/modules/events/view/view.component.html b/scripts/monitoring/src/app/modules/events/view/view.component.html index f29c316336..ba91232822 100644 --- a/scripts/monitoring/src/app/modules/events/view/view.component.html +++ b/scripts/monitoring/src/app/modules/events/view/view.component.html @@ -19,7 +19,6 @@ <div class="metrics-cards"> <mat-card class="worker-metrics-card" id="events-metric-card"> - <canvas id="event-timeline"></canvas> </mat-card> </div> diff --git a/scripts/monitoring/src/app/modules/events/view/view.component.ts b/scripts/monitoring/src/app/modules/events/view/view.component.ts index b7b0ebce83..6329f62370 100644 --- a/scripts/monitoring/src/app/modules/events/view/view.component.ts +++ b/scripts/monitoring/src/app/modules/events/view/view.component.ts @@ -28,6 +28,7 @@ import { constants } from "../../../constants"; import 'chartjs-adapter-moment'; import { Subject } from "rxjs"; import { EventStage } from "../../../models/eventStage.model"; +import { Utils } from "../../../utils"; @Component({ selector: 'app-view-worker-events', @@ -41,7 +42,7 @@ export class ViewWorkerEventsComponent { @ViewChild(MatPaginator) paginator: MatPaginator; @ViewChild(MatSort) sort: MatSort; - private 
eventTimelineChart: Chart; + private eventTimelineChart: any = {}; private stopPollingStatistics = new Subject<any>(); @@ -56,7 +57,7 @@ export class ViewWorkerEventsComponent { this.statistics = new Statistics(); - const eventCanvasEle: any = document.getElementById('event-timeline'); + const eventSectionEle: any = document.getElementById('events-metric-card'); this.fedSiteService.getStatisticsPolling(id, this.stopPollingStatistics).subscribe(stats => { this.statistics = stats; @@ -64,66 +65,75 @@ export class ViewWorkerEventsComponent { const timeframe = this.getTimeframe(); const minVal = this.getLastSeconds(timeframe[1], 3); - if (!this.eventTimelineChart) { - this.eventTimelineChart = new Chart(eventCanvasEle.getContext('2d'), { - type: 'bar', - data: { - labels: [], - datasets: [] - }, - options: { - indexAxis: 'y', - responsive: true, - plugins: { - legend: { - position: 'top', - onClick: () => null, - onHover: () => null, - onLeave: () => null, - labels: { - generateLabels(chart: Chart): LegendItem[] { - let legendItemsTmp: LegendItem[] = []; - - for (const dataset of chart.data.datasets) { - const label = dataset.label! 
- if (!legendItemsTmp.find(i => i.text === label)) { - let li: LegendItem = { - text: label, - //@ts-ignore - fillStyle: dataset.backgroundColor, - //@ts-ignore - strokeStyle: dataset.borderColor, + const coordinatorNames = this.getCoordinatorNames(); + + for (const coordinatorName of coordinatorNames) { + + if (!this.eventTimelineChart[coordinatorName]) { + const canvas: any = document.createElement("canvas"); + canvas.width = 400; + eventSectionEle.appendChild(canvas); + + this.eventTimelineChart[coordinatorName] = new Chart(canvas.getContext('2d'), { + type: 'bar', + data: { + labels: [], + datasets: [] + }, + options: { + indexAxis: 'y', + responsive: true, + plugins: { + legend: { + position: 'top', + onClick: () => null, + onHover: () => null, + onLeave: () => null, + labels: { + generateLabels(chart: Chart): LegendItem[] { + let legendItemsTmp: LegendItem[] = []; + + for (const dataset of chart.data.datasets) { + const label = dataset.label! + if (!legendItemsTmp.find(i => i.text === label)) { + let li: LegendItem = { + text: label, + //@ts-ignore + fillStyle: dataset.backgroundColor, + //@ts-ignore + strokeStyle: dataset.borderColor, + } + legendItemsTmp.push(li); } - legendItemsTmp.push(li); } - } - return legendItemsTmp; + return legendItemsTmp; + } } + }, + title: { + display: true, + text: `Event timeline of worker with respect to coordinator ${coordinatorName}` } }, - title: { - display: true, - text: 'Event timeline of worker with respect to coordinators' - } - }, - scales: { - x: { - min: 0, - ticks: { - callback: function(value, index, ticks) { - // @ts-ignore - return new Date(minVal + value).toLocaleTimeString(); - } + scales: { + x: { + min: 0, + ticks: { + callback: function(value, index, ticks) { + // @ts-ignore + return new Date(minVal + value).toLocaleTimeString(); + } + }, + stacked: true }, - stacked: true + y: { + stacked: true + } }, - y: { - stacked: true - } }, - }, - }) + }) + } } this.updateEventTimeline(); @@ -198,10 +208,11 @@ 
export class ViewWorkerEventsComponent { const coordinatorNames = this.getCoordinatorNames(); coordinatorNames.forEach(c => { - this.eventTimelineChart.data.datasets = []; - this.eventTimelineChart.data.labels = [coordinatorNames]; + this.eventTimelineChart[c].data.datasets = []; + this.eventTimelineChart[c].data.labels = [c]; let coordinatorEvents = this.statistics.events.filter(e => e.coordinatorName === c); + coordinatorEvents.sort(Utils.sortEventsStartDate); let stageStack: EventStage[] = []; @@ -217,14 +228,14 @@ export class ViewWorkerEventsComponent { let nextStage = event.stages[stageIndex]; stageStack.push(nextStage); - this.eventTimelineChart.data.datasets.push({ + this.eventTimelineChart[c].data.datasets.push({ type: 'bar', label: currentStage.operation, backgroundColor: this.getColor(currentStage.operation), data: [new Date(currentStage.endTime).getTime() - new Date(currentStage.startTime).getTime()] }); - this.placeIntermediateBars(currentStage, nextStage); + this.placeIntermediateBars(currentStage, nextStage, c); } } else { stageStack.push(event.stages[0]); @@ -232,7 +243,7 @@ export class ViewWorkerEventsComponent { const lastStage = stageStack.pop()!; - this.eventTimelineChart.data.datasets.push({ + this.eventTimelineChart[c].data.datasets.push({ type: 'bar', label: lastStage.operation, borderColor: constants.chartColors.red, @@ -247,25 +258,25 @@ export class ViewWorkerEventsComponent { }); } - this.eventTimelineChart.update('none'); + this.eventTimelineChart[c].update('none'); }) } - private placeIntermediateBars(first: EventStage, second: EventStage) { + private placeIntermediateBars(first: EventStage, second: EventStage, coordinatorName: any) { let firstEnd = new Date(first.endTime).getTime(); let secondStart = new Date(second.startTime).getTime(); let diff = secondStart - firstEnd; if (diff > 0) { - this.eventTimelineChart.data.datasets.push({ + this.eventTimelineChart[coordinatorName].data.datasets.push({ type: 'bar', label: 'Idle', 
backgroundColor: constants.chartColors.white, data: [diff] }); } else if (diff < 0) { - this.eventTimelineChart.data.datasets.push({ + this.eventTimelineChart[coordinatorName].data.datasets.push({ type: 'bar', label: 'Overlap', backgroundColor: constants.chartColors.grey, diff --git a/scripts/monitoring/src/app/utils.ts b/scripts/monitoring/src/app/utils.ts index d16a6ac684..093564cf4f 100644 --- a/scripts/monitoring/src/app/utils.ts +++ b/scripts/monitoring/src/app/utils.ts @@ -22,7 +22,9 @@ export class Utils { return a.x < b.x ? -1 : (a.x > b.x ? 1 : 0); } - public static sortStartDate(a, b) { - return a.startTime < b.startTime ? -1 : (a.startTime > b.startTime ? 1 : 0); + public static sortEventsStartDate(a, b) { + let aFirstStage = a.stages[0]; + let bFirstStage = b.stages[0]; + return aFirstStage.startTime < bFirstStage.startTime ? -1 : (aFirstStage.startTime > bFirstStage.startTime ? 1 : 0); } } diff --git a/scripts/staging/SIMD-double-vectors/systemds b/scripts/staging/SIMD-double-vectors/systemds index 61b73726b4..6c17a8e0fa 100755 --- a/scripts/staging/SIMD-double-vectors/systemds +++ b/scripts/staging/SIMD-double-vectors/systemds @@ -167,10 +167,6 @@ Worker Usage: $0 [-r] WORKER [SystemDS.jar] <portnumber> [arguments] [-help] port : The port to open for the federated worker. -Federated Monitoring Usage: $0 [-r] FEDMONITOR [SystemDS.jar] <portnumber> [arguments] [-help] - - port : The port to open for the federated monitoring tool. - Set custom launch configuration by setting/editing SYSTEMDS_STANDALONE_OPTS and/or SYSTEMDS_DISTRIBUTED_OPTS. 
@@ -257,8 +253,8 @@ elif echo "$1" | grep -q "WORKER"; then printUsage fi shift -elif echo "$1" | grep -q "FEDMONITOR"; then - FEDMONITOR=1 +elif echo "$1" | grep -q "FEDMONITORING"; then + FEDMONITORING=1 shift if echo "$1" | grep -q "jar"; then SYSTEMDS_JAR_FILE=$1 @@ -287,8 +283,8 @@ if [ -z "$WORKER" ] ; then WORKER=0 fi -if [ -z "$FEDMONITOR" ] ; then - FEDMONITOR=0 +if [ -z "$FEDMONITORING" ] ; then + FEDMONITORING=0 fi # find me a SystemDS jar file to run @@ -449,7 +445,7 @@ elif [ "$FEDMONITORING" == 1 ]; then -cp $CLASSPATH \ $LOG4JPROP \ org.apache.sysds.api.DMLScript \ - -fedMonitor $PORT \ + -fedMonitoring $PORT \ $CONFIG_FILE \ $*" print_out "Executing command: $CMD" diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java index 7686d84d52..70af5ba9e8 100644 --- a/src/main/java/org/apache/sysds/api/DMLOptions.java +++ b/src/main/java/org/apache/sysds/api/DMLOptions.java @@ -75,6 +75,7 @@ public class DMLOptions { public int fedWorkerPort = -1; public boolean fedMonitoring = false; public int fedMonitoringPort = -1; + public String fedMonitoringAddress = null; public int pythonPort = -1; public boolean checkPrivacy = false; // Check which privacy constraints are loaded and checked during federated execution public boolean federatedCompilation = false; // Compile federated instructions based on input federation state and privacy constraints. 
@@ -97,7 +98,8 @@ public class DMLOptions { ", statsCount=" + statsCount + ", fedStats=" + fedStats + ", fedStatsCount=" + fedStatsCount + - ", fedMonitor=" + fedMonitoring + + ", fedMonitoring=" + fedMonitoring + + ", fedMonitoringAddress" + fedMonitoringAddress + ", memStats=" + memStats + ", explainType=" + explainType + ", execMode=" + execMode + @@ -235,9 +237,13 @@ public class DMLOptions { dmlOptions.fedWorkerPort = Integer.parseInt(line.getOptionValue("w")); } - if (line.hasOption("fedMonitor")) { + if (line.hasOption("fedMonitoring")) { dmlOptions.fedMonitoring= true; - dmlOptions.fedMonitoringPort = Integer.parseInt(line.getOptionValue("fedMonitor")); + dmlOptions.fedMonitoringPort = Integer.parseInt(line.getOptionValue("fedMonitoring")); + } + + if (line.hasOption("fedMonitoringAddress")) { + dmlOptions.fedMonitoringAddress = line.getOptionValue("fedMonitoringAddress"); } if (line.hasOption("f")){ @@ -368,7 +374,10 @@ public class DMLOptions { .hasOptionalArg().create("w"); Option monitorOpt = OptionBuilder .withDescription("Starts a federated monitoring backend with the given argument as the port.") - .hasOptionalArg().create("fedMonitor"); + .hasOptionalArg().create("fedMonitoring"); + Option registerMonitorOpt = OptionBuilder + .withDescription("Registers the coordinator for monitoring with the specified address of the monitoring tool.") + .hasOptionalArg().create("fedMonitoringAddress"); Option checkPrivacy = OptionBuilder .withDescription("Check which privacy constraints are loaded and checked during federated execution") .create("checkPrivacy"); @@ -396,6 +405,7 @@ public class DMLOptions { options.addOption(lineageOpt); options.addOption(fedOpt); options.addOption(monitorOpt); + options.addOption(registerMonitorOpt); options.addOption(monitorIdOpt); options.addOption(checkPrivacy); options.addOption(federatedCompilation); diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 
881cd47040..671f2bbcde 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -25,12 +25,18 @@ import java.io.FileReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.net.InetAddress; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.text.DateFormat; import java.text.SimpleDateFormat; import java.util.Date; import java.util.Map; import java.util.Scanner; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.cli.AlreadySelectedException; import org.apache.commons.cli.HelpFormatter; import org.apache.commons.lang.StringUtils; @@ -65,6 +71,8 @@ import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedData; import org.apache.sysds.runtime.controlprogram.federated.FederatedWorker; import org.apache.sysds.runtime.controlprogram.federated.monitoring.FederatedMonitoringServer; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.WorkerModel; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler; import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool; @@ -141,6 +149,8 @@ public class DMLScript // Global seed public static int SEED = -1; + public static String MONITORING_ADDRESS = null; + // flag that indicates whether or not to suppress any prints to stdout public static boolean _suppressPrint2Stdout = false; //set default local spark configuration - used for local testing @@ -262,6 +272,7 @@ public class DMLScript LINEAGE_DEBUGGER = dmlOptions.lineage_debugger; SEED = dmlOptions.seed; + String fnameOptConfig = 
dmlOptions.configFile; boolean isFile = dmlOptions.filePath != null; String fileOrScript = isFile ? dmlOptions.filePath : dmlOptions.script; @@ -286,10 +297,15 @@ public class DMLScript } if(dmlOptions.fedMonitoring) { + loadConfiguration(fnameOptConfig); new FederatedMonitoringServer(dmlOptions.fedMonitoringPort, dmlOptions.debug); return true; } + if(dmlOptions.fedMonitoringAddress != null) { + MONITORING_ADDRESS = dmlOptions.fedMonitoringAddress; + } + LineageCacheConfig.setConfig(LINEAGE_REUSE); LineageCacheConfig.setCachePolicy(LINEAGE_POLICY); LineageCacheConfig.setEstimator(LINEAGE_ESTIMATE); @@ -410,6 +426,9 @@ public class DMLScript { // print basic time, environment info, and process id printStartExecInfo(dmlScriptStr); + + // optionally register for monitoring + registerForMonitoring(); //Step 1: parse configuration files & write any configuration specific global variables loadConfiguration(fnameOptConfig); @@ -585,6 +604,37 @@ public class DMLScript if(info) LOG.info("Process id: " + IDHandler.obtainProcessID()); } + + private static void registerForMonitoring() { + + if (MONITORING_ADDRESS != null && !MONITORING_ADDRESS.isBlank() && !MONITORING_ADDRESS.isEmpty()) { + try { + + String uriString = MONITORING_ADDRESS + "/coordinators"; + + ObjectMapper objectMapper = new ObjectMapper(); + + var model = new CoordinatorModel(); + model.name = InetAddress.getLocalHost().getHostName(); + model.host = InetAddress.getLocalHost().getHostName(); + model.processId = Long.parseLong(IDHandler.obtainProcessID()); + + String requestBody = objectMapper + .writerWithDefaultPrettyPrinter() + .writeValueAsString(model); + + var client = HttpClient.newHttpClient(); + var request = HttpRequest.newBuilder(URI.create(uriString)) + .header("accept", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(requestBody)) + .build(); + + client.send(request, HttpResponse.BodyHandlers.ofString()); + } catch (IOException | InterruptedException e) { + throw new 
RuntimeException(e); + } + } + } private static String getDateTime() { DateFormat dateFormat = new SimpleDateFormat("MM/dd/yyyy HH:mm:ss"); diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java index dd2d3a15ba..58a78f25e9 100644 --- a/src/main/java/org/apache/sysds/conf/DMLConfig.java +++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java @@ -123,6 +123,8 @@ public class DMLConfig public static final String FEDERATED_PAR_CONN = "sysds.federated.par_conn"; public static final String FEDERATED_READCACHE = "sysds.federated.readcache"; public static final String PRIVACY_CONSTRAINT_MOCK = "sysds.federated.priv_mock"; + /** Trigger frequency of the collecting and parsing statistics process on registered workers for monitoring in seconds */ + public static final String FEDERATED_MONITOR_FREQUENCY = "sysds.federated.monitorFreq"; public static final int DEFAULT_FEDERATED_PORT = 4040; // borrowed default Spark Port public static final int DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS = 8; @@ -194,6 +196,7 @@ public class DMLConfig _defaultVals.put(FEDERATED_PAR_CONN, "-1"); // vcores _defaultVals.put(FEDERATED_PAR_INST, "-1"); // vcores _defaultVals.put(FEDERATED_READCACHE, "true"); // vcores + _defaultVals.put(FEDERATED_MONITOR_FREQUENCY, "3"); _defaultVals.put(PRIVACY_CONSTRAINT_MOCK, null); } @@ -447,7 +450,7 @@ public class DMLConfig PRINT_GPU_MEMORY_INFO, AVAILABLE_GPUS, SYNCHRONIZE_GPU, EAGER_CUDA_FREE, FLOATING_POINT_PRECISION, GPU_EVICTION_POLICY, LOCAL_SPARK_NUM_THREADS, EVICTION_SHADOW_BUFFERSIZE, GPU_MEMORY_ALLOCATOR, GPU_MEMORY_UTILIZATION_FACTOR, USE_SSL_FEDERATED_COMMUNICATION, DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT, - FEDERATED_TIMEOUT + FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY }; StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java index 574245da20..b748fe8740 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java @@ -303,6 +303,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { switch(method) { case READ_VAR: + eventStage.operation = method.name(); result = readData(request, ecm); // matrix/frame break; case PUT_VAR: @@ -317,12 +318,14 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { result = execInstruction(request, ecm, eventStage); break; case EXEC_UDF: - result = execUDF(request, ecm); + result = execUDF(request, ecm, eventStage); break; case CLEAR: + eventStage.operation = method.name(); result = execClear(ecm); break; case NOOP: + eventStage.operation = method.name(); result = execNoop(); break; default: @@ -623,13 +626,16 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { } } - private FederatedResponse execUDF(FederatedRequest request, ExecutionContextMap ecm) { + private FederatedResponse execUDF(FederatedRequest request, ExecutionContextMap ecm, EventStageModel eventStage) { checkNumParams(request.getNumParams(), 1); ExecutionContext ec = ecm.get(request.getTID()); // get function and input parameters try { FederatedUDF udf = (FederatedUDF) request.getParam(0); + + eventStage.operation = udf.getClass().getSimpleName(); + Data[] inputs = Arrays.stream(udf.getInputIDs()).mapToObj(id -> ec.getVariable(String.valueOf(id))) .map(PrivacyMonitor::handlePrivacy).toArray(Data[]::new); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/README.md b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/README.md index 65377f675f..ff89363612 100644 --- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/README.md +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/README.md @@ -28,10 +28,10 @@ The backend process can be started in a similar manner with how a worker is star ```bash cd systemds mvn package - ./bin/systemds [-r] FEDMONITOR [SystemDS.jar] <portnumber> [arguments] + ./bin/systemds [-r] FEDMONITORING [SystemDS.jar] <portnumber> [arguments] ``` -Or with the specified **-fedMonitor 8080** flag indicating the start of the backend process on the specified port, in our case **8080**. +Or with the specified **-fedMonitoring 8080** flag indicating the start of the backend process on the specified port, in our case **8080**. ## Main components diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java index fde6f1b713..cce0eacf00 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/CoordinatorController.java @@ -19,13 +19,14 @@ package org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers; -import io.netty.handler.codec.http.FullHttpResponse; -import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel; import org.apache.sysds.runtime.controlprogram.federated.monitoring.Request; import org.apache.sysds.runtime.controlprogram.federated.monitoring.Response; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel; import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.CoordinatorService; import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.MapperService; +import 
io.netty.handler.codec.http.FullHttpResponse; + public class CoordinatorController implements IController { private final CoordinatorService coordinatorService = new CoordinatorService(); @@ -42,9 +43,25 @@ public class CoordinatorController implements IController { @Override public FullHttpResponse update(Request request, Long objectId) { + var result = coordinatorService.get(objectId); + + if (result == null) { + return Response.notFound(Constants.NOT_FOUND_MSG); + } + var model = MapperService.getModelFromBody(request, CoordinatorModel.class); - model.generateMonitoringKey(); + model.id = objectId; + // Setting host + model.host = model.host == null ? result.host : model.host; + + // Setting processId + model.processId = model.processId == null ? result.processId : model.processId; + + // Setting name + model.name = model.name == null ? result.name : model.name; + + model.generateMonitoringKey(); coordinatorService.update(model); return Response.ok(model.toString()); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/IController.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/IController.java index 17a6df58be..2f6882fcdf 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/IController.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/IController.java @@ -19,9 +19,10 @@ package org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers; -import io.netty.handler.codec.http.FullHttpResponse; import org.apache.sysds.runtime.controlprogram.federated.monitoring.Request; +import io.netty.handler.codec.http.FullHttpResponse; + public interface IController { FullHttpResponse create(final Request request); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/StatisticsController.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/StatisticsController.java index 5fdeba1b09..a2ec3c5de4 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/StatisticsController.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/StatisticsController.java @@ -19,12 +19,13 @@ package org.apache.sysds.runtime.controlprogram.federated.monitoring.controllers; -import io.netty.handler.codec.http.FullHttpResponse; import org.apache.sysds.runtime.controlprogram.federated.monitoring.Request; import org.apache.sysds.runtime.controlprogram.federated.monitoring.Response; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.StatisticsOptions; import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.StatisticsService; +import io.netty.handler.codec.http.FullHttpResponse; + public class StatisticsController implements IController { private final StatisticsService statisticsService = new StatisticsService(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java index 95e47c54b1..e81834e035 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/controllers/WorkerController.java @@ -42,7 +42,20 @@ public class WorkerController implements IController { @Override public FullHttpResponse update(Request request, Long objectId) { + var result = workerService.get(objectId); + + if (result == null) { + return Response.notFound(Constants.NOT_FOUND_MSG); + } + var model = MapperService.getModelFromBody(request, WorkerModel.class); + model.id = objectId; + + // Setting address + model.address = 
model.address == null ? result.address : model.address; + + // Setting name + model.name = model.name == null ? result.name : model.name; workerService.update(model); model.setOnlineStatus(workerService.getWorkerOnlineStatus(model.id)); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java index f19bde7b4f..9c23478b40 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/repositories/DerbyRepository.java @@ -36,7 +36,6 @@ import java.util.List; public class DerbyRepository implements IRepository { private final static String DB_CONNECTION = "jdbc:derby:memory:derbyDB"; - private final Connection _db; private final List<BaseModel> _allEntities = new ArrayList<>(List.of( new WorkerModel(), new CoordinatorModel(), @@ -60,17 +59,15 @@ public class DerbyRepository implements IRepository { private static final String GET_ALL_ENTITIES_STMT = "SELECT * FROM %s"; public DerbyRepository() { - _db = createMonitoringDatabase(); + createMonitoringDatabase(); } - private Connection createMonitoringDatabase() { + private void createMonitoringDatabase() { Connection db = null; try { // Creates only if DB doesn't exist db = DriverManager.getConnection(DB_CONNECTION + ";create=true"); createMonitoringEntitiesInDB(db); - - return db; } catch (SQLException e) { throw new RuntimeException(e); @@ -127,7 +124,7 @@ public class DerbyRepository implements IRepository { PreparedStatement st = null; long id = -1L; - try { + try (var db = DriverManager.getConnection(DB_CONNECTION)) { StringBuilder sb = new StringBuilder(); @@ -156,7 +153,7 @@ public class DerbyRepository implements IRepository { sb.replace(sb.length() - 1, sb.length(), ")"); String bindVarsStr = 
String.format("(%s)", String.join(",", Collections.nCopies(dbFieldCount, "?"))); - st = _db.prepareStatement(String.format(ENTITY_INSERT_STMT, sb, bindVarsStr), PreparedStatement.RETURN_GENERATED_KEYS); + st = db.prepareStatement(String.format(ENTITY_INSERT_STMT, sb, bindVarsStr), PreparedStatement.RETURN_GENERATED_KEYS); int bindVarIndex = 1; for (var field: fields) { @@ -198,13 +195,15 @@ public class DerbyRepository implements IRepository { public <T extends BaseModel> T getEntity(Long id, Class<T> type) { T resultModel = null; - try { + PreparedStatement st = null; + + try (var db = DriverManager.getConnection(DB_CONNECTION)) { var entityName = type.getSimpleName().replace(Constants.ENTITY_CLASS_SUFFIX, ""); - PreparedStatement st = _db.prepareStatement( - String.format(GET_ENTITY_WITH_COL_STMT, entityName, Constants.ENTITY_ID_COL)); + st = db.prepareStatement(String.format(GET_ENTITY_WITH_COL_STMT, entityName, Constants.ENTITY_ID_COL)); st.setLong(1, id); + var resultSet = st.executeQuery(); if (resultSet.next()){ @@ -219,12 +218,12 @@ public class DerbyRepository implements IRepository { public <T extends BaseModel> List<T> getAllEntities(Class<T> type) { List<T> resultModels = new ArrayList<>(); + PreparedStatement st = null; - try { + try (var db = DriverManager.getConnection(DB_CONNECTION)) { var entityName = type.getSimpleName().replace(Constants.ENTITY_CLASS_SUFFIX, ""); - PreparedStatement st = _db.prepareStatement( - String.format(GET_ALL_ENTITIES_STMT, entityName)); + st = db.prepareStatement(String.format(GET_ALL_ENTITIES_STMT, entityName)); var resultSet = st.executeQuery(); while (resultSet.next()){ @@ -244,15 +243,13 @@ public class DerbyRepository implements IRepository { List<T> resultModels = new ArrayList<>(); PreparedStatement st = null; - try { + try (var db = DriverManager.getConnection(DB_CONNECTION)) { var entityName = type.getSimpleName().replace(Constants.ENTITY_CLASS_SUFFIX, ""); if (rowCount < 0) { - st = _db.prepareStatement( - 
String.format(GET_ENTITY_WITH_COL_STMT, entityName, fieldName)); + st = db.prepareStatement(String.format(GET_ENTITY_WITH_COL_STMT, entityName, fieldName)); } else { - st = _db.prepareStatement( - String.format(GET_ENTITY_WITH_COL_LIMIT_STMT, entityName, fieldName, rowCount)); + st = db.prepareStatement(String.format(GET_ENTITY_WITH_COL_LIMIT_STMT, entityName, fieldName, rowCount)); } if (value.getClass().isAssignableFrom(String.class)) { @@ -274,12 +271,13 @@ public class DerbyRepository implements IRepository { public <T extends BaseModel> void removeAllEntitiesByField(String fieldName, Object value, Class<T> type) { - try { + PreparedStatement st = null; + + try (var db = DriverManager.getConnection(DB_CONNECTION)) { var entityName = type.getSimpleName().replace(Constants.ENTITY_CLASS_SUFFIX, ""); - PreparedStatement st = _db.prepareStatement( - String.format(DELETE_ENTITY_WITH_COL_STMT, entityName, fieldName)); + st = db.prepareStatement(String.format(DELETE_ENTITY_WITH_COL_STMT, entityName, fieldName)); if (value.getClass().isAssignableFrom(String.class)) { st.setString(1, String.valueOf(value)); @@ -296,7 +294,9 @@ public class DerbyRepository implements IRepository { @Override public <T extends BaseModel> void updateEntity(T model) { - try { + PreparedStatement st = null; + + try (var db = DriverManager.getConnection(DB_CONNECTION)) { StringBuilder sb = new StringBuilder(); var entityName = model.getClass().getSimpleName().replace(Constants.ENTITY_CLASS_SUFFIX, ""); @@ -324,7 +324,7 @@ public class DerbyRepository implements IRepository { sb.replace(sb.length() - 1, sb.length(), ""); - PreparedStatement st = _db.prepareStatement(String.format(UPDATE_ENTITY_WITH_COL_STMT, entityName, sb, Constants.ENTITY_ID_COL)); + st = db.prepareStatement(String.format(UPDATE_ENTITY_WITH_COL_STMT, entityName, sb, Constants.ENTITY_ID_COL)); for (int i = 0; i < fieldsToChange.size(); i++) { var field = fieldsToChange.get(i); @@ -352,13 +352,16 @@ public class DerbyRepository 
implements IRepository { @Override public <T extends BaseModel> void removeEntity(Long id, Class<T> type) { - try { + + PreparedStatement st = null; + + try (var db = DriverManager.getConnection(DB_CONNECTION)) { var entityName = type.getSimpleName().replace(Constants.ENTITY_CLASS_SUFFIX, ""); - PreparedStatement st = _db.prepareStatement( - String.format(DELETE_ENTITY_WITH_COL_STMT, entityName, Constants.ENTITY_ID_COL)); + st = db.prepareStatement(String.format(DELETE_ENTITY_WITH_COL_STMT, entityName, Constants.ENTITY_ID_COL)); st.setLong(1, id); + st.executeUpdate(); } catch (SQLException e) { throw new RuntimeException(e); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/StatisticsService.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/StatisticsService.java index d60f7935a4..d99a13ea95 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/StatisticsService.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/StatisticsService.java @@ -21,7 +21,10 @@ package org.apache.sysds.runtime.controlprogram.federated.monitoring.services; import java.net.InetSocketAddress; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -50,34 +53,64 @@ public class StatisticsService { private static final IRepository entityRepository = new DerbyRepository(); public StatisticsModel getAll(Long workerId, StatisticsOptions options) { + CompletableFuture<Void> utilizationFuture = null; + CompletableFuture<Void> trafficFuture = null; + CompletableFuture<Void> eventsFuture = null; + CompletableFuture<Void> dataObjFuture = null; + CompletableFuture<Void> requestsFuture = null; + var stats = new 
StatisticsModel(); if (options.utilization) { - stats.utilization = entityRepository.getAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, workerId, UtilizationModel.class, options.rowCount); + utilizationFuture = CompletableFuture + .supplyAsync(() -> entityRepository.getAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, workerId, UtilizationModel.class, options.rowCount)) + .thenAcceptAsync(result -> stats.utilization = result); } if (options.traffic) { - stats.traffic = entityRepository.getAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, workerId, TrafficModel.class, options.rowCount); + trafficFuture = CompletableFuture + .supplyAsync(() -> entityRepository.getAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, workerId, TrafficModel.class, options.rowCount)) + .thenAcceptAsync(result -> stats.traffic = result); } if (options.events) { - stats.events = entityRepository.getAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, workerId, EventModel.class, options.rowCount); + eventsFuture = CompletableFuture + .supplyAsync(() -> { + var events = entityRepository.getAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, workerId, EventModel.class, options.rowCount); - for (var event: stats.events) { - event.setCoordinatorName(entityRepository.getEntity(event.coordinatorId, CoordinatorModel.class).name); + for (var event : events) { + event.setCoordinatorName(entityRepository.getEntity(event.coordinatorId, CoordinatorModel.class).name); - event.stages = entityRepository.getAllEntitiesByField(Constants.ENTITY_EVENT_ID_COL, event.id, EventStageModel.class); - } + event.stages = entityRepository.getAllEntitiesByField(Constants.ENTITY_EVENT_ID_COL, event.id, EventStageModel.class); + } + + return events; + }) + .thenAcceptAsync(result -> stats.events = result); } if (options.dataObjects) { - stats.dataObjects = entityRepository.getAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, workerId, DataObjectModel.class); + dataObjFuture = CompletableFuture + .supplyAsync(() -> 
entityRepository.getAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, workerId, DataObjectModel.class)) + .thenAcceptAsync(result -> stats.dataObjects = result); } if (options.requests) { - stats.requests = entityRepository.getAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, workerId, RequestModel.class); + requestsFuture = CompletableFuture + .supplyAsync(() -> entityRepository.getAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, workerId, RequestModel.class)) + .thenAcceptAsync(result -> stats.requests = result); } + List<CompletableFuture<Void>> completableFutures = Arrays.asList(utilizationFuture, trafficFuture, eventsFuture, dataObjFuture, requestsFuture); + + completableFutures.forEach(cf -> { + try { + cf.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }); + return stats; } @@ -122,7 +155,6 @@ public class StatisticsService { traffic.forEach(t -> t.workerId = workerId); dataObjects.forEach(o -> o.workerId = workerId); - for (var event: events) { event.workerId = workerId; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java index 854b804c14..2e8c663b79 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/monitoring/services/WorkerService.java @@ -19,30 +19,35 @@ package org.apache.sysds.runtime.controlprogram.federated.monitoring.services; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + import org.apache.commons.lang3.tuple.MutablePair; import org.apache.commons.lang3.tuple.Pair; +import 
org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.DataObjectModel; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.RequestModel; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.StatisticsModel; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.WorkerModel; import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.Constants; import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.DerbyRepository; import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.IRepository; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; - public class WorkerService { private static final IRepository entityRepository = new DerbyRepository(); // { workerId, { workerAddress, workerStatus } } private static final Map<Long, Pair<String, Boolean>> cachedWorkers = new HashMap<>(); + private static ScheduledExecutorService executorService; public WorkerService() { - ScheduledExecutorService executor = Executors.newScheduledThreadPool(1); - executor.scheduleAtFixedRate(syncWorkerStatisticsWithDB(), 0, 3, TimeUnit.SECONDS); + var freq = ConfigurationManager.getDMLConfig().getDoubleValue(DMLConfig.FEDERATED_MONITOR_FREQUENCY); + startStatsCollectionProcess(1, freq); } public Long create(WorkerModel model) { @@ -105,61 +110,84 @@ public class WorkerService { } } - private static Runnable syncWorkerStatisticsWithDB() { - return () -> { + private static synchronized void startStatsCollectionProcess(int threadCount, double frequencySeconds) { + if (executorService == null) { + executorService = Executors.newScheduledThreadPool(threadCount); + 
executorService.scheduleAtFixedRate(syncWorkerStatisticsRunnable(), 0, Math.round(frequencySeconds * 1000), TimeUnit.MILLISECONDS); + } + } - for(Map.Entry<Long, Pair<String, Boolean>> entry : cachedWorkers.entrySet()) { - Long id = entry.getKey(); - String address = entry.getValue().getLeft(); + public static void syncWorkerStatisticsWithDB(StatisticsModel stats, Long id) { - var stats = StatisticsService.getWorkerStatistics(id, address); + // NOTE: This part of the code is not directly connected to requests coming from the frontend + // and runs in the background. There is no need to handle the result data from the futures since + // it is directly saved in the database, and it will be returned in the next frontend request. - if (stats != null) { + if (stats != null) { - cachedWorkers.get(id).setValue(true); + cachedWorkers.get(id).setValue(true); - if (stats.utilization != null) { - entityRepository.createEntity(stats.utilization.get(0)); - } - if (stats.traffic != null) { - for (var trafficEntity: stats.traffic) { - if (trafficEntity.coordinatorId > 0) { - entityRepository.createEntity(trafficEntity); - } + if (stats.utilization != null) { + CompletableFuture.runAsync(() -> entityRepository.createEntity(stats.utilization.get(0))); + } + if (stats.traffic != null) { + CompletableFuture.runAsync(() -> { + for (var trafficEntity : stats.traffic) { + if (trafficEntity.coordinatorId > 0) { + entityRepository.createEntity(trafficEntity); } } - if (stats.events != null) { - for (var eventEntity: stats.events) { - if (eventEntity.coordinatorId > 0) { - var eventId = entityRepository.createEntity(eventEntity); + }); + } + if (stats.events != null) { + for (var eventEntity: stats.events) { + if (eventEntity.coordinatorId > 0) { + CompletableFuture.runAsync(() -> { + var eventId = entityRepository.createEntity(eventEntity); - for (var stageEntity: eventEntity.stages) { - stageEntity.eventId = eventId; + for (var stageEntity : eventEntity.stages) { + stageEntity.eventId = 
eventId; - entityRepository.createEntity(stageEntity); - } + entityRepository.createEntity(stageEntity); } - } + }); } - if (stats.dataObjects != null) { - entityRepository.removeAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, id, DataObjectModel.class); + } + } + if (stats.dataObjects != null) { + CompletableFuture.runAsync(() -> { + entityRepository.removeAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, id, DataObjectModel.class); - for (var dataObjectEntity: stats.dataObjects) { - entityRepository.createEntity(dataObjectEntity); - } + for (var dataObjectEntity : stats.dataObjects) { + entityRepository.createEntity(dataObjectEntity); } - if (stats.requests != null) { - entityRepository.removeAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, id, RequestModel.class); + }); + } + if (stats.requests != null) { + CompletableFuture.runAsync(() -> { + entityRepository.removeAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, id, RequestModel.class); - for (var requestEntity: stats.requests) { - if (requestEntity.coordinatorId > 0) { - entityRepository.createEntity(requestEntity); - } + for (var requestEntity : stats.requests) { + if (requestEntity.coordinatorId > 0) { + entityRepository.createEntity(requestEntity); } } - } else { - cachedWorkers.get(id).setValue(false); - } + }); + } + } else { + cachedWorkers.get(id).setValue(false); + } + } + + private static Runnable syncWorkerStatisticsRunnable() { + return () -> { + for(Map.Entry<Long, Pair<String, Boolean>> entry : cachedWorkers.entrySet()) { + Long id = entry.getKey(); + String address = entry.getValue().getLeft(); + + CompletableFuture + .supplyAsync(() -> StatisticsService.getWorkerStatistics(id, address)) + .thenAcceptAsync(stats -> syncWorkerStatisticsWithDB(stats, id)); } }; } diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index c5f7d1a54b..f849111b72 100644 --- 
a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -1620,7 +1620,7 @@ public abstract class AutomatedTestBase { String classpath = System.getProperty("java.class.path"); String path = System.getProperty("java.home") + separator + "bin" + separator + "java"; String[] args = ArrayUtils.addAll(new String[]{path, "-cp", classpath, DMLScript.class.getName(), - "-fedMonitor", Integer.toString(port)}, addArgs); + "-fedMonitoring", Integer.toString(port)}, addArgs); ProcessBuilder processBuilder = new ProcessBuilder(args); try { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedBackendPerformanceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedBackendPerformanceTest.java new file mode 100644 index 0000000000..34f3d386a9 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedBackendPerformanceTest.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.federated.monitoring; + +import static java.lang.Thread.sleep; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.WorkerModel; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +import java.net.http.HttpResponse; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class FederatedBackendPerformanceTest extends FederatedMonitoringTestBase { + private static final Log LOG = LogFactory.getLog(FederatedBackendPerformanceTest.class.getName()); + private final static String TEST_NAME = "FederatedBackendPerformanceTest"; + private final static String TEST_DIR = "functions/federated/monitoring/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedBackendPerformanceTest.class.getSimpleName() + "/"; + private static final String PERFORMANCE_FORMAT = "For %d number of requests, milliseconds elapsed %d."; + + private static int[] workerPort; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"})); + startFedMonitoring(null); + workerPort = startFedWorkers(1); + } + + @Test + @Ignore + public void testBackendPerformance() throws InterruptedException { + int numRequests = 20; + + double meanExecTime = 0.f; + double numRepetitionsExperiment = 100.f; + + addEntities(1, Entity.WORKER); + updateEntity(new WorkerModel(1L, "Worker", "localhost:" + workerPort[0]), Entity.WORKER); + // Give time for statistics to be collected (70s) + sleep(70000); + + ExecutorService executor = 
Executors.newFixedThreadPool(numRequests); + + for (int j = -10; j < numRepetitionsExperiment; j++) { + + long start = System.currentTimeMillis(); + + // Returns a list of Futures holding their status and results when all complete. + // Future.isDone() is true for each element of the returned list + // https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/ExecutorService.html#invokeAll(java.util.Collection) + List<Future<HttpResponse<?>>> taskFutures = executor.invokeAll(Collections.nCopies(numRequests, + () -> getEntities(Entity.STATISTICS))); + + long finish = System.currentTimeMillis(); + long elapsedTime = (finish - start); + + if (j >= 0) { + meanExecTime += elapsedTime; + } + + taskFutures.forEach(res -> { + try { + Assert.assertEquals("Stats parsed correctly", res.get().statusCode(), 200); + } catch (InterruptedException | ExecutionException e) { + e.printStackTrace(); + } + }); + + // Wait for a second at the end of each iteration + sleep(500); + } + + executor.shutdown(); + + // Wait until all threads are finished + // Returns true if all tasks have completed following shut down. + // Note that isTerminated is never true unless either shutdown or shutdownNow was called first. 
+ while (!executor.isTerminated()); + + LOG.info(String.format(PERFORMANCE_FORMAT, numRequests, Math.round(meanExecTime / numRepetitionsExperiment))); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedCoordinatorIntegrationCRUDTest.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedCoordinatorIntegrationCRUDTest.java index d3cc095034..4e85da6266 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedCoordinatorIntegrationCRUDTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedCoordinatorIntegrationCRUDTest.java @@ -21,7 +21,7 @@ package org.apache.sysds.test.functions.federated.monitoring; import org.apache.commons.lang.StringUtils; import org.apache.http.HttpStatus; -import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.WorkerModel; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; import org.junit.Assert; @@ -43,7 +43,7 @@ public class FederatedCoordinatorIntegrationCRUDTest extends FederatedMonitoring @Test public void testCoordinatorAddedForMonitoring() { - var addedCoordinators = addEntities(1); + var addedCoordinators = addEntities(1, Entity.COORDINATOR); var firstCoordinatorStatus = addedCoordinators.get(0).statusCode(); Assert.assertEquals("Added coordinator status code", HttpStatus.SC_OK, firstCoordinatorStatus); @@ -52,10 +52,10 @@ public class FederatedCoordinatorIntegrationCRUDTest extends FederatedMonitoring @Test @Ignore public void testCoordinatorRemovedFromMonitoring() { - addEntities(2); - var statusCode = removeEntity(1L).statusCode(); + addEntities(2, Entity.COORDINATOR); + var statusCode = removeEntity(1L, Entity.COORDINATOR).statusCode(); - var getAllCoordinatorsResponse = getEntities(); + var getAllCoordinatorsResponse = 
getEntities(Entity.COORDINATOR); var numReturnedCoordinators = StringUtils.countMatches(getAllCoordinatorsResponse.body().toString(), "id"); Assert.assertEquals("Removed coordinator status code", HttpStatus.SC_OK, statusCode); @@ -65,12 +65,12 @@ public class FederatedCoordinatorIntegrationCRUDTest extends FederatedMonitoring @Test @Ignore public void testCoordinatorDataUpdated() { - addEntities(3); - var newCoordinatorData = new WorkerModel(1L, "NonExistentName", "nonexistent.address"); + addEntities(3, Entity.COORDINATOR); + var newCoordinatorData = new CoordinatorModel(1L); - var editedCoordinator = updateEntity(newCoordinatorData); + var editedCoordinator = updateEntity(newCoordinatorData, Entity.COORDINATOR); - var getAllCoordinatorsResponse = getEntities(); + var getAllCoordinatorsResponse = getEntities(Entity.COORDINATOR); var numCoordinatorsNewData = StringUtils.countMatches(getAllCoordinatorsResponse.body().toString(), newCoordinatorData.name); Assert.assertEquals("Updated coordinator status code", HttpStatus.SC_OK, editedCoordinator.statusCode()); @@ -81,14 +81,14 @@ public class FederatedCoordinatorIntegrationCRUDTest extends FederatedMonitoring @Ignore public void testCorrectAmountAddedCoordinatorsForMonitoring() { int numCoordinators = 3; - var addedCoordinators = addEntities(numCoordinators); + var addedCoordinators = addEntities(numCoordinators, Entity.COORDINATOR); for (int i = 0; i < numCoordinators; i++) { var coordinatorStatus = addedCoordinators.get(i).statusCode(); Assert.assertEquals("Added coordinator status code", HttpStatus.SC_OK, coordinatorStatus); } - var getAllCoordinatorsResponse = getEntities(); + var getAllCoordinatorsResponse = getEntities(Entity.COORDINATOR); var numReturnedCoordinators = StringUtils.countMatches(getAllCoordinatorsResponse.body().toString(), "id"); Assert.assertEquals("Amount of coordinators to get", numCoordinators, numReturnedCoordinators); diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java index 2a0901b15a..a3eb95abfc 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedMonitoringTestBase.java @@ -27,6 +27,8 @@ import java.net.http.HttpResponse; import java.util.ArrayList; import java.util.List; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.BaseModel; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.WorkerModel; import org.apache.sysds.test.functions.federated.multitenant.MultiTenantTestBase; import org.junit.After; @@ -40,7 +42,14 @@ public abstract class FederatedMonitoringTestBase extends MultiTenantTestBase { private static final String MAIN_URI = "http://localhost"; private static final String WORKER_MAIN_PATH = "/workers"; - // private static final String COORDINATOR_MAIN_PATH = "/coordinators"; + private static final String COORDINATOR_MAIN_PATH = "/coordinators"; + private static final String STATISTICS_MAIN_PATH = "/statistics"; + + public enum Entity { + WORKER, + COORDINATOR, + STATISTICS + } @Override public abstract void setUp(); @@ -63,10 +72,15 @@ public abstract class FederatedMonitoringTestBase extends MultiTenantTestBase { monitoringProcess = startLocalFedMonitoring(monitoringPort, addArgs); } - protected List<HttpResponse<?>> addEntities(int count) { + protected List<HttpResponse<?>> addEntities(int count, Entity entity) { String uriStr = MAIN_URI + ":" + monitoringPort + WORKER_MAIN_PATH; String name = "Worker"; + if (entity == Entity.COORDINATOR) { + uriStr = MAIN_URI + ":" + monitoringPort + COORDINATOR_MAIN_PATH; + name = "Coordinator"; + } 
+ List<HttpResponse<?>> responses = new ArrayList<>(); try { ObjectMapper objectMapper = new ObjectMapper(); @@ -89,14 +103,25 @@ public abstract class FederatedMonitoringTestBase extends MultiTenantTestBase { } } - protected HttpResponse<?> updateEntity(WorkerModel editModel) { - String uriStr = MAIN_URI + ":" + monitoringPort + WORKER_MAIN_PATH; + protected HttpResponse<?> updateEntity(BaseModel editModel, Entity entity) { + String uriStr = MAIN_URI + ":" + monitoringPort + WORKER_MAIN_PATH + "/" + editModel.id; + + if (entity == Entity.COORDINATOR) { + uriStr = MAIN_URI + ":" + monitoringPort + COORDINATOR_MAIN_PATH + "/" + editModel.id; + } try { ObjectMapper objectMapper = new ObjectMapper(); + String requestBody = objectMapper .writerWithDefaultPrettyPrinter() - .writeValueAsString(new WorkerModel(editModel.id, editModel.name, editModel.address)); + .writeValueAsString(new WorkerModel(editModel.id, ((WorkerModel)editModel).name, ((WorkerModel)editModel).address)); + + if (entity == Entity.COORDINATOR) { + requestBody = objectMapper + .writerWithDefaultPrettyPrinter() + .writeValueAsString(new CoordinatorModel()); + } var client = HttpClient.newHttpClient(); var request = HttpRequest.newBuilder(URI.create(uriStr)) .header("accept", "application/json") @@ -110,9 +135,13 @@ public abstract class FederatedMonitoringTestBase extends MultiTenantTestBase { } } - protected HttpResponse<?> removeEntity(Long id) { + protected HttpResponse<?> removeEntity(Long id, Entity entity) { String uriStr = MAIN_URI + ":" + monitoringPort + WORKER_MAIN_PATH + "/" + id; + if (entity == Entity.COORDINATOR) { + uriStr = MAIN_URI + ":" + monitoringPort + COORDINATOR_MAIN_PATH + "/" + id; + } + try { var client = HttpClient.newHttpClient(); var request = HttpRequest.newBuilder(URI.create(uriStr)) @@ -127,9 +156,17 @@ public abstract class FederatedMonitoringTestBase extends MultiTenantTestBase { } } - protected HttpResponse<?> getEntities() { + protected HttpResponse<?> 
getEntities(Entity entity) { String uriStr = MAIN_URI + ":" + monitoringPort + WORKER_MAIN_PATH; + if (entity == Entity.COORDINATOR) { + uriStr = MAIN_URI + ":" + monitoringPort + COORDINATOR_MAIN_PATH; + } + + if (entity == Entity.STATISTICS) { + uriStr = MAIN_URI + ":" + monitoringPort + STATISTICS_MAIN_PATH + "/1"; + } + try { var client = HttpClient.newHttpClient(); var request = HttpRequest.newBuilder(URI.create(uriStr)) diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java index b70c9c11c2..65fee61438 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerIntegrationCRUDTest.java @@ -43,7 +43,7 @@ public class FederatedWorkerIntegrationCRUDTest extends FederatedMonitoringTestB @Test public void testWorkerAddedForMonitoring() { - var addedWorkers = addEntities(1); + var addedWorkers = addEntities(1, Entity.WORKER); var firstWorkerStatus = addedWorkers.get(0).statusCode(); Assert.assertEquals("Added worker status code", HttpStatus.SC_OK, firstWorkerStatus); @@ -52,10 +52,10 @@ public class FederatedWorkerIntegrationCRUDTest extends FederatedMonitoringTestB @Test @Ignore public void testWorkerRemovedFromMonitoring() { - addEntities(2); - var statusCode = removeEntity(1L).statusCode(); + addEntities(2, Entity.WORKER); + var statusCode = removeEntity(1L, Entity.WORKER).statusCode(); - var getAllWorkersResponse = getEntities(); + var getAllWorkersResponse = getEntities(Entity.WORKER); var numReturnedWorkers = StringUtils.countMatches(getAllWorkersResponse.body().toString(), "id"); Assert.assertEquals("Removed worker status code", HttpStatus.SC_OK, statusCode); @@ -65,12 +65,12 @@ public class FederatedWorkerIntegrationCRUDTest extends 
FederatedMonitoringTestB @Test @Ignore public void testWorkerDataUpdated() { - addEntities(3); + addEntities(3, Entity.WORKER); var newWorkerData = new WorkerModel(1L, "NonExistentName", "nonexistent.address"); - var editedWorker = updateEntity(newWorkerData); + var editedWorker = updateEntity(newWorkerData, Entity.WORKER); - var getAllWorkersResponse = getEntities(); + var getAllWorkersResponse = getEntities(Entity.WORKER); var numWorkersNewData = StringUtils.countMatches(getAllWorkersResponse.body().toString(), newWorkerData.name); Assert.assertEquals("Updated worker status code", HttpStatus.SC_OK, editedWorker.statusCode()); @@ -81,14 +81,14 @@ public class FederatedWorkerIntegrationCRUDTest extends FederatedMonitoringTestB @Ignore public void testCorrectAmountAddedWorkersForMonitoring() { int numWorkers = 3; - var addedWorkers = addEntities(numWorkers); + var addedWorkers = addEntities(numWorkers, Entity.WORKER); for (int i = 0; i < numWorkers; i++) { var workerStatus = addedWorkers.get(i).statusCode(); Assert.assertEquals("Added worker status code", HttpStatus.SC_OK, workerStatus); } - var getAllWorkersResponse = getEntities(); + var getAllWorkersResponse = getEntities(Entity.WORKER); var numReturnedWorkers = StringUtils.countMatches(getAllWorkersResponse.body().toString(), "id"); Assert.assertEquals("Amount of workers to get", numWorkers, numReturnedWorkers); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java index 5d092daf8e..cb2e91d113 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/monitoring/FederatedWorkerStatisticsTest.java @@ -19,26 +19,49 @@ package org.apache.sysds.test.functions.federated.monitoring; +import org.apache.commons.logging.Log; +import 
org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.DataObjectModel; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.EventModel; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.EventStageModel; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.RequestModel; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.StatisticsModel; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.StatisticsOptions; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.WorkerModel; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.Constants; import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.DerbyRepository; +import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.IRepository; import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.StatisticsService; import org.apache.sysds.runtime.controlprogram.federated.monitoring.services.WorkerService; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; import org.junit.Assert; +import org.junit.Ignore; import org.junit.Test; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + public class FederatedWorkerStatisticsTest extends FederatedMonitoringTestBase { + private static final Log LOG = LogFactory.getLog(FederatedWorkerStatisticsTest.class.getName()); + private final static String TEST_NAME = "FederatedWorkerStatisticsTest"; private final static String TEST_DIR = 
"functions/federated/monitoring/"; private static final String TEST_CLASS_DIR = TEST_DIR + FederatedWorkerStatisticsTest.class.getSimpleName() + "/"; + private static final String PERFORMANCE_FORMAT = "For %d number of workers, milliseconds elapsed %d."; + private static int[] workerPorts; + private final IRepository entityRepository = new DerbyRepository(); private final WorkerService workerMonitoringService = new WorkerService(); private final StatisticsService statisticsMonitoringService = new StatisticsService(); @@ -46,7 +69,7 @@ public class FederatedWorkerStatisticsTest extends FederatedMonitoringTestBase { public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"})); - workerPorts = startFedWorkers(3); + workerPorts = startFedWorkers(6); } @Test @@ -58,6 +81,55 @@ public class FederatedWorkerStatisticsTest extends FederatedMonitoringTestBase { Assert.assertNotEquals("Utilization stats parsed correctly", 0, model.utilization.size()); } + @Test + @Ignore + public void testWorkerStatisticsPerformance() throws InterruptedException { + ExecutorService executor = Executors.newFixedThreadPool(workerPorts.length); + + double meanExecTime = 0.f; + double numRepetitionsExperiment = 100.f; + + for (int j = -10; j < numRepetitionsExperiment; j++) { + + Collection<Callable<StatisticsModel>> collect = new ArrayList<>(); + Collection<Callable<Boolean>> parse = new ArrayList<>(); + + for (int i = 1; i <= workerPorts.length; i++) { + long id = i; + String address = "localhost:" + workerPorts[i - 1]; + workerMonitoringService.create(new WorkerModel(id, "Worker", address)); + collect.add(() -> StatisticsService.getWorkerStatistics(id, address)); + } + + long start = System.currentTimeMillis(); + + // Returns a list of Futures holding their status and results when all complete. 
+ // Future.isDone() is true for each element of the returned list + // https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/ExecutorService.html#invokeAll(java.util.Collection) + List<Future<StatisticsModel>> taskFutures = executor.invokeAll(collect); + + taskFutures.forEach(res -> parse.add(() -> syncWorkerStats(res.get(), res.get().traffic.get(0).workerId))); + + executor.invokeAll(parse); + + long finish = System.currentTimeMillis(); + long elapsedTime = (finish - start); + + if (j >= 0) { + meanExecTime += elapsedTime; + } + } + + executor.shutdown(); + + // Wait until all threads are finish + // Returns true if all tasks have completed following shut down. + // Note that isTerminated is never true unless either shutdown or shutdownNow was called first. + while (!executor.isTerminated()); + + LOG.info(String.format(PERFORMANCE_FORMAT, workerPorts.length, Math.round(meanExecTime / numRepetitionsExperiment))); + } + @Test public void testWorkerStatisticsReturnedForMonitoring() { workerMonitoringService.create(new WorkerModel(1L, "Worker", "localhost:" + workerPorts[0])); @@ -81,7 +153,6 @@ public class FederatedWorkerStatisticsTest extends FederatedMonitoringTestBase { new EventStageModel(); - workerMonitoringService.create(new WorkerModel(1L, "Worker", "localhost:8001")); var options = new StatisticsOptions(); options.utilization = true; @@ -90,4 +161,83 @@ public class FederatedWorkerStatisticsTest extends FederatedMonitoringTestBase { Assert.assertEquals("Utilization field of model contains worker statistics", 0, stats.utilization.size()); } + + private Boolean syncWorkerStats(StatisticsModel stats, Long id) { + CompletableFuture<Boolean> utilizationFuture = null; + CompletableFuture<Boolean> trafficFuture = null; + CompletableFuture<Boolean> eventsFuture = null; + CompletableFuture<Boolean> dataObjFuture = null; + CompletableFuture<Boolean> requestsFuture = null; + + if (stats != null) { + + if (stats.utilization != null) { + utilizationFuture = 
CompletableFuture.supplyAsync(() -> { + entityRepository.createEntity(stats.utilization.get(0)); + return true; + }); + } + if (stats.traffic != null) { + trafficFuture = CompletableFuture.supplyAsync(() -> { + for (var trafficEntity : stats.traffic) { + if (trafficEntity.coordinatorId > 0) { + entityRepository.createEntity(trafficEntity); + } + } + return true; + }); + } + if (stats.events != null) { + eventsFuture = CompletableFuture.supplyAsync(() -> { + for (var eventEntity: stats.events) { + if (eventEntity.coordinatorId > 0) { + var eventId = entityRepository.createEntity(eventEntity); + + for (var stageEntity : eventEntity.stages) { + stageEntity.eventId = eventId; + + entityRepository.createEntity(stageEntity); + } + } + } + return true; + }); + } + if (stats.dataObjects != null) { + dataObjFuture = CompletableFuture.supplyAsync(() -> { + entityRepository.removeAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, id, DataObjectModel.class); + + for (var dataObjectEntity : stats.dataObjects) { + entityRepository.createEntity(dataObjectEntity); + } + + return true; + }); + } + if (stats.requests != null) { + requestsFuture = CompletableFuture.supplyAsync(() -> { + entityRepository.removeAllEntitiesByField(Constants.ENTITY_WORKER_ID_COL, id, RequestModel.class); + + for (var requestEntity : stats.requests) { + if (requestEntity.coordinatorId > 0) { + entityRepository.createEntity(requestEntity); + } + } + + return true; + }); + } + } + List<CompletableFuture<Boolean>> completableFutures = Arrays.asList(utilizationFuture, trafficFuture, eventsFuture, dataObjFuture, requestsFuture); + + completableFutures.forEach(cf -> { + try { + cf.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }); + + return true; + } }
