This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 582b9c3 [SYSTEMDS-3039] Tracking and consolidation of federated statistics
582b9c3 is described below
commit 582b9c3f622d87cd9d11b9dd01abcb0a6f179309
Author: ywcb00 <[email protected]>
AuthorDate: Sun Jul 4 19:52:56 2021 +0200
[SYSTEMDS-3039] Tracking and consolidation of federated statistics
Closes #1321.
---
src/main/java/org/apache/sysds/api/DMLOptions.java | 19 ++
src/main/java/org/apache/sysds/api/DMLScript.java | 72 ++---
.../federated/FederatedStatistics.java | 311 +++++++++++++++++++++
.../instructions/fed/InitFEDInstruction.java | 5 +
.../java/org/apache/sysds/utils/Statistics.java | 22 +-
.../primitives/FederatedStatisticsTest.java | 5 +-
6 files changed, 395 insertions(+), 39 deletions(-)
diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java
index 0949c06..fbdaa90 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -52,6 +52,8 @@ public class DMLOptions {
public boolean clean = false; //
Whether to clean up all SystemDS working directories (FS, DFS)
public boolean stats = false; //
Whether to record and print the statistics
public int statsCount = 10; //
Default statistics count
+ public boolean fedStats = false; //
Whether to record and print the federated statistics
+ public int fedStatsCount = 10; //
Default federated statistics count
public boolean memStats = false; // max
memory statistics
public Explain.ExplainType explainType = Explain.ExplainType.NONE;
// Whether to print the "Explain" and if so, what type
public ExecMode execMode =
OptimizerUtils.getDefaultExecutionMode(); // Execution mode standalone, MR,
Spark or a hybrid
@@ -85,6 +87,8 @@ public class DMLOptions {
", clean=" + clean +
", stats=" + stats +
", statsCount=" + statsCount +
+ ", fedStats=" + fedStats +
+ ", fedStatsCount=" + fedStatsCount +
", memStats=" + memStats +
", explainType=" + explainType +
", execMode=" + execMode +
@@ -193,6 +197,17 @@ public class DMLOptions {
}
}
}
+ dmlOptions.fedStats = line.hasOption("fedStats");
+ if (dmlOptions.fedStats) {
+ String fedStatsCount = line.getOptionValue("fedStats");
+ if(fedStatsCount != null) {
+ try {
+ dmlOptions.fedStatsCount =
Integer.parseInt(fedStatsCount);
+ } catch (NumberFormatException e) {
+ throw new
org.apache.commons.cli.ParseException("Invalid argument specified for -fedStats
option, must be a valid integer");
+ }
+ }
+ }
dmlOptions.memStats = line.hasOption("mem");
dmlOptions.clean = line.hasOption("clean");
@@ -265,6 +280,9 @@ public class DMLOptions {
Option statsOpt = OptionBuilder.withArgName("count")
.withDescription("monitors and reports summary
execution statistics; heavy hitter <count> is 10 unless overridden; default
off")
.hasOptionalArg().create("stats");
+ Option fedStatsOpt = OptionBuilder.withArgName("count")
+ .withDescription("monitors and reports summary
execution statistics of federated workers; heavy hitter <count> is 10 unless
overridden; default off")
+ .hasOptionalArg().create("fedStats");
Option memOpt = OptionBuilder.withDescription("monitors and
reports max memory consumption in CP; default off")
.create("mem");
Option explainOpt = OptionBuilder.withArgName("level")
@@ -299,6 +317,7 @@ public class DMLOptions {
options.addOption(configOpt);
options.addOption(cleanOpt);
options.addOption(statsOpt);
+ options.addOption(fedStatsOpt);
options.addOption(memOpt);
options.addOption(explainOpt);
options.addOption(execOpt);
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index 7d2bf16..e2e67a5 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -82,26 +82,28 @@ import org.apache.sysds.utils.Explain.ExplainType;
public class DMLScript
{
- private static ExecMode EXEC_MODE =
DMLOptions.defaultOptions.execMode; // the execution mode
- public static boolean STATISTICS =
DMLOptions.defaultOptions.stats; // whether to print statistics
- public static boolean JMLC_MEM_STATISTICS = false;
// whether to gather memory use stats in JMLC
- public static int STATISTICS_COUNT =
DMLOptions.defaultOptions.statsCount; // statistics maximum heavy hitter count
- public static int STATISTICS_MAX_WRAP_LEN = 30;
// statistics maximum wrap length
- public static ExplainType EXPLAIN =
DMLOptions.defaultOptions.explainType; // explain type
- public static String DML_FILE_PATH_ANTLR_PARSER =
DMLOptions.defaultOptions.filePath; // filename of dml/pydml script
- public static String FLOATING_POINT_PRECISION = "double";
// data type to use internally
- public static boolean PRINT_GPU_MEMORY_INFO = false;
// whether to print GPU memory-related information
- public static long EVICTION_SHADOW_BUFFER_MAX_BYTES = 0;
// maximum number of bytes to use for shadow buffer
- public static long EVICTION_SHADOW_BUFFER_CURR_BYTES = 0;
// number of bytes to use for shadow buffer
- public static double GPU_MEMORY_UTILIZATION_FACTOR = 0.9;
// fraction of available GPU memory to use
- public static String GPU_MEMORY_ALLOCATOR = "cuda";
// GPU memory allocator to use
- public static boolean LINEAGE = DMLOptions.defaultOptions.lineage;
// whether compute lineage trace
- public static boolean LINEAGE_DEDUP =
DMLOptions.defaultOptions.lineage_dedup; // whether deduplicate lineage
items
- public static ReuseCacheType LINEAGE_REUSE =
DMLOptions.defaultOptions.linReuseType; // whether lineage-based reuse
- public static LineageCachePolicy LINEAGE_POLICY =
DMLOptions.defaultOptions.linCachePolicy; // lineage cache eviction policy
- public static boolean LINEAGE_ESTIMATE =
DMLOptions.defaultOptions.lineage_estimate; // whether estimate reuse benefits
- public static boolean LINEAGE_DEBUGGER =
DMLOptions.defaultOptions.lineage_debugger; // whether enable lineage debugger
- public static boolean CHECK_PRIVACY =
DMLOptions.defaultOptions.checkPrivacy; // Check which privacy constraints
are loaded and checked during federated execution
+ private static ExecMode EXEC_MODE =
DMLOptions.defaultOptions.execMode; // the execution mode
+ public static boolean STATISTICS =
DMLOptions.defaultOptions.stats; // whether to print statistics
+ public static boolean JMLC_MEM_STATISTICS = false;
// whether to gather memory use stats in JMLC
+ public static int STATISTICS_COUNT =
DMLOptions.defaultOptions.statsCount; // statistics maximum heavy hitter
count
+ public static int STATISTICS_MAX_WRAP_LEN = 30;
// statistics maximum wrap length
+ public static boolean FED_STATISTICS =
DMLOptions.defaultOptions.fedStats; // whether to print federated
statistics
+ public static int FED_STATISTICS_COUNT =
DMLOptions.defaultOptions.fedStatsCount; // federated statistics maximum
heavy hitter count
+ public static ExplainType EXPLAIN =
DMLOptions.defaultOptions.explainType; // explain type
+ public static String DML_FILE_PATH_ANTLR_PARSER =
DMLOptions.defaultOptions.filePath; // filename of dml/pydml script
+ public static String FLOATING_POINT_PRECISION = "double";
// data type to use internally
+ public static boolean PRINT_GPU_MEMORY_INFO = false;
// whether to print GPU memory-related information
+ public static long EVICTION_SHADOW_BUFFER_MAX_BYTES = 0;
// maximum number of bytes to use for shadow buffer
+ public static long EVICTION_SHADOW_BUFFER_CURR_BYTES = 0;
// number of bytes to use for shadow buffer
+ public static double GPU_MEMORY_UTILIZATION_FACTOR = 0.9;
// fraction of available GPU memory to use
+ public static String GPU_MEMORY_ALLOCATOR = "cuda";
// GPU memory allocator to use
+ public static boolean LINEAGE = DMLOptions.defaultOptions.lineage;
// whether compute lineage trace
+ public static boolean LINEAGE_DEDUP =
DMLOptions.defaultOptions.lineage_dedup; // whether deduplicate
lineage items
+ public static ReuseCacheType LINEAGE_REUSE =
DMLOptions.defaultOptions.linReuseType; // whether lineage-based reuse
+ public static LineageCachePolicy LINEAGE_POLICY =
DMLOptions.defaultOptions.linCachePolicy; // lineage cache eviction policy
+ public static boolean LINEAGE_ESTIMATE =
DMLOptions.defaultOptions.lineage_estimate; // whether estimate reuse
benefits
+ public static boolean LINEAGE_DEBUGGER =
DMLOptions.defaultOptions.lineage_debugger; // whether enable lineage
debugger
+ public static boolean CHECK_PRIVACY =
DMLOptions.defaultOptions.checkPrivacy; // Check which privacy
constraints are loaded and checked during federated execution
public static boolean USE_ACCELERATOR =
DMLOptions.defaultOptions.gpu;
public static boolean FORCE_ACCELERATOR =
DMLOptions.defaultOptions.forceGPU;
@@ -212,20 +214,22 @@ public class DMLScript
try
{
- STATISTICS = dmlOptions.stats;
- STATISTICS_COUNT = dmlOptions.statsCount;
- JMLC_MEM_STATISTICS = dmlOptions.memStats;
- USE_ACCELERATOR = dmlOptions.gpu;
- FORCE_ACCELERATOR = dmlOptions.forceGPU;
- EXPLAIN = dmlOptions.explainType;
- EXEC_MODE = dmlOptions.execMode;
- LINEAGE = dmlOptions.lineage;
- LINEAGE_DEDUP = dmlOptions.lineage_dedup;
- LINEAGE_REUSE = dmlOptions.linReuseType;
- LINEAGE_POLICY = dmlOptions.linCachePolicy;
- LINEAGE_ESTIMATE = dmlOptions.lineage_estimate;
- CHECK_PRIVACY = dmlOptions.checkPrivacy;
- LINEAGE_DEBUGGER = dmlOptions.lineage_debugger;
+ STATISTICS = dmlOptions.stats;
+ STATISTICS_COUNT = dmlOptions.statsCount;
+ FED_STATISTICS = dmlOptions.fedStats;
+ FED_STATISTICS_COUNT = dmlOptions.fedStatsCount;
+ JMLC_MEM_STATISTICS = dmlOptions.memStats;
+ USE_ACCELERATOR = dmlOptions.gpu;
+ FORCE_ACCELERATOR = dmlOptions.forceGPU;
+ EXPLAIN = dmlOptions.explainType;
+ EXEC_MODE = dmlOptions.execMode;
+ LINEAGE = dmlOptions.lineage;
+ LINEAGE_DEDUP = dmlOptions.lineage_dedup;
+ LINEAGE_REUSE = dmlOptions.linReuseType;
+ LINEAGE_POLICY = dmlOptions.linCachePolicy;
+ LINEAGE_ESTIMATE = dmlOptions.lineage_estimate;
+ CHECK_PRIVACY = dmlOptions.checkPrivacy;
+ LINEAGE_DEBUGGER = dmlOptions.lineage_debugger;
String fnameOptConfig = dmlOptions.configFile;
boolean isFile = dmlOptions.filePath != null;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
new file mode 100644
index 0000000..14f29d9
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+import java.io.Serializable;
+import java.net.InetSocketAddress;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.concurrent.Future;
+import javax.net.ssl.SSLException;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.CacheStatsCollection;
+import
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.GCStatsCollection;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.utils.Statistics;
+
+public class FederatedStatistics {
+ private static Set<Pair<String, Integer>> _fedWorkerAddresses = new
HashSet<>();
+
+ public static void registerFedWorker(String host, int port) {
+ _fedWorkerAddresses.add(new ImmutablePair<>(host, new
Integer(port)));
+ }
+
+ public static String displayFedWorkers() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("Federated Worker Addresses:\n");
+ for(Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
+ sb.append(String.format(" %s:%d", fedAddr.getLeft(),
fedAddr.getRight().intValue()));
+ sb.append("\n");
+ }
+ return sb.toString();
+ }
+
+ public static String displayFedStatistics(int numHeavyHitters) {
+ StringBuilder sb = new StringBuilder();
+ FedStatsCollection fedStats = collectFedStats();
+ sb.append("SystemDS Federated Statistics:\n");
+ sb.append(displayCacheStats(fedStats.cacheStats));
+ sb.append(String.format("Total JIT compile time:\t\t%.3f
sec.\n", fedStats.jitCompileTime));
+ sb.append(displayGCStats(fedStats.gcStats));
+ sb.append(displayHeavyHitters(fedStats.heavyHitters,
numHeavyHitters));
+ return sb.toString();
+ }
+
+ public static String displayCacheStats(CacheStatsCollection csc) {
+ StringBuilder sb = new StringBuilder();
+ sb.append(String.format("Cache hits
(Mem/Li/WB/FS/HDFS):\t%d/%d/%d/%d/%d.\n",
+ csc.memHits, csc.linHits, csc.fsBuffHits, csc.fsHits,
csc.hdfsHits));
+ sb.append(String.format("Cache writes
(Li/WB/FS/HDFS):\t%d/%d/%d/%d.\n",
+ csc.linWrites, csc.fsBuffWrites, csc.fsWrites,
csc.hdfsWrites));
+ sb.append(String.format("Cache times (ACQr/m, RLS,
EXP):\t%.3f/%.3f/%.3f/%.3f sec.\n",
+ csc.acqRTime, csc.acqMTime, csc.rlsTime, csc.expTime));
+ return sb.toString();
+ }
+
+ public static String displayGCStats(GCStatsCollection gcsc) {
+ StringBuilder sb = new StringBuilder();
+ sb.append(String.format("Total JVM GC count:\t\t%d.\n",
gcsc.gcCount));
+ sb.append(String.format("Total JVM GC time:\t\t%.3f sec.\n",
gcsc.gcTime));
+ return sb.toString();
+ }
+
+ public static String displayHeavyHitters(HashMap<String, Pair<Long,
Double>> heavyHitters) {
+ return displayHeavyHitters(heavyHitters, 10);
+ }
+
+ public static String displayHeavyHitters(HashMap<String, Pair<Long,
Double>> heavyHitters, int num) {
+ StringBuilder sb = new StringBuilder();
+ @SuppressWarnings("unchecked")
+ Entry<String, Pair<Long, Double>>[] hhArr =
heavyHitters.entrySet().toArray(new Entry[0]);
+ Arrays.sort(hhArr, new Comparator<Entry<String, Pair<Long,
Double>>>() {
+ public int compare(Entry<String, Pair<Long, Double>>
e1, Entry<String, Pair<Long, Double>> e2) {
+ return
e1.getValue().getRight().compareTo(e2.getValue().getRight());
+ }
+ });
+
+ sb.append("Heavy hitter instructions:\n");
+ final String numCol = "#";
+ final String instCol = "Instruction";
+ final String timeSCol = "Time(s)";
+ final String countCol = "Count";
+ int numHittersToDisplay = Math.min(num, hhArr.length);
+ int maxNumLen = String.valueOf(numHittersToDisplay).length();
+ int maxInstLen = instCol.length();
+ int maxTimeSLen = timeSCol.length();
+ int maxCountLen = countCol.length();
+ DecimalFormat sFormat = new DecimalFormat("#,##0.000");
+ for (int counter = 0; counter < numHittersToDisplay; counter++)
{
+ Entry<String, Pair<Long, Double>> hh =
hhArr[hhArr.length - 1 - counter];
+ String instruction = hh.getKey();
+ maxInstLen = Math.max(maxInstLen, instruction.length());
+ String timeString =
sFormat.format(hh.getValue().getRight());
+ maxTimeSLen = Math.max(maxTimeSLen,
timeString.length());
+ maxCountLen = Math.max(maxCountLen,
String.valueOf(hh.getValue().getLeft()).length());
+ }
+ maxInstLen = Math.min(maxInstLen,
DMLScript.STATISTICS_MAX_WRAP_LEN);
+ sb.append(String.format( " %" + maxNumLen + "s %-" +
maxInstLen + "s %"
+ + maxTimeSLen + "s %" + maxCountLen + "s", numCol,
instCol, timeSCol, countCol));
+ sb.append("\n");
+
+ for (int counter = 0; counter < numHittersToDisplay; counter++)
{
+ String instruction = hhArr[hhArr.length - 1 -
counter].getKey();
+ String [] wrappedInstruction =
Statistics.wrap(instruction, maxInstLen);
+
+ String timeSString = sFormat.format(hhArr[hhArr.length
- 1 - counter].getValue().getRight());
+
+ long count = hhArr[hhArr.length - 1 -
counter].getValue().getLeft();
+ int numLines = wrappedInstruction.length;
+
+ for(int wrapIter = 0; wrapIter < numLines; wrapIter++) {
+ String instStr = (wrapIter <
wrappedInstruction.length) ? wrappedInstruction[wrapIter] : "";
+ if(wrapIter == 0) {
+ sb.append(String.format(
+ " %" + maxNumLen + "d %-" +
maxInstLen + "s %" + maxTimeSLen + "s %"
+ + maxCountLen + "d", (counter +
1), instStr, timeSString, count));
+ }
+ else {
+ sb.append(String.format(
+ " %" + maxNumLen + "s %-" +
maxInstLen + "s %" + maxTimeSLen + "s %"
+ + maxCountLen + "s", "",
instStr, "", ""));
+ }
+ sb.append("\n");
+ }
+ }
+
+ return sb.toString();
+ }
+
+ private static FedStatsCollection collectFedStats() {
+ Future<FederatedResponse>[] responses = getFederatedResponses();
+ FedStatsCollection aggFedStats = new FedStatsCollection();
+ for(Future<FederatedResponse> res : responses) {
+ try {
+ Object[] tmp = res.get().getData();
+ if(tmp[0] instanceof FedStatsCollection)
+
aggFedStats.aggregate((FedStatsCollection)tmp[0]);
+ } catch(Exception e) {
+ throw new DMLRuntimeException("Exception of
type " + e.getClass().toString()
+ + " thrown while " + "getting the
federated stats of the federated response: ", e);
+ }
+ }
+ return aggFedStats;
+ }
+
+ private static Future<FederatedResponse>[] getFederatedResponses() {
+ List<Future<FederatedResponse>> ret = new ArrayList<>();
+ for(Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
+ InetSocketAddress isa = new
InetSocketAddress(fedAddr.getLeft(), fedAddr.getRight());
+ FederatedRequest frUDF = new
FederatedRequest(RequestType.EXEC_UDF, -1,
+ new FedStatsCollectFunction());
+ try {
+
ret.add(FederatedData.executeFederatedOperation(isa, frUDF));
+ } catch(SSLException ssle) {
+ throw new DMLRuntimeException("SSLException
while getting the federated stats from "
+ + isa.toString() + ": ", ssle);
+ } catch (Exception e) {
+ throw new DMLRuntimeException("Exeption of type
" + e.getClass().getName()
+ + " thrown while getting stats from
federated worker: ", e);
+ }
+ }
+ @SuppressWarnings("unchecked")
+ Future<FederatedResponse>[] retArr = ret.toArray(new Future[0]);
+ return retArr;
+ }
+
+ private static class FedStatsCollectFunction extends FederatedUDF {
+ private static final long serialVersionUID = 1L;
+
+ public FedStatsCollectFunction() {
+ super(new long[] { });
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data...
data) {
+ FedStatsCollection fedStats = new FedStatsCollection();
+ fedStats.collectStats();
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, fedStats);
+ }
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
+ }
+
+ protected static class FedStatsCollection implements Serializable {
+ private static final long serialVersionUID = 1L;
+
+ private void collectStats() {
+ cacheStats.collectStats();
+ jitCompileTime =
((double)Statistics.getJITCompileTime()) / 1000; // in sec
+ gcStats.collectStats();
+ heavyHitters = Statistics.getHeavyHittersHashMap();
+ }
+
+ private void aggregate(FedStatsCollection that) {
+ cacheStats.aggregate(that.cacheStats);
+ jitCompileTime += that.jitCompileTime;
+ gcStats.aggregate(that.gcStats);
+ that.heavyHitters.forEach(
+ (key, value) -> heavyHitters.merge(key, value,
(v1, v2) ->
+ new ImmutablePair<>(v1.getLeft() +
v2.getLeft(), v1.getRight() + v2.getRight()))
+ );
+ }
+
+ protected static class CacheStatsCollection implements
Serializable {
+ private static final long serialVersionUID = 1L;
+
+ private void collectStats() {
+ memHits = CacheStatistics.getMemHits();
+ linHits = CacheStatistics.getLinHits();
+ fsBuffHits = CacheStatistics.getFSBuffHits();
+ fsHits = CacheStatistics.getFSHits();
+ hdfsHits = CacheStatistics.getHDFSHits();
+ linWrites = CacheStatistics.getLinWrites();
+ fsBuffWrites =
CacheStatistics.getFSBuffWrites();
+ fsWrites = CacheStatistics.getFSWrites();
+ hdfsWrites = CacheStatistics.getHDFSWrites();
+ acqRTime =
((double)CacheStatistics.getAcquireRTime()) / 1000000000; // in sec
+ acqMTime =
((double)CacheStatistics.getAcquireMTime()) / 1000000000; // in sec
+ rlsTime =
((double)CacheStatistics.getReleaseTime()) / 1000000000; // in sec
+ expTime =
((double)CacheStatistics.getExportTime()) / 1000000000; // in sec
+ }
+
+ private void aggregate(CacheStatsCollection that) {
+ memHits += that.memHits;
+ linHits += that.linHits;
+ fsBuffHits += that.fsBuffHits;
+ fsHits += that.fsHits;
+ hdfsHits += that.hdfsHits;
+ linWrites += that.linWrites;
+ fsBuffWrites += that.fsBuffWrites;
+ fsWrites += that.fsWrites;
+ hdfsWrites += that.hdfsWrites;
+ acqRTime += that.acqRTime;
+ acqMTime += that.acqMTime;
+ rlsTime += that.rlsTime;
+ expTime += that.expTime;
+ }
+
+ private long memHits = 0;
+ private long linHits = 0;
+ private long fsBuffHits = 0;
+ private long fsHits = 0;
+ private long hdfsHits = 0;
+ private long linWrites = 0;
+ private long fsBuffWrites = 0;
+ private long fsWrites = 0;
+ private long hdfsWrites = 0;
+ private double acqRTime = 0;
+ private double acqMTime = 0;
+ private double rlsTime = 0;
+ private double expTime = 0;
+ }
+
+ protected static class GCStatsCollection implements
Serializable {
+ private static final long serialVersionUID = 1L;
+
+ private void collectStats() {
+ gcCount = Statistics.getJVMgcCount();
+ gcTime = ((double)Statistics.getJVMgcTime()) /
1000; // in sec
+ }
+
+ private void aggregate(GCStatsCollection that) {
+ gcCount += that.gcCount;
+ gcTime += that.gcTime;
+ }
+
+ private long gcCount = 0;
+ private double gcTime = 0;
+ }
+
+ private CacheStatsCollection cacheStats = new
CacheStatsCollection();
+ private double jitCompileTime = 0;
+ private GCStatsCollection gcStats = new GCStatsCollection();
+ private HashMap<String, Pair<Long, Double>> heavyHitters = new
HashMap<>();
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 9b6d3f0..b4d2e04 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -47,6 +47,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -120,6 +121,10 @@ public class InitFEDInstruction extends FEDInstruction
implements LineageTraceab
String host = parsedValues[0];
int port = Integer.parseInt(parsedValues[1]);
String filePath = parsedValues[2];
+
+ // register the federated worker for federated
statistics creation
+ FederatedStatistics.registerFedWorker(host,
port);
+
// get beginning and end of data ranges
List<Data> rangesData = ranges.getData();
Data beginData = rangesData.get(i * 2);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index d4247a7..dd8ddce 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -25,6 +25,7 @@ import java.lang.management.ManagementFactory;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Comparator;
+import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
@@ -32,12 +33,15 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.concurrent.atomic.LongAdder;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -732,6 +736,17 @@ public class Statistics
return (tmp != null) ? tmp.count.longValue() : 0;
}
+ public static HashMap<String, Pair<Long, Double>>
getHeavyHittersHashMap() {
+ HashMap<String, Pair<Long, Double>> heavyHitters = new
HashMap<>();
+ for(String opcode : _instStats.keySet()) {
+ InstStats val = _instStats.get(opcode);
+ long count = val.count.longValue();
+ double time = val.time.longValue() / 1000000000d; // in
sec
+ heavyHitters.put(opcode, new ImmutablePair<Long,
Double>(new Long(count), new Double(time)));
+ }
+ return heavyHitters;
+ }
+
/**
* Obtain a string tabular representation of the heavy hitter
instructions
* that displays the time, instruction count, and optionally GPU stats
about
@@ -956,7 +971,7 @@ public class Statistics
}
- private static String [] wrap(String str, int wrapLength) {
+ public static String [] wrap(String str, int wrapLength) {
int numLines = (int) Math.ceil( ((double)str.length()) /
wrapLength);
int len = str.length();
String [] ret = new String[numLines];
@@ -1105,6 +1120,11 @@ public class Statistics
if (DMLScript.CHECK_PRIVACY)
sb.append(CheckedConstraintsLog.display());
+ if(DMLScript.FED_STATISTICS) {
+ sb.append("\n");
+
sb.append(FederatedStatistics.displayFedStatistics(DMLScript.FED_STATISTICS_COUNT));
+ }
+
return sb.toString();
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
index 09ca19e..54d89e6 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
@@ -30,14 +30,12 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-@Ignore
public class FederatedStatisticsTest extends AutomatedTestBase {
private final static String TEST_DIR = "functions/federated/";
@@ -105,7 +103,6 @@ public class FederatedStatisticsTest extends
AutomatedTestBase {
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-args", input("X1"), input("X2"),
input("Y"), expected("Z")};
@@ -113,7 +110,7 @@ public class FederatedStatisticsTest extends
AutomatedTestBase {
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "30", "-nvargs",
+ programArgs = new String[] {"-stats", "30", "-fedStats",
"-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2,
input("X2")), "rows=" + rows, "cols=" + cols,
"in_Y=" + input("Y"), "out=" + output("Z")};